|
|
@ -10,7 +10,7 @@ from concurrent import futures |
|
|
|
import unittest |
|
|
|
from typing import Iterable, NamedTuple, Tuple, List |
|
|
|
|
|
|
|
from aiorpcx import TaskGroup |
|
|
|
from aiorpcx import TaskGroup, timeout_after, TaskTimeout |
|
|
|
|
|
|
|
from electrum import bitcoin |
|
|
|
from electrum import constants |
|
|
@ -113,7 +113,8 @@ class MockWallet: |
|
|
|
|
|
|
|
|
|
|
|
class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): |
|
|
|
MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1 |
|
|
|
MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1 |
|
|
|
TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0 |
|
|
|
|
|
|
|
def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name): |
|
|
|
self.name = name |
|
|
@ -121,6 +122,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): |
|
|
|
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1) |
|
|
|
self.node_keypair = local_keypair |
|
|
|
self.network = MockNetwork(tx_queue) |
|
|
|
self.taskgroup = TaskGroup() |
|
|
|
self.lnwatcher = None |
|
|
|
self.listen_server = None |
|
|
|
self._channels = {chan.channel_id: chan for chan in chans} |
|
|
|
self.payments = {} |
|
|
|
self.logs = defaultdict(list) |
|
|
@ -184,6 +188,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): |
|
|
|
return self.name |
|
|
|
|
|
|
|
async def stop(self): |
|
|
|
await LNWallet.stop(self) |
|
|
|
if self.channel_db: |
|
|
|
self.channel_db.stop() |
|
|
|
await self.channel_db.stopped_event.wait() |
|
|
@ -216,6 +221,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): |
|
|
|
_calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice |
|
|
|
handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc |
|
|
|
is_trampoline_peer = LNWallet.is_trampoline_peer |
|
|
|
wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed |
|
|
|
on_proxy_changed = LNWallet.on_proxy_changed |
|
|
|
|
|
|
|
|
|
|
|
class MockTransport: |
|
|
@ -291,13 +298,9 @@ class SquareGraph(NamedTuple): |
|
|
|
def all_lnworkers(self) -> Iterable[MockLNWallet]: |
|
|
|
return self.w_a, self.w_b, self.w_c, self.w_d |
|
|
|
|
|
|
|
async def stop_and_cleanup(self): |
|
|
|
async with TaskGroup() as group: |
|
|
|
for lnworker in self.all_lnworkers(): |
|
|
|
await group.spawn(lnworker.stop()) |
|
|
|
|
|
|
|
|
|
|
|
class PaymentDone(Exception): pass |
|
|
|
class TestSuccess(Exception): pass |
|
|
|
|
|
|
|
|
|
|
|
class TestPeer(ElectrumTestCase): |
|
|
@ -837,6 +840,50 @@ class TestPeer(ElectrumTestCase): |
|
|
|
graph = self.prepare_chans_and_peers_in_square() |
|
|
|
self._run_mpp(graph, {'alice_uses_trampoline':True, 'attempts':1}, {'alice_uses_trampoline':True, 'attempts':3}) |
|
|
|
|
|
|
|
@needs_test_with_all_chacha20_implementations |
|
|
|
def test_fail_pending_htlcs_on_shutdown(self): |
|
|
|
"""Alice tries to pay Dave via MPP. Dave receives some HTLCs but not all. |
|
|
|
Dave shuts down (stops wallet). |
|
|
|
We test if Dave fails the pending HTLCs during shutdown. |
|
|
|
""" |
|
|
|
graph = self.prepare_chans_and_peers_in_square() |
|
|
|
self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL)) |
|
|
|
self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL)) |
|
|
|
amount_to_pay = 600_000_000_000 |
|
|
|
peers = graph.all_peers() |
|
|
|
graph.w_d.MPP_EXPIRY = 120 |
|
|
|
graph.w_d.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3 |
|
|
|
async def pay(): |
|
|
|
graph.w_d.features |= LnFeatures.BASIC_MPP_OPT |
|
|
|
graph.w_b.enable_htlc_forwarding.clear() # Bob will hold forwarded HTLCs |
|
|
|
assert graph.w_a.network.channel_db is not None |
|
|
|
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay) |
|
|
|
try: |
|
|
|
async with timeout_after(0.5): |
|
|
|
result, log = await graph.w_a.pay_invoice(pay_req, attempts=1) |
|
|
|
except TaskTimeout: |
|
|
|
# by now Dave hopefully received some HTLCs: |
|
|
|
self.assertTrue(len(graph.chan_dc.hm.htlcs(LOCAL)) > 0) |
|
|
|
self.assertTrue(len(graph.chan_dc.hm.htlcs(REMOTE)) > 0) |
|
|
|
else: |
|
|
|
self.fail(f"pay_invoice finished but was not supposed to. result={result}") |
|
|
|
await graph.w_d.stop() |
|
|
|
# Dave is supposed to have failed the pending incomplete MPP HTLCs |
|
|
|
self.assertEqual(0, len(graph.chan_dc.hm.htlcs(LOCAL))) |
|
|
|
self.assertEqual(0, len(graph.chan_dc.hm.htlcs(REMOTE))) |
|
|
|
raise TestSuccess() |
|
|
|
|
|
|
|
async def f(): |
|
|
|
async with TaskGroup() as group: |
|
|
|
for peer in peers: |
|
|
|
await group.spawn(peer._message_loop()) |
|
|
|
await group.spawn(peer.htlc_switch()) |
|
|
|
await asyncio.sleep(0.2) |
|
|
|
await group.spawn(pay()) |
|
|
|
|
|
|
|
with self.assertRaises(TestSuccess): |
|
|
|
run(f()) |
|
|
|
|
|
|
|
@needs_test_with_all_chacha20_implementations |
|
|
|
def test_close(self): |
|
|
|
alice_channel, bob_channel = create_test_channels() |
|
|
|