Browse Source

Merge pull request #7099 from SomberNight/202103_fail_pending_htlcs_on_shutdown

fail pending htlcs on shutdown
patch-4
ThomasV 4 years ago
committed by GitHub
parent
commit
6004a04705
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 49
      electrum/lnhtlc.py
  2. 45
      electrum/lnpeer.py
  3. 41
      electrum/lnworker.py
  4. 62
      electrum/tests/test_lnpeer.py

49
electrum/lnhtlc.py

@ -359,6 +359,55 @@ class HTLCManager:
return False return False
return ctns[ctx_owner] <= self.ctn_oldest_unrevoked(ctx_owner) return ctns[ctx_owner] <= self.ctn_oldest_unrevoked(ctx_owner)
@with_lock
def is_htlc_irrevocably_removed_yet(
self,
*,
ctx_owner: HTLCOwner = None,
htlc_proposer: HTLCOwner,
htlc_id: int,
) -> bool:
"""Returns whether the removal of an htlc was irrevocably committed to `ctx_owner's` ctx.
The removal can either be a fulfill/settle or a fail; they are not distinguished.
If `ctx_owner` is None, both parties' ctxs are checked.
"""
in_local = self._is_htlc_irrevocably_removed_yet(
ctx_owner=LOCAL, htlc_proposer=htlc_proposer, htlc_id=htlc_id)
in_remote = self._is_htlc_irrevocably_removed_yet(
ctx_owner=REMOTE, htlc_proposer=htlc_proposer, htlc_id=htlc_id)
if ctx_owner is None:
return in_local and in_remote
elif ctx_owner == LOCAL:
return in_local
elif ctx_owner == REMOTE:
return in_remote
else:
raise Exception(f"unexpected ctx_owner: {ctx_owner!r}")
@with_lock
def _is_htlc_irrevocably_removed_yet(
self,
*,
ctx_owner: HTLCOwner,
htlc_proposer: HTLCOwner,
htlc_id: int,
) -> bool:
htlc_id = int(htlc_id)
if htlc_id >= self.get_next_htlc_id(htlc_proposer):
return False
if htlc_id in self.log[htlc_proposer]['settles']:
ctn_of_settle = self.log[htlc_proposer]['settles'][htlc_id][ctx_owner]
else:
ctn_of_settle = None
if htlc_id in self.log[htlc_proposer]['fails']:
ctn_of_fail = self.log[htlc_proposer]['fails'][htlc_id][ctx_owner]
else:
ctn_of_fail = None
ctn_of_rm = ctn_of_settle or ctn_of_fail or None
if ctn_of_rm is None:
return False
return ctn_of_rm <= self.ctn_oldest_unrevoked(ctx_owner)
@with_lock @with_lock
def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction, def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction,
ctn: int = None) -> Dict[int, UpdateAddHtlc]: ctn: int = None) -> Dict[int, UpdateAddHtlc]:

45
electrum/lnpeer.py

@ -9,11 +9,12 @@ from collections import OrderedDict, defaultdict
import asyncio import asyncio
import os import os
import time import time
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set
from datetime import datetime from datetime import datetime
import functools import functools
import aiorpcx import aiorpcx
from aiorpcx import TaskGroup
from .crypto import sha256, sha256d from .crypto import sha256, sha256d
from . import bitcoin, util from . import bitcoin, util
@ -74,6 +75,7 @@ class Peer(Logger):
self._sent_init = False # type: bool self._sent_init = False # type: bool
self._received_init = False # type: bool self._received_init = False # type: bool
self.initialized = asyncio.Future() self.initialized = asyncio.Future()
self.got_disconnected = asyncio.Event()
self.querying = asyncio.Event() self.querying = asyncio.Event()
self.transport = transport self.transport = transport
self.pubkey = pubkey # remote pubkey self.pubkey = pubkey # remote pubkey
@ -98,6 +100,11 @@ class Peer(Logger):
self.orphan_channel_updates = OrderedDict() self.orphan_channel_updates = OrderedDict()
Logger.__init__(self) Logger.__init__(self)
self.taskgroup = SilentTaskGroup() self.taskgroup = SilentTaskGroup()
# HTLCs offered by REMOTE, that we started removing but are still active:
self.received_htlcs_pending_removal = set() # type: Set[Tuple[Channel, int]]
self.received_htlc_removed_event = asyncio.Event()
self._htlc_switch_iterstart_event = asyncio.Event()
self._htlc_switch_iterdone_event = asyncio.Event()
def send_message(self, message_name: str, **kwargs): def send_message(self, message_name: str, **kwargs):
assert type(message_name) is str assert type(message_name) is str
@ -492,6 +499,7 @@ class Peer(Logger):
except: except:
pass pass
self.lnworker.peer_closed(self) self.lnworker.peer_closed(self)
self.got_disconnected.set()
def is_static_remotekey(self): def is_static_remotekey(self):
return self.features.supports(LnFeatures.OPTION_STATIC_REMOTEKEY_OPT) return self.features.supports(LnFeatures.OPTION_STATIC_REMOTEKEY_OPT)
@ -1575,6 +1583,7 @@ class Peer(Logger):
self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}" assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id) assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id)
self.received_htlcs_pending_removal.add((chan, htlc_id))
chan.settle_htlc(preimage, htlc_id) chan.settle_htlc(preimage, htlc_id)
self.send_message( self.send_message(
"update_fulfill_htlc", "update_fulfill_htlc",
@ -1585,6 +1594,7 @@ class Peer(Logger):
def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes): def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes):
self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.") self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}" assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
self.received_htlcs_pending_removal.add((chan, htlc_id))
chan.fail_htlc(htlc_id) chan.fail_htlc(htlc_id)
self.send_message( self.send_message(
"update_fail_htlc", "update_fail_htlc",
@ -1596,9 +1606,10 @@ class Peer(Logger):
def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRoutingFailure): def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRoutingFailure):
self.logger.info(f"fail_malformed_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.") self.logger.info(f"fail_malformed_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}" assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
chan.fail_htlc(htlc_id)
if not (reason.code & OnionFailureCodeMetaFlag.BADONION and len(reason.data) == 32): if not (reason.code & OnionFailureCodeMetaFlag.BADONION and len(reason.data) == 32):
raise Exception(f"unexpected reason when sending 'update_fail_malformed_htlc': {reason!r}") raise Exception(f"unexpected reason when sending 'update_fail_malformed_htlc': {reason!r}")
self.received_htlcs_pending_removal.add((chan, htlc_id))
chan.fail_htlc(htlc_id)
self.send_message( self.send_message(
"update_fail_malformed_htlc", "update_fail_malformed_htlc",
channel_id=chan.channel_id, channel_id=chan.channel_id,
@ -1800,8 +1811,13 @@ class Peer(Logger):
async def htlc_switch(self): async def htlc_switch(self):
await self.initialized await self.initialized
while True: while True:
await asyncio.sleep(0.1) self._htlc_switch_iterdone_event.set()
self._htlc_switch_iterdone_event.clear()
await asyncio.sleep(0.1) # TODO maybe make this partly event-driven
self._htlc_switch_iterstart_event.set()
self._htlc_switch_iterstart_event.clear()
self.ping_if_required() self.ping_if_required()
self._maybe_cleanup_received_htlcs_pending_removal()
for chan_id, chan in self.channels.items(): for chan_id, chan in self.channels.items():
if not chan.can_send_ctx_updates(): if not chan.can_send_ctx_updates():
continue continue
@ -1853,6 +1869,29 @@ class Peer(Logger):
for htlc_id in done: for htlc_id in done:
unfulfilled.pop(htlc_id) unfulfilled.pop(htlc_id)
def _maybe_cleanup_received_htlcs_pending_removal(self) -> None:
done = set()
for chan, htlc_id in self.received_htlcs_pending_removal:
if chan.hm.is_htlc_irrevocably_removed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id):
done.add((chan, htlc_id))
if done:
for key in done:
self.received_htlcs_pending_removal.remove(key)
self.received_htlc_removed_event.set()
self.received_htlc_removed_event.clear()
async def wait_one_htlc_switch_iteration(self) -> None:
"""Waits until the HTLC switch does a full iteration or the peer disconnects,
whichever happens first.
"""
async def htlc_switch_iteration():
await self._htlc_switch_iterstart_event.wait()
await self._htlc_switch_iterdone_event.wait()
async with TaskGroup(wait=any) as group:
await group.spawn(htlc_switch_iteration())
await group.spawn(self.got_disconnected.wait())
async def process_unfulfilled_htlc( async def process_unfulfilled_htlc(
self, *, self, *,
chan: Channel, chan: Channel,

41
electrum/lnworker.py

@ -22,7 +22,7 @@ import urllib.parse
import dns.resolver import dns.resolver
import dns.exception import dns.exception
from aiorpcx import run_in_thread, TaskGroup, NetAddress from aiorpcx import run_in_thread, TaskGroup, NetAddress, ignore_after
from . import constants, util from . import constants, util
from . import keystore from . import keystore
@ -195,6 +195,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
self.features = features self.features = features
self.network = None # type: Optional[Network] self.network = None # type: Optional[Network]
self.config = None # type: Optional[SimpleConfig] self.config = None # type: Optional[SimpleConfig]
self.stopping_soon = False # whether we are being shut down
util.register_callback(self.on_proxy_changed, ['proxy_set']) util.register_callback(self.on_proxy_changed, ['proxy_set'])
@ -268,6 +269,8 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
async def _maintain_connectivity(self): async def _maintain_connectivity(self):
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)
if self.stopping_soon:
return
now = time.time() now = time.time()
if len(self._peers) >= NUM_PEERS_TARGET: if len(self._peers) >= NUM_PEERS_TARGET:
continue continue
@ -575,6 +578,7 @@ class LNWallet(LNWorker):
lnwatcher: Optional['LNWalletWatcher'] lnwatcher: Optional['LNWalletWatcher']
MPP_EXPIRY = 120 MPP_EXPIRY = 120
TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3 # seconds
def __init__(self, wallet: 'Abstract_Wallet', xprv): def __init__(self, wallet: 'Abstract_Wallet', xprv):
self.wallet = wallet self.wallet = wallet
@ -707,9 +711,32 @@ class LNWallet(LNWorker):
asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop) asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
async def stop(self): async def stop(self):
await super().stop() self.stopping_soon = True
await self.lnwatcher.stop() if self.listen_server: # stop accepting new peers
self.lnwatcher = None self.listen_server.close()
async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS):
await self.wait_for_received_pending_htlcs_to_get_removed()
await LNWorker.stop(self)
if self.lnwatcher:
await self.lnwatcher.stop()
self.lnwatcher = None
async def wait_for_received_pending_htlcs_to_get_removed(self):
assert self.stopping_soon is True
# We try to fail pending MPP HTLCs, and wait a bit for them to get removed.
# Note: even without MPP, if we just failed/fulfilled an HTLC, it is good
# to wait a bit for it to become irrevocably removed.
# Note: we don't wait for *all htlcs* to get removed, only for those
# that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
async with TaskGroup() as group:
for peer in self.peers.values():
await group.spawn(peer.wait_one_htlc_switch_iteration())
while True:
if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()):
break
async with TaskGroup(wait=any) as group:
for peer in self.peers.values():
await group.spawn(peer.received_htlc_removed_event.wait())
def peer_closed(self, peer): def peer_closed(self, peer):
for chan in self.channels_for_peer(peer.pubkey).values(): for chan in self.channels_for_peer(peer.pubkey).values():
@ -1635,7 +1662,9 @@ class LNWallet(LNWorker):
if not is_accepted and not is_expired: if not is_accepted and not is_expired:
total = sum([_htlc.amount_msat for scid, _htlc in htlc_set]) total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set]) first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set])
if time.time() - first_timestamp > self.MPP_EXPIRY: if self.stopping_soon:
is_expired = True # try to time out pending HTLCs before shutting down
elif time.time() - first_timestamp > self.MPP_EXPIRY:
is_expired = True is_expired = True
elif total == expected_msat: elif total == expected_msat:
is_accepted = True is_accepted = True
@ -1897,6 +1926,8 @@ class LNWallet(LNWorker):
async def reestablish_peers_and_channels(self): async def reestablish_peers_and_channels(self):
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)
if self.stopping_soon:
return
for chan in self.channels.values(): for chan in self.channels.values():
if chan.is_closed(): if chan.is_closed():
continue continue

62
electrum/tests/test_lnpeer.py

@ -10,7 +10,7 @@ from concurrent import futures
import unittest import unittest
from typing import Iterable, NamedTuple, Tuple, List from typing import Iterable, NamedTuple, Tuple, List
from aiorpcx import TaskGroup from aiorpcx import TaskGroup, timeout_after, TaskTimeout
from electrum import bitcoin from electrum import bitcoin
from electrum import constants from electrum import constants
@ -113,7 +113,8 @@ class MockWallet:
class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1 MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1
TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0
def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name): def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name):
self.name = name self.name = name
@ -121,6 +122,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1) NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
self.node_keypair = local_keypair self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue) self.network = MockNetwork(tx_queue)
self.taskgroup = TaskGroup()
self.lnwatcher = None
self.listen_server = None
self._channels = {chan.channel_id: chan for chan in chans} self._channels = {chan.channel_id: chan for chan in chans}
self.payments = {} self.payments = {}
self.logs = defaultdict(list) self.logs = defaultdict(list)
@ -147,6 +151,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
self.trampoline_forwarding_failures = {} self.trampoline_forwarding_failures = {}
self.inflight_payments = set() self.inflight_payments = set()
self.preimages = {} self.preimages = {}
self.stopping_soon = False
def get_invoice_status(self, key): def get_invoice_status(self, key):
pass pass
@ -183,6 +188,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
return self.name return self.name
async def stop(self): async def stop(self):
await LNWallet.stop(self)
if self.channel_db: if self.channel_db:
self.channel_db.stop() self.channel_db.stop()
await self.channel_db.stopped_event.wait() await self.channel_db.stopped_event.wait()
@ -215,6 +221,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
_calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice _calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice
handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
is_trampoline_peer = LNWallet.is_trampoline_peer is_trampoline_peer = LNWallet.is_trampoline_peer
wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
on_proxy_changed = LNWallet.on_proxy_changed
class MockTransport: class MockTransport:
@ -290,13 +298,9 @@ class SquareGraph(NamedTuple):
def all_lnworkers(self) -> Iterable[MockLNWallet]: def all_lnworkers(self) -> Iterable[MockLNWallet]:
return self.w_a, self.w_b, self.w_c, self.w_d return self.w_a, self.w_b, self.w_c, self.w_d
async def stop_and_cleanup(self):
async with TaskGroup() as group:
for lnworker in self.all_lnworkers():
await group.spawn(lnworker.stop())
class PaymentDone(Exception): pass class PaymentDone(Exception): pass
class TestSuccess(Exception): pass
class TestPeer(ElectrumTestCase): class TestPeer(ElectrumTestCase):
@ -836,6 +840,50 @@ class TestPeer(ElectrumTestCase):
graph = self.prepare_chans_and_peers_in_square() graph = self.prepare_chans_and_peers_in_square()
self._run_mpp(graph, {'alice_uses_trampoline':True, 'attempts':1}, {'alice_uses_trampoline':True, 'attempts':3}) self._run_mpp(graph, {'alice_uses_trampoline':True, 'attempts':1}, {'alice_uses_trampoline':True, 'attempts':3})
@needs_test_with_all_chacha20_implementations
def test_fail_pending_htlcs_on_shutdown(self):
"""Alice tries to pay Dave via MPP. Dave receives some HTLCs but not all.
Dave shuts down (stops wallet).
We test if Dave fails the pending HTLCs during shutdown.
"""
graph = self.prepare_chans_and_peers_in_square()
self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL))
self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL))
amount_to_pay = 600_000_000_000
peers = graph.all_peers()
graph.w_d.MPP_EXPIRY = 120
graph.w_d.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3
async def pay():
graph.w_d.features |= LnFeatures.BASIC_MPP_OPT
graph.w_b.enable_htlc_forwarding.clear() # Bob will hold forwarded HTLCs
assert graph.w_a.network.channel_db is not None
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay)
try:
async with timeout_after(0.5):
result, log = await graph.w_a.pay_invoice(pay_req, attempts=1)
except TaskTimeout:
# by now Dave hopefully received some HTLCs:
self.assertTrue(len(graph.chan_dc.hm.htlcs(LOCAL)) > 0)
self.assertTrue(len(graph.chan_dc.hm.htlcs(REMOTE)) > 0)
else:
self.fail(f"pay_invoice finished but was not supposed to. result={result}")
await graph.w_d.stop()
# Dave is supposed to have failed the pending incomplete MPP HTLCs
self.assertEqual(0, len(graph.chan_dc.hm.htlcs(LOCAL)))
self.assertEqual(0, len(graph.chan_dc.hm.htlcs(REMOTE)))
raise TestSuccess()
async def f():
async with TaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
await group.spawn(pay())
with self.assertRaises(TestSuccess):
run(f())
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
def test_close(self): def test_close(self):
alice_channel, bob_channel = create_test_channels() alice_channel, bob_channel = create_test_channels()

Loading…
Cancel
Save