diff --git a/electrum/sql_db.py b/electrum/sql_db.py index 4d40beaec..c6928a87f 100644 --- a/electrum/sql_db.py +++ b/electrum/sql_db.py @@ -67,8 +67,8 @@ class SqlDB(Logger): self.conn.commit() self.conn.close() - self.asyncio_loop.call_soon_threadsafe(self.stopped_event.set) self.logger.info("SQL thread terminated") + self.asyncio_loop.call_soon_threadsafe(self.stopped_event.set) def create_database(self): raise NotImplementedError() diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index a14906069..dd4cdcb96 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -8,7 +8,7 @@ import logging import concurrent from concurrent import futures import unittest -from typing import Iterable, NamedTuple, Tuple +from typing import Iterable, NamedTuple, Tuple, List from aiorpcx import TaskGroup @@ -182,6 +182,11 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): def diagnostic_name(self): return self.name + async def stop(self): + if self.channel_db: + self.channel_db.stop() + await self.channel_db.stopped_event.wait() + get_payments = LNWallet.get_payments get_payment_info = LNWallet.get_payment_info save_payment_info = LNWallet.save_payment_info @@ -282,6 +287,14 @@ class SquareGraph(NamedTuple): def all_peers(self) -> Iterable[Peer]: return self.peer_ab, self.peer_ac, self.peer_ba, self.peer_bd, self.peer_ca, self.peer_cd, self.peer_db, self.peer_dc + def all_lnworkers(self) -> Iterable[MockLNWallet]: + return self.w_a, self.w_b, self.w_c, self.w_d + + async def stop_and_cleanup(self): + async with TaskGroup() as group: + for lnworker in self.all_lnworkers(): + await group.spawn(lnworker.stop()) + class PaymentDone(Exception): pass @@ -296,11 +309,19 @@ class TestPeer(ElectrumTestCase): def setUp(self): super().setUp() self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop() + self._lnworkers_created = [] # type: List[MockLNWallet] def tearDown(self): - super().tearDown() + async def cleanup_lnworkers(): + async with TaskGroup() as group: + for lnworker in self._lnworkers_created: + await group.spawn(lnworker.stop()) + self._lnworkers_created.clear() + run(cleanup_lnworkers()) + self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1) self._loop_thread.join(timeout=1) + super().tearDown() def prepare_peers(self, alice_channel, bob_channel): k1, k2 = keypair(), keypair() @@ -310,6 +331,7 @@ class TestPeer(ElectrumTestCase): q1, q2 = asyncio.Queue(), asyncio.Queue() w1 = MockLNWallet(local_keypair=k1, chans=[alice_channel], tx_queue=q1, name=bob_channel.name) w2 = MockLNWallet(local_keypair=k2, chans=[bob_channel], tx_queue=q2, name=alice_channel.name) + self._lnworkers_created.extend([w1, w2]) p1 = Peer(w1, k2.pubkey, t1) p2 = Peer(w2, k1.pubkey, t2) w1._peers[p1.pubkey] = p1 @@ -338,6 +360,7 @@ class TestPeer(ElectrumTestCase): w_b = MockLNWallet(local_keypair=key_b, chans=[chan_ba, chan_bd], tx_queue=txq_b, name="bob") w_c = MockLNWallet(local_keypair=key_c, chans=[chan_ca, chan_cd], tx_queue=txq_c, name="carol") w_d = MockLNWallet(local_keypair=key_d, chans=[chan_db, chan_dc], tx_queue=txq_d, name="dave") + self._lnworkers_created.extend([w_a, w_b, w_c, w_d]) peer_ab = Peer(w_a, key_b.pubkey, trans_ab) peer_ac = Peer(w_a, key_c.pubkey, trans_ac) peer_ba = Peer(w_b, key_a.pubkey, trans_ba) @@ -799,6 +822,7 @@ class TestPeer(ElectrumTestCase): graph = self.prepare_chans_and_peers_in_square() graph.w_d.features |= LnFeatures.BASIC_MPP_OPT graph.w_a.network.channel_db.stop() + run(graph.w_a.network.channel_db.stopped_event.wait()) graph.w_a.network.channel_db = None # Note: single attempt will fail with insufficient trampoline fee with self.assertRaises(NoPathFound): diff --git a/electrum/tests/test_lnrouter.py b/electrum/tests/test_lnrouter.py index fa99e6015..50bcd6ec7 100644 --- a/electrum/tests/test_lnrouter.py +++ b/electrum/tests/test_lnrouter.py @@ -94,10 +94,8 @@ class Test_LNRouter(TestCaseForTestnet): self.assertEqual(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', route[0].node_id) self.assertEqual(bfh('0000000000000003'), route[0].short_channel_id) - # need to duplicate tear_down here, as we also need to wait for the sql thread to stop - self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1) - self._loop_thread.join(timeout=1) - cdb.sql_thread.join(timeout=1) + cdb.stop() + asyncio.run_coroutine_threadsafe(cdb.stopped_event.wait(), self.asyncio_loop).result() @needs_test_with_all_chacha20_implementations def test_new_onion_packet_legacy(self):