diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 42a07ea19..bb935f2f1 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1360,34 +1360,36 @@ class Peer(Logger): raise OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') outgoing_chan_upd = next_chan.get_outgoing_gossip_channel_update()[2:] outgoing_chan_upd_len = len(outgoing_chan_upd).to_bytes(2, byteorder="big") + outgoing_chan_upd_message = outgoing_chan_upd_len + outgoing_chan_upd if not next_chan.can_send_update_add_htlc(): self.logger.info(f"cannot forward htlc. next_chan {next_chan_scid} cannot send ctx updates. " f"chan state {next_chan.get_state()!r}, peer state: {next_chan.peer_state!r}") - data = outgoing_chan_upd_len + outgoing_chan_upd - raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=data) + raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message) + try: + next_amount_msat_htlc = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"] + except: + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') + if not next_chan.can_pay(next_amount_msat_htlc): + self.logger.info(f"cannot forward htlc due to transient errors (likely due to insufficient funds)") + raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message) try: next_cltv_expiry = processed_onion.hop_data.payload["outgoing_cltv_value"]["outgoing_cltv_value"] except: raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') if htlc.cltv_expiry - next_cltv_expiry < next_chan.forwarding_cltv_expiry_delta: - data = htlc.cltv_expiry.to_bytes(4, byteorder="big") + outgoing_chan_upd_len + outgoing_chan_upd + data = htlc.cltv_expiry.to_bytes(4, byteorder="big") + outgoing_chan_upd_message raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_CLTV_EXPIRY, data=data) if htlc.cltv_expiry - lnutil.MIN_FINAL_CLTV_EXPIRY_ACCEPTED <= local_height \ or next_cltv_expiry <= local_height: - data = outgoing_chan_upd_len + outgoing_chan_upd - raise OnionRoutingFailure(code=OnionFailureCode.EXPIRY_TOO_SOON, data=data) + raise OnionRoutingFailure(code=OnionFailureCode.EXPIRY_TOO_SOON, data=outgoing_chan_upd_message) if max(htlc.cltv_expiry, next_cltv_expiry) > local_height + lnutil.NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE: raise OnionRoutingFailure(code=OnionFailureCode.EXPIRY_TOO_FAR, data=b'') - try: - next_amount_msat_htlc = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"] - except: - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') forwarding_fees = fee_for_edge_msat( forwarded_amount_msat=next_amount_msat_htlc, fee_base_msat=next_chan.forwarding_fee_base_msat, fee_proportional_millionths=next_chan.forwarding_fee_proportional_millionths) if htlc.amount_msat - next_amount_msat_htlc < forwarding_fees: - data = next_amount_msat_htlc.to_bytes(8, byteorder="big") + outgoing_chan_upd_len + outgoing_chan_upd + data = next_amount_msat_htlc.to_bytes(8, byteorder="big") + outgoing_chan_upd_message raise OnionRoutingFailure(code=OnionFailureCode.FEE_INSUFFICIENT, data=data) self.logger.info(f'forwarding htlc to {next_chan.node_id}') next_htlc = UpdateAddHtlc( @@ -1409,8 +1411,7 @@ class Peer(Logger): ) except BaseException as e: self.logger.info(f"failed to forward htlc: error sending message. {e}") - data = outgoing_chan_upd_len + outgoing_chan_upd - raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=data) + raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message) return next_chan_scid, next_htlc.htlc_id def maybe_forward_trampoline( diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index cfe54323c..49f71386b 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -8,7 +8,7 @@ import logging import concurrent from concurrent import futures import unittest -from typing import Iterable, NamedTuple, Tuple, List +from typing import Iterable, NamedTuple, Tuple, List, Dict from aiorpcx import TaskGroup, timeout_after, TaskTimeout @@ -221,6 +221,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): is_trampoline_peer = LNWallet.is_trampoline_peer wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed on_proxy_changed = LNWallet.on_proxy_changed + _decode_channel_update_msg = LNWallet._decode_channel_update_msg + _handle_chanupd_from_failed_htlc = LNWallet._handle_chanupd_from_failed_htlc class MockTransport: @@ -347,12 +349,38 @@ class TestPeer(TestCaseForTestnet): p2.mark_open(bob_channel) return p1, p2, w1, w2, q1, q2 - def prepare_chans_and_peers_in_square(self) -> SquareGraph: + def prepare_chans_and_peers_in_square(self, funds_distribution: Dict[str, Tuple[int, int]]=None) -> SquareGraph: + if not funds_distribution: + funds_distribution = {} key_a, key_b, key_c, key_d = [keypair() for i in range(4)] - chan_ab, chan_ba = create_test_channels(alice_name="alice", bob_name="bob", alice_pubkey=key_a.pubkey, bob_pubkey=key_b.pubkey) - chan_ac, chan_ca = create_test_channels(alice_name="alice", bob_name="carol", alice_pubkey=key_a.pubkey, bob_pubkey=key_c.pubkey) - chan_bd, chan_db = create_test_channels(alice_name="bob", bob_name="dave", alice_pubkey=key_b.pubkey, bob_pubkey=key_d.pubkey) - chan_cd, chan_dc = create_test_channels(alice_name="carol", bob_name="dave", alice_pubkey=key_c.pubkey, bob_pubkey=key_d.pubkey) + local_balance, remote_balance = funds_distribution.get('ab') or (None, None) + chan_ab, chan_ba = create_test_channels( + alice_name="alice", bob_name="bob", + alice_pubkey=key_a.pubkey, bob_pubkey=key_b.pubkey, + local_msat=local_balance, + remote_msat=remote_balance, + ) + local_balance, remote_balance = funds_distribution.get('ac') or (None, None) + chan_ac, chan_ca = create_test_channels( + alice_name="alice", bob_name="carol", + alice_pubkey=key_a.pubkey, bob_pubkey=key_c.pubkey, + local_msat=local_balance, + remote_msat=remote_balance, + ) + local_balance, remote_balance = funds_distribution.get('bd') or (None, None) + chan_bd, chan_db = create_test_channels( + alice_name="bob", bob_name="dave", + alice_pubkey=key_b.pubkey, bob_pubkey=key_d.pubkey, + local_msat=local_balance, + remote_msat=remote_balance, + ) + local_balance, remote_balance = funds_distribution.get('cd') or (None, None) + chan_cd, chan_dc = create_test_channels( + alice_name="carol", bob_name="dave", + alice_pubkey=key_c.pubkey, bob_pubkey=key_d.pubkey, + local_msat=local_balance, + remote_msat=remote_balance, + ) trans_ab, trans_ba = transport_pair(key_a, key_b, chan_ab.name, chan_ba.name) trans_ac, trans_ca = transport_pair(key_a, key_c, chan_ac.name, chan_ca.name) trans_bd, trans_db = transport_pair(key_b, key_d, chan_bd.name, chan_db.name) @@ -776,6 +804,43 @@ class TestPeer(TestCaseForTestnet): with self.assertRaises(PaymentDone): run(f()) + @needs_test_with_all_chacha20_implementations + def test_payment_with_temp_channel_failure(self): + # prepare channels such that a temporary channel failure happens at c->d + funds_distribution = { + 'ac': (200_000_000, 200_000_000), # low fees + 'cd': (50_000_000, 200_000_000), # low fees + 'ab': (200_000_000, 200_000_000), # high fees + 'bd': (200_000_000, 200_000_000), # high fees + } + # the payment happens in three attempts: + # 1. along ac->cd due to low fees with temp channel failure: + # with chanupd: ORPHANED, private channel update + # 2. along ac->cd with temp channel failure: + # with chanupd: ORPHANED, private channel update, but already received, channel gets blacklisted + # 3. along ab->bd with success + amount_to_pay = 100_000_000 + graph = self.prepare_chans_and_peers_in_square(funds_distribution) + peers = graph.all_peers() + async def pay(lnaddr, pay_req): + self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) + result, log = await graph.w_a.pay_invoice(pay_req, attempts=3) + self.assertTrue(result) + self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, log[0].failure_msg.code) + self.assertEqual(OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, log[1].failure_msg.code) + raise PaymentDone() + async def f(): + async with TaskGroup() as group: + for peer in peers: + await group.spawn(peer._message_loop()) + await group.spawn(peer.htlc_switch()) + await asyncio.sleep(0.2) + lnaddr, pay_req = await self.prepare_invoice(graph.w_d, amount_msat=amount_to_pay, include_routing_hints=True) + await group.spawn(pay(lnaddr, pay_req)) + with self.assertRaises(PaymentDone): + run(f()) + def _run_mpp(self, graph, kwargs1, kwargs2): self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL)) self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL))