Browse Source

test_lnpeer: add some multi-hop payment unit tests

master
SomberNight 5 years ago
parent
commit
cc4029c335
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 4
      electrum/lnpeer.py
  2. 1
      electrum/lnworker.py
  3. 184
      electrum/tests/test_lnpeer.py

4
electrum/lnpeer.py

@ -1510,7 +1510,9 @@ class Peer(Logger):
self.logger.info(f"error processing onion packet: {e!r}") self.logger.info(f"error processing onion packet: {e!r}")
error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
else: else:
if processed_onion.are_we_final: if self.lnworker._fail_htlcs_with_temp_node_failure:
error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
elif processed_onion.are_we_final:
preimage, error_reason = self.maybe_fulfill_htlc( preimage, error_reason = self.maybe_fulfill_htlc(
chan=chan, chan=chan,
htlc=htlc, htlc=htlc,

1
electrum/lnworker.py

@ -494,6 +494,7 @@ class LNWallet(LNWorker):
# used in tests # used in tests
self.enable_htlc_settle = asyncio.Event() self.enable_htlc_settle = asyncio.Event()
self.enable_htlc_settle.set() self.enable_htlc_settle.set()
self._fail_htlcs_with_temp_node_failure = False
# note: accessing channels (besides simple lookup) needs self.lock! # note: accessing channels (besides simple lookup) needs self.lock!
self._channels = {} # type: Dict[bytes, Channel] self._channels = {} # type: Dict[bytes, Channel]

184
electrum/tests/test_lnpeer.py

@ -8,7 +8,7 @@ import logging
import concurrent import concurrent
from concurrent import futures from concurrent import futures
import unittest import unittest
from typing import Iterable from typing import Iterable, NamedTuple
from aiorpcx import TaskGroup from aiorpcx import TaskGroup
@ -24,12 +24,13 @@ from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner
from electrum.lnchannel import ChannelState, PeerState, Channel from electrum.lnchannel import ChannelState, PeerState, Channel
from electrum.lnrouter import LNPathFinder from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
from electrum.channel_db import ChannelDB from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet, NoPathFound from electrum.lnworker import LNWallet, NoPathFound
from electrum.lnmsg import encode_msg, decode_msg from electrum.lnmsg import encode_msg, decode_msg
from electrum.logging import console_stderr_handler, Logger from electrum.logging import console_stderr_handler, Logger
from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID
from electrum.lnonion import OnionFailureCode
from .test_lnchannel import create_test_channels from .test_lnchannel import create_test_channels
from .test_bitcoin import needs_test_with_all_chacha20_implementations from .test_bitcoin import needs_test_with_all_chacha20_implementations
@ -117,6 +118,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
# used in tests # used in tests
self.enable_htlc_settle = asyncio.Event() self.enable_htlc_settle = asyncio.Event()
self.enable_htlc_settle.set() self.enable_htlc_settle.set()
self._fail_htlcs_with_temp_node_failure = False
def get_invoice_status(self, key): def get_invoice_status(self, key):
pass pass
@ -212,6 +214,37 @@ def transport_pair(k1, k2, name1, name2):
return t1, t2 return t1, t2
class DiamondGraph(NamedTuple):
# A
# / \
# B C
# \ /
# D
w_a: MockLNWallet
w_b: MockLNWallet
w_c: MockLNWallet
w_d: MockLNWallet
peer_ab: Peer
peer_ac: Peer
peer_ba: Peer
peer_bd: Peer
peer_ca: Peer
peer_cd: Peer
peer_db: Peer
peer_dc: Peer
chan_ab: Channel
chan_ac: Channel
chan_ba: Channel
chan_bd: Channel
chan_ca: Channel
chan_cd: Channel
chan_db: Channel
chan_dc: Channel
def all_peers(self) -> Iterable[Peer]:
return self.peer_ab, self.peer_ac, self.peer_ba, self.peer_bd, self.peer_ca, self.peer_cd, self.peer_db, self.peer_dc
class PaymentDone(Exception): pass class PaymentDone(Exception): pass
@ -252,6 +285,77 @@ class TestPeer(ElectrumTestCase):
p2.mark_open(bob_channel) p2.mark_open(bob_channel)
return p1, p2, w1, w2, q1, q2 return p1, p2, w1, w2, q1, q2
def prepare_chans_and_peers_in_diamond(self) -> DiamondGraph:
key_a, key_b, key_c, key_d = [keypair() for i in range(4)]
chan_ab, chan_ba = create_test_channels(alice_name="alice", bob_name="bob", alice_pubkey=key_a.pubkey, bob_pubkey=key_b.pubkey)
chan_ac, chan_ca = create_test_channels(alice_name="alice", bob_name="carol", alice_pubkey=key_a.pubkey, bob_pubkey=key_c.pubkey)
chan_bd, chan_db = create_test_channels(alice_name="bob", bob_name="dave", alice_pubkey=key_b.pubkey, bob_pubkey=key_d.pubkey)
chan_cd, chan_dc = create_test_channels(alice_name="carol", bob_name="dave", alice_pubkey=key_c.pubkey, bob_pubkey=key_d.pubkey)
trans_ab, trans_ba = transport_pair(key_a, key_b, chan_ab.name, chan_ba.name)
trans_ac, trans_ca = transport_pair(key_a, key_c, chan_ac.name, chan_ca.name)
trans_bd, trans_db = transport_pair(key_b, key_d, chan_bd.name, chan_db.name)
trans_cd, trans_dc = transport_pair(key_c, key_d, chan_cd.name, chan_dc.name)
txq_a, txq_b, txq_c, txq_d = [asyncio.Queue() for i in range(4)]
w_a = MockLNWallet(local_keypair=key_a, chans=[chan_ab, chan_ac], tx_queue=txq_a)
w_b = MockLNWallet(local_keypair=key_b, chans=[chan_ba, chan_bd], tx_queue=txq_b)
w_c = MockLNWallet(local_keypair=key_c, chans=[chan_ca, chan_cd], tx_queue=txq_c)
w_d = MockLNWallet(local_keypair=key_d, chans=[chan_db, chan_dc], tx_queue=txq_d)
peer_ab = Peer(w_a, key_b.pubkey, trans_ab)
peer_ac = Peer(w_a, key_c.pubkey, trans_ac)
peer_ba = Peer(w_b, key_a.pubkey, trans_ba)
peer_bd = Peer(w_b, key_d.pubkey, trans_bd)
peer_ca = Peer(w_c, key_a.pubkey, trans_ca)
peer_cd = Peer(w_c, key_d.pubkey, trans_cd)
peer_db = Peer(w_d, key_b.pubkey, trans_db)
peer_dc = Peer(w_d, key_c.pubkey, trans_dc)
w_a._peers[peer_ab.pubkey] = peer_ab
w_a._peers[peer_ac.pubkey] = peer_ac
w_b._peers[peer_ba.pubkey] = peer_ba
w_b._peers[peer_bd.pubkey] = peer_bd
w_c._peers[peer_ca.pubkey] = peer_ca
w_c._peers[peer_cd.pubkey] = peer_cd
w_d._peers[peer_db.pubkey] = peer_db
w_d._peers[peer_dc.pubkey] = peer_dc
w_b.network.config.set_key('lightning_forward_payments', True)
w_c.network.config.set_key('lightning_forward_payments', True)
# mark_open won't work if state is already OPEN.
# so set it to FUNDED
for chan in [chan_ab, chan_ac, chan_ba, chan_bd, chan_ca, chan_cd, chan_db, chan_dc]:
chan._state = ChannelState.FUNDED
# this populates the channel graph:
peer_ab.mark_open(chan_ab)
peer_ac.mark_open(chan_ac)
peer_ba.mark_open(chan_ba)
peer_bd.mark_open(chan_bd)
peer_ca.mark_open(chan_ca)
peer_cd.mark_open(chan_cd)
peer_db.mark_open(chan_db)
peer_dc.mark_open(chan_dc)
return DiamondGraph(
w_a=w_a,
w_b=w_b,
w_c=w_c,
w_d=w_d,
peer_ab=peer_ab,
peer_ac=peer_ac,
peer_ba=peer_ba,
peer_bd=peer_bd,
peer_ca=peer_ca,
peer_cd=peer_cd,
peer_db=peer_db,
peer_dc=peer_dc,
chan_ab=chan_ab,
chan_ac=chan_ac,
chan_ba=chan_ba,
chan_bd=chan_bd,
chan_ca=chan_ca,
chan_cd=chan_cd,
chan_db=chan_db,
chan_dc=chan_dc,
)
@staticmethod @staticmethod
async def prepare_invoice( async def prepare_invoice(
w2: MockLNWallet, # receiver w2: MockLNWallet, # receiver
@ -382,6 +486,82 @@ class TestPeer(ElectrumTestCase):
self.assertEqual(bob_init_balance_msat + num_payments * payment_value_sat * 1000, bob_channel.balance(HTLCOwner.LOCAL)) self.assertEqual(bob_init_balance_msat + num_payments * payment_value_sat * 1000, bob_channel.balance(HTLCOwner.LOCAL))
self.assertEqual(bob_init_balance_msat + num_payments * payment_value_sat * 1000, alice_channel.balance(HTLCOwner.REMOTE)) self.assertEqual(bob_init_balance_msat + num_payments * payment_value_sat * 1000, alice_channel.balance(HTLCOwner.REMOTE))
@needs_test_with_all_chacha20_implementations
def test_payment_multihop(self):
graph = self.prepare_chans_and_peers_in_diamond()
peers = graph.all_peers()
async def pay(pay_req):
result, log = await graph.w_a._pay(pay_req)
self.assertTrue(result)
raise PaymentDone()
async def f():
async with TaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
await group.spawn(pay(pay_req))
with self.assertRaises(PaymentDone):
run(f())
@needs_test_with_all_chacha20_implementations
def test_payment_multihop_with_preselected_path(self):
graph = self.prepare_chans_and_peers_in_diamond()
peers = graph.all_peers()
async def pay(pay_req):
with self.subTest(msg="bad path: edges do not chain together"):
path = [PathEdge(node_id=graph.w_c.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id),
PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)]
result, log = await graph.w_a._pay(pay_req, full_path=path)
self.assertFalse(result)
self.assertTrue(isinstance(log[0].exception, LNPathInconsistent))
with self.subTest(msg="bad path: last node id differs from invoice pubkey"):
path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id)]
result, log = await graph.w_a._pay(pay_req, full_path=path)
self.assertFalse(result)
self.assertTrue(isinstance(log[0].exception, LNPathInconsistent))
with self.subTest(msg="good path"):
path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id),
PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)]
result, log = await graph.w_a._pay(pay_req, full_path=path)
self.assertTrue(result)
self.assertEqual([edge.short_channel_id for edge in path],
[edge.short_channel_id for edge in log[0].route])
raise PaymentDone()
async def f():
async with TaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
await group.spawn(pay(pay_req))
with self.assertRaises(PaymentDone):
run(f())
@needs_test_with_all_chacha20_implementations
def test_payment_multihop_temp_node_failure(self):
graph = self.prepare_chans_and_peers_in_diamond()
graph.w_b._fail_htlcs_with_temp_node_failure = True
graph.w_c._fail_htlcs_with_temp_node_failure = True
peers = graph.all_peers()
async def pay(pay_req):
result, log = await graph.w_a._pay(pay_req)
self.assertFalse(result)
self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_details.failure_msg.code)
raise PaymentDone()
async def f():
async with TaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
await group.spawn(pay(pay_req))
with self.assertRaises(PaymentDone):
run(f())
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
def test_close(self): def test_close(self):
alice_channel, bob_channel = create_test_channels() alice_channel, bob_channel = create_test_channels()

Loading…
Cancel
Save