@with_lock
def is_htlc_irrevocably_removed_yet(
        self,
        *,
        ctx_owner: HTLCOwner = None,
        htlc_proposer: HTLCOwner,
        htlc_id: int,
) -> bool:
    """Returns whether the removal of an htlc was irrevocably committed to `ctx_owner's` ctx.

    The removal can either be a fulfill/settle or a fail; they are not distinguished.
    If `ctx_owner` is None, both parties' ctxs are checked and both must agree.

    Raises:
        Exception: if `ctx_owner` is neither None, LOCAL nor REMOTE.
    """
    # Validate the argument up front, so that an invalid ctx_owner is reported
    # before any state is inspected, and only look at the ctx(s) actually requested.
    if ctx_owner is None:
        owners = (LOCAL, REMOTE)
    elif ctx_owner == LOCAL:
        owners = (LOCAL,)
    elif ctx_owner == REMOTE:
        owners = (REMOTE,)
    else:
        raise Exception(f"unexpected ctx_owner: {ctx_owner!r}")
    return all(
        self._is_htlc_irrevocably_removed_yet(
            ctx_owner=owner, htlc_proposer=htlc_proposer, htlc_id=htlc_id)
        for owner in owners
    )


@with_lock
def _is_htlc_irrevocably_removed_yet(
        self,
        *,
        ctx_owner: HTLCOwner,
        htlc_proposer: HTLCOwner,
        htlc_id: int,
) -> bool:
    """Single-ctx variant of `is_htlc_irrevocably_removed_yet`.

    Returns False if `htlc_id` has not been handed out yet for `htlc_proposer`,
    or if neither the 'settles' nor the 'fails' log has a ctn recorded for
    `ctx_owner`. Otherwise returns whether that ctn is not newer than the
    oldest unrevoked ctn of `ctx_owner`.
    """
    htlc_id = int(htlc_id)
    if htlc_id >= self.get_next_htlc_id(htlc_proposer):
        return False
    proposer_log = self.log[htlc_proposer]
    ctn_of_rm = None
    # A settle takes precedence over a fail, as before. The ctn is compared
    # against None explicitly: chaining with `or` would treat a recorded ctn
    # of 0 as "no removal recorded".
    for action in ('settles', 'fails'):
        removals = proposer_log[action]
        if htlc_id not in removals:
            continue
        ctn = removals[htlc_id][ctx_owner]
        if ctn is not None:
            ctn_of_rm = ctn
            break
    if ctn_of_rm is None:
        return False
    return ctn_of_rm <= self.ctn_oldest_unrevoked(ctx_owner)
aa9c3976c..82a5e1d93 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -9,11 +9,12 @@ from collections import OrderedDict, defaultdict import asyncio import os import time -from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union +from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set from datetime import datetime import functools import aiorpcx +from aiorpcx import TaskGroup from .crypto import sha256, sha256d from . import bitcoin, util @@ -74,6 +75,7 @@ class Peer(Logger): self._sent_init = False # type: bool self._received_init = False # type: bool self.initialized = asyncio.Future() + self.got_disconnected = asyncio.Event() self.querying = asyncio.Event() self.transport = transport self.pubkey = pubkey # remote pubkey @@ -98,6 +100,11 @@ class Peer(Logger): self.orphan_channel_updates = OrderedDict() Logger.__init__(self) self.taskgroup = SilentTaskGroup() + # HTLCs offered by REMOTE, that we started removing but are still active: + self.received_htlcs_pending_removal = set() # type: Set[Tuple[Channel, int]] + self.received_htlc_removed_event = asyncio.Event() + self._htlc_switch_iterstart_event = asyncio.Event() + self._htlc_switch_iterdone_event = asyncio.Event() def send_message(self, message_name: str, **kwargs): assert type(message_name) is str @@ -492,6 +499,7 @@ class Peer(Logger): except: pass self.lnworker.peer_closed(self) + self.got_disconnected.set() def is_static_remotekey(self): return self.features.supports(LnFeatures.OPTION_STATIC_REMOTEKEY_OPT) @@ -1575,6 +1583,7 @@ class Peer(Logger): self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. 
htlc_id {htlc_id}") assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}" assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id) + self.received_htlcs_pending_removal.add((chan, htlc_id)) chan.settle_htlc(preimage, htlc_id) self.send_message( "update_fulfill_htlc", @@ -1585,6 +1594,7 @@ class Peer(Logger): def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes): self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.") assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}" + self.received_htlcs_pending_removal.add((chan, htlc_id)) chan.fail_htlc(htlc_id) self.send_message( "update_fail_htlc", @@ -1596,9 +1606,10 @@ class Peer(Logger): def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRoutingFailure): self.logger.info(f"fail_malformed_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.") assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}" - chan.fail_htlc(htlc_id) if not (reason.code & OnionFailureCodeMetaFlag.BADONION and len(reason.data) == 32): raise Exception(f"unexpected reason when sending 'update_fail_malformed_htlc': {reason!r}") + self.received_htlcs_pending_removal.add((chan, htlc_id)) + chan.fail_htlc(htlc_id) self.send_message( "update_fail_malformed_htlc", channel_id=chan.channel_id, @@ -1800,8 +1811,13 @@ class Peer(Logger): async def htlc_switch(self): await self.initialized while True: - await asyncio.sleep(0.1) + self._htlc_switch_iterdone_event.set() + self._htlc_switch_iterdone_event.clear() + await asyncio.sleep(0.1) # TODO maybe make this partly event-driven + self._htlc_switch_iterstart_event.set() + self._htlc_switch_iterstart_event.clear() self.ping_if_required() + self._maybe_cleanup_received_htlcs_pending_removal() for chan_id, chan in self.channels.items(): if not chan.can_send_ctx_updates(): continue @@ -1853,6 +1869,29 @@ class Peer(Logger): for htlc_id 
def _maybe_cleanup_received_htlcs_pending_removal(self) -> None:
    """Drop entries of `received_htlcs_pending_removal` whose removal is irrevocable.

    An entry is dropped once the channel's HTLC manager reports the REMOTE-proposed
    htlc as irrevocably removed. If at least one entry was dropped,
    `received_htlc_removed_event` is pulsed (set, then immediately cleared).
    """
    pending = self.received_htlcs_pending_removal
    removed = {
        (chan, htlc_id)
        for chan, htlc_id in pending
        if chan.hm.is_htlc_irrevocably_removed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id)
    }
    if not removed:
        return
    # mutate the existing set in place; other code holds a reference to it
    pending.difference_update(removed)
    self.received_htlc_removed_event.set()
    self.received_htlc_removed_event.clear()


async def wait_one_htlc_switch_iteration(self) -> None:
    """Waits until the HTLC switch does a full iteration or the peer disconnects,
    whichever happens first.
    """
    async def full_iteration() -> None:
        # a full iteration is observed as an "iteration start" pulse
        # followed by an "iteration done" pulse
        await self._htlc_switch_iterstart_event.wait()
        await self._htlc_switch_iterdone_event.wait()

    async with TaskGroup(wait=any) as tg:
        await tg.spawn(full_iteration())
        await tg.spawn(self.got_disconnected.wait())
async def stop(self):
    """Shut down this wallet's lightning worker.

    Marks the worker as stopping, closes the listen server (if any) so no new
    peers are accepted, gives received HTLCs that are pending removal a bounded
    amount of time (`ignore_after(3)`) to get removed, and then stops the
    parent worker and the lnwatcher.
    """
    self.stopping_soon = True
    server = self.listen_server
    if server:  # stop accepting new peers
        server.close()
    async with ignore_after(3):
        await self.wait_for_received_pending_htlcs_to_get_removed()
    await super().stop()
    await self.lnwatcher.stop()
    self.lnwatcher = None


async def wait_for_received_pending_htlcs_to_get_removed(self):
    """Wait until no peer has received HTLCs that are still pending removal.

    Must only be called once `stopping_soon` has been set. First lets every
    peer's HTLC switch complete one iteration, then keeps waiting on the peers'
    `received_htlc_removed_event` until all `received_htlcs_pending_removal`
    sets are empty. There is no timeout here; the caller is expected to bound
    the wait.
    """
    assert self.stopping_soon is True
    # We try to fail pending MPP HTLCs, and wait a bit for them to get removed.
    # Note: even without MPP, if we just failed/fulfilled an HTLC, it is good
    #       to wait a bit for it to become irrevocably removed.
    # Note: we don't wait for *all htlcs* to get removed, only for those
    #       that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
    async with TaskGroup() as switch_group:
        for peer in self.peers.values():
            await switch_group.spawn(peer.wait_one_htlc_switch_iteration())
    while any(peer.received_htlcs_pending_removal for peer in self.peers.values()):
        # wake up as soon as any peer reports a removal, then re-check all peers
        async with TaskGroup(wait=any) as removal_group:
            for peer in self.peers.values():
                await removal_group.spawn(peer.received_htlc_removed_event.wait())