diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index fee1f51a7..306708a46 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1473,7 +1473,7 @@ class Peer(Logger): chan: Channel, htlc: UpdateAddHtlc, processed_onion: ProcessedOnionPacket, - is_trampoline: bool = False) -> Tuple[Optional[bytes], Optional[bytes]]: + is_trampoline: bool = False) -> Tuple[Optional[bytes], Optional[OnionPacket]]: """As a final recipient of an HTLC, decide if we should fulfill it. Return (preimage, trampoline_onion_packet) with at most a single element not None diff --git a/electrum/lnworker.py b/electrum/lnworker.py index a1b126ed2..6bce0e068 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -577,9 +577,9 @@ class LNWallet(LNWorker): lnwatcher: Optional['LNWalletWatcher'] def __init__(self, wallet: 'Abstract_Wallet', xprv): - Logger.__init__(self) self.wallet = wallet self.db = wallet.db + Logger.__init__(self) LNWorker.__init__(self, xprv, LNWALLET_FEATURES) self.config = wallet.config self.lnwatcher = None @@ -621,6 +621,9 @@ class LNWallet(LNWorker): def get_channel_by_id(self, channel_id: bytes) -> Optional[Channel]: return self._channels.get(channel_id, None) + def diagnostic_name(self): + return self.wallet.diagnostic_name() + @ignore_exceptions @log_exceptions async def sync_with_local_watchtower(self): diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 13d2c0742..6ca920c8c 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -30,10 +30,11 @@ from electrum.channel_db import ChannelDB from electrum.lnworker import LNWallet, NoPathFound from electrum.lnmsg import encode_msg, decode_msg from electrum.logging import console_stderr_handler, Logger -from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID +from electrum.lnworker import PaymentInfo, RECEIVED from electrum.lnonion import OnionFailureCode from electrum.lnutil import ChannelBlackList, derive_payment_secret_from_payment_preimage from electrum.lnutil import LOCAL, REMOTE +from electrum.invoices import PR_PAID, PR_UNPAID from .test_lnchannel import create_test_channels from .test_bitcoin import needs_test_with_all_chacha20_implementations @@ -112,7 +113,8 @@ class MockWallet: class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): - def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue): + def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name): + self.name = name Logger.__init__(self) NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1) self.node_keypair = local_keypair @@ -173,6 +175,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): def save_channel(self, chan): print("Ignoring channel save") + def diagnostic_name(self): + return self.name + get_payments = LNWallet.get_payments get_payment_info = LNWallet.get_payment_info save_payment_info = LNWallet.save_payment_info @@ -298,8 +303,8 @@ class TestPeer(ElectrumTestCase): bob_channel.node_id = k1.pubkey t1, t2 = transport_pair(k1, k2, alice_channel.name, bob_channel.name) q1, q2 = asyncio.Queue(), asyncio.Queue() - w1 = MockLNWallet(local_keypair=k1, chans=[alice_channel], tx_queue=q1) - w2 = MockLNWallet(local_keypair=k2, chans=[bob_channel], tx_queue=q2) + w1 = MockLNWallet(local_keypair=k1, chans=[alice_channel], tx_queue=q1, name=bob_channel.name) + w2 = MockLNWallet(local_keypair=k2, chans=[bob_channel], tx_queue=q2, name=alice_channel.name) p1 = Peer(w1, k2.pubkey, t1) p2 = Peer(w2, k1.pubkey, t2) w1._peers[p1.pubkey] = p1 @@ -324,10 +329,10 @@ class TestPeer(ElectrumTestCase): trans_bd, trans_db = transport_pair(key_b, key_d, chan_bd.name, chan_db.name) trans_cd, trans_dc = transport_pair(key_c, key_d, chan_cd.name, chan_dc.name) txq_a, txq_b, txq_c, txq_d = [asyncio.Queue() for i in range(4)] - w_a = MockLNWallet(local_keypair=key_a, chans=[chan_ab, chan_ac], tx_queue=txq_a) - w_b = MockLNWallet(local_keypair=key_b, chans=[chan_ba, chan_bd], tx_queue=txq_b) - w_c = MockLNWallet(local_keypair=key_c, chans=[chan_ca, chan_cd], tx_queue=txq_c) - w_d = MockLNWallet(local_keypair=key_d, chans=[chan_db, chan_dc], tx_queue=txq_d) + w_a = MockLNWallet(local_keypair=key_a, chans=[chan_ab, chan_ac], tx_queue=txq_a, name="alice") + w_b = MockLNWallet(local_keypair=key_b, chans=[chan_ba, chan_bd], tx_queue=txq_b, name="bob") + w_c = MockLNWallet(local_keypair=key_c, chans=[chan_ca, chan_cd], tx_queue=txq_c, name="carol") + w_d = MockLNWallet(local_keypair=key_d, chans=[chan_db, chan_dc], tx_queue=txq_d, name="dave") peer_ab = Peer(w_a, key_b.pubkey, trans_ab) peer_ac = Peer(w_a, key_c.pubkey, trans_ac) peer_ba = Peer(w_b, key_a.pubkey, trans_ba) @@ -489,11 +494,14 @@ class TestPeer(ElectrumTestCase): @needs_test_with_all_chacha20_implementations def test_payment(self): + """Alice pays Bob a single HTLC via direct channel.""" alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) - async def pay(pay_req): + async def pay(lnaddr, pay_req): + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) result, log = await w1.pay_invoice(pay_req) self.assertTrue(result) + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) raise PaymentDone() async def f(): async with TaskGroup() as group: @@ -503,7 +511,9 @@ class TestPeer(ElectrumTestCase): await group.spawn(p2.htlc_switch()) await asyncio.sleep(0.01) lnaddr, pay_req = await self.prepare_invoice(w2) - await group.spawn(pay(pay_req)) + invoice_features = lnaddr.get_features() + self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) + await group.spawn(pay(lnaddr, pay_req)) with self.assertRaises(PaymentDone): run(f()) @@ -614,9 +624,11 @@ class TestPeer(ElectrumTestCase): def test_payment_multihop(self): graph = self.prepare_chans_and_peers_in_square() peers = graph.all_peers() - async def pay(pay_req): + async def pay(lnaddr, pay_req): + self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) result, log = await graph.w_a.pay_invoice(pay_req) self.assertTrue(result) + self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) raise PaymentDone() async def f(): async with TaskGroup() as group: @@ -625,7 +637,7 @@ class TestPeer(ElectrumTestCase): await group.spawn(peer.htlc_switch()) await asyncio.sleep(0.2) lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True) - await group.spawn(pay(pay_req)) + await group.spawn(pay(lnaddr, pay_req)) with self.assertRaises(PaymentDone): run(f()) @@ -679,9 +691,11 @@ class TestPeer(ElectrumTestCase): graph.w_b.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True) graph.w_c.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True) peers = graph.all_peers() - async def pay(pay_req): + async def pay(lnaddr, pay_req): + self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) result, log = await graph.w_a.pay_invoice(pay_req) self.assertFalse(result) + self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code) raise PaymentDone() async def f(): @@ -691,7 +705,7 @@ class TestPeer(ElectrumTestCase): await group.spawn(peer.htlc_switch()) await asyncio.sleep(0.2) lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True) - await group.spawn(pay(pay_req)) + await group.spawn(pay(lnaddr, pay_req)) with self.assertRaises(PaymentDone): run(f()) @@ -702,12 +716,14 @@ class TestPeer(ElectrumTestCase): graph = self.prepare_chans_and_peers_in_square() graph.w_c.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True) peers = graph.all_peers() - async def pay(pay_req): + async def pay(lnaddr, pay_req): self.assertEqual(500000000000, graph.chan_ab.balance(LOCAL)) self.assertEqual(500000000000, graph.chan_db.balance(LOCAL)) + self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) result, log = await graph.w_a.pay_invoice(pay_req, attempts=2) self.assertEqual(2, len(log)) self.assertTrue(result) + self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) self.assertEqual([graph.chan_ac.short_channel_id, graph.chan_cd.short_channel_id], [edge.short_channel_id for edge in log[0].route]) self.assertEqual([graph.chan_ab.short_channel_id, graph.chan_bd.short_channel_id], @@ -726,7 +742,7 @@ class TestPeer(ElectrumTestCase): lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True) invoice_features = lnaddr.get_features() self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) - await group.spawn(pay(pay_req)) + await group.spawn(pay(lnaddr, pay_req)) with self.assertRaises(PaymentDone): run(f()) @@ -737,8 +753,10 @@ class TestPeer(ElectrumTestCase): peers = graph.all_peers() async def pay(): lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay) + self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) result, log = await graph.w_a.pay_invoice(pay_req, attempts=attempts) if result: + self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) raise PaymentDone() else: raise NoPathFound()