
lnpeer: more forwarding is now event-driven

This should make unit tests less reliant on sleeps.
Branch: patch-4
Author: SomberNight (3 years ago)
Commit: 56b03e2e8d
3 changed files:

  1. electrum/lnpeer.py (28 lines changed)
  2. electrum/lnworker.py (18 lines changed)
  3. electrum/tests/test_lnpeer.py (2 lines changed)

electrum/lnpeer.py
@@ -112,6 +112,7 @@ class Peer(Logger):
         self._htlc_switch_iterstart_event = asyncio.Event()
         self._htlc_switch_iterdone_event = asyncio.Event()
         self._received_revack_event = asyncio.Event()
+        self.downstream_htlc_resolved_event = asyncio.Event()

     def send_message(self, message_name: str, **kwargs):
         assert type(message_name) is str
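
Note on the pattern: these events are used as wake-up pulses, set() immediately followed by clear(). set() wakes every coroutine currently blocked in wait(), and the immediate clear() re-arms the event so that later waiters block again. A minimal standalone sketch of those semantics (names illustrative, not from the codebase):

    import asyncio

    async def main() -> None:
        event = asyncio.Event()

        async def waiter(name: str) -> None:
            await event.wait()          # blocks until the event is pulsed
            print(f"{name} woke up")

        tasks = [asyncio.create_task(waiter(f"w{i}")) for i in range(3)]
        await asyncio.sleep(0)          # let all waiters reach event.wait()
        # the "pulse": set() wakes everyone currently waiting; clear()
        # immediately re-arms the event so later wait() calls block again
        event.set()
        event.clear()
        await asyncio.gather(*tasks)

    asyncio.run(main())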
@@ -1198,16 +1199,17 @@ class Peer(Logger):
         chan.receive_fail_htlc(htlc_id, error_bytes=reason)  # TODO handle exc and maybe fail channel (e.g. bad htlc_id)
         self.maybe_send_commitment(chan)

-    def maybe_send_commitment(self, chan: Channel):
+    def maybe_send_commitment(self, chan: Channel) -> bool:
         # REMOTE should revoke first before we can sign a new ctx
         if chan.hm.is_revack_pending(REMOTE):
-            return
+            return False
         # if there are no changes, we will not (and must not) send a new commitment
         if not chan.has_pending_changes(REMOTE):
-            return
+            return False
         self.logger.info(f'send_commitment. chan {chan.short_channel_id}. ctn: {chan.get_next_ctn(REMOTE)}.')
         sig_64, htlc_sigs = chan.sign_next_commitment()
         self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs))
+        return True

     def pay(self, *,
             route: 'LNPaymentRoute',
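
Annotation: maybe_send_commitment() now reports whether a commitment_signed was actually sent, instead of silently returning. A toy, self-contained sketch of that contract (TinyChannel is hypothetical, not Electrum's Channel):

    from typing import List

    class TinyChannel:
        """Toy stand-in illustrating the new contract: the method
        reports via bool whether it actually signed anything."""
        def __init__(self) -> None:
            self.revack_pending = False
            self.pending_changes: List[str] = []

        def maybe_send(self) -> bool:
            if self.revack_pending:        # must not sign before remote revokes
                return False
            if not self.pending_changes:   # nothing to commit to
                return False
            print(f"signing {len(self.pending_changes)} change(s)")
            self.pending_changes.clear()
            return True

    chan = TinyChannel()
    assert chan.maybe_send() is False      # no changes yet
    chan.pending_changes.append("add_htlc")
    assert chan.maybe_send() is True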
@@ -1424,6 +1426,7 @@ class Peer(Logger):
         except BaseException as e:
             self.logger.info(f"failed to forward htlc: error sending message. {e}")
             raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message)
+        next_peer.maybe_send_commitment(next_chan)
         return next_chan_scid, next_htlc.htlc_id

     def maybe_forward_trampoline(
@@ -1845,11 +1848,14 @@ class Peer(Logger):
             self._htlc_switch_iterdone_event.set()
             self._htlc_switch_iterdone_event.clear()
             # We poll every 0.1 sec to check if there is work to do,
-            # or we can be woken up when receiving a revack.
-            # TODO when forwarding, we should also be woken up when there are
-            #      certain events with the downstream peer
+            # or we can also be triggered via events.
+            # When forwarding an HTLC originating from this peer (the upstream),
+            # we can get triggered for events that happen on the downstream peer.
+            # TODO: trampoline forwarding relies on the polling
             async with ignore_after(0.1):
-                await self._received_revack_event.wait()
+                async with TaskGroup(wait=any) as group:
+                    await group.spawn(self._received_revack_event.wait())
+                    await group.spawn(self.downstream_htlc_resolved_event.wait())
             self._htlc_switch_iterstart_event.set()
             self._htlc_switch_iterstart_event.clear()
             self.ping_if_required()
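
Annotation: ignore_after and TaskGroup come from aiorpcx; the combination means "sleep for at most 0.1 s, but wake immediately if either event fires". A rough stdlib-only equivalent of the same idea (a sketch, not the code Electrum runs):

    import asyncio

    async def wait_for_wakeup(revack_event: asyncio.Event,
                              downstream_event: asyncio.Event,
                              timeout: float = 0.1) -> None:
        """Return after `timeout` seconds, or earlier if either event fires."""
        waiters = [asyncio.create_task(revack_event.wait()),
                   asyncio.create_task(downstream_event.wait())]
        done, pending = await asyncio.wait(
            waiters, timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
        for task in pending:   # clean up whichever waiter lost the race
            task.cancel()

Either way the caller falls through to the next switch iteration, whether woken by an event or by the timeout.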
@@ -1861,6 +1867,8 @@ class Peer(Logger):
         done = set()
         unfulfilled = chan.unfulfilled_htlcs
         for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_info) in unfulfilled.items():
+            if forwarding_info:
+                self.lnworker.downstream_htlc_to_upstream_peer_map[forwarding_info] = self.pubkey
             if not chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id):
                 continue
             htlc = chan.hm.get_htlc_by_id(REMOTE, htlc_id)
@@ -1886,6 +1894,7 @@ class Peer(Logger):
                 error_bytes = construct_onion_error(e, onion_packet, our_onion_private_key=self.privkey)
             if fw_info:
                 unfulfilled[htlc_id] = local_ctn, remote_ctn, onion_packet_hex, fw_info
+                self.lnworker.downstream_htlc_to_upstream_peer_map[fw_info] = self.pubkey
             elif preimage or error_reason or error_bytes:
                 if preimage:
                     if not self.lnworker.enable_htlc_settle:
@@ -1904,7 +1913,10 @@ class Peer(Logger):
                 done.add(htlc_id)
         # cleanup
         for htlc_id in done:
-            unfulfilled.pop(htlc_id)
+            local_ctn, remote_ctn, onion_packet_hex, forwarding_info = unfulfilled.pop(htlc_id)
+            if forwarding_info:
+                self.lnworker.downstream_htlc_to_upstream_peer_map.pop(forwarding_info, None)
+        self.maybe_send_commitment(chan)

     def _maybe_cleanup_received_htlcs_pending_removal(self) -> None:
         done = set()

electrum/lnworker.py
@@ -647,6 +647,8 @@ class LNWallet(LNWorker):
             self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
         self.trampoline_forwarding_failures = {}  # todo: should be persisted
+        # map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys
+        self.downstream_htlc_to_upstream_peer_map = {}  # type: Dict[Tuple[str, int], bytes]

     def has_deterministic_node_id(self):
         return bool(self.db.get('lightning_xprv'))
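
Annotation: keys are (short_channel_id.hex(), htlc_id) pairs identifying the HTLC we offered on the downstream channel; values are the upstream peer's node pubkey. The `# type:` comment matches the typing-comment style this file already uses; the same shape with a modern inline annotation, plus a registration example with purely illustrative values:

    from typing import Dict, Tuple

    # key:   (short_channel_id.hex(), htlc_id) of the HTLC offered downstream
    # value: node pubkey of the upstream peer the incoming HTLC came from
    downstream_htlc_to_upstream_peer_map: Dict[Tuple[str, int], bytes] = {}

    # registration, mirroring what the htlc switch does (values illustrative):
    fw_info = ("0000000000000000", 5)   # fake scid hex and htlc_id
    upstream_pubkey = bytes(33)         # fake 33-byte node pubkey
    downstream_htlc_to_upstream_peer_map[fw_info] = upstream_pubkey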
@@ -1847,8 +1849,23 @@ class LNWallet(LNWorker):
             info = info._replace(status=status)
             self.save_payment_info(info)

+    def _on_maybe_forwarded_htlc_resolved(self, chan: Channel, htlc_id: int) -> None:
+        """Called when an HTLC we offered on chan gets irrevocably fulfilled or failed.
+        If we find this was a forwarded HTLC, the upstream peer is notified.
+        """
+        fw_info = chan.short_channel_id.hex(), htlc_id
+        upstream_peer_pubkey = self.downstream_htlc_to_upstream_peer_map.get(fw_info)
+        if not upstream_peer_pubkey:
+            return
+        upstream_peer = self.peers.get(upstream_peer_pubkey)
+        if not upstream_peer:
+            return
+        upstream_peer.downstream_htlc_resolved_event.set()
+        upstream_peer.downstream_htlc_resolved_event.clear()
+
     def htlc_fulfilled(self, chan: Channel, payment_hash: bytes, htlc_id: int):
         util.trigger_callback('htlc_fulfilled', payment_hash, chan, htlc_id)
+        self._on_maybe_forwarded_htlc_resolved(chan=chan, htlc_id=htlc_id)
         q = self.sent_htlcs.get(payment_hash)
         if q:
             route, payment_secret, amount_msat, bucket_msat, amount_receiver_msat = self.sent_htlcs_routes[(payment_hash, chan.short_channel_id, htlc_id)]
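
Annotation: a self-contained simulation of this lookup-and-notify flow. The real code maps fw_info to a peer pubkey and goes through self.peers; here the map holds the event directly to keep the sketch short:

    import asyncio
    from typing import Dict, Tuple

    async def demo() -> None:
        # stand-ins for Peer.downstream_htlc_resolved_event and the lnworker map
        upstream_event = asyncio.Event()
        fw_map: Dict[Tuple[str, int], asyncio.Event] = {}

        fw_info = ("0000000000000000", 7)   # illustrative (scid_hex, htlc_id)
        fw_map[fw_info] = upstream_event    # registered by the htlc switch

        async def upstream_htlc_switch() -> None:
            await upstream_event.wait()
            print("upstream switch woken: downstream HTLC resolved")

        task = asyncio.create_task(upstream_htlc_switch())
        await asyncio.sleep(0)              # let the switch reach wait()

        # downstream resolution: look up the upstream waiter and pulse it,
        # mirroring _on_maybe_forwarded_htlc_resolved()
        event = fw_map.pop(fw_info, None)   # pop mirrors the switch's cleanup pass
        if event is not None:
            event.set()
            event.clear()
        await task

    asyncio.run(demo())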
@@ -1871,6 +1888,7 @@ class LNWallet(LNWorker):
             failure_message: Optional['OnionRoutingFailure']):

         util.trigger_callback('htlc_failed', payment_hash, chan, htlc_id)
+        self._on_maybe_forwarded_htlc_resolved(chan=chan, htlc_id=htlc_id)
         q = self.sent_htlcs.get(payment_hash)
         if q:
             # detect if it is part of a bucket

electrum/tests/test_lnpeer.py
@@ -153,6 +153,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
         self.inflight_payments = set()
         self.preimages = {}
         self.stopping_soon = False
+        self.downstream_htlc_to_upstream_peer_map = {}
         self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}")
@@ -241,6 +242,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
     on_proxy_changed = LNWallet.on_proxy_changed
     _decode_channel_update_msg = LNWallet._decode_channel_update_msg
     _handle_chanupd_from_failed_htlc = LNWallet._handle_chanupd_from_failed_htlc
+    _on_maybe_forwarded_htlc_resolved = LNWallet._on_maybe_forwarded_htlc_resolved


 class MockTransport:
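
Annotation: this wiring is what lets tests wait on the new event with a bound instead of sleeping a fixed amount. A hypothetical helper along those lines (not in this commit):

    import asyncio

    async def wait_until_downstream_resolved(peer, timeout: float = 5.0) -> None:
        # hypothetical test helper: bounded wait on the new event instead of
        # a fixed `await asyncio.sleep(...)`; raises TimeoutError on expiry
        await asyncio.wait_for(peer.downstream_htlc_resolved_event.wait(), timeout)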
