From 3c019c2f9c4d2fdefe52d84444632b54d421e140 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Tue, 9 Mar 2021 17:52:36 +0100 Subject: [PATCH] daemon/wallet/network: make stop() methods async --- electrum/address_synchronizer.py | 19 +++++++----- electrum/daemon.py | 43 +++++++++++++++++--------- electrum/gui/__init__.py | 6 ++++ electrum/gui/kivy/main_window.py | 3 +- electrum/gui/qt/settings_dialog.py | 3 +- electrum/interface.py | 2 +- electrum/lnwatcher.py | 4 +-- electrum/lnworker.py | 12 +++---- electrum/network.py | 31 +++++++++---------- electrum/sql_db.py | 3 ++ electrum/tests/test_storage_upgrade.py | 15 +++++++-- electrum/tests/test_wallet.py | 15 +++++++-- electrum/util.py | 10 +++--- electrum/wallet.py | 26 ++++++++++------ run_electrum | 1 - 15 files changed, 123 insertions(+), 70 deletions(-) diff --git a/electrum/address_synchronizer.py b/electrum/address_synchronizer.py index 63935aaf0..922372014 100644 --- a/electrum/address_synchronizer.py +++ b/electrum/address_synchronizer.py @@ -28,6 +28,8 @@ import itertools from collections import defaultdict from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence, List +from aiorpcx import TaskGroup + from . import bitcoin, util from .bitcoin import COINBASE_MATURITY from .util import profiler, bfh, TxMinedInfo, UnrelatedTransactionException @@ -197,16 +199,19 @@ class AddressSynchronizer(Logger): def on_blockchain_updated(self, event, *args): self._get_addr_balance_cache = {} # invalidate cache - def stop(self): + async def stop(self): if self.network: - if self.synchronizer: - asyncio.run_coroutine_threadsafe(self.synchronizer.stop(), self.network.asyncio_loop) + try: + async with TaskGroup() as group: + if self.synchronizer: + await group.spawn(self.synchronizer.stop()) + if self.verifier: + await group.spawn(self.verifier.stop()) + finally: # even if we get cancelled self.synchronizer = None - if self.verifier: - asyncio.run_coroutine_threadsafe(self.verifier.stop(), self.network.asyncio_loop) self.verifier = None - util.unregister_callback(self.on_blockchain_updated) - self.db.put('stored_height', self.get_local_height()) + util.unregister_callback(self.on_blockchain_updated) + self.db.put('stored_height', self.get_local_height()) def add_address(self, address): if not self.db.get_addr_history(address): diff --git a/electrum/daemon.py b/electrum/daemon.py index 72a3c7d8a..1fa5466ea 100644 --- a/electrum/daemon.py +++ b/electrum/daemon.py @@ -29,7 +29,7 @@ import time import traceback import sys import threading -from typing import Dict, Optional, Tuple, Iterable, Callable, Union, Sequence, Mapping +from typing import Dict, Optional, Tuple, Iterable, Callable, Union, Sequence, Mapping, TYPE_CHECKING from base64 import b64decode, b64encode from collections import defaultdict import concurrent @@ -38,7 +38,7 @@ import json import aiohttp from aiohttp import web, client_exceptions -from aiorpcx import TaskGroup +from aiorpcx import TaskGroup, timeout_after, TaskTimeout from . import util from .network import Network @@ -53,6 +53,9 @@ from .simple_config import SimpleConfig from .exchange_rate import FxThread from .logging import get_logger, Logger +if TYPE_CHECKING: + from electrum import gui + _logger = get_logger(__name__) @@ -407,6 +410,7 @@ class PayServer(Logger): class Daemon(Logger): network: Optional[Network] + gui_object: Optional[Union['gui.qt.ElectrumGui', 'gui.kivy.ElectrumGui']] @profiler def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True): @@ -523,7 +527,8 @@ class Daemon(Logger): wallet = self._wallets.pop(path, None) if not wallet: return False - wallet.stop() + fut = asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop) + fut.result() return True def run_daemon(self): @@ -544,20 +549,28 @@ class Daemon(Logger): self.running = False def on_stop(self): + self.logger.info("on_stop() entered. initiating shutdown") if self.gui_object: self.gui_object.stop() - # stop network/wallets - for k, wallet in self._wallets.items(): - wallet.stop() - if self.network: - self.logger.info("shutting down network") - self.network.stop() - self.logger.info("stopping taskgroup") - fut = asyncio.run_coroutine_threadsafe(self.taskgroup.cancel_remaining(), self.asyncio_loop) - try: - fut.result(timeout=2) - except (concurrent.futures.TimeoutError, concurrent.futures.CancelledError, asyncio.CancelledError): - pass + + @log_exceptions + async def stop_async(): + self.logger.info("stopping all wallets") + async with TaskGroup() as group: + for k, wallet in self._wallets.items(): + await group.spawn(wallet.stop()) + self.logger.info("stopping network and taskgroup") + try: + async with timeout_after(2): + async with TaskGroup() as group: + if self.network: + await group.spawn(self.network.stop(full_shutdown=True)) + await group.spawn(self.taskgroup.cancel_remaining()) + except TaskTimeout: + pass + + fut = asyncio.run_coroutine_threadsafe(stop_async(), self.asyncio_loop) + fut.result() self.logger.info("removing lockfile") remove_lockfile(get_lockfile(self.config)) self.logger.info("stopped") diff --git a/electrum/gui/__init__.py b/electrum/gui/__init__.py index 02fe271cc..f0bef9a44 100644 --- a/electrum/gui/__init__.py +++ b/electrum/gui/__init__.py @@ -3,3 +3,9 @@ # The Wallet object is instantiated by the GUI # Notifications about network events are sent to the GUI by using network.register_callback() + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from . import qt + from . import kivy diff --git a/electrum/gui/kivy/main_window.py b/electrum/gui/kivy/main_window.py index 64ae76b06..de624a0bd 100644 --- a/electrum/gui/kivy/main_window.py +++ b/electrum/gui/kivy/main_window.py @@ -190,7 +190,8 @@ class ElectrumWindow(App, Logger): if self.use_gossip: self.network.start_gossip() else: - self.network.stop_gossip() + self.network.run_from_another_thread( + self.network.stop_gossip()) android_backups = BooleanProperty(False) def on_android_backups(self, instance, x): diff --git a/electrum/gui/qt/settings_dialog.py b/electrum/gui/qt/settings_dialog.py index 307930944..eb6ce9aa0 100644 --- a/electrum/gui/qt/settings_dialog.py +++ b/electrum/gui/qt/settings_dialog.py @@ -141,7 +141,8 @@ channels graph and compute payment path locally, instead of using trampoline pay if use_gossip: self.window.network.start_gossip() else: - self.window.network.stop_gossip() + self.window.network.run_from_another_thread( + self.window.network.stop_gossip()) util.trigger_callback('ln_gossip_sync_progress') # FIXME: update all wallet windows util.trigger_callback('channels_updated', self.wallet) diff --git a/electrum/interface.py b/electrum/interface.py index 0bf4f50b9..8e85733fd 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -695,7 +695,7 @@ class Interface(Logger): # We give up after a while and just abort the connection. # Note: specifically if the server is running Fulcrum, waiting seems hopeless, # the connection must be aborted (see https://github.com/cculianu/Fulcrum/issues/76) - force_after = 2 # seconds + force_after = 1 # seconds if self.session: await self.session.close(force_after=force_after) # monitor_connection will cancel tasks diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py index e0dc2e5f3..5bd788c51 100644 --- a/electrum/lnwatcher.py +++ b/electrum/lnwatcher.py @@ -147,8 +147,8 @@ class LNWatcher(AddressSynchronizer): # status gets populated when we run self.channel_status = {} - def stop(self): - super().stop() + async def stop(self): + await super().stop() util.unregister_callback(self.on_network_update) def get_channel_status(self, outpoint): diff --git a/electrum/lnworker.py b/electrum/lnworker.py index d10a9735d..ee839670b 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -311,11 +311,11 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]): self._add_peers_from_config() asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop) - def stop(self): + async def stop(self): if self.listen_server: - self.network.asyncio_loop.call_soon_threadsafe(self.listen_server.close) - asyncio.run_coroutine_threadsafe(self.taskgroup.cancel_remaining(), self.network.asyncio_loop) + self.listen_server.close() util.unregister_callback(self.on_proxy_changed) + await self.taskgroup.cancel_remaining() def _add_peers_from_config(self): peer_list = self.config.get('lightning_peers', []) @@ -704,9 +704,9 @@ class LNWallet(LNWorker): tg_coro = self.taskgroup.spawn(coro) asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop) - def stop(self): - super().stop() - self.lnwatcher.stop() + async def stop(self): + await super().stop() + await self.lnwatcher.stop() self.lnwatcher = None def peer_closed(self, peer): diff --git a/electrum/network.py b/electrum/network.py index 6ba39b01b..5f4c9c09f 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -252,6 +252,11 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): default_server: ServerAddr _recent_servers: List[ServerAddr] + channel_blacklist: 'ChannelBlackList' + channel_db: Optional['ChannelDB'] = None + lngossip: Optional['LNGossip'] = None + local_watchtower: Optional['WatchTower'] = None + def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None): global _INSTANCE assert _INSTANCE is None, "Network is a singleton!" @@ -344,9 +349,6 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): # lightning network self.channel_blacklist = ChannelBlackList() - self.channel_db = None # type: Optional[ChannelDB] - self.lngossip = None # type: Optional[LNGossip] - self.local_watchtower = None # type: Optional[WatchTower] if self.config.get('run_local_watchtower', False): from . import lnwatcher self.local_watchtower = lnwatcher.WatchTower(self) @@ -373,11 +375,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): self.lngossip = lnworker.LNGossip() self.lngossip.start_network(self) - def stop_gossip(self): + async def stop_gossip(self, *, full_shutdown: bool = False): if self.lngossip: - self.lngossip.stop() + await self.lngossip.stop() self.lngossip = None self.channel_db.stop() + if full_shutdown: + await self.channel_db.stopped_event.wait() self.channel_db = None def run_from_another_thread(self, coro, *, timeout=None): @@ -623,7 +627,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): self.auto_connect = net_params.auto_connect if self.proxy != proxy or self.oneserver != net_params.oneserver: # Restart the network defaulting to the given server - await self._stop() + await self.stop(full_shutdown=False) self.default_server = server await self._start() elif self.default_server != server: @@ -1217,13 +1221,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): asyncio.run_coroutine_threadsafe(self._start(), self.asyncio_loop) @log_exceptions - async def _stop(self, full_shutdown=False): + async def stop(self, *, full_shutdown: bool = True): self.logger.info("stopping network") try: # note: cancel_remaining ~cannot be cancelled, it suppresses CancelledError - await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2) + await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=1) except (asyncio.TimeoutError, asyncio.CancelledError) as e: - self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}") + self.logger.info(f"exc during taskgroup cancellation: {repr(e)}") self.taskgroup = None self.interface = None self.interfaces = {} @@ -1231,13 +1235,8 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): self._closing_ifaces.clear() if not full_shutdown: util.trigger_callback('network_updated') - - def stop(self): - assert self._loop_thread != threading.current_thread(), 'must not be called from network thread' - fut = asyncio.run_coroutine_threadsafe(self._stop(full_shutdown=True), self.asyncio_loop) - try: - fut.result(timeout=2) - except (concurrent.futures.TimeoutError, concurrent.futures.CancelledError): pass + if full_shutdown: + await self.stop_gossip(full_shutdown=full_shutdown) async def _ensure_there_is_a_main_interface(self): if self.is_connected(): diff --git a/electrum/sql_db.py b/electrum/sql_db.py index b1ed5921f..4d40beaec 100644 --- a/electrum/sql_db.py +++ b/electrum/sql_db.py @@ -25,6 +25,7 @@ class SqlDB(Logger): Logger.__init__(self) self.asyncio_loop = asyncio_loop self.stopping = False + self.stopped_event = asyncio.Event() self.path = path test_read_write_permissions(path) self.commit_interval = commit_interval @@ -65,6 +66,8 @@ class SqlDB(Logger): # write self.conn.commit() self.conn.close() + + self.asyncio_loop.call_soon_threadsafe(self.stopped_event.set) self.logger.info("SQL thread terminated") def create_database(self): diff --git a/electrum/tests/test_storage_upgrade.py b/electrum/tests/test_storage_upgrade.py index 8fedb1241..7f5dbd61d 100644 --- a/electrum/tests/test_storage_upgrade.py +++ b/electrum/tests/test_storage_upgrade.py @@ -3,10 +3,12 @@ import tempfile import os import json from typing import Optional +import asyncio from electrum.wallet_db import WalletDB from electrum.wallet import Wallet from electrum import constants +from electrum import util from .test_wallet import WalletTestCase @@ -15,6 +17,15 @@ from .test_wallet import WalletTestCase # TODO hw wallet with client version 2.6.x (single-, and multiacc) class TestStorageUpgrade(WalletTestCase): + def setUp(self): + super().setUp() + self.asyncio_loop, self._stop_loop, self._loop_thread = util.create_and_start_event_loop() + + def tearDown(self): + super().tearDown() + self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1) + self._loop_thread.join(timeout=1) + def testnet_wallet(func): # note: it's ok to modify global network constants in subclasses of SequentialTestCase def wrapper(self, *args, **kwargs): @@ -281,7 +292,7 @@ class TestStorageUpgrade(WalletTestCase): # to simulate ks.opportunistically_fill_in_missing_info_from_device(): ks._root_fingerprint = "deadbeef" ks.is_requesting_to_be_rewritten_to_wallet_file = True - wallet.stop() + asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop).result() def test_upgrade_from_client_2_9_3_importedkeys_keystore_changes(self): # see #6401 @@ -292,7 +303,7 @@ class TestStorageUpgrade(WalletTestCase): ["p2wpkh:L1cgMEnShp73r9iCukoPE3MogLeueNYRD9JVsfT1zVHyPBR3KqBY"], password=None ) - wallet.stop() + asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop).result() @testnet_wallet def test_upgrade_from_client_3_3_8_xpub_with_realistic_history(self): diff --git a/electrum/tests/test_wallet.py b/electrum/tests/test_wallet.py index 10c213eec..916522e35 100644 --- a/electrum/tests/test_wallet.py +++ b/electrum/tests/test_wallet.py @@ -5,8 +5,9 @@ import os import json from decimal import Decimal import time - from io import StringIO +import asyncio + from electrum.storage import WalletStorage from electrum.wallet_db import FINAL_SEED_VERSION from electrum.wallet import (Abstract_Wallet, Standard_Wallet, create_new_wallet, @@ -16,6 +17,7 @@ from electrum.util import TxMinedInfo, InvalidPassword from electrum.bitcoin import COIN from electrum.wallet_db import WalletDB from electrum.simple_config import SimpleConfig +from electrum import util from . import ElectrumTestCase @@ -237,6 +239,15 @@ class TestCreateRestoreWallet(WalletTestCase): class TestWalletPassword(WalletTestCase): + def setUp(self): + super().setUp() + self.asyncio_loop, self._stop_loop, self._loop_thread = util.create_and_start_event_loop() + + def tearDown(self): + super().tearDown() + self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1) + self._loop_thread.join(timeout=1) + def test_update_password_of_imported_wallet(self): wallet_str = '{"addr_history":{"1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr":[],"15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA":[],"1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6":[]},"addresses":{"change":[],"receiving":["1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr","1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6","15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA"]},"keystore":{"keypairs":{"0344b1588589958b0bcab03435061539e9bcf54677c104904044e4f8901f4ebdf5":"L2sED74axVXC4H8szBJ4rQJrkfem7UMc6usLCPUoEWxDCFGUaGUM","0389508c13999d08ffae0f434a085f4185922d64765c0bff2f66e36ad7f745cc5f":"L3Gi6EQLvYw8gEEUckmqawkevfj9s8hxoQDFveQJGZHTfyWnbk1U","04575f52b82f159fa649d2a4c353eb7435f30206f0a6cb9674fbd659f45082c37d559ffd19bea9c0d3b7dcc07a7b79f4cffb76026d5d4dff35341efe99056e22d2":"5JyVyXU1LiRXATvRTQvR9Kp8Rx1X84j2x49iGkjSsXipydtByUq"},"type":"imported"},"pruned_txo":{},"seed_version":13,"stored_height":-1,"transactions":{},"tx_fees":{},"txi":{},"txo":{},"use_encryption":false,"verified_tx3":{},"wallet_type":"standard","winpos-qt":[100,100,840,405]}' db = WalletDB(wallet_str, manual_upgrades=False) @@ -273,7 +284,7 @@ class TestWalletPassword(WalletTestCase): db = WalletDB(wallet_str, manual_upgrades=False) storage = WalletStorage(self.wallet_path) wallet = Wallet(db, storage, config=self.config) - wallet.stop() + asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop).result() storage = WalletStorage(self.wallet_path) # if storage.is_encrypted(): diff --git a/electrum/util.py b/electrum/util.py index c002f6bdb..002c91beb 100644 --- a/electrum/util.py +++ b/electrum/util.py @@ -1205,11 +1205,9 @@ class NetworkJobOnDefaultServer(Logger, ABC): if taskgroup != self.taskgroup: raise asyncio.CancelledError() - async def stop(self): - unregister_callback(self._restart) - await self._stop() - - async def _stop(self): + async def stop(self, *, full_shutdown: bool = True): + if full_shutdown: + unregister_callback(self._restart) await self.taskgroup.cancel_remaining() @log_exceptions @@ -1219,7 +1217,7 @@ class NetworkJobOnDefaultServer(Logger, ABC): return # we should get called again soon async with self._restart_lock: - await self._stop() + await self.stop(full_shutdown=False) self._reset() await self._start(interface) diff --git a/electrum/wallet.py b/electrum/wallet.py index a8ed68a16..89d9dcc64 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -46,7 +46,7 @@ import itertools import threading import enum -from aiorpcx import TaskGroup +from aiorpcx import TaskGroup, timeout_after, TaskTimeout from .i18n import _ from .bip32 import BIP32Node, convert_bip32_intpath_to_strpath, convert_bip32_path_to_list_of_uint32 @@ -353,15 +353,21 @@ class Abstract_Wallet(AddressSynchronizer, ABC): ln_xprv = node.to_xprv() self.db.put('lightning_privkey2', ln_xprv) - def stop(self): - super().stop() - if any([ks.is_requesting_to_be_rewritten_to_wallet_file for ks in self.get_keystores()]): - self.save_keystore() - if self.network: - if self.lnworker: - self.lnworker.stop() - self.lnworker = None - self.save_db() + async def stop(self): + """Stop all networking and save DB to disk.""" + try: + async with timeout_after(5): + await super().stop() + if self.network: + if self.lnworker: + await self.lnworker.stop() + self.lnworker = None + except TaskTimeout: + pass + finally: # even if we get cancelled + if any([ks.is_requesting_to_be_rewritten_to_wallet_file for ks in self.get_keystores()]): + self.save_keystore() + self.save_db() def set_up_to_date(self, b): super().set_up_to_date(b) diff --git a/run_electrum b/run_electrum index d9b06d06c..7af9a3c89 100755 --- a/run_electrum +++ b/run_electrum @@ -345,7 +345,6 @@ def main(): print_stderr('unknown command:', uri) sys.exit(1) - # singleton config = SimpleConfig(config_options) if config.get('testnet'):