Browse Source

tests: test payreq status after getting paid via LN

The test failures corresponding to single-part (non-MPP) payments expose a bug.

see 196b4c00a3/electrum/lnpeer.py (L1538-L1539)
`lnworker.add_received_htlc` is not called for single-part payments...
patch-4
SomberNight 4 years ago
parent
commit
a125cd5392
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 2
      electrum/lnpeer.py
  2. 5
      electrum/lnworker.py
  3. 50
      electrum/tests/test_lnpeer.py

2
electrum/lnpeer.py

@ -1473,7 +1473,7 @@ class Peer(Logger):
chan: Channel, chan: Channel,
htlc: UpdateAddHtlc, htlc: UpdateAddHtlc,
processed_onion: ProcessedOnionPacket, processed_onion: ProcessedOnionPacket,
is_trampoline: bool = False) -> Tuple[Optional[bytes], Optional[bytes]]: is_trampoline: bool = False) -> Tuple[Optional[bytes], Optional[OnionPacket]]:
"""As a final recipient of an HTLC, decide if we should fulfill it. """As a final recipient of an HTLC, decide if we should fulfill it.
Return (preimage, trampoline_onion_packet) with at most a single element not None Return (preimage, trampoline_onion_packet) with at most a single element not None

5
electrum/lnworker.py

@ -577,9 +577,9 @@ class LNWallet(LNWorker):
lnwatcher: Optional['LNWalletWatcher'] lnwatcher: Optional['LNWalletWatcher']
def __init__(self, wallet: 'Abstract_Wallet', xprv): def __init__(self, wallet: 'Abstract_Wallet', xprv):
Logger.__init__(self)
self.wallet = wallet self.wallet = wallet
self.db = wallet.db self.db = wallet.db
Logger.__init__(self)
LNWorker.__init__(self, xprv, LNWALLET_FEATURES) LNWorker.__init__(self, xprv, LNWALLET_FEATURES)
self.config = wallet.config self.config = wallet.config
self.lnwatcher = None self.lnwatcher = None
@ -621,6 +621,9 @@ class LNWallet(LNWorker):
def get_channel_by_id(self, channel_id: bytes) -> Optional[Channel]: def get_channel_by_id(self, channel_id: bytes) -> Optional[Channel]:
return self._channels.get(channel_id, None) return self._channels.get(channel_id, None)
def diagnostic_name(self):
return self.wallet.diagnostic_name()
@ignore_exceptions @ignore_exceptions
@log_exceptions @log_exceptions
async def sync_with_local_watchtower(self): async def sync_with_local_watchtower(self):

50
electrum/tests/test_lnpeer.py

@ -30,10 +30,11 @@ from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet, NoPathFound from electrum.lnworker import LNWallet, NoPathFound
from electrum.lnmsg import encode_msg, decode_msg from electrum.lnmsg import encode_msg, decode_msg
from electrum.logging import console_stderr_handler, Logger from electrum.logging import console_stderr_handler, Logger
from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID from electrum.lnworker import PaymentInfo, RECEIVED
from electrum.lnonion import OnionFailureCode from electrum.lnonion import OnionFailureCode
from electrum.lnutil import ChannelBlackList, derive_payment_secret_from_payment_preimage from electrum.lnutil import ChannelBlackList, derive_payment_secret_from_payment_preimage
from electrum.lnutil import LOCAL, REMOTE from electrum.lnutil import LOCAL, REMOTE
from electrum.invoices import PR_PAID, PR_UNPAID
from .test_lnchannel import create_test_channels from .test_lnchannel import create_test_channels
from .test_bitcoin import needs_test_with_all_chacha20_implementations from .test_bitcoin import needs_test_with_all_chacha20_implementations
@ -112,7 +113,8 @@ class MockWallet:
class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue): def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name):
self.name = name
Logger.__init__(self) Logger.__init__(self)
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1) NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
self.node_keypair = local_keypair self.node_keypair = local_keypair
@ -173,6 +175,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
def save_channel(self, chan): def save_channel(self, chan):
print("Ignoring channel save") print("Ignoring channel save")
def diagnostic_name(self):
return self.name
get_payments = LNWallet.get_payments get_payments = LNWallet.get_payments
get_payment_info = LNWallet.get_payment_info get_payment_info = LNWallet.get_payment_info
save_payment_info = LNWallet.save_payment_info save_payment_info = LNWallet.save_payment_info
@ -298,8 +303,8 @@ class TestPeer(ElectrumTestCase):
bob_channel.node_id = k1.pubkey bob_channel.node_id = k1.pubkey
t1, t2 = transport_pair(k1, k2, alice_channel.name, bob_channel.name) t1, t2 = transport_pair(k1, k2, alice_channel.name, bob_channel.name)
q1, q2 = asyncio.Queue(), asyncio.Queue() q1, q2 = asyncio.Queue(), asyncio.Queue()
w1 = MockLNWallet(local_keypair=k1, chans=[alice_channel], tx_queue=q1) w1 = MockLNWallet(local_keypair=k1, chans=[alice_channel], tx_queue=q1, name=bob_channel.name)
w2 = MockLNWallet(local_keypair=k2, chans=[bob_channel], tx_queue=q2) w2 = MockLNWallet(local_keypair=k2, chans=[bob_channel], tx_queue=q2, name=alice_channel.name)
p1 = Peer(w1, k2.pubkey, t1) p1 = Peer(w1, k2.pubkey, t1)
p2 = Peer(w2, k1.pubkey, t2) p2 = Peer(w2, k1.pubkey, t2)
w1._peers[p1.pubkey] = p1 w1._peers[p1.pubkey] = p1
@ -324,10 +329,10 @@ class TestPeer(ElectrumTestCase):
trans_bd, trans_db = transport_pair(key_b, key_d, chan_bd.name, chan_db.name) trans_bd, trans_db = transport_pair(key_b, key_d, chan_bd.name, chan_db.name)
trans_cd, trans_dc = transport_pair(key_c, key_d, chan_cd.name, chan_dc.name) trans_cd, trans_dc = transport_pair(key_c, key_d, chan_cd.name, chan_dc.name)
txq_a, txq_b, txq_c, txq_d = [asyncio.Queue() for i in range(4)] txq_a, txq_b, txq_c, txq_d = [asyncio.Queue() for i in range(4)]
w_a = MockLNWallet(local_keypair=key_a, chans=[chan_ab, chan_ac], tx_queue=txq_a) w_a = MockLNWallet(local_keypair=key_a, chans=[chan_ab, chan_ac], tx_queue=txq_a, name="alice")
w_b = MockLNWallet(local_keypair=key_b, chans=[chan_ba, chan_bd], tx_queue=txq_b) w_b = MockLNWallet(local_keypair=key_b, chans=[chan_ba, chan_bd], tx_queue=txq_b, name="bob")
w_c = MockLNWallet(local_keypair=key_c, chans=[chan_ca, chan_cd], tx_queue=txq_c) w_c = MockLNWallet(local_keypair=key_c, chans=[chan_ca, chan_cd], tx_queue=txq_c, name="carol")
w_d = MockLNWallet(local_keypair=key_d, chans=[chan_db, chan_dc], tx_queue=txq_d) w_d = MockLNWallet(local_keypair=key_d, chans=[chan_db, chan_dc], tx_queue=txq_d, name="dave")
peer_ab = Peer(w_a, key_b.pubkey, trans_ab) peer_ab = Peer(w_a, key_b.pubkey, trans_ab)
peer_ac = Peer(w_a, key_c.pubkey, trans_ac) peer_ac = Peer(w_a, key_c.pubkey, trans_ac)
peer_ba = Peer(w_b, key_a.pubkey, trans_ba) peer_ba = Peer(w_b, key_a.pubkey, trans_ba)
@ -489,11 +494,14 @@ class TestPeer(ElectrumTestCase):
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
def test_payment(self): def test_payment(self):
"""Alice pays Bob a single HTLC via direct channel."""
alice_channel, bob_channel = create_test_channels() alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
async def pay(pay_req): async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash))
result, log = await w1.pay_invoice(pay_req) result, log = await w1.pay_invoice(pay_req)
self.assertTrue(result) self.assertTrue(result)
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
raise PaymentDone() raise PaymentDone()
async def f(): async def f():
async with TaskGroup() as group: async with TaskGroup() as group:
@ -503,7 +511,9 @@ class TestPeer(ElectrumTestCase):
await group.spawn(p2.htlc_switch()) await group.spawn(p2.htlc_switch())
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
lnaddr, pay_req = await self.prepare_invoice(w2) lnaddr, pay_req = await self.prepare_invoice(w2)
await group.spawn(pay(pay_req)) invoice_features = lnaddr.get_features()
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
await group.spawn(pay(lnaddr, pay_req))
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f()) run(f())
@ -614,9 +624,11 @@ class TestPeer(ElectrumTestCase):
def test_payment_multihop(self): def test_payment_multihop(self):
graph = self.prepare_chans_and_peers_in_square() graph = self.prepare_chans_and_peers_in_square()
peers = graph.all_peers() peers = graph.all_peers()
async def pay(pay_req): async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
result, log = await graph.w_a.pay_invoice(pay_req) result, log = await graph.w_a.pay_invoice(pay_req)
self.assertTrue(result) self.assertTrue(result)
self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
raise PaymentDone() raise PaymentDone()
async def f(): async def f():
async with TaskGroup() as group: async with TaskGroup() as group:
@ -625,7 +637,7 @@ class TestPeer(ElectrumTestCase):
await group.spawn(peer.htlc_switch()) await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True) lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
await group.spawn(pay(pay_req)) await group.spawn(pay(lnaddr, pay_req))
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f()) run(f())
@ -679,9 +691,11 @@ class TestPeer(ElectrumTestCase):
graph.w_b.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True) graph.w_b.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True)
graph.w_c.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True) graph.w_c.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True)
peers = graph.all_peers() peers = graph.all_peers()
async def pay(pay_req): async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
result, log = await graph.w_a.pay_invoice(pay_req) result, log = await graph.w_a.pay_invoice(pay_req)
self.assertFalse(result) self.assertFalse(result)
self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code) self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code)
raise PaymentDone() raise PaymentDone()
async def f(): async def f():
@ -691,7 +705,7 @@ class TestPeer(ElectrumTestCase):
await group.spawn(peer.htlc_switch()) await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True) lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
await group.spawn(pay(pay_req)) await group.spawn(pay(lnaddr, pay_req))
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f()) run(f())
@ -702,12 +716,14 @@ class TestPeer(ElectrumTestCase):
graph = self.prepare_chans_and_peers_in_square() graph = self.prepare_chans_and_peers_in_square()
graph.w_c.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True) graph.w_c.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True)
peers = graph.all_peers() peers = graph.all_peers()
async def pay(pay_req): async def pay(lnaddr, pay_req):
self.assertEqual(500000000000, graph.chan_ab.balance(LOCAL)) self.assertEqual(500000000000, graph.chan_ab.balance(LOCAL))
self.assertEqual(500000000000, graph.chan_db.balance(LOCAL)) self.assertEqual(500000000000, graph.chan_db.balance(LOCAL))
self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
result, log = await graph.w_a.pay_invoice(pay_req, attempts=2) result, log = await graph.w_a.pay_invoice(pay_req, attempts=2)
self.assertEqual(2, len(log)) self.assertEqual(2, len(log))
self.assertTrue(result) self.assertTrue(result)
self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
self.assertEqual([graph.chan_ac.short_channel_id, graph.chan_cd.short_channel_id], self.assertEqual([graph.chan_ac.short_channel_id, graph.chan_cd.short_channel_id],
[edge.short_channel_id for edge in log[0].route]) [edge.short_channel_id for edge in log[0].route])
self.assertEqual([graph.chan_ab.short_channel_id, graph.chan_bd.short_channel_id], self.assertEqual([graph.chan_ab.short_channel_id, graph.chan_bd.short_channel_id],
@ -726,7 +742,7 @@ class TestPeer(ElectrumTestCase):
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True) lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
invoice_features = lnaddr.get_features() invoice_features = lnaddr.get_features()
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
await group.spawn(pay(pay_req)) await group.spawn(pay(lnaddr, pay_req))
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f()) run(f())
@ -737,8 +753,10 @@ class TestPeer(ElectrumTestCase):
peers = graph.all_peers() peers = graph.all_peers()
async def pay(): async def pay():
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay) lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay)
self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
result, log = await graph.w_a.pay_invoice(pay_req, attempts=attempts) result, log = await graph.w_a.pay_invoice(pay_req, attempts=attempts)
if result: if result:
self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
raise PaymentDone() raise PaymentDone()
else: else:
raise NoPathFound() raise NoPathFound()

Loading…
Cancel
Save