|
@ -8,7 +8,7 @@ import logging |
|
|
import concurrent |
|
|
import concurrent |
|
|
from concurrent import futures |
|
|
from concurrent import futures |
|
|
import unittest |
|
|
import unittest |
|
|
from typing import Iterable, NamedTuple, Tuple, List |
|
|
from typing import Iterable, NamedTuple, Tuple, List, Dict |
|
|
|
|
|
|
|
|
from aiorpcx import TaskGroup, timeout_after, TaskTimeout |
|
|
from aiorpcx import TaskGroup, timeout_after, TaskTimeout |
|
|
|
|
|
|
|
@ -221,6 +221,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): |
|
|
is_trampoline_peer = LNWallet.is_trampoline_peer |
|
|
is_trampoline_peer = LNWallet.is_trampoline_peer |
|
|
wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed |
|
|
wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed |
|
|
on_proxy_changed = LNWallet.on_proxy_changed |
|
|
on_proxy_changed = LNWallet.on_proxy_changed |
|
|
|
|
|
_decode_channel_update_msg = LNWallet._decode_channel_update_msg |
|
|
|
|
|
_handle_chanupd_from_failed_htlc = LNWallet._handle_chanupd_from_failed_htlc |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MockTransport: |
|
|
class MockTransport: |
|
@ -347,12 +349,38 @@ class TestPeer(TestCaseForTestnet): |
|
|
p2.mark_open(bob_channel) |
|
|
p2.mark_open(bob_channel) |
|
|
return p1, p2, w1, w2, q1, q2 |
|
|
return p1, p2, w1, w2, q1, q2 |
|
|
|
|
|
|
|
|
def prepare_chans_and_peers_in_square(self) -> SquareGraph: |
|
|
def prepare_chans_and_peers_in_square(self, funds_distribution: Dict[str, Tuple[int, int]]=None) -> SquareGraph: |
|
|
|
|
|
if not funds_distribution: |
|
|
|
|
|
funds_distribution = {} |
|
|
key_a, key_b, key_c, key_d = [keypair() for i in range(4)] |
|
|
key_a, key_b, key_c, key_d = [keypair() for i in range(4)] |
|
|
chan_ab, chan_ba = create_test_channels(alice_name="alice", bob_name="bob", alice_pubkey=key_a.pubkey, bob_pubkey=key_b.pubkey) |
|
|
local_balance, remote_balance = funds_distribution.get('ab') or (None, None) |
|
|
chan_ac, chan_ca = create_test_channels(alice_name="alice", bob_name="carol", alice_pubkey=key_a.pubkey, bob_pubkey=key_c.pubkey) |
|
|
chan_ab, chan_ba = create_test_channels( |
|
|
chan_bd, chan_db = create_test_channels(alice_name="bob", bob_name="dave", alice_pubkey=key_b.pubkey, bob_pubkey=key_d.pubkey) |
|
|
alice_name="alice", bob_name="bob", |
|
|
chan_cd, chan_dc = create_test_channels(alice_name="carol", bob_name="dave", alice_pubkey=key_c.pubkey, bob_pubkey=key_d.pubkey) |
|
|
alice_pubkey=key_a.pubkey, bob_pubkey=key_b.pubkey, |
|
|
|
|
|
local_msat=local_balance, |
|
|
|
|
|
remote_msat=remote_balance, |
|
|
|
|
|
) |
|
|
|
|
|
local_balance, remote_balance = funds_distribution.get('ac') or (None, None) |
|
|
|
|
|
chan_ac, chan_ca = create_test_channels( |
|
|
|
|
|
alice_name="alice", bob_name="carol", |
|
|
|
|
|
alice_pubkey=key_a.pubkey, bob_pubkey=key_c.pubkey, |
|
|
|
|
|
local_msat=local_balance, |
|
|
|
|
|
remote_msat=remote_balance, |
|
|
|
|
|
) |
|
|
|
|
|
local_balance, remote_balance = funds_distribution.get('bd') or (None, None) |
|
|
|
|
|
chan_bd, chan_db = create_test_channels( |
|
|
|
|
|
alice_name="bob", bob_name="dave", |
|
|
|
|
|
alice_pubkey=key_b.pubkey, bob_pubkey=key_d.pubkey, |
|
|
|
|
|
local_msat=local_balance, |
|
|
|
|
|
remote_msat=remote_balance, |
|
|
|
|
|
) |
|
|
|
|
|
local_balance, remote_balance = funds_distribution.get('cd') or (None, None) |
|
|
|
|
|
chan_cd, chan_dc = create_test_channels( |
|
|
|
|
|
alice_name="carol", bob_name="dave", |
|
|
|
|
|
alice_pubkey=key_c.pubkey, bob_pubkey=key_d.pubkey, |
|
|
|
|
|
local_msat=local_balance, |
|
|
|
|
|
remote_msat=remote_balance, |
|
|
|
|
|
) |
|
|
trans_ab, trans_ba = transport_pair(key_a, key_b, chan_ab.name, chan_ba.name) |
|
|
trans_ab, trans_ba = transport_pair(key_a, key_b, chan_ab.name, chan_ba.name) |
|
|
trans_ac, trans_ca = transport_pair(key_a, key_c, chan_ac.name, chan_ca.name) |
|
|
trans_ac, trans_ca = transport_pair(key_a, key_c, chan_ac.name, chan_ca.name) |
|
|
trans_bd, trans_db = transport_pair(key_b, key_d, chan_bd.name, chan_db.name) |
|
|
trans_bd, trans_db = transport_pair(key_b, key_d, chan_bd.name, chan_db.name) |
|
@ -776,6 +804,43 @@ class TestPeer(TestCaseForTestnet): |
|
|
with self.assertRaises(PaymentDone): |
|
|
with self.assertRaises(PaymentDone): |
|
|
run(f()) |
|
|
run(f()) |
|
|
|
|
|
|
|
|
|
|
|
@needs_test_with_all_chacha20_implementations |
|
|
|
|
|
def test_payment_with_temp_channel_failure(self): |
|
|
|
|
|
# prepare channels such that a temporary channel failure happens at c->d |
|
|
|
|
|
funds_distribution = { |
|
|
|
|
|
'ac': (200_000_000, 200_000_000), # low fees |
|
|
|
|
|
'cd': (50_000_000, 200_000_000), # low fees |
|
|
|
|
|
'ab': (200_000_000, 200_000_000), # high fees |
|
|
|
|
|
'bd': (200_000_000, 200_000_000), # high fees |
|
|
|
|
|
} |
|
|
|
|
|
# the payment happens in three attempts: |
|
|
|
|
|
# 1. along ac->cd due to low fees with temp channel failure: |
|
|
|
|
|
# with chanupd: ORPHANED, private channel update |
|
|
|
|
|
# 2. along ac->cd with temp channel failure: |
|
|
|
|
|
# with chanupd: ORPHANED, private channel update, but already received, channel gets blacklisted |
|
|
|
|
|
# 3. along ab->bd with success |
|
|
|
|
|
amount_to_pay = 100_000_000 |
|
|
|
|
|
graph = self.prepare_chans_and_peers_in_square(funds_distribution) |
|
|
|
|
|
peers = graph.all_peers() |
|
|
|
|
|
async def pay(lnaddr, pay_req): |
|
|
|
|
|
self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) |
|
|
|
|
|
result, log = await graph.w_a.pay_invoice(pay_req, attempts=3) |
|
|
|
|
|
self.assertTrue(result) |
|
|
|
|
|
self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash)) |
|
|
|
|
|
self.assertEqual(OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, log[0].failure_msg.code) |
|
|
|
|
|
self.assertEqual(OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, log[1].failure_msg.code) |
|
|
|
|
|
raise PaymentDone() |
|
|
|
|
|
async def f(): |
|
|
|
|
|
async with TaskGroup() as group: |
|
|
|
|
|
for peer in peers: |
|
|
|
|
|
await group.spawn(peer._message_loop()) |
|
|
|
|
|
await group.spawn(peer.htlc_switch()) |
|
|
|
|
|
await asyncio.sleep(0.2) |
|
|
|
|
|
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, amount_msat=amount_to_pay, include_routing_hints=True) |
|
|
|
|
|
await group.spawn(pay(lnaddr, pay_req)) |
|
|
|
|
|
with self.assertRaises(PaymentDone): |
|
|
|
|
|
run(f()) |
|
|
|
|
|
|
|
|
def _run_mpp(self, graph, kwargs1, kwargs2): |
|
|
def _run_mpp(self, graph, kwargs1, kwargs2): |
|
|
self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL)) |
|
|
self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL)) |
|
|
self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL)) |
|
|
self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL)) |
|
|