diff --git a/electrum/address_synchronizer.py b/electrum/address_synchronizer.py index e7e262dc7..cd58feedf 100644 --- a/electrum/address_synchronizer.py +++ b/electrum/address_synchronizer.py @@ -56,11 +56,9 @@ class AddressSynchronizer(PrintError): def __init__(self, storage): self.storage = storage self.network = None - # verifier (SPV) and synchronizer are started in start_threads - self.synchronizer = None - self.verifier = None - self.sync_restart_lock = asyncio.Lock() - self.group = None + # verifier (SPV) and synchronizer are started in start_network + self.synchronizer = None # type: Synchronizer + self.verifier = None # type: SPV # locks: if you need to take multiple ones, acquire them in the order they are defined here! self.lock = threading.RLock() self.transaction_lock = threading.RLock() @@ -143,45 +141,20 @@ class AddressSynchronizer(PrintError): # add it in case it was previously unconfirmed self.add_unverified_tx(tx_hash, tx_height) - @aiosafe - async def on_default_server_changed(self, event): - async with self.sync_restart_lock: - self.stop_threads(write_to_disk=False) - await self._start_threads() - def start_network(self, network): self.network = network if self.network is not None: - self.network.register_callback(self.on_default_server_changed, ['default_server_changed']) - asyncio.run_coroutine_threadsafe(self._start_threads(), network.asyncio_loop) - - async def _start_threads(self): - interface = self.network.interface - if interface is None: - return # we should get called again soon - - self.verifier = SPV(self.network, self) - self.synchronizer = synchronizer = Synchronizer(self) - assert self.group is None, 'group already exists' - self.group = SilentTaskGroup() - - async def job(): - async with self.group as group: - await group.spawn(self.verifier.main(group)) - await group.spawn(self.synchronizer.send_subscriptions(group)) - await group.spawn(self.synchronizer.handle_status(group)) - await group.spawn(self.synchronizer.main()) - # we are being cancelled now - interface.session.unsubscribe(synchronizer.status_queue) - await interface.group.spawn(job) + self.synchronizer = Synchronizer(self) + self.verifier = SPV(self.network, self) def stop_threads(self, write_to_disk=True): if self.network: - self.synchronizer = None - self.verifier = None - if self.group: - asyncio.run_coroutine_threadsafe(self.group.cancel_remaining(), self.network.asyncio_loop) - self.group = None + if self.synchronizer: + asyncio.run_coroutine_threadsafe(self.synchronizer.stop(), self.network.asyncio_loop) + self.synchronizer = None + if self.verifier: + asyncio.run_coroutine_threadsafe(self.verifier.stop(), self.network.asyncio_loop) + self.verifier = None self.storage.put('stored_height', self.get_local_height()) if write_to_disk: self.save_transactions() diff --git a/electrum/commands.py b/electrum/commands.py index 03f7d33a3..9ce251f7a 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -40,7 +40,7 @@ from .bitcoin import is_address, hash_160, COIN, TYPE_ADDRESS from .i18n import _ from .transaction import Transaction, multisig_script, TxOutput from .paymentrequest import PR_PAID, PR_UNPAID, PR_UNKNOWN, PR_EXPIRED -from .plugin import run_hook +from .synchronizer import Notifier known_commands = {} @@ -635,21 +635,11 @@ class Commands: self.wallet.remove_payment_request(k, self.config) @command('n') - def notify(self, address, URL): + def notify(self, address: str, URL: str): """Watch an address. Every time the address changes, a http POST is sent to the URL.""" - raise NotImplementedError() # TODO this method is currently broken - def callback(x): - import urllib.request - headers = {'content-type':'application/json'} - data = {'address':address, 'status':x.get('result')} - serialized_data = util.to_bytes(json.dumps(data)) - try: - req = urllib.request.Request(URL, serialized_data, headers) - response_stream = urllib.request.urlopen(req, timeout=5) - util.print_error('Got Response for %s' % address) - except BaseException as e: - util.print_error(str(e)) - self.network.subscribe_to_addresses([address], callback) + if not hasattr(self, "_notifier"): + self._notifier = Notifier(self.network) + self.network.run_from_another_thread(self._notifier.start_watching_queue.put((address, URL))) return True @command('wn') diff --git a/electrum/interface.py b/electrum/interface.py index ef3711381..e1586a85b 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -28,7 +28,7 @@ import ssl import sys import traceback import asyncio -from typing import Tuple, Union +from typing import Tuple, Union, List from collections import defaultdict import aiorpcx @@ -57,7 +57,7 @@ class NotificationSession(ClientSession): # will catch the exception, count errors, and at some point disconnect if isinstance(request, Notification): params, result = request.args[:-1], request.args[-1] - key = self.get_index(request.method, params) + key = self.get_hashable_key_for_rpc_call(request.method, params) if key in self.subscriptions: self.cache[key] = result for queue in self.subscriptions[key]: @@ -78,10 +78,10 @@ class NotificationSession(ClientSession): except asyncio.TimeoutError as e: raise RequestTimedOut('request timed out: {}'.format(args)) from e - async def subscribe(self, method, params, queue): + async def subscribe(self, method: str, params: List, queue: asyncio.Queue): # note: until the cache is written for the first time, # each 'subscribe' call might make a request on the network. - key = self.get_index(method, params) + key = self.get_hashable_key_for_rpc_call(method, params) self.subscriptions[key].append(queue) if key in self.cache: result = self.cache[key] @@ -99,7 +99,7 @@ class NotificationSession(ClientSession): v.remove(queue) @classmethod - def get_index(cls, method, params): + def get_hashable_key_for_rpc_call(cls, method, params): """Hashable index for subscriptions and cache""" return str(method) + repr(params) @@ -141,7 +141,7 @@ class Interface(PrintError): self._requested_chunks = set() self.network = network self._set_proxy(proxy) - self.session = None + self.session = None # type: NotificationSession self.tip_header = None self.tip = 0 diff --git a/electrum/network.py b/electrum/network.py index 5eb7db4c9..d2d9a46ba 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -852,3 +852,54 @@ class Network(PrintError): await self.interface.group.spawn(self._request_fee_estimates, self.interface) await asyncio.sleep(0.1) + + +class NetworkJobOnDefaultServer(PrintError): + """An abstract base class for a job that runs on the main network + interface. Every time the main interface changes, the job is + restarted, and some of its internals are reset. + """ + def __init__(self, network: Network): + asyncio.set_event_loop(network.asyncio_loop) + self.network = network + self.interface = None # type: Interface + self._restart_lock = asyncio.Lock() + self._reset() + asyncio.run_coroutine_threadsafe(self._restart(), network.asyncio_loop) + network.register_callback(self._restart, ['default_server_changed']) + + def _reset(self): + """Initialise fields. Called every time the underlying + server connection changes. + """ + self.group = SilentTaskGroup() + + async def _start(self, interface): + self.interface = interface + await interface.group.spawn(self._start_tasks) + + async def _start_tasks(self): + """Start tasks in self.group. Called every time the underlying + server connection changes. + """ + raise NotImplementedError() # implemented by subclasses + + async def stop(self): + await self.group.cancel_remaining() + + @aiosafe + async def _restart(self, *args): + interface = self.network.interface + if interface is None: + return # we should get called again soon + + async with self._restart_lock: + await self.stop() + self._reset() + await self._start(interface) + + @property + def session(self): + s = self.interface.session + assert s is not None + return s diff --git a/electrum/synchronizer.py b/electrum/synchronizer.py index cb84f963f..12bbb18c8 100644 --- a/electrum/synchronizer.py +++ b/electrum/synchronizer.py @@ -24,12 +24,15 @@ # SOFTWARE. import asyncio import hashlib +from typing import Dict, List +from collections import defaultdict from aiorpcx import TaskGroup, run_in_thread from .transaction import Transaction -from .util import bh2u, PrintError +from .util import bh2u, make_aiohttp_session from .bitcoin import address_to_scripthash +from .network import NetworkJobOnDefaultServer def history_status(h): @@ -41,7 +44,68 @@ def history_status(h): return bh2u(hashlib.sha256(status.encode('ascii')).digest()) -class Synchronizer(PrintError): +class SynchronizerBase(NetworkJobOnDefaultServer): + """Subscribe over the network to a set of addresses, and monitor their statuses. + Every time a status changes, run a coroutine provided by the subclass. + """ + def __init__(self, network): + NetworkJobOnDefaultServer.__init__(self, network) + self.asyncio_loop = network.asyncio_loop + + def _reset(self): + super()._reset() + self.requested_addrs = set() + self.scripthash_to_address = {} + self._processed_some_notifications = False # so that we don't miss them + # Queues + self.add_queue = asyncio.Queue() + self.status_queue = asyncio.Queue() + + async def _start_tasks(self): + try: + async with self.group as group: + await group.spawn(self.send_subscriptions()) + await group.spawn(self.handle_status()) + await group.spawn(self.main()) + finally: + # we are being cancelled now + self.session.unsubscribe(self.status_queue) + + def add(self, addr): + asyncio.run_coroutine_threadsafe(self._add_address(addr), self.asyncio_loop) + + async def _add_address(self, addr): + if addr in self.requested_addrs: return + self.requested_addrs.add(addr) + await self.add_queue.put(addr) + + async def _on_address_status(self, addr, status): + """Handle the change of the status of an address.""" + raise NotImplementedError() # implemented by subclasses + + async def send_subscriptions(self): + async def subscribe_to_address(addr): + h = address_to_scripthash(addr) + self.scripthash_to_address[h] = addr + await self.session.subscribe('blockchain.scripthash.subscribe', [h], self.status_queue) + self.requested_addrs.remove(addr) + + while True: + addr = await self.add_queue.get() + await self.group.spawn(subscribe_to_address, addr) + + async def handle_status(self): + while True: + h, status = await self.status_queue.get() + addr = self.scripthash_to_address[h] + await self.group.spawn(self._on_address_status, addr, status) + self._processed_some_notifications = True + + async def main(self): + raise NotImplementedError() # implemented by subclasses + + +class Synchronizer(SynchronizerBase): '''The synchronizer keeps the wallet up-to-date with its set of addresses and their transactions. It subscribes over the network to wallet addresses, gets the wallet to generate new addresses @@ -51,16 +115,12 @@ class Synchronizer(PrintError): ''' def __init__(self, wallet): self.wallet = wallet - self.network = wallet.network - self.asyncio_loop = wallet.network.asyncio_loop + SynchronizerBase.__init__(self, wallet.network) + + def _reset(self): + super()._reset() self.requested_tx = {} self.requested_histories = {} - self.requested_addrs = set() - self.scripthash_to_address = {} - self._processed_some_notifications = False # so that we don't miss them - # Queues - self.add_queue = asyncio.Queue() - self.status_queue = asyncio.Queue() def diagnostic_name(self): return '{}:{}'.format(self.__class__.__name__, self.wallet.diagnostic_name()) @@ -70,14 +130,6 @@ class Synchronizer(PrintError): and not self.requested_histories and not self.requested_tx) - def add(self, addr): - asyncio.run_coroutine_threadsafe(self._add(addr), self.asyncio_loop) - - async def _add(self, addr): - if addr in self.requested_addrs: return - self.requested_addrs.add(addr) - await self.add_queue.put(addr) - async def _on_address_status(self, addr, status): history = self.wallet.history.get(addr, []) if history_status(history) == status: @@ -144,30 +196,6 @@ class Synchronizer(PrintError): # callbacks self.wallet.network.trigger_callback('new_transaction', self.wallet, tx) - async def send_subscriptions(self, group: TaskGroup): - async def subscribe_to_address(addr): - h = address_to_scripthash(addr) - self.scripthash_to_address[h] = addr - await self.session.subscribe('blockchain.scripthash.subscribe', [h], self.status_queue) - self.requested_addrs.remove(addr) - - while True: - addr = await self.add_queue.get() - await group.spawn(subscribe_to_address, addr) - - async def handle_status(self, group: TaskGroup): - while True: - h, status = await self.status_queue.get() - addr = self.scripthash_to_address[h] - await group.spawn(self._on_address_status, addr, status) - self._processed_some_notifications = True - - @property - def session(self): - s = self.wallet.network.interface.session - assert s is not None - return s - async def main(self): self.wallet.set_up_to_date(False) # request missing txns, if any @@ -178,7 +206,7 @@ class Synchronizer(PrintError): await self._request_missing_txs(history) # add addresses to bootstrap for addr in self.wallet.get_addresses(): - await self._add(addr) + await self._add_address(addr) # main loop while True: await asyncio.sleep(0.1) @@ -189,3 +217,37 @@ class Synchronizer(PrintError): self._processed_some_notifications = False self.wallet.set_up_to_date(up_to_date) self.wallet.network.trigger_callback('wallet_updated', self.wallet) + + +class Notifier(SynchronizerBase): + """Watch addresses. Every time the status of an address changes, + an HTTP POST is sent to the corresponding URL. + """ + def __init__(self, network): + SynchronizerBase.__init__(self, network) + self.watched_addresses = defaultdict(list) # type: Dict[str, List[str]] + self.start_watching_queue = asyncio.Queue() + + async def main(self): + # resend existing subscriptions if we were restarted + for addr in self.watched_addresses: + await self._add_address(addr) + # main loop + while True: + addr, url = await self.start_watching_queue.get() + self.watched_addresses[addr].append(url) + await self._add_address(addr) + + async def _on_address_status(self, addr, status): + self.print_error('new status for addr {}'.format(addr)) + headers = {'content-type': 'application/json'} + data = {'address': addr, 'status': status} + for url in self.watched_addresses[addr]: + try: + async with make_aiohttp_session(proxy=self.network.proxy, headers=headers) as session: + async with session.post(url, json=data, headers=headers) as resp: + await resp.text() + except Exception as e: + self.print_error(str(e)) + else: + self.print_error('Got Response for {}'.format(addr)) diff --git a/electrum/util.py b/electrum/util.py index 20af6e179..9ef3e5169 100644 --- a/electrum/util.py +++ b/electrum/util.py @@ -869,7 +869,12 @@ VerifiedTxInfo = NamedTuple("VerifiedTxInfo", [("height", int), ("txpos", int), ("header_hash", str)]) -def make_aiohttp_session(proxy): + +def make_aiohttp_session(proxy: dict, headers=None, timeout=None): + if headers is None: + headers = {'User-Agent': 'Electrum'} + if timeout is None: + timeout = aiohttp.ClientTimeout(total=10) if proxy: connector = SocksConnector( socks_ver=SocksVer.SOCKS5 if proxy['mode'] == 'socks5' else SocksVer.SOCKS4, @@ -879,9 +884,9 @@ def make_aiohttp_session(proxy): password=proxy.get('password', None), rdns=True ) - return aiohttp.ClientSession(headers={'User-Agent' : 'Electrum'}, timeout=aiohttp.ClientTimeout(total=10), connector=connector) + return aiohttp.ClientSession(headers=headers, timeout=timeout, connector=connector) else: - return aiohttp.ClientSession(headers={'User-Agent' : 'Electrum'}, timeout=aiohttp.ClientTimeout(total=10)) + return aiohttp.ClientSession(headers=headers, timeout=timeout) class SilentTaskGroup(TaskGroup): diff --git a/electrum/verifier.py b/electrum/verifier.py index cb6bdfa6d..d2c357597 100644 --- a/electrum/verifier.py +++ b/electrum/verifier.py @@ -25,14 +25,14 @@ import asyncio from typing import Sequence, Optional import aiorpcx -from aiorpcx import TaskGroup -from .util import PrintError, bh2u, VerifiedTxInfo +from .util import bh2u, VerifiedTxInfo from .bitcoin import Hash, hash_decode, hash_encode from .transaction import Transaction from .blockchain import hash_header from .interface import GracefulDisconnect from . import constants +from .network import NetworkJobOnDefaultServer class MerkleVerificationFailure(Exception): pass @@ -41,26 +41,33 @@ class MerkleRootMismatch(MerkleVerificationFailure): pass class InnerNodeOfSpvProofIsValidTx(MerkleVerificationFailure): pass -class SPV(PrintError): +class SPV(NetworkJobOnDefaultServer): """ Simple Payment Verification """ def __init__(self, network, wallet): + NetworkJobOnDefaultServer.__init__(self, network) self.wallet = wallet - self.network = network + + def _reset(self): + super()._reset() self.merkle_roots = {} # txid -> merkle root (once it has been verified) self.requested_merkle = set() # txid set of pending requests + async def _start_tasks(self): + async with self.group as group: + await group.spawn(self.main) + def diagnostic_name(self): return '{}:{}'.format(self.__class__.__name__, self.wallet.diagnostic_name()) - async def main(self, group: TaskGroup): + async def main(self): self.blockchain = self.network.blockchain() while True: await self._maybe_undo_verifications() - await self._request_proofs(group) + await self._request_proofs() await asyncio.sleep(0.1) - async def _request_proofs(self, group: TaskGroup): + async def _request_proofs(self): local_height = self.blockchain.height() unverified = self.wallet.get_unverified_txs() @@ -75,12 +82,12 @@ class SPV(PrintError): header = self.blockchain.read_header(tx_height) if header is None: if tx_height < constants.net.max_checkpoint(): - await group.spawn(self.network.request_chunk(tx_height, None, can_return_early=True)) + await self.group.spawn(self.network.request_chunk(tx_height, None, can_return_early=True)) continue # request now self.print_error('requested merkle', tx_hash) self.requested_merkle.add(tx_hash) - await group.spawn(self._request_and_verify_single_proof, tx_hash, tx_height) + await self.group.spawn(self._request_and_verify_single_proof, tx_hash, tx_height) async def _request_and_verify_single_proof(self, tx_hash, tx_height): try: diff --git a/electrum/websockets.py b/electrum/websockets.py index 4637f83e1..b4c380b16 100644 --- a/electrum/websockets.py +++ b/electrum/websockets.py @@ -22,44 +22,49 @@ # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import queue -import threading, os, json +import threading +import os +import json from collections import defaultdict +import asyncio +from typing import Dict, List +import traceback +import sys + try: from SimpleWebSocketServer import WebSocket, SimpleSSLWebSocketServer except ImportError: - import sys sys.exit("install SimpleWebSocketServer") -from . import util +from .util import PrintError from . import bitcoin +from .synchronizer import SynchronizerBase + +request_queue = asyncio.Queue() -request_queue = queue.Queue() -class ElectrumWebSocket(WebSocket): +class ElectrumWebSocket(WebSocket, PrintError): def handleMessage(self): assert self.data[0:3] == 'id:' - util.print_error("message received", self.data) + self.print_error("message received", self.data) request_id = self.data[3:] - request_queue.put((self, request_id)) + asyncio.run_coroutine_threadsafe( + request_queue.put((self, request_id)), asyncio.get_event_loop()) def handleConnected(self): - util.print_error("connected", self.address) + self.print_error("connected", self.address) def handleClose(self): - util.print_error("closed", self.address) + self.print_error("closed", self.address) - -class WsClientThread(util.DaemonThread): +class BalanceMonitor(SynchronizerBase): def __init__(self, config, network): - util.DaemonThread.__init__(self) - self.network = network + SynchronizerBase.__init__(self, network) self.config = config - self.response_queue = queue.Queue() - self.subscriptions = defaultdict(list) + self.expected_payments = defaultdict(list) # type: Dict[str, List[WebSocket, int]] def make_request(self, request_id): # read json file @@ -72,69 +77,47 @@ class WsClientThread(util.DaemonThread): amount = d.get('amount') return addr, amount - def reading_thread(self): - while self.is_running(): - try: - ws, request_id = request_queue.get() - except queue.Empty: - continue + async def main(self): + # resend existing subscriptions if we were restarted + for addr in self.expected_payments: + await self._add_address(addr) + # main loop + while True: + ws, request_id = await request_queue.get() try: addr, amount = self.make_request(request_id) - except: + except Exception: + traceback.print_exc(file=sys.stderr) continue - l = self.subscriptions.get(addr, []) - l.append((ws, amount)) - self.subscriptions[addr] = l - self.network.subscribe_to_addresses([addr], self.response_queue.put) - - def run(self): - threading.Thread(target=self.reading_thread).start() - while self.is_running(): - try: - r = self.response_queue.get(timeout=0.1) - except queue.Empty: - continue - util.print_error('response', r) - method = r.get('method') - result = r.get('result') - if result is None: - continue - if method == 'blockchain.scripthash.subscribe': - addr = r.get('params')[0] - scripthash = bitcoin.address_to_scripthash(addr) - self.network.get_balance_for_scripthash( - scripthash, self.response_queue.put) - elif method == 'blockchain.scripthash.get_balance': - scripthash = r.get('params')[0] - addr = self.network.h2addr.get(scripthash, None) - if addr is None: - util.print_error( - "can't find address for scripthash: %s" % scripthash) - l = self.subscriptions.get(addr, []) - for ws, amount in l: - if not ws.closed: - if sum(result.values()) >=amount: - ws.sendMessage('paid') + self.expected_payments[addr].append((ws, amount)) + await self._add_address(addr) + async def _on_address_status(self, addr, status): + self.print_error('new status for addr {}'.format(addr)) + sh = bitcoin.address_to_scripthash(addr) + balance = await self.network.get_balance_for_scripthash(sh) + for ws, amount in self.expected_payments[addr]: + if not ws.closed: + if sum(balance.values()) >= amount: + ws.sendMessage('paid') class WebSocketServer(threading.Thread): - def __init__(self, config, ns): + def __init__(self, config, network): threading.Thread.__init__(self) self.config = config - self.net_server = ns + self.network = network + asyncio.set_event_loop(network.asyncio_loop) self.daemon = True + self.balance_monitor = BalanceMonitor(self.config, self.network) + self.start() def run(self): - t = WsClientThread(self.config, self.net_server) - t.start() - + asyncio.set_event_loop(self.network.asyncio_loop) host = self.config.get('websocket_server') port = self.config.get('websocket_port', 9999) certfile = self.config.get('ssl_chain') keyfile = self.config.get('ssl_privkey') self.server = SimpleSSLWebSocketServer(host, port, ElectrumWebSocket, certfile, keyfile) self.server.serveforever() - - diff --git a/run_electrum b/run_electrum index 148c54ee4..9e9f1a9ca 100755 --- a/run_electrum +++ b/run_electrum @@ -438,7 +438,7 @@ if __name__ == '__main__': d = daemon.Daemon(config, fd) if config.get('websocket_server'): from electrum import websockets - websockets.WebSocketServer(config, d.network).start() + websockets.WebSocketServer(config, d.network) if config.get('requests_dir'): path = os.path.join(config.get('requests_dir'), 'index.html') if not os.path.exists(path):