Browse Source

Merge pull request #7426 from bitromortac/2107-trampoline-test

flexible lnpeer test graphs and spp trampoline tests
patch-4
ghost43 3 years ago
committed by GitHub
parent
commit
47132790c1
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 7
      electrum/lnpeer.py
  2. 513
      electrum/tests/test_lnpeer.py
  3. 28
      electrum/trampoline.py

7
electrum/lnpeer.py

@ -1443,7 +1443,12 @@ class Peer(Logger):
payload = trampoline_onion.hop_data.payload payload = trampoline_onion.hop_data.payload
payment_hash = htlc.payment_hash payment_hash = htlc.payment_hash
payment_secret = os.urandom(32) payment_data = payload.get('payment_data')
if payment_data: # legacy case
payment_secret = payment_data['payment_secret']
else:
payment_secret = os.urandom(32)
try: try:
outgoing_node_id = payload["outgoing_node_id"]["outgoing_node_id"] outgoing_node_id = payload["outgoing_node_id"]["outgoing_node_id"]
amt_to_forward = payload["amt_to_forward"]["amt_to_forward"] amt_to_forward = payload["amt_to_forward"]["amt_to_forward"]

513
electrum/tests/test_lnpeer.py

@ -289,38 +289,62 @@ class PeerInTests(Peer):
DELAY_INC_MSG_PROCESSING_SLEEP = 0 # disable rate-limiting DELAY_INC_MSG_PROCESSING_SLEEP = 0 # disable rate-limiting
class SquareGraph(NamedTuple): high_fee_channel = {
# A 'local_balance_msat': 10 * bitcoin.COIN * 1000 // 2,
# high fee / \ low fee 'remote_balance_msat': 10 * bitcoin.COIN * 1000 // 2,
# B C 'local_base_fee_msat': 500_000,
# high fee \ / low fee 'local_fee_rate_millionths': 500,
# D 'remote_base_fee_msat': 500_000,
w_a: MockLNWallet 'remote_fee_rate_millionths': 500,
w_b: MockLNWallet }
w_c: MockLNWallet
w_d: MockLNWallet low_fee_channel = {
peer_ab: Peer 'local_balance_msat': 10 * bitcoin.COIN * 1000 // 2,
peer_ac: Peer 'remote_balance_msat': 10 * bitcoin.COIN * 1000 // 2,
peer_ba: Peer 'local_base_fee_msat': 1_000,
peer_bd: Peer 'local_fee_rate_millionths': 1,
peer_ca: Peer 'remote_base_fee_msat': 1_000,
peer_cd: Peer 'remote_fee_rate_millionths': 1,
peer_db: Peer }
peer_dc: Peer
chan_ab: Channel GRAPH_DEFINITIONS = {
chan_ac: Channel 'square_graph': {
chan_ba: Channel 'alice': {
chan_bd: Channel 'channels': {
chan_ca: Channel # we should use copies of channel definitions if
chan_cd: Channel # we want to independently alter them in a test
chan_db: Channel 'bob': high_fee_channel.copy(),
chan_dc: Channel 'carol': low_fee_channel.copy(),
},
def all_peers(self) -> Iterable[Peer]: },
return self.peer_ab, self.peer_ac, self.peer_ba, self.peer_bd, self.peer_ca, self.peer_cd, self.peer_db, self.peer_dc 'bob': {
'channels': {
def all_lnworkers(self) -> Iterable[MockLNWallet]: 'dave': high_fee_channel.copy(),
return self.w_a, self.w_b, self.w_c, self.w_d },
'config': {
'lightning_forward_payments': True,
'lightning_forward_trampoline_payments': True,
},
},
'carol': {
'channels': {
'dave': low_fee_channel.copy(),
},
'config': {
'lightning_forward_payments': True,
'lightning_forward_trampoline_payments': True,
},
},
'dave': {
},
}
}
class Graph(NamedTuple):
workers: Dict[str, MockLNWallet]
peers: Dict[Tuple[str, str], Peer]
channels: Dict[Tuple[str, str], Channel]
class PaymentDone(Exception): pass class PaymentDone(Exception): pass
@ -373,115 +397,71 @@ class TestPeer(TestCaseForTestnet):
p2.mark_open(bob_channel) p2.mark_open(bob_channel)
return p1, p2, w1, w2, q1, q2 return p1, p2, w1, w2, q1, q2
def prepare_chans_and_peers_in_square(self, funds_distribution: Dict[str, Tuple[int, int]]=None) -> SquareGraph: def prepare_chans_and_peers_in_graph(self, graph_definition) -> Graph:
if not funds_distribution: keys = {k: keypair() for k in graph_definition}
funds_distribution = {} txs_queues = {k: asyncio.Queue() for k in graph_definition}
key_a, key_b, key_c, key_d = [keypair() for i in range(4)] channels = {} # type: Dict[Tuple[str, str], Channel]
local_balance, remote_balance = funds_distribution.get('ab') or (None, None) transports = {}
chan_ab, chan_ba = create_test_channels( workers = {} # type: Dict[str, MockLNWallet]
alice_name="alice", bob_name="bob", peers = {}
alice_pubkey=key_a.pubkey, bob_pubkey=key_b.pubkey,
local_msat=local_balance, # create channels
remote_msat=remote_balance, for a, definition in graph_definition.items():
) for b, channel_def in definition.get('channels', {}).items():
local_balance, remote_balance = funds_distribution.get('ac') or (None, None) channel_ab, channel_ba = create_test_channels(
chan_ac, chan_ca = create_test_channels( alice_name=a,
alice_name="alice", bob_name="carol", bob_name=b,
alice_pubkey=key_a.pubkey, bob_pubkey=key_c.pubkey, alice_pubkey=keys[a].pubkey,
local_msat=local_balance, bob_pubkey=keys[b].pubkey,
remote_msat=remote_balance, local_msat=channel_def['local_balance_msat'],
) remote_msat=channel_def['remote_balance_msat'],
local_balance, remote_balance = funds_distribution.get('bd') or (None, None) )
chan_bd, chan_db = create_test_channels( channels[(a, b)], channels[(b, a)] = channel_ab, channel_ba
alice_name="bob", bob_name="dave", transport_ab, transport_ba = transport_pair(keys[a], keys[b], channel_ab.name, channel_ba.name)
alice_pubkey=key_b.pubkey, bob_pubkey=key_d.pubkey, transports[(a, b)], transports[(b, a)] = transport_ab, transport_ba
local_msat=local_balance, # set fees
remote_msat=remote_balance, channel_ab.forwarding_fee_proportional_millionths = channel_def['local_fee_rate_millionths']
) channel_ab.forwarding_fee_base_msat = channel_def['local_base_fee_msat']
local_balance, remote_balance = funds_distribution.get('cd') or (None, None) channel_ba.forwarding_fee_proportional_millionths = channel_def['remote_fee_rate_millionths']
chan_cd, chan_dc = create_test_channels( channel_ba.forwarding_fee_base_msat = channel_def['remote_base_fee_msat']
alice_name="carol", bob_name="dave",
alice_pubkey=key_c.pubkey, bob_pubkey=key_d.pubkey, # create workers and peers
local_msat=local_balance, for a, definition in graph_definition.items():
remote_msat=remote_balance, channels_of_node = [c for k, c in channels.items() if k[0] == a]
) workers[a] = MockLNWallet(local_keypair=keys[a], chans=channels_of_node, tx_queue=txs_queues[a], name=a)
trans_ab, trans_ba = transport_pair(key_a, key_b, chan_ab.name, chan_ba.name) self._lnworkers_created.extend(list(workers.values()))
trans_ac, trans_ca = transport_pair(key_a, key_c, chan_ac.name, chan_ca.name)
trans_bd, trans_db = transport_pair(key_b, key_d, chan_bd.name, chan_db.name) # create peers
trans_cd, trans_dc = transport_pair(key_c, key_d, chan_cd.name, chan_dc.name) for ab in channels.keys():
txq_a, txq_b, txq_c, txq_d = [asyncio.Queue() for i in range(4)] peers[ab] = Peer(workers[ab[0]], keys[ab[1]].pubkey, transports[ab])
w_a = MockLNWallet(local_keypair=key_a, chans=[chan_ab, chan_ac], tx_queue=txq_a, name="alice")
w_b = MockLNWallet(local_keypair=key_b, chans=[chan_ba, chan_bd], tx_queue=txq_b, name="bob") # add peers to workers
w_c = MockLNWallet(local_keypair=key_c, chans=[chan_ca, chan_cd], tx_queue=txq_c, name="carol") for a, w in workers.items():
w_d = MockLNWallet(local_keypair=key_d, chans=[chan_db, chan_dc], tx_queue=txq_d, name="dave") for ab, peer_ab in peers.items():
self._lnworkers_created.extend([w_a, w_b, w_c, w_d]) if ab[0] == a:
peer_ab = PeerInTests(w_a, key_b.pubkey, trans_ab) w._peers[peer_ab.pubkey] = peer_ab
peer_ac = PeerInTests(w_a, key_c.pubkey, trans_ac)
peer_ba = PeerInTests(w_b, key_a.pubkey, trans_ba) # set forwarding properties
peer_bd = PeerInTests(w_b, key_d.pubkey, trans_bd) for a, definition in graph_definition.items():
peer_ca = PeerInTests(w_c, key_a.pubkey, trans_ca) for property in definition.get('config', {}).items():
peer_cd = PeerInTests(w_c, key_d.pubkey, trans_cd) workers[a].network.config.set_key(*property)
peer_db = PeerInTests(w_d, key_b.pubkey, trans_db)
peer_dc = PeerInTests(w_d, key_c.pubkey, trans_dc)
w_a._peers[peer_ab.pubkey] = peer_ab
w_a._peers[peer_ac.pubkey] = peer_ac
w_b._peers[peer_ba.pubkey] = peer_ba
w_b._peers[peer_bd.pubkey] = peer_bd
w_c._peers[peer_ca.pubkey] = peer_ca
w_c._peers[peer_cd.pubkey] = peer_cd
w_d._peers[peer_db.pubkey] = peer_db
w_d._peers[peer_dc.pubkey] = peer_dc
w_b.network.config.set_key('lightning_forward_payments', True)
w_c.network.config.set_key('lightning_forward_payments', True)
w_b.network.config.set_key('lightning_forward_trampoline_payments', True)
w_c.network.config.set_key('lightning_forward_trampoline_payments', True)
# forwarding fees, etc
chan_ab.forwarding_fee_proportional_millionths *= 500
chan_ab.forwarding_fee_base_msat *= 500
chan_ba.forwarding_fee_proportional_millionths *= 500
chan_ba.forwarding_fee_base_msat *= 500
chan_bd.forwarding_fee_proportional_millionths *= 500
chan_bd.forwarding_fee_base_msat *= 500
chan_db.forwarding_fee_proportional_millionths *= 500
chan_db.forwarding_fee_base_msat *= 500
# mark_open won't work if state is already OPEN. # mark_open won't work if state is already OPEN.
# so set it to FUNDED # so set it to FUNDED
for chan in [chan_ab, chan_ac, chan_ba, chan_bd, chan_ca, chan_cd, chan_db, chan_dc]: for channel_ab in channels.values():
chan._state = ChannelState.FUNDED channel_ab._state = ChannelState.FUNDED
# this populates the channel graph: # this populates the channel graph:
peer_ab.mark_open(chan_ab) for ab, peer_ab in peers.items():
peer_ac.mark_open(chan_ac) peer_ab.mark_open(channels[ab])
peer_ba.mark_open(chan_ba)
peer_bd.mark_open(chan_bd) graph = Graph(
peer_ca.mark_open(chan_ca) workers=workers,
peer_cd.mark_open(chan_cd) peers=peers,
peer_db.mark_open(chan_db) channels=channels,
peer_dc.mark_open(chan_dc)
graph = SquareGraph(
w_a=w_a,
w_b=w_b,
w_c=w_c,
w_d=w_d,
peer_ab=peer_ab,
peer_ac=peer_ac,
peer_ba=peer_ba,
peer_bd=peer_bd,
peer_ca=peer_ca,
peer_cd=peer_cd,
peer_db=peer_db,
peer_dc=peer_dc,
chan_ab=chan_ab,
chan_ac=chan_ac,
chan_ba=chan_ba,
chan_bd=chan_bd,
chan_ca=chan_ca,
chan_cd=chan_cd,
chan_db=chan_db,
chan_dc=chan_dc,
) )
for a in workers:
print(f"{a} -> pubkey {keys[a].pubkey}")
return graph return graph
@staticmethod @staticmethod
@ -707,13 +687,13 @@ class TestPeer(TestCaseForTestnet):
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
def test_payment_multihop(self): def test_payment_multihop(self):
graph = self.prepare_chans_and_peers_in_square() graph = self.prepare_chans_and_peers_in_graph(GRAPH_DEFINITIONS['square_graph'])
peers = graph.all_peers() peers = graph.peers.values()
async def pay(lnaddr, pay_req): async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
result, log = await graph.w_a.pay_invoice(pay_req) result, log = await graph.workers['alice'].pay_invoice(pay_req)
self.assertTrue(result) self.assertTrue(result)
self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
raise PaymentDone() raise PaymentDone()
async def f(): async def f():
async with TaskGroup() as group: async with TaskGroup() as group:
@ -721,39 +701,39 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop()) await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch()) await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True) lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
await group.spawn(pay(lnaddr, pay_req)) await group.spawn(pay(lnaddr, pay_req))
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f()) run(f())
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
def test_payment_multihop_with_preselected_path(self): def test_payment_multihop_with_preselected_path(self):
graph = self.prepare_chans_and_peers_in_square() graph = self.prepare_chans_and_peers_in_graph(GRAPH_DEFINITIONS['square_graph'])
peers = graph.all_peers() peers = graph.peers.values()
async def pay(pay_req): async def pay(pay_req):
with self.subTest(msg="bad path: edges do not chain together"): with self.subTest(msg="bad path: edges do not chain together"):
path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey, path = [PathEdge(start_node=graph.workers['alice'].node_keypair.pubkey,
end_node=graph.w_c.node_keypair.pubkey, end_node=graph.workers['carol'].node_keypair.pubkey,
short_channel_id=graph.chan_ab.short_channel_id), short_channel_id=graph.channels[('alice', 'bob')].short_channel_id),
PathEdge(start_node=graph.w_b.node_keypair.pubkey, PathEdge(start_node=graph.workers['bob'].node_keypair.pubkey,
end_node=graph.w_d.node_keypair.pubkey, end_node=graph.workers['dave'].node_keypair.pubkey,
short_channel_id=graph.chan_bd.short_channel_id)] short_channel_id=graph.channels['bob', 'dave'].short_channel_id)]
with self.assertRaises(LNPathInconsistent): with self.assertRaises(LNPathInconsistent):
await graph.w_a.pay_invoice(pay_req, full_path=path) await graph.workers['alice'].pay_invoice(pay_req, full_path=path)
with self.subTest(msg="bad path: last node id differs from invoice pubkey"): with self.subTest(msg="bad path: last node id differs from invoice pubkey"):
path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey, path = [PathEdge(start_node=graph.workers['alice'].node_keypair.pubkey,
end_node=graph.w_b.node_keypair.pubkey, end_node=graph.workers['bob'].node_keypair.pubkey,
short_channel_id=graph.chan_ab.short_channel_id)] short_channel_id=graph.channels[('alice', 'bob')].short_channel_id)]
with self.assertRaises(LNPathInconsistent): with self.assertRaises(LNPathInconsistent):
await graph.w_a.pay_invoice(pay_req, full_path=path) await graph.workers['alice'].pay_invoice(pay_req, full_path=path)
with self.subTest(msg="good path"): with self.subTest(msg="good path"):
path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey, path = [PathEdge(start_node=graph.workers['alice'].node_keypair.pubkey,
end_node=graph.w_b.node_keypair.pubkey, end_node=graph.workers['bob'].node_keypair.pubkey,
short_channel_id=graph.chan_ab.short_channel_id), short_channel_id=graph.channels[('alice', 'bob')].short_channel_id),
PathEdge(start_node=graph.w_b.node_keypair.pubkey, PathEdge(start_node=graph.workers['bob'].node_keypair.pubkey,
end_node=graph.w_d.node_keypair.pubkey, end_node=graph.workers['dave'].node_keypair.pubkey,
short_channel_id=graph.chan_bd.short_channel_id)] short_channel_id=graph.channels['bob', 'dave'].short_channel_id)]
result, log = await graph.w_a.pay_invoice(pay_req, full_path=path) result, log = await graph.workers['alice'].pay_invoice(pay_req, full_path=path)
self.assertTrue(result) self.assertTrue(result)
self.assertEqual( self.assertEqual(
[edge.short_channel_id for edge in path], [edge.short_channel_id for edge in path],
@ -765,22 +745,22 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop()) await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch()) await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True) lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
await group.spawn(pay(pay_req)) await group.spawn(pay(pay_req))
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f()) run(f())
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
def test_payment_multihop_temp_node_failure(self): def test_payment_multihop_temp_node_failure(self):
graph = self.prepare_chans_and_peers_in_square() graph = self.prepare_chans_and_peers_in_graph(GRAPH_DEFINITIONS['square_graph'])
graph.w_b.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True) graph.workers['bob'].network.config.set_key('test_fail_htlcs_with_temp_node_failure', True)
graph.w_c.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True) graph.workers['carol'].network.config.set_key('test_fail_htlcs_with_temp_node_failure', True)
peers = graph.all_peers() peers = graph.peers.values()
async def pay(lnaddr, pay_req): async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
result, log = await graph.w_a.pay_invoice(pay_req) result, log = await graph.workers['alice'].pay_invoice(pay_req)
self.assertFalse(result) self.assertFalse(result)
self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code) self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code)
raise PaymentDone() raise PaymentDone()
async def f(): async def f():
@ -789,7 +769,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop()) await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch()) await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True) lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
await group.spawn(pay(lnaddr, pay_req)) await group.spawn(pay(lnaddr, pay_req))
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f()) run(f())
@ -798,25 +778,25 @@ class TestPeer(TestCaseForTestnet):
def test_payment_multihop_route_around_failure(self): def test_payment_multihop_route_around_failure(self):
# Alice will pay Dave. Alice first tries A->C->D route, due to lower fees, but Carol # Alice will pay Dave. Alice first tries A->C->D route, due to lower fees, but Carol
# will fail the htlc and get blacklisted. Alice will then try A->B->D and succeed. # will fail the htlc and get blacklisted. Alice will then try A->B->D and succeed.
graph = self.prepare_chans_and_peers_in_square() graph = self.prepare_chans_and_peers_in_graph(GRAPH_DEFINITIONS['square_graph'])
graph.w_c.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True) graph.workers['carol'].network.config.set_key('test_fail_htlcs_with_temp_node_failure', True)
peers = graph.all_peers() peers = graph.peers.values()
async def pay(lnaddr, pay_req): async def pay(lnaddr, pay_req):
self.assertEqual(500000000000, graph.chan_ab.balance(LOCAL)) self.assertEqual(500000000000, graph.channels[('alice', 'bob')].balance(LOCAL))
self.assertEqual(500000000000, graph.chan_db.balance(LOCAL)) self.assertEqual(500000000000, graph.channels[('dave', 'bob')].balance(LOCAL))
self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
result, log = await graph.w_a.pay_invoice(pay_req, attempts=2) result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=2)
self.assertEqual(2, len(log)) self.assertEqual(2, len(log))
self.assertTrue(result) self.assertTrue(result)
self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual([graph.chan_ac.short_channel_id, graph.chan_cd.short_channel_id], self.assertEqual([graph.channels[('alice', 'carol')].short_channel_id, graph.channels[('carol', 'dave')].short_channel_id],
[edge.short_channel_id for edge in log[0].route]) [edge.short_channel_id for edge in log[0].route])
self.assertEqual([graph.chan_ab.short_channel_id, graph.chan_bd.short_channel_id], self.assertEqual([graph.channels[('alice', 'bob')].short_channel_id, graph.channels[('bob', 'dave')].short_channel_id],
[edge.short_channel_id for edge in log[1].route]) [edge.short_channel_id for edge in log[1].route])
self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code) self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code)
self.assertEqual(499899450000, graph.chan_ab.balance(LOCAL)) self.assertEqual(499899450000, graph.channels[('alice', 'bob')].balance(LOCAL))
await asyncio.sleep(0.2) # wait for COMMITMENT_SIGNED / REVACK msgs to update balance await asyncio.sleep(0.2) # wait for COMMITMENT_SIGNED / REVACK msgs to update balance
self.assertEqual(500100000000, graph.chan_db.balance(LOCAL)) self.assertEqual(500100000000, graph.channels[('dave', 'bob')].balance(LOCAL))
raise PaymentDone() raise PaymentDone()
async def f(): async def f():
async with TaskGroup() as group: async with TaskGroup() as group:
@ -824,7 +804,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop()) await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch()) await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True) lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
invoice_features = lnaddr.get_features() invoice_features = lnaddr.get_features()
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
await group.spawn(pay(lnaddr, pay_req)) await group.spawn(pay(lnaddr, pay_req))
@ -834,43 +814,47 @@ class TestPeer(TestCaseForTestnet):
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
def test_payment_with_temp_channel_failure_and_liquidty_hints(self): def test_payment_with_temp_channel_failure_and_liquidty_hints(self):
# prepare channels such that a temporary channel failure happens at c->d # prepare channels such that a temporary channel failure happens at c->d
funds_distribution = { graph_definition = GRAPH_DEFINITIONS['square_graph'].copy()
'ac': (200_000_000, 200_000_000), # low fees graph_definition['alice']['channels']['carol']['local_balance_msat'] = 200_000_000
'cd': (50_000_000, 200_000_000), # low fees graph_definition['alice']['channels']['carol']['remote_balance_msat'] = 200_000_000
'ab': (200_000_000, 200_000_000), # high fees graph_definition['carol']['channels']['dave']['local_balance_msat'] = 50_000_000
'bd': (200_000_000, 200_000_000), # high fees graph_definition['carol']['channels']['dave']['remote_balance_msat'] = 200_000_000
} graph_definition['alice']['channels']['bob']['local_balance_msat'] = 200_000_000
graph_definition['alice']['channels']['bob']['remote_balance_msat'] = 200_000_000
graph_definition['bob']['channels']['dave']['local_balance_msat'] = 200_000_000
graph_definition['bob']['channels']['dave']['remote_balance_msat'] = 200_000_000
graph = self.prepare_chans_and_peers_in_graph(graph_definition)
# the payment happens in two attempts: # the payment happens in two attempts:
# 1. along a->c->d due to low fees with temp channel failure: # 1. along a->c->d due to low fees with temp channel failure:
# with chanupd: ORPHANED, private channel update # with chanupd: ORPHANED, private channel update
# c->d gets a liquidity hint and gets blocked # c->d gets a liquidity hint and gets blocked
# 2. along a->b->d with success # 2. along a->b->d with success
amount_to_pay = 100_000_000 amount_to_pay = 100_000_000
graph = self.prepare_chans_and_peers_in_square(funds_distribution) peers = graph.peers.values()
peers = graph.all_peers()
async def pay(lnaddr, pay_req): async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
result, log = await graph.w_a.pay_invoice(pay_req, attempts=3) result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=3)
self.assertTrue(result) self.assertTrue(result)
self.assertEqual(2, len(log)) self.assertEqual(2, len(log))
self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, log[0].failure_msg.code) self.assertEqual(OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, log[0].failure_msg.code)
liquidity_hints = graph.w_a.network.path_finder.liquidity_hints liquidity_hints = graph.workers['alice'].network.path_finder.liquidity_hints
pubkey_a = graph.w_a.node_keypair.pubkey pubkey_a = graph.workers['alice'].node_keypair.pubkey
pubkey_b = graph.w_b.node_keypair.pubkey pubkey_b = graph.workers['bob'].node_keypair.pubkey
pubkey_c = graph.w_c.node_keypair.pubkey pubkey_c = graph.workers['carol'].node_keypair.pubkey
pubkey_d = graph.w_d.node_keypair.pubkey pubkey_d = graph.workers['dave'].node_keypair.pubkey
# check liquidity hints for failing route: # check liquidity hints for failing route:
hint_ac = liquidity_hints.get_hint(graph.chan_ac.short_channel_id) hint_ac = liquidity_hints.get_hint(graph.channels[('alice', 'carol')].short_channel_id)
hint_cd = liquidity_hints.get_hint(graph.chan_cd.short_channel_id) hint_cd = liquidity_hints.get_hint(graph.channels[('carol', 'dave')].short_channel_id)
self.assertEqual(amount_to_pay, hint_ac.can_send(pubkey_a < pubkey_c)) self.assertEqual(amount_to_pay, hint_ac.can_send(pubkey_a < pubkey_c))
self.assertEqual(None, hint_ac.cannot_send(pubkey_a < pubkey_c)) self.assertEqual(None, hint_ac.cannot_send(pubkey_a < pubkey_c))
self.assertEqual(None, hint_cd.can_send(pubkey_c < pubkey_d)) self.assertEqual(None, hint_cd.can_send(pubkey_c < pubkey_d))
self.assertEqual(amount_to_pay, hint_cd.cannot_send(pubkey_c < pubkey_d)) self.assertEqual(amount_to_pay, hint_cd.cannot_send(pubkey_c < pubkey_d))
# check liquidity hints for successful route: # check liquidity hints for successful route:
hint_ab = liquidity_hints.get_hint(graph.chan_ab.short_channel_id) hint_ab = liquidity_hints.get_hint(graph.channels[('alice', 'bob')].short_channel_id)
hint_bd = liquidity_hints.get_hint(graph.chan_bd.short_channel_id) hint_bd = liquidity_hints.get_hint(graph.channels[('bob', 'dave')].short_channel_id)
self.assertEqual(amount_to_pay, hint_ab.can_send(pubkey_a < pubkey_b)) self.assertEqual(amount_to_pay, hint_ab.can_send(pubkey_a < pubkey_b))
self.assertEqual(None, hint_ab.cannot_send(pubkey_a < pubkey_b)) self.assertEqual(None, hint_ab.cannot_send(pubkey_a < pubkey_b))
self.assertEqual(amount_to_pay, hint_bd.can_send(pubkey_b < pubkey_d)) self.assertEqual(amount_to_pay, hint_bd.can_send(pubkey_b < pubkey_d))
@ -883,17 +867,17 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop()) await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch()) await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, amount_msat=amount_to_pay, include_routing_hints=True) lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], amount_msat=amount_to_pay, include_routing_hints=True)
await group.spawn(pay(lnaddr, pay_req)) await group.spawn(pay(lnaddr, pay_req))
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f()) run(f())
def _run_mpp(self, graph, fail_kwargs, success_kwargs): def _run_mpp(self, graph, fail_kwargs, success_kwargs):
"""Tests a multipart payment scenario for failing and successful cases.""" """Tests a multipart payment scenario for failing and successful cases."""
self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL)) self.assertEqual(500_000_000_000, graph.channels[('alice', 'bob')].balance(LOCAL))
self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL)) self.assertEqual(500_000_000_000, graph.channels[('alice', 'carol')].balance(LOCAL))
amount_to_pay = 600_000_000_000 amount_to_pay = 600_000_000_000
peers = graph.all_peers() peers = graph.peers.values()
async def pay( async def pay(
attempts=1, attempts=1,
alice_uses_trampoline=False, alice_uses_trampoline=False,
@ -901,25 +885,25 @@ class TestPeer(TestCaseForTestnet):
mpp_invoice=True mpp_invoice=True
): ):
if mpp_invoice: if mpp_invoice:
graph.w_d.features |= LnFeatures.BASIC_MPP_OPT graph.workers['dave'].features |= LnFeatures.BASIC_MPP_OPT
if not bob_forwarding: if not bob_forwarding:
graph.w_b.enable_htlc_forwarding = False graph.workers['bob'].enable_htlc_forwarding = False
if alice_uses_trampoline: if alice_uses_trampoline:
if graph.w_a.network.channel_db: if graph.workers['alice'].network.channel_db:
graph.w_a.network.channel_db.stop() graph.workers['alice'].network.channel_db.stop()
await graph.w_a.network.channel_db.stopped_event.wait() await graph.workers['alice'].network.channel_db.stopped_event.wait()
graph.w_a.network.channel_db = None graph.workers['alice'].network.channel_db = None
else: else:
assert graph.w_a.network.channel_db is not None assert graph.workers['alice'].network.channel_db is not None
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay) lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay)
self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
result, log = await graph.w_a.pay_invoice(pay_req, attempts=attempts) result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=attempts)
if not bob_forwarding: if not bob_forwarding:
# reset to previous state, sleep 2s so that the second htlc can time out # reset to previous state, sleep 2s so that the second htlc can time out
graph.w_b.enable_htlc_forwarding = True graph.workers['bob'].enable_htlc_forwarding = True
await asyncio.sleep(2) await asyncio.sleep(2)
if result: if result:
self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
raise PaymentDone() raise PaymentDone()
else: else:
raise NoPathFound() raise NoPathFound()
@ -939,21 +923,66 @@ class TestPeer(TestCaseForTestnet):
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
def test_payment_multipart_with_timeout(self): def test_payment_multipart_with_timeout(self):
graph = self.prepare_chans_and_peers_in_square() graph = self.prepare_chans_and_peers_in_graph(GRAPH_DEFINITIONS['square_graph'])
self._run_mpp(graph, {'bob_forwarding': False}, {'bob_forwarding': True}) self._run_mpp(graph, {'bob_forwarding': False}, {'bob_forwarding': True})
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
def test_payment_multipart(self): def test_payment_multipart(self):
graph = self.prepare_chans_and_peers_in_square() graph = self.prepare_chans_and_peers_in_graph(GRAPH_DEFINITIONS['square_graph'])
self._run_mpp(graph, {'mpp_invoice': False}, {'mpp_invoice': True}) self._run_mpp(graph, {'mpp_invoice': False}, {'mpp_invoice': True})
@needs_test_with_all_chacha20_implementations
def test_payment_trampoline(self):
async def turn_on_trampoline_alice():
if graph.workers['alice'].network.channel_db:
graph.workers['alice'].network.channel_db.stop()
await graph.workers['alice'].network.channel_db.stopped_event.wait()
graph.workers['alice'].network.channel_db = None
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=10)
self.assertTrue(result)
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
raise PaymentDone()
async def f():
await turn_on_trampoline_alice()
async with TaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
await group.spawn(pay(lnaddr, pay_req))
for is_legacy in (True, False):
graph_definition = GRAPH_DEFINITIONS['square_graph'].copy()
# insert a channel from bob to carol for faster tests,
# otherwise will fail randomly
graph_definition['bob']['channels']['carol'] = high_fee_channel
graph = self.prepare_chans_and_peers_in_graph(graph_definition)
peers = graph.peers.values()
if is_legacy:
# turn off trampoline features
graph.workers['dave'].features = graph.workers['dave'].features ^ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT
# declare routing nodes as trampoline nodes
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
graph.workers['bob'].name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.workers['bob'].node_keypair.pubkey),
graph.workers['carol'].name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.workers['carol'].node_keypair.pubkey),
}
with self.assertRaises(PaymentDone):
run(f())
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
def test_payment_multipart_trampoline(self): def test_payment_multipart_trampoline(self):
# single attempt will fail with insufficient trampoline fee # single attempt will fail with insufficient trampoline fee
graph = self.prepare_chans_and_peers_in_square() graph = self.prepare_chans_and_peers_in_graph(GRAPH_DEFINITIONS['square_graph'])
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
graph.w_b.name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.w_b.node_keypair.pubkey), graph.workers['bob'].name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.workers['bob'].node_keypair.pubkey),
graph.w_c.name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.w_c.node_keypair.pubkey), graph.workers['carol'].name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.workers['carol'].node_keypair.pubkey),
} }
try: try:
self._run_mpp( self._run_mpp(
@ -969,31 +998,31 @@ class TestPeer(TestCaseForTestnet):
Dave shuts down (stops wallet). Dave shuts down (stops wallet).
We test if Dave fails the pending HTLCs during shutdown. We test if Dave fails the pending HTLCs during shutdown.
""" """
graph = self.prepare_chans_and_peers_in_square() graph = self.prepare_chans_and_peers_in_graph(GRAPH_DEFINITIONS['square_graph'])
self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL)) self.assertEqual(500_000_000_000, graph.channels[('alice', 'bob')].balance(LOCAL))
self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL)) self.assertEqual(500_000_000_000, graph.channels[('alice', 'carol')].balance(LOCAL))
amount_to_pay = 600_000_000_000 amount_to_pay = 600_000_000_000
peers = graph.all_peers() peers = graph.peers.values()
graph.w_d.MPP_EXPIRY = 120 graph.workers['dave'].MPP_EXPIRY = 120
graph.w_d.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3 graph.workers['dave'].TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3
async def pay(): async def pay():
graph.w_d.features |= LnFeatures.BASIC_MPP_OPT graph.workers['dave'].features |= LnFeatures.BASIC_MPP_OPT
graph.w_b.enable_htlc_forwarding = False # Bob will hold forwarded HTLCs graph.workers['bob'].enable_htlc_forwarding = False # Bob will hold forwarded HTLCs
assert graph.w_a.network.channel_db is not None assert graph.workers['alice'].network.channel_db is not None
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay) lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay)
try: try:
async with timeout_after(1.0): async with timeout_after(0.5):
result, log = await graph.w_a.pay_invoice(pay_req, attempts=1) result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=1)
except TaskTimeout: except TaskTimeout:
# by now Dave hopefully received some HTLCs: # by now Dave hopefully received some HTLCs:
self.assertTrue(len(graph.chan_dc.hm.htlcs(LOCAL)) > 0) self.assertTrue(len(graph.channels[('dave', 'carol')].hm.htlcs(LOCAL)) > 0)
self.assertTrue(len(graph.chan_dc.hm.htlcs(REMOTE)) > 0) self.assertTrue(len(graph.channels[('dave', 'carol')].hm.htlcs(REMOTE)) > 0)
else: else:
self.fail(f"pay_invoice finished but was not supposed to. result={result}") self.fail(f"pay_invoice finished but was not supposed to. result={result}")
await graph.w_d.stop() await graph.workers['dave'].stop()
# Dave is supposed to have failed the pending incomplete MPP HTLCs # Dave is supposed to have failed the pending incomplete MPP HTLCs
self.assertEqual(0, len(graph.chan_dc.hm.htlcs(LOCAL))) self.assertEqual(0, len(graph.channels[('dave', 'carol')].hm.htlcs(LOCAL)))
self.assertEqual(0, len(graph.chan_dc.hm.htlcs(REMOTE))) self.assertEqual(0, len(graph.channels[('dave', 'carol')].hm.htlcs(REMOTE)))
raise SuccessfulTest() raise SuccessfulTest()
async def f(): async def f():

28
electrum/trampoline.py

@ -186,16 +186,22 @@ def create_trampoline_route(
route[-1].outgoing_node_id = invoice_pubkey route[-1].outgoing_node_id = invoice_pubkey
else: # end-to-end trampoline else: # end-to-end trampoline
if r_tag_chosen_for_e2e_trampoline: if r_tag_chosen_for_e2e_trampoline:
pubkey, scid, feebase, feerate, cltv = r_tag_chosen_for_e2e_trampoline pubkey = r_tag_chosen_for_e2e_trampoline[0]
if route[-1].end_node != pubkey: if route[-1].end_node != pubkey:
route.append( # We don't use the forwarding policy from the route hint, which
TrampolineEdge( # is only valid for legacy forwarding. Trampoline forwarders require
start_node=route[-1].end_node, # higher fees and cltv deltas.
end_node=pubkey, trampoline_fee_level = trampoline_fee_levels[pubkey]
fee_base_msat=feebase, if trampoline_fee_level < len(TRAMPOLINE_FEES):
fee_proportional_millionths=feerate, fee_policy = TRAMPOLINE_FEES[trampoline_fee_level]
cltv_expiry_delta=cltv, route.append(
node_features=trampoline_features)) TrampolineEdge(
start_node=route[-1].end_node,
end_node=pubkey,
fee_base_msat=fee_policy['fee_base_msat'],
fee_proportional_millionths=fee_policy['fee_proportional_millionths'],
cltv_expiry_delta=fee_policy['cltv_expiry_delta'],
node_features=trampoline_features))
# Final edge (not part of the route if payment is legacy, but eclair requires an encrypted blob) # Final edge (not part of the route if payment is legacy, but eclair requires an encrypted blob)
route.append( route.append(
@ -241,7 +247,7 @@ def create_trampoline_onion(*, route, amount_msat, final_cltv, total_msat, payme
# only for final # only for final
if i == num_hops - 1: if i == num_hops - 1:
payload["payment_data"] = { payload["payment_data"] = {
"payment_secret":payment_secret, "payment_secret": payment_secret,
"total_msat": total_msat "total_msat": total_msat
} }
# legacy # legacy
@ -249,7 +255,7 @@ def create_trampoline_onion(*, route, amount_msat, final_cltv, total_msat, payme
payload["invoice_features"] = {"invoice_features":route_edge.invoice_features} payload["invoice_features"] = {"invoice_features":route_edge.invoice_features}
payload["invoice_routing_info"] = {"invoice_routing_info":route_edge.invoice_routing_info} payload["invoice_routing_info"] = {"invoice_routing_info":route_edge.invoice_routing_info}
payload["payment_data"] = { payload["payment_data"] = {
"payment_secret":payment_secret, "payment_secret": payment_secret,
"total_msat": total_msat "total_msat": total_msat
} }
_logger.info(f'payload {i} {payload}') _logger.info(f'payload {i} {payload}')

Loading…
Cancel
Save