Browse Source

tests: add test for prev

patch-4
SomberNight 4 years ago
parent
commit
2487a3fa90
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 10
      electrum/lnworker.py
  2. 61
      electrum/tests/test_lnpeer.py

10
electrum/lnworker.py

@ -578,6 +578,7 @@ class LNWallet(LNWorker):
lnwatcher: Optional['LNWalletWatcher']
MPP_EXPIRY = 120
TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3 # seconds
def __init__(self, wallet: 'Abstract_Wallet', xprv):
self.wallet = wallet
@ -713,11 +714,12 @@ class LNWallet(LNWorker):
self.stopping_soon = True
if self.listen_server: # stop accepting new peers
self.listen_server.close()
async with ignore_after(3):
async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS):
await self.wait_for_received_pending_htlcs_to_get_removed()
await super().stop()
await self.lnwatcher.stop()
self.lnwatcher = None
await LNWorker.stop(self)
if self.lnwatcher:
await self.lnwatcher.stop()
self.lnwatcher = None
async def wait_for_received_pending_htlcs_to_get_removed(self):
assert self.stopping_soon is True

61
electrum/tests/test_lnpeer.py

@ -10,7 +10,7 @@ from concurrent import futures
import unittest
from typing import Iterable, NamedTuple, Tuple, List
from aiorpcx import TaskGroup
from aiorpcx import TaskGroup, timeout_after, TaskTimeout
from electrum import bitcoin
from electrum import constants
@ -113,7 +113,8 @@ class MockWallet:
class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1
MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1
TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0
def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name):
self.name = name
@ -121,6 +122,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue)
self.taskgroup = TaskGroup()
self.lnwatcher = None
self.listen_server = None
self._channels = {chan.channel_id: chan for chan in chans}
self.payments = {}
self.logs = defaultdict(list)
@ -184,6 +188,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
return self.name
async def stop(self):
await LNWallet.stop(self)
if self.channel_db:
self.channel_db.stop()
await self.channel_db.stopped_event.wait()
@ -216,6 +221,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
_calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice
handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
is_trampoline_peer = LNWallet.is_trampoline_peer
wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
on_proxy_changed = LNWallet.on_proxy_changed
class MockTransport:
@ -291,13 +298,9 @@ class SquareGraph(NamedTuple):
def all_lnworkers(self) -> Iterable[MockLNWallet]:
return self.w_a, self.w_b, self.w_c, self.w_d
async def stop_and_cleanup(self):
async with TaskGroup() as group:
for lnworker in self.all_lnworkers():
await group.spawn(lnworker.stop())
class PaymentDone(Exception): pass
class TestSuccess(Exception): pass
class TestPeer(ElectrumTestCase):
@ -837,6 +840,50 @@ class TestPeer(ElectrumTestCase):
graph = self.prepare_chans_and_peers_in_square()
self._run_mpp(graph, {'alice_uses_trampoline':True, 'attempts':1}, {'alice_uses_trampoline':True, 'attempts':3})
@needs_test_with_all_chacha20_implementations
def test_fail_pending_htlcs_on_shutdown(self):
"""Alice tries to pay Dave via MPP. Dave receives some HTLCs but not all.
Dave shuts down (stops wallet).
We test if Dave fails the pending HTLCs during shutdown.
"""
graph = self.prepare_chans_and_peers_in_square()
self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL))
self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL))
amount_to_pay = 600_000_000_000
peers = graph.all_peers()
graph.w_d.MPP_EXPIRY = 120
graph.w_d.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3
async def pay():
graph.w_d.features |= LnFeatures.BASIC_MPP_OPT
graph.w_b.enable_htlc_forwarding.clear() # Bob will hold forwarded HTLCs
assert graph.w_a.network.channel_db is not None
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay)
try:
async with timeout_after(0.5):
result, log = await graph.w_a.pay_invoice(pay_req, attempts=1)
except TaskTimeout:
# by now Dave hopefully received some HTLCs:
self.assertTrue(len(graph.chan_dc.hm.htlcs(LOCAL)) > 0)
self.assertTrue(len(graph.chan_dc.hm.htlcs(REMOTE)) > 0)
else:
self.fail(f"pay_invoice finished but was not supposed to. result={result}")
await graph.w_d.stop()
# Dave is supposed to have failed the pending incomplete MPP HTLCs
self.assertEqual(0, len(graph.chan_dc.hm.htlcs(LOCAL)))
self.assertEqual(0, len(graph.chan_dc.hm.htlcs(REMOTE)))
raise TestSuccess()
async def f():
async with TaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
await group.spawn(pay())
with self.assertRaises(TestSuccess):
run(f())
@needs_test_with_all_chacha20_implementations
def test_close(self):
alice_channel, bob_channel = create_test_channels()

Loading…
Cancel
Save