Browse Source

Merge pull request #7099 from SomberNight/202103_fail_pending_htlcs_on_shutdown

fail pending htlcs on shutdown
patch-4
ThomasV 4 years ago
committed by GitHub
parent
commit
6004a04705
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 49
      electrum/lnhtlc.py
  2. 45
      electrum/lnpeer.py
  3. 41
      electrum/lnworker.py
  4. 62
      electrum/tests/test_lnpeer.py

49
electrum/lnhtlc.py

@ -359,6 +359,55 @@ class HTLCManager:
return False
return ctns[ctx_owner] <= self.ctn_oldest_unrevoked(ctx_owner)
@with_lock
def is_htlc_irrevocably_removed_yet(
self,
*,
ctx_owner: HTLCOwner = None,
htlc_proposer: HTLCOwner,
htlc_id: int,
) -> bool:
"""Returns whether the removal of an htlc was irrevocably committed to `ctx_owner's` ctx.
The removal can either be a fulfill/settle or a fail; they are not distinguished.
If `ctx_owner` is None, both parties' ctxs are checked.
"""
in_local = self._is_htlc_irrevocably_removed_yet(
ctx_owner=LOCAL, htlc_proposer=htlc_proposer, htlc_id=htlc_id)
in_remote = self._is_htlc_irrevocably_removed_yet(
ctx_owner=REMOTE, htlc_proposer=htlc_proposer, htlc_id=htlc_id)
if ctx_owner is None:
return in_local and in_remote
elif ctx_owner == LOCAL:
return in_local
elif ctx_owner == REMOTE:
return in_remote
else:
raise Exception(f"unexpected ctx_owner: {ctx_owner!r}")
@with_lock
def _is_htlc_irrevocably_removed_yet(
self,
*,
ctx_owner: HTLCOwner,
htlc_proposer: HTLCOwner,
htlc_id: int,
) -> bool:
htlc_id = int(htlc_id)
if htlc_id >= self.get_next_htlc_id(htlc_proposer):
return False
if htlc_id in self.log[htlc_proposer]['settles']:
ctn_of_settle = self.log[htlc_proposer]['settles'][htlc_id][ctx_owner]
else:
ctn_of_settle = None
if htlc_id in self.log[htlc_proposer]['fails']:
ctn_of_fail = self.log[htlc_proposer]['fails'][htlc_id][ctx_owner]
else:
ctn_of_fail = None
ctn_of_rm = ctn_of_settle or ctn_of_fail or None
if ctn_of_rm is None:
return False
return ctn_of_rm <= self.ctn_oldest_unrevoked(ctx_owner)
@with_lock
def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction,
ctn: int = None) -> Dict[int, UpdateAddHtlc]:

45
electrum/lnpeer.py

@ -9,11 +9,12 @@ from collections import OrderedDict, defaultdict
import asyncio
import os
import time
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set
from datetime import datetime
import functools
import aiorpcx
from aiorpcx import TaskGroup
from .crypto import sha256, sha256d
from . import bitcoin, util
@ -74,6 +75,7 @@ class Peer(Logger):
self._sent_init = False # type: bool
self._received_init = False # type: bool
self.initialized = asyncio.Future()
self.got_disconnected = asyncio.Event()
self.querying = asyncio.Event()
self.transport = transport
self.pubkey = pubkey # remote pubkey
@ -98,6 +100,11 @@ class Peer(Logger):
self.orphan_channel_updates = OrderedDict()
Logger.__init__(self)
self.taskgroup = SilentTaskGroup()
# HTLCs offered by REMOTE, that we started removing but are still active:
self.received_htlcs_pending_removal = set() # type: Set[Tuple[Channel, int]]
self.received_htlc_removed_event = asyncio.Event()
self._htlc_switch_iterstart_event = asyncio.Event()
self._htlc_switch_iterdone_event = asyncio.Event()
def send_message(self, message_name: str, **kwargs):
assert type(message_name) is str
@ -492,6 +499,7 @@ class Peer(Logger):
except:
pass
self.lnworker.peer_closed(self)
self.got_disconnected.set()
def is_static_remotekey(self):
return self.features.supports(LnFeatures.OPTION_STATIC_REMOTEKEY_OPT)
@ -1575,6 +1583,7 @@ class Peer(Logger):
self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id)
self.received_htlcs_pending_removal.add((chan, htlc_id))
chan.settle_htlc(preimage, htlc_id)
self.send_message(
"update_fulfill_htlc",
@ -1585,6 +1594,7 @@ class Peer(Logger):
def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes):
self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
self.received_htlcs_pending_removal.add((chan, htlc_id))
chan.fail_htlc(htlc_id)
self.send_message(
"update_fail_htlc",
@ -1596,9 +1606,10 @@ class Peer(Logger):
def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRoutingFailure):
self.logger.info(f"fail_malformed_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
chan.fail_htlc(htlc_id)
if not (reason.code & OnionFailureCodeMetaFlag.BADONION and len(reason.data) == 32):
raise Exception(f"unexpected reason when sending 'update_fail_malformed_htlc': {reason!r}")
self.received_htlcs_pending_removal.add((chan, htlc_id))
chan.fail_htlc(htlc_id)
self.send_message(
"update_fail_malformed_htlc",
channel_id=chan.channel_id,
@ -1800,8 +1811,13 @@ class Peer(Logger):
async def htlc_switch(self):
await self.initialized
while True:
await asyncio.sleep(0.1)
self._htlc_switch_iterdone_event.set()
self._htlc_switch_iterdone_event.clear()
await asyncio.sleep(0.1) # TODO maybe make this partly event-driven
self._htlc_switch_iterstart_event.set()
self._htlc_switch_iterstart_event.clear()
self.ping_if_required()
self._maybe_cleanup_received_htlcs_pending_removal()
for chan_id, chan in self.channels.items():
if not chan.can_send_ctx_updates():
continue
@ -1853,6 +1869,29 @@ class Peer(Logger):
for htlc_id in done:
unfulfilled.pop(htlc_id)
def _maybe_cleanup_received_htlcs_pending_removal(self) -> None:
done = set()
for chan, htlc_id in self.received_htlcs_pending_removal:
if chan.hm.is_htlc_irrevocably_removed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id):
done.add((chan, htlc_id))
if done:
for key in done:
self.received_htlcs_pending_removal.remove(key)
self.received_htlc_removed_event.set()
self.received_htlc_removed_event.clear()
async def wait_one_htlc_switch_iteration(self) -> None:
"""Waits until the HTLC switch does a full iteration or the peer disconnects,
whichever happens first.
"""
async def htlc_switch_iteration():
await self._htlc_switch_iterstart_event.wait()
await self._htlc_switch_iterdone_event.wait()
async with TaskGroup(wait=any) as group:
await group.spawn(htlc_switch_iteration())
await group.spawn(self.got_disconnected.wait())
async def process_unfulfilled_htlc(
self, *,
chan: Channel,

41
electrum/lnworker.py

@ -22,7 +22,7 @@ import urllib.parse
import dns.resolver
import dns.exception
from aiorpcx import run_in_thread, TaskGroup, NetAddress
from aiorpcx import run_in_thread, TaskGroup, NetAddress, ignore_after
from . import constants, util
from . import keystore
@ -195,6 +195,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
self.features = features
self.network = None # type: Optional[Network]
self.config = None # type: Optional[SimpleConfig]
self.stopping_soon = False # whether we are being shut down
util.register_callback(self.on_proxy_changed, ['proxy_set'])
@ -268,6 +269,8 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
async def _maintain_connectivity(self):
while True:
await asyncio.sleep(1)
if self.stopping_soon:
return
now = time.time()
if len(self._peers) >= NUM_PEERS_TARGET:
continue
@ -575,6 +578,7 @@ class LNWallet(LNWorker):
lnwatcher: Optional['LNWalletWatcher']
MPP_EXPIRY = 120
TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3 # seconds
def __init__(self, wallet: 'Abstract_Wallet', xprv):
self.wallet = wallet
@ -707,9 +711,32 @@ class LNWallet(LNWorker):
asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
async def stop(self):
await super().stop()
await self.lnwatcher.stop()
self.lnwatcher = None
self.stopping_soon = True
if self.listen_server: # stop accepting new peers
self.listen_server.close()
async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS):
await self.wait_for_received_pending_htlcs_to_get_removed()
await LNWorker.stop(self)
if self.lnwatcher:
await self.lnwatcher.stop()
self.lnwatcher = None
async def wait_for_received_pending_htlcs_to_get_removed(self):
assert self.stopping_soon is True
# We try to fail pending MPP HTLCs, and wait a bit for them to get removed.
# Note: even without MPP, if we just failed/fulfilled an HTLC, it is good
# to wait a bit for it to become irrevocably removed.
# Note: we don't wait for *all htlcs* to get removed, only for those
# that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
async with TaskGroup() as group:
for peer in self.peers.values():
await group.spawn(peer.wait_one_htlc_switch_iteration())
while True:
if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()):
break
async with TaskGroup(wait=any) as group:
for peer in self.peers.values():
await group.spawn(peer.received_htlc_removed_event.wait())
def peer_closed(self, peer):
for chan in self.channels_for_peer(peer.pubkey).values():
@ -1635,7 +1662,9 @@ class LNWallet(LNWorker):
if not is_accepted and not is_expired:
total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set])
if time.time() - first_timestamp > self.MPP_EXPIRY:
if self.stopping_soon:
is_expired = True # try to time out pending HTLCs before shutting down
elif time.time() - first_timestamp > self.MPP_EXPIRY:
is_expired = True
elif total == expected_msat:
is_accepted = True
@ -1897,6 +1926,8 @@ class LNWallet(LNWorker):
async def reestablish_peers_and_channels(self):
while True:
await asyncio.sleep(1)
if self.stopping_soon:
return
for chan in self.channels.values():
if chan.is_closed():
continue

62
electrum/tests/test_lnpeer.py

@ -10,7 +10,7 @@ from concurrent import futures
import unittest
from typing import Iterable, NamedTuple, Tuple, List
from aiorpcx import TaskGroup
from aiorpcx import TaskGroup, timeout_after, TaskTimeout
from electrum import bitcoin
from electrum import constants
@ -113,7 +113,8 @@ class MockWallet:
class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1
MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1
TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0
def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name):
self.name = name
@ -121,6 +122,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue)
self.taskgroup = TaskGroup()
self.lnwatcher = None
self.listen_server = None
self._channels = {chan.channel_id: chan for chan in chans}
self.payments = {}
self.logs = defaultdict(list)
@ -147,6 +151,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
self.trampoline_forwarding_failures = {}
self.inflight_payments = set()
self.preimages = {}
self.stopping_soon = False
def get_invoice_status(self, key):
pass
@ -183,6 +188,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
return self.name
async def stop(self):
await LNWallet.stop(self)
if self.channel_db:
self.channel_db.stop()
await self.channel_db.stopped_event.wait()
@ -215,6 +221,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
_calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice
handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
is_trampoline_peer = LNWallet.is_trampoline_peer
wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
on_proxy_changed = LNWallet.on_proxy_changed
class MockTransport:
@ -290,13 +298,9 @@ class SquareGraph(NamedTuple):
def all_lnworkers(self) -> Iterable[MockLNWallet]:
return self.w_a, self.w_b, self.w_c, self.w_d
async def stop_and_cleanup(self):
async with TaskGroup() as group:
for lnworker in self.all_lnworkers():
await group.spawn(lnworker.stop())
class PaymentDone(Exception): pass
class TestSuccess(Exception): pass
class TestPeer(ElectrumTestCase):
@ -836,6 +840,50 @@ class TestPeer(ElectrumTestCase):
graph = self.prepare_chans_and_peers_in_square()
self._run_mpp(graph, {'alice_uses_trampoline':True, 'attempts':1}, {'alice_uses_trampoline':True, 'attempts':3})
@needs_test_with_all_chacha20_implementations
def test_fail_pending_htlcs_on_shutdown(self):
"""Alice tries to pay Dave via MPP. Dave receives some HTLCs but not all.
Dave shuts down (stops wallet).
We test if Dave fails the pending HTLCs during shutdown.
"""
graph = self.prepare_chans_and_peers_in_square()
self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL))
self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL))
amount_to_pay = 600_000_000_000
peers = graph.all_peers()
graph.w_d.MPP_EXPIRY = 120
graph.w_d.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3
async def pay():
graph.w_d.features |= LnFeatures.BASIC_MPP_OPT
graph.w_b.enable_htlc_forwarding.clear() # Bob will hold forwarded HTLCs
assert graph.w_a.network.channel_db is not None
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay)
try:
async with timeout_after(0.5):
result, log = await graph.w_a.pay_invoice(pay_req, attempts=1)
except TaskTimeout:
# by now Dave hopefully received some HTLCs:
self.assertTrue(len(graph.chan_dc.hm.htlcs(LOCAL)) > 0)
self.assertTrue(len(graph.chan_dc.hm.htlcs(REMOTE)) > 0)
else:
self.fail(f"pay_invoice finished but was not supposed to. result={result}")
await graph.w_d.stop()
# Dave is supposed to have failed the pending incomplete MPP HTLCs
self.assertEqual(0, len(graph.chan_dc.hm.htlcs(LOCAL)))
self.assertEqual(0, len(graph.chan_dc.hm.htlcs(REMOTE)))
raise TestSuccess()
async def f():
async with TaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
await group.spawn(pay())
with self.assertRaises(TestSuccess):
run(f())
@needs_test_with_all_chacha20_implementations
def test_close(self):
alice_channel, bob_channel = create_test_channels()

Loading…
Cancel
Save