Browse Source

lnworker: try to fail pending HTLCs when shutting down

This is most useful when receiving MPP where there is a non-trivial chance
that we have received some HTLCs for a payment but not all, and the user
closes the program. We try to fail them and wait for the fails to get
ACKed, with a timeout of course.
patch-4
SomberNight 4 years ago
parent
commit
cb78f73ed0
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 49
      electrum/lnhtlc.py
  2. 45
      electrum/lnpeer.py
  3. 33
      electrum/lnworker.py
  4. 1
      electrum/tests/test_lnpeer.py

49
electrum/lnhtlc.py

@ -359,6 +359,55 @@ class HTLCManager:
return False return False
return ctns[ctx_owner] <= self.ctn_oldest_unrevoked(ctx_owner) return ctns[ctx_owner] <= self.ctn_oldest_unrevoked(ctx_owner)
@with_lock
def is_htlc_irrevocably_removed_yet(
self,
*,
ctx_owner: HTLCOwner = None,
htlc_proposer: HTLCOwner,
htlc_id: int,
) -> bool:
"""Returns whether the removal of an htlc was irrevocably committed to `ctx_owner's` ctx.
The removal can either be a fulfill/settle or a fail; they are not distinguished.
If `ctx_owner` is None, both parties' ctxs are checked.
"""
in_local = self._is_htlc_irrevocably_removed_yet(
ctx_owner=LOCAL, htlc_proposer=htlc_proposer, htlc_id=htlc_id)
in_remote = self._is_htlc_irrevocably_removed_yet(
ctx_owner=REMOTE, htlc_proposer=htlc_proposer, htlc_id=htlc_id)
if ctx_owner is None:
return in_local and in_remote
elif ctx_owner == LOCAL:
return in_local
elif ctx_owner == REMOTE:
return in_remote
else:
raise Exception(f"unexpected ctx_owner: {ctx_owner!r}")
@with_lock
def _is_htlc_irrevocably_removed_yet(
self,
*,
ctx_owner: HTLCOwner,
htlc_proposer: HTLCOwner,
htlc_id: int,
) -> bool:
htlc_id = int(htlc_id)
if htlc_id >= self.get_next_htlc_id(htlc_proposer):
return False
if htlc_id in self.log[htlc_proposer]['settles']:
ctn_of_settle = self.log[htlc_proposer]['settles'][htlc_id][ctx_owner]
else:
ctn_of_settle = None
if htlc_id in self.log[htlc_proposer]['fails']:
ctn_of_fail = self.log[htlc_proposer]['fails'][htlc_id][ctx_owner]
else:
ctn_of_fail = None
ctn_of_rm = ctn_of_settle or ctn_of_fail or None
if ctn_of_rm is None:
return False
return ctn_of_rm <= self.ctn_oldest_unrevoked(ctx_owner)
@with_lock @with_lock
def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction, def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction,
ctn: int = None) -> Dict[int, UpdateAddHtlc]: ctn: int = None) -> Dict[int, UpdateAddHtlc]:

45
electrum/lnpeer.py

@ -9,11 +9,12 @@ from collections import OrderedDict, defaultdict
import asyncio import asyncio
import os import os
import time import time
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set
from datetime import datetime from datetime import datetime
import functools import functools
import aiorpcx import aiorpcx
from aiorpcx import TaskGroup
from .crypto import sha256, sha256d from .crypto import sha256, sha256d
from . import bitcoin, util from . import bitcoin, util
@ -74,6 +75,7 @@ class Peer(Logger):
self._sent_init = False # type: bool self._sent_init = False # type: bool
self._received_init = False # type: bool self._received_init = False # type: bool
self.initialized = asyncio.Future() self.initialized = asyncio.Future()
self.got_disconnected = asyncio.Event()
self.querying = asyncio.Event() self.querying = asyncio.Event()
self.transport = transport self.transport = transport
self.pubkey = pubkey # remote pubkey self.pubkey = pubkey # remote pubkey
@ -98,6 +100,11 @@ class Peer(Logger):
self.orphan_channel_updates = OrderedDict() self.orphan_channel_updates = OrderedDict()
Logger.__init__(self) Logger.__init__(self)
self.taskgroup = SilentTaskGroup() self.taskgroup = SilentTaskGroup()
# HTLCs offered by REMOTE, that we started removing but are still active:
self.received_htlcs_pending_removal = set() # type: Set[Tuple[Channel, int]]
self.received_htlc_removed_event = asyncio.Event()
self._htlc_switch_iterstart_event = asyncio.Event()
self._htlc_switch_iterdone_event = asyncio.Event()
def send_message(self, message_name: str, **kwargs): def send_message(self, message_name: str, **kwargs):
assert type(message_name) is str assert type(message_name) is str
@ -492,6 +499,7 @@ class Peer(Logger):
except: except:
pass pass
self.lnworker.peer_closed(self) self.lnworker.peer_closed(self)
self.got_disconnected.set()
def is_static_remotekey(self): def is_static_remotekey(self):
return self.features.supports(LnFeatures.OPTION_STATIC_REMOTEKEY_OPT) return self.features.supports(LnFeatures.OPTION_STATIC_REMOTEKEY_OPT)
@ -1575,6 +1583,7 @@ class Peer(Logger):
self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}" assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id) assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id)
self.received_htlcs_pending_removal.add((chan, htlc_id))
chan.settle_htlc(preimage, htlc_id) chan.settle_htlc(preimage, htlc_id)
self.send_message( self.send_message(
"update_fulfill_htlc", "update_fulfill_htlc",
@ -1585,6 +1594,7 @@ class Peer(Logger):
def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes): def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes):
self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.") self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}" assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
self.received_htlcs_pending_removal.add((chan, htlc_id))
chan.fail_htlc(htlc_id) chan.fail_htlc(htlc_id)
self.send_message( self.send_message(
"update_fail_htlc", "update_fail_htlc",
@ -1596,9 +1606,10 @@ class Peer(Logger):
def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRoutingFailure): def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRoutingFailure):
self.logger.info(f"fail_malformed_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.") self.logger.info(f"fail_malformed_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}" assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
chan.fail_htlc(htlc_id)
if not (reason.code & OnionFailureCodeMetaFlag.BADONION and len(reason.data) == 32): if not (reason.code & OnionFailureCodeMetaFlag.BADONION and len(reason.data) == 32):
raise Exception(f"unexpected reason when sending 'update_fail_malformed_htlc': {reason!r}") raise Exception(f"unexpected reason when sending 'update_fail_malformed_htlc': {reason!r}")
self.received_htlcs_pending_removal.add((chan, htlc_id))
chan.fail_htlc(htlc_id)
self.send_message( self.send_message(
"update_fail_malformed_htlc", "update_fail_malformed_htlc",
channel_id=chan.channel_id, channel_id=chan.channel_id,
@ -1800,8 +1811,13 @@ class Peer(Logger):
async def htlc_switch(self): async def htlc_switch(self):
await self.initialized await self.initialized
while True: while True:
await asyncio.sleep(0.1) self._htlc_switch_iterdone_event.set()
self._htlc_switch_iterdone_event.clear()
await asyncio.sleep(0.1) # TODO maybe make this partly event-driven
self._htlc_switch_iterstart_event.set()
self._htlc_switch_iterstart_event.clear()
self.ping_if_required() self.ping_if_required()
self._maybe_cleanup_received_htlcs_pending_removal()
for chan_id, chan in self.channels.items(): for chan_id, chan in self.channels.items():
if not chan.can_send_ctx_updates(): if not chan.can_send_ctx_updates():
continue continue
@ -1853,6 +1869,29 @@ class Peer(Logger):
for htlc_id in done: for htlc_id in done:
unfulfilled.pop(htlc_id) unfulfilled.pop(htlc_id)
def _maybe_cleanup_received_htlcs_pending_removal(self) -> None:
done = set()
for chan, htlc_id in self.received_htlcs_pending_removal:
if chan.hm.is_htlc_irrevocably_removed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id):
done.add((chan, htlc_id))
if done:
for key in done:
self.received_htlcs_pending_removal.remove(key)
self.received_htlc_removed_event.set()
self.received_htlc_removed_event.clear()
async def wait_one_htlc_switch_iteration(self) -> None:
"""Waits until the HTLC switch does a full iteration or the peer disconnects,
whichever happens first.
"""
async def htlc_switch_iteration():
await self._htlc_switch_iterstart_event.wait()
await self._htlc_switch_iterdone_event.wait()
async with TaskGroup(wait=any) as group:
await group.spawn(htlc_switch_iteration())
await group.spawn(self.got_disconnected.wait())
async def process_unfulfilled_htlc( async def process_unfulfilled_htlc(
self, *, self, *,
chan: Channel, chan: Channel,

33
electrum/lnworker.py

@ -22,7 +22,7 @@ import urllib.parse
import dns.resolver import dns.resolver
import dns.exception import dns.exception
from aiorpcx import run_in_thread, TaskGroup, NetAddress from aiorpcx import run_in_thread, TaskGroup, NetAddress, ignore_after
from . import constants, util from . import constants, util
from . import keystore from . import keystore
@ -195,6 +195,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
self.features = features self.features = features
self.network = None # type: Optional[Network] self.network = None # type: Optional[Network]
self.config = None # type: Optional[SimpleConfig] self.config = None # type: Optional[SimpleConfig]
self.stopping_soon = False # whether we are being shut down
util.register_callback(self.on_proxy_changed, ['proxy_set']) util.register_callback(self.on_proxy_changed, ['proxy_set'])
@ -268,6 +269,8 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
async def _maintain_connectivity(self): async def _maintain_connectivity(self):
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)
if self.stopping_soon:
return
now = time.time() now = time.time()
if len(self._peers) >= NUM_PEERS_TARGET: if len(self._peers) >= NUM_PEERS_TARGET:
continue continue
@ -707,10 +710,32 @@ class LNWallet(LNWorker):
asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop) asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
async def stop(self): async def stop(self):
self.stopping_soon = True
if self.listen_server: # stop accepting new peers
self.listen_server.close()
async with ignore_after(3):
await self.wait_for_received_pending_htlcs_to_get_removed()
await super().stop() await super().stop()
await self.lnwatcher.stop() await self.lnwatcher.stop()
self.lnwatcher = None self.lnwatcher = None
async def wait_for_received_pending_htlcs_to_get_removed(self):
assert self.stopping_soon is True
# We try to fail pending MPP HTLCs, and wait a bit for them to get removed.
# Note: even without MPP, if we just failed/fulfilled an HTLC, it is good
# to wait a bit for it to become irrevocably removed.
# Note: we don't wait for *all htlcs* to get removed, only for those
# that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
async with TaskGroup() as group:
for peer in self.peers.values():
await group.spawn(peer.wait_one_htlc_switch_iteration())
while True:
if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()):
break
async with TaskGroup(wait=any) as group:
for peer in self.peers.values():
await group.spawn(peer.received_htlc_removed_event.wait())
def peer_closed(self, peer): def peer_closed(self, peer):
for chan in self.channels_for_peer(peer.pubkey).values(): for chan in self.channels_for_peer(peer.pubkey).values():
chan.peer_state = PeerState.DISCONNECTED chan.peer_state = PeerState.DISCONNECTED
@ -1635,7 +1660,9 @@ class LNWallet(LNWorker):
if not is_accepted and not is_expired: if not is_accepted and not is_expired:
total = sum([_htlc.amount_msat for scid, _htlc in htlc_set]) total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set]) first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set])
if time.time() - first_timestamp > self.MPP_EXPIRY: if self.stopping_soon:
is_expired = True # try to time out pending HTLCs before shutting down
elif time.time() - first_timestamp > self.MPP_EXPIRY:
is_expired = True is_expired = True
elif total == expected_msat: elif total == expected_msat:
is_accepted = True is_accepted = True
@ -1897,6 +1924,8 @@ class LNWallet(LNWorker):
async def reestablish_peers_and_channels(self): async def reestablish_peers_and_channels(self):
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)
if self.stopping_soon:
return
for chan in self.channels.values(): for chan in self.channels.values():
if chan.is_closed(): if chan.is_closed():
continue continue

1
electrum/tests/test_lnpeer.py

@ -147,6 +147,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
self.trampoline_forwarding_failures = {} self.trampoline_forwarding_failures = {}
self.inflight_payments = set() self.inflight_payments = set()
self.preimages = {} self.preimages = {}
self.stopping_soon = False
def get_invoice_status(self, key): def get_invoice_status(self, key):
pass pass

Loading…
Cancel
Save