diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 5dcf4972e..1435389dc 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -112,6 +112,7 @@ class Peer(Logger): self._htlc_switch_iterstart_event = asyncio.Event() self._htlc_switch_iterdone_event = asyncio.Event() self._received_revack_event = asyncio.Event() + self.downstream_htlc_resolved_event = asyncio.Event() def send_message(self, message_name: str, **kwargs): assert type(message_name) is str @@ -1198,16 +1199,17 @@ class Peer(Logger): chan.receive_fail_htlc(htlc_id, error_bytes=reason) # TODO handle exc and maybe fail channel (e.g. bad htlc_id) self.maybe_send_commitment(chan) - def maybe_send_commitment(self, chan: Channel): + def maybe_send_commitment(self, chan: Channel) -> bool: # REMOTE should revoke first before we can sign a new ctx if chan.hm.is_revack_pending(REMOTE): - return + return False # if there are no changes, we will not (and must not) send a new commitment if not chan.has_pending_changes(REMOTE): - return + return False self.logger.info(f'send_commitment. chan {chan.short_channel_id}. ctn: {chan.get_next_ctn(REMOTE)}.') sig_64, htlc_sigs = chan.sign_next_commitment() self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs)) + return True def pay(self, *, route: 'LNPaymentRoute', @@ -1424,6 +1426,7 @@ class Peer(Logger): except BaseException as e: self.logger.info(f"failed to forward htlc: error sending message. {e}") raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message) + next_peer.maybe_send_commitment(next_chan) return next_chan_scid, next_htlc.htlc_id def maybe_forward_trampoline( @@ -1845,11 +1848,14 @@ class Peer(Logger): self._htlc_switch_iterdone_event.set() self._htlc_switch_iterdone_event.clear() # We poll every 0.1 sec to check if there is work to do, - # or we can be woken up when receiving a revack. - # TODO when forwarding, we should also be woken up when there are - # certain events with the downstream peer + # or we can also be triggered via events. + # When forwarding an HTLC originating from this peer (the upstream), + # we can get triggered for events that happen on the downstream peer. + # TODO: trampoline forwarding relies on the polling async with ignore_after(0.1): - await self._received_revack_event.wait() + async with TaskGroup(wait=any) as group: + await group.spawn(self._received_revack_event.wait()) + await group.spawn(self.downstream_htlc_resolved_event.wait()) self._htlc_switch_iterstart_event.set() self._htlc_switch_iterstart_event.clear() self.ping_if_required() @@ -1861,6 +1867,8 @@ class Peer(Logger): done = set() unfulfilled = chan.unfulfilled_htlcs for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_info) in unfulfilled.items(): + if forwarding_info: + self.lnworker.downstream_htlc_to_upstream_peer_map[forwarding_info] = self.pubkey if not chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id): continue htlc = chan.hm.get_htlc_by_id(REMOTE, htlc_id) @@ -1886,6 +1894,7 @@ class Peer(Logger): error_bytes = construct_onion_error(e, onion_packet, our_onion_private_key=self.privkey) if fw_info: unfulfilled[htlc_id] = local_ctn, remote_ctn, onion_packet_hex, fw_info + self.lnworker.downstream_htlc_to_upstream_peer_map[fw_info] = self.pubkey elif preimage or error_reason or error_bytes: if preimage: if not self.lnworker.enable_htlc_settle: @@ -1904,7 +1913,10 @@ class Peer(Logger): done.add(htlc_id) # cleanup for htlc_id in done: - unfulfilled.pop(htlc_id) + local_ctn, remote_ctn, onion_packet_hex, forwarding_info = unfulfilled.pop(htlc_id) + if forwarding_info: + self.lnworker.downstream_htlc_to_upstream_peer_map.pop(forwarding_info, None) + self.maybe_send_commitment(chan) def _maybe_cleanup_received_htlcs_pending_removal(self) -> None: done = set() diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 6b1bde00d..efc0b4391 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -647,6 +647,8 @@ class LNWallet(LNWorker): self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT) self.trampoline_forwarding_failures = {} # todo: should be persisted + # map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys + self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes] def has_deterministic_node_id(self): return bool(self.db.get('lightning_xprv')) @@ -1847,8 +1849,23 @@ class LNWallet(LNWorker): info = info._replace(status=status) self.save_payment_info(info) + def _on_maybe_forwarded_htlc_resolved(self, chan: Channel, htlc_id: int) -> None: + """Called when an HTLC we offered on chan gets irrevocably fulfilled or failed. + If we find this was a forwarded HTLC, the upstream peer is notified. + """ + fw_info = chan.short_channel_id.hex(), htlc_id + upstream_peer_pubkey = self.downstream_htlc_to_upstream_peer_map.get(fw_info) + if not upstream_peer_pubkey: + return + upstream_peer = self.peers.get(upstream_peer_pubkey) + if not upstream_peer: + return + upstream_peer.downstream_htlc_resolved_event.set() + upstream_peer.downstream_htlc_resolved_event.clear() + def htlc_fulfilled(self, chan: Channel, payment_hash: bytes, htlc_id: int): util.trigger_callback('htlc_fulfilled', payment_hash, chan, htlc_id) + self._on_maybe_forwarded_htlc_resolved(chan=chan, htlc_id=htlc_id) q = self.sent_htlcs.get(payment_hash) if q: route, payment_secret, amount_msat, bucket_msat, amount_receiver_msat = self.sent_htlcs_routes[(payment_hash, chan.short_channel_id, htlc_id)] @@ -1871,6 +1888,7 @@ class LNWallet(LNWorker): failure_message: Optional['OnionRoutingFailure']): util.trigger_callback('htlc_failed', payment_hash, chan, htlc_id) + self._on_maybe_forwarded_htlc_resolved(chan=chan, htlc_id=htlc_id) q = self.sent_htlcs.get(payment_hash) if q: # detect if it is part of a bucket diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index a3a766efb..772a3af72 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -153,6 +153,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): self.inflight_payments = set() self.preimages = {} self.stopping_soon = False + self.downstream_htlc_to_upstream_peer_map = {} self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}") @@ -241,6 +242,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): on_proxy_changed = LNWallet.on_proxy_changed _decode_channel_update_msg = LNWallet._decode_channel_update_msg _handle_chanupd_from_failed_htlc = LNWallet._handle_chanupd_from_failed_htlc + _on_maybe_forwarded_htlc_resolved = LNWallet._on_maybe_forwarded_htlc_resolved class MockTransport: