diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index ef2f88ffa..2cb9cc0f0 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -624,6 +624,7 @@ class Channel(Logger): assert type(whose) is HTLCOwner initial = self.config[whose].initial_msat + # TODO slow. -- and 'balance' is called from a decent number of places (e.g. 'make_commitment') for direction, htlc in self.hm.all_settled_htlcs_ever(ctx_owner, ctn): # note: could "simplify" to (whose * ctx_owner == direction * SENT) if whose == ctx_owner: diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py index b7b151f4e..40b564367 100644 --- a/electrum/lnhtlc.py +++ b/electrum/lnhtlc.py @@ -13,10 +13,10 @@ class HTLCManager: if len(log) == 0: initial = { - 'adds': {}, - 'locked_in': {}, - 'settles': {}, - 'fails': {}, + 'adds': {}, # "side who offered htlc" -> htlc_id -> htlc + 'locked_in': {}, # "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn + 'settles': {}, # "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn + 'fails': {}, # "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn 'fee_updates': {}, # "side who initiated fee update" -> action -> list of FeeUpdates 'revack_pending': False, 'next_htlc_id': 0, @@ -36,6 +36,7 @@ class HTLCManager: if not log[sub]['fee_updates']: log[sub]['fee_updates'][0] = FeeUpdate(rate=initial_feerate, ctn_local=0, ctn_remote=0) self.log = log + self._init_maybe_active_htlc_ids() def ctn_latest(self, sub: HTLCOwner) -> int: """Return the ctn for the latest (newest that has a valid sig) ctx of sub""" @@ -73,6 +74,7 @@ class HTLCManager: self.log[LOCAL]['adds'][htlc_id] = htlc self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE)+1} self.log[LOCAL]['next_htlc_id'] += 1 + self._maybe_active_htlc_ids[LOCAL].add(htlc_id) return htlc def recv_htlc(self, htlc: UpdateAddHtlc) -> None: @@ -83,6 +85,7 @@ class HTLCManager: self.log[REMOTE]['adds'][htlc_id] = htlc self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL)+1, REMOTE: None} self.log[REMOTE]['next_htlc_id'] += 1 + self._maybe_active_htlc_ids[REMOTE].add(htlc_id) def send_settle(self, htlc_id: int) -> None: self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1} @@ -130,13 +133,17 @@ class HTLCManager: self.log[LOCAL]['ctn'] += 1 self._set_revack_pending(LOCAL, False) # htlcs - for ctns in self.log[REMOTE]['locked_in'].values(): + for htlc_id in self._maybe_active_htlc_ids[REMOTE]: + ctns = self.log[REMOTE]['locked_in'][htlc_id] if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL): ctns[REMOTE] = self.ctn_latest(REMOTE) + 1 for log_action in ('settles', 'fails'): - for ctns in self.log[LOCAL][log_action].values(): + for htlc_id in self._maybe_active_htlc_ids[LOCAL]: + ctns = self.log[LOCAL][log_action].get(htlc_id, None) + if ctns is None: continue if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL): ctns[REMOTE] = self.ctn_latest(REMOTE) + 1 + self._update_maybe_active_htlc_ids() # fee updates for k, fee_update in list(self.log[REMOTE]['fee_updates'].items()): if fee_update.ctn_remote is None and fee_update.ctn_local <= self.ctn_latest(LOCAL): @@ -146,13 +153,17 @@ class HTLCManager: self.log[REMOTE]['ctn'] += 1 self._set_revack_pending(REMOTE, False) # htlcs - for ctns in self.log[LOCAL]['locked_in'].values(): + for htlc_id in self._maybe_active_htlc_ids[LOCAL]: + ctns = self.log[LOCAL]['locked_in'][htlc_id] if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE): ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 for log_action in ('settles', 'fails'): - for ctns in self.log[REMOTE][log_action].values(): + for htlc_id in self._maybe_active_htlc_ids[REMOTE]: + ctns = self.log[REMOTE][log_action].get(htlc_id, None) + if ctns is None: continue if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE): ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 + self._update_maybe_active_htlc_ids() # fee updates for k, fee_update in list(self.log[LOCAL]['fee_updates'].items()): if fee_update.ctn_local is None and fee_update.ctn_remote <= self.ctn_latest(REMOTE): @@ -161,6 +172,32 @@ class HTLCManager: # no need to keep local update raw msgs anymore, they have just been ACKed. self.log['unacked_local_updates2'].pop(self.log[REMOTE]['ctn'], None) + def _update_maybe_active_htlc_ids(self) -> None: + # Loosely, we want a set that contains the htlcs that are + # not "removed and revoked from all ctxs of both parties". + # It is guaranteed that those htlcs are in the set, but older htlcs might be there too: + # there is a sanity margin of 1 ctn -- this relaxes the care needed re order of method calls. + sanity_margin = 1 + for htlc_proposer in (LOCAL, REMOTE): + for log_action in ('settles', 'fails'): + for htlc_id in list(self._maybe_active_htlc_ids[htlc_proposer]): + ctns = self.log[htlc_proposer][log_action].get(htlc_id, None) + if ctns is None: continue + if (ctns[LOCAL] is not None + and ctns[LOCAL] <= self.ctn_oldest_unrevoked(LOCAL) - sanity_margin + and ctns[REMOTE] is not None + and ctns[REMOTE] <= self.ctn_oldest_unrevoked(REMOTE) - sanity_margin): + self._maybe_active_htlc_ids[htlc_proposer].remove(htlc_id) + + def _init_maybe_active_htlc_ids(self): + self._maybe_active_htlc_ids = {LOCAL: set(), REMOTE: set()} # first idx is "side who offered htlc" + # add all htlcs + for htlc_proposer in (LOCAL, REMOTE): + for htlc_id in self.log[htlc_proposer]['adds']: + self._maybe_active_htlc_ids[htlc_proposer].add(htlc_id) + # remove old htlcs + self._update_maybe_active_htlc_ids() + def discard_unsigned_remote_updates(self): """Discard updates sent by the remote, that the remote itself did not yet sign (i.e. there was no corresponding commitment_signed msg) @@ -170,6 +207,7 @@ class HTLCManager: if ctns[LOCAL] > self.ctn_latest(LOCAL): del self.log[REMOTE]['locked_in'][htlc_id] del self.log[REMOTE]['adds'][htlc_id] + self._maybe_active_htlc_ids[REMOTE].discard(htlc_id) if self.log[REMOTE]['locked_in']: self.log[REMOTE]['next_htlc_id'] = max([int(x) for x in self.log[REMOTE]['locked_in'].keys()]) + 1 else: @@ -222,7 +260,12 @@ class HTLCManager: party = subject if direction == SENT else subject.inverted() settles = self.log[party]['settles'] fails = self.log[party]['fails'] - for htlc_id, ctns in self.log[party]['locked_in'].items(): + if ctn >= self.ctn_oldest_unrevoked(subject): + considered_htlc_ids = self._maybe_active_htlc_ids[party] + else: # ctn is too old; need to consider full log (slow...) + considered_htlc_ids = self.log[party]['locked_in'] + for htlc_id in considered_htlc_ids: + ctns = self.log[party]['locked_in'][htlc_id] if ctns[subject] is not None and ctns[subject] <= ctn: not_settled = htlc_id not in settles or settles[htlc_id][subject] is None or settles[htlc_id][subject] > ctn not_failed = htlc_id not in fails or fails[htlc_id][subject] is None or fails[htlc_id][subject] > ctn @@ -290,32 +333,50 @@ class HTLCManager: received = [(RECEIVED, x) for x in self.all_settled_htlcs_ever_by_direction(subject, RECEIVED, ctn)] return sent + received - def received_in_ctn(self, ctn: int) -> Sequence[UpdateAddHtlc]: + def _get_htlcs_that_got_removed_exactly_at_ctn( + self, ctn: int, *, ctx_owner: HTLCOwner, htlc_proposer: HTLCOwner, log_action: str, + ) -> Sequence[UpdateAddHtlc]: + if ctn >= self.ctn_oldest_unrevoked(ctx_owner): + considered_htlc_ids = self._maybe_active_htlc_ids[htlc_proposer] + else: # ctn is too old; need to consider full log (slow...) + considered_htlc_ids = self.log[htlc_proposer][log_action] + htlcs = [] + for htlc_id in considered_htlc_ids: + ctns = self.log[htlc_proposer][log_action].get(htlc_id, None) + if ctns is None: continue + if ctns[ctx_owner] == ctn: + htlcs.append(self.log[htlc_proposer]['adds'][htlc_id]) + return htlcs + + def received_in_ctn(self, local_ctn: int) -> Sequence[UpdateAddHtlc]: """ received htlcs that became fulfilled when we send a revocation. - we check only local, because they are commited in the remote ctx first. + we check only local, because they are committed in the remote ctx first. """ - return [self.log[REMOTE]['adds'][htlc_id] - for htlc_id, ctns in self.log[REMOTE]['settles'].items() - if ctns[LOCAL] == ctn] + return self._get_htlcs_that_got_removed_exactly_at_ctn(local_ctn, + ctx_owner=LOCAL, + htlc_proposer=REMOTE, + log_action='settles') - def sent_in_ctn(self, ctn: int) -> Sequence[UpdateAddHtlc]: + def sent_in_ctn(self, remote_ctn: int) -> Sequence[UpdateAddHtlc]: """ sent htlcs that became fulfilled when we received a revocation - we check only remote, because they are commited in the local ctx first. + we check only remote, because they are committed in the local ctx first. """ - return [self.log[LOCAL]['adds'][htlc_id] - for htlc_id, ctns in self.log[LOCAL]['settles'].items() - if ctns[REMOTE] == ctn] + return self._get_htlcs_that_got_removed_exactly_at_ctn(remote_ctn, + ctx_owner=REMOTE, + htlc_proposer=LOCAL, + log_action='settles') - def failed_in_ctn(self, ctn: int) -> Sequence[UpdateAddHtlc]: + def failed_in_ctn(self, remote_ctn: int) -> Sequence[UpdateAddHtlc]: """ sent htlcs that became failed when we received a revocation - we check only remote, because they are commited in the local ctx first. + we check only remote, because they are committed in the local ctx first. """ - return [self.log[LOCAL]['adds'][htlc_id] - for htlc_id, ctns in self.log[LOCAL]['fails'].items() - if ctns[REMOTE] == ctn] + return self._get_htlcs_that_got_removed_exactly_at_ctn(remote_ctn, + ctx_owner=REMOTE, + htlc_proposer=LOCAL, + log_action='fails') ##### Queries re Fees: diff --git a/electrum/lnworker.py b/electrum/lnworker.py index f9cdea190..456c34765 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -124,7 +124,7 @@ FALLBACK_NODE_LIST_MAINNET = [ class PaymentInfo(NamedTuple): payment_hash: bytes - amount: int + amount: int # in satoshis direction: int status: int @@ -934,6 +934,7 @@ class LNWallet(LNWorker): success = payment_attempt_log.success if success: break + self.logger.debug(f'payment attempts log for RHASH {key}: {repr(log)}') self.network.trigger_callback('invoice_status', key) return success diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 9b3fce90b..682cf298d 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -7,6 +7,9 @@ from collections import defaultdict import logging import concurrent from concurrent import futures +import unittest + +from aiorpcx import TaskGroup from electrum import constants from electrum.network import Network @@ -18,13 +21,13 @@ from electrum.util import bh2u, create_and_start_event_loop from electrum.lnpeer import Peer from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving -from electrum.lnutil import PaymentFailure, LnLocalFeatures +from electrum.lnutil import PaymentFailure, LnLocalFeatures, HTLCOwner from electrum.lnchannel import channel_states, peer_states, Channel from electrum.lnrouter import LNPathFinder from electrum.channel_db import ChannelDB from electrum.lnworker import LNWallet, NoPathFound from electrum.lnmsg import encode_msg, decode_msg -from electrum.logging import console_stderr_handler +from electrum.logging import console_stderr_handler, Logger from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID from .test_lnchannel import create_test_channels @@ -81,8 +84,9 @@ class MockWallet: def is_lightning_backup(self): return False -class MockLNWallet: +class MockLNWallet(Logger): def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue): + Logger.__init__(self) self.remote_keypair = remote_keypair self.node_keypair = local_keypair self.network = MockNetwork(tx_queue) @@ -216,9 +220,11 @@ class TestPeer(ElectrumTestCase): return p1, p2, w1, w2, q1, q2 @staticmethod - def prepare_invoice(w2 # receiver - ): - amount_sat = 100000 + def prepare_invoice( + w2, # receiver + *, + amount_sat=100_000, + ): amount_btc = amount_sat/Decimal(COIN) payment_preimage = os.urandom(32) RHASH = sha256(payment_preimage) @@ -300,6 +306,35 @@ class TestPeer(ElectrumTestCase): with self.assertRaises(concurrent.futures.CancelledError): run(f()) + @unittest.skip("too expensive") + #@needs_test_with_all_chacha20_implementations + def test_payments_stresstest(self): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + alice_init_balance_msat = alice_channel.balance(HTLCOwner.LOCAL) + bob_init_balance_msat = bob_channel.balance(HTLCOwner.LOCAL) + num_payments = 1000 + #pay_reqs1 = [self.prepare_invoice(w1, amount_sat=1) for i in range(num_payments)] + pay_reqs2 = [self.prepare_invoice(w2, amount_sat=1) for i in range(num_payments)] + max_htlcs_in_flight = asyncio.Semaphore(5) + async def single_payment(pay_req): + async with max_htlcs_in_flight: + await w1._pay(pay_req) + async def many_payments(): + async with TaskGroup() as group: + for pay_req in pay_reqs2: + await group.spawn(single_payment(pay_req)) + gath.cancel() + gath = asyncio.gather(many_payments(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) + async def f(): + await gath + with self.assertRaises(concurrent.futures.CancelledError): + run(f()) + self.assertEqual(alice_init_balance_msat - num_payments * 1000, alice_channel.balance(HTLCOwner.LOCAL)) + self.assertEqual(alice_init_balance_msat - num_payments * 1000, bob_channel.balance(HTLCOwner.REMOTE)) + self.assertEqual(bob_init_balance_msat + num_payments * 1000, bob_channel.balance(HTLCOwner.LOCAL)) + self.assertEqual(bob_init_balance_msat + num_payments * 1000, alice_channel.balance(HTLCOwner.REMOTE)) + @needs_test_with_all_chacha20_implementations def test_close(self): alice_channel, bob_channel = create_test_channels() @@ -352,5 +387,6 @@ class TestPeer(ElectrumTestCase): with self.assertRaises(PaymentFailure): run(f()) + def run(coro): return asyncio.run_coroutine_threadsafe(coro, loop=asyncio.get_event_loop()).result()