From 3b6af914e1e89b1b86cdad73f10ec1ba260b9d6a Mon Sep 17 00:00:00 2001 From: ThomasV Date: Tue, 11 Sep 2018 14:57:59 +0200 Subject: [PATCH] add multiplexing capability to NotificationSession, simplify interface --- electrum/interface.py | 77 ++++++++++++++++++++-------------------- electrum/synchronizer.py | 4 +-- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/electrum/interface.py b/electrum/interface.py index 64e773842..bac6f59f6 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -47,22 +47,21 @@ from . import constants class NotificationSession(ClientSession): - def __init__(self, scripthash, header, *args, **kwargs): + def __init__(self, *args, **kwargs): super(NotificationSession, self).__init__(*args, **kwargs) - # queues: - self.scripthash = scripthash - self.header = header + self.subscriptions = {} + self.cache = {} async def handle_request(self, request): # note: if server sends malformed request and we raise, the superclass # will catch the exception, count errors, and at some point disconnect if isinstance(request, Notification): - if request.method == 'blockchain.scripthash.subscribe' and self.scripthash is not None: - scripthash, status = request.args - await self.scripthash.put((scripthash, status)) - elif request.method == 'blockchain.headers.subscribe' and self.header is not None: - deser = deserialize_header(bfh(request.args[0]['hex']), request.args[0]['height']) - await self.header.put(deser) + params, result = request.args[:-1], request.args[-1] + key = request.method + repr(params) + if key in self.subscriptions: + self.cache[key] = result + for queue in self.subscriptions[key]: + await queue.put(request.args) else: assert False, request.method @@ -73,6 +72,17 @@ class NotificationSession(ClientSession): super().send_request(*args, **kwargs), timeout) + async def subscribe(self, method, params, queue): + key = method + repr(params) + if key in self.subscriptions: + self.subscriptions[key].append(queue) + result = self.cache[key] + else: + result = await self.send_request(method, params) + self.subscriptions[key] = [queue] + self.cache[key] = result + await queue.put(params + [result]) + # FIXME this is often raised inside a TaskGroup, but then it's not silent :( class GracefulDisconnect(AIOSafeSilentException): pass @@ -122,7 +132,6 @@ class Interface(PrintError): self.tip_header = None self.tip = 0 - self.tip_lock = asyncio.Lock() # TODO combine? self.fut = asyncio.get_event_loop().create_task(self.run()) @@ -280,7 +289,7 @@ class Interface(PrintError): async def open_session(self, sslc, exit_early): header_queue = asyncio.Queue() - self.session = NotificationSession(None, header_queue, self.host, self.port, ssl=sslc, proxy=self.proxy) + self.session = NotificationSession(self.host, self.port, ssl=sslc, proxy=self.proxy) async with self.session as session: try: ver = await session.send_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION]) @@ -289,14 +298,10 @@ class Interface(PrintError): if exit_early: return self.print_error(ver, self.host) - subscription_res = await session.send_request('blockchain.headers.subscribe') - self.tip_header = blockchain.deserialize_header(bfh(subscription_res['hex']), subscription_res['height']) - self.tip = subscription_res['height'] - self.mark_ready() - copy_header_queue = asyncio.Queue() + await session.subscribe('blockchain.headers.subscribe', [], header_queue) async with self.group as group: - await group.spawn(self.run_fetch_blocks(subscription_res, copy_header_queue)) - await group.spawn(self.subscribe_to_headers(header_queue, copy_header_queue)) + await group.spawn(self.ping()) + await group.spawn(self.run_fetch_blocks(header_queue)) await group.spawn(self.monitor_connection()) # NOTE: group.__aexit__ will be called here; this is needed to notice exceptions in the group! @@ -306,33 +311,29 @@ class Interface(PrintError): if not self.session or self.session.is_closing(): raise GracefulDisconnect('server closed session') - async def subscribe_to_headers(self, header_queue, copy_header_queue): + async def ping(self): while True: - try: - new_header = await asyncio.wait_for(header_queue.get(), 300) - async with self.tip_lock: - self.tip_header = new_header - self.tip = new_header['block_height'] - await copy_header_queue.put(new_header) - except concurrent.futures.TimeoutError: - await self.session.send_request('server.ping', timeout=10) + await asyncio.sleep(300) + await self.session.send_request('server.ping', timeout=10) def close(self): self.fut.cancel() asyncio.get_event_loop().create_task(self.group.cancel_remaining()) - async def run_fetch_blocks(self, sub_reply, replies): - if self.tip < constants.net.max_checkpoint(): - raise GracefulDisconnect('server tip below max checkpoint') - - async with self.network.bhi_lock: - height = self.blockchain.height()+1 - await replies.put(blockchain.deserialize_header(bfh(sub_reply['hex']), sub_reply['height'])) - + async def run_fetch_blocks(self, header_queue): while True: self.network.notify('updated') - item = await replies.get() - async with self.network.bhi_lock, self.tip_lock: + item = await header_queue.get() + item = item[0] + height = item['height'] + item = blockchain.deserialize_header(bfh(item['hex']), item['height']) + self.tip_header = item + self.tip = height + if self.tip < constants.net.max_checkpoint(): + raise GracefulDisconnect('server tip below max checkpoint') + if not self.ready.done(): + self.mark_ready() + async with self.network.bhi_lock: if self.blockchain.height() < item['block_height']-1: _, height = await self.sync_until(height, None) if self.blockchain.height() >= height and self.blockchain.check_header(item): diff --git a/electrum/synchronizer.py b/electrum/synchronizer.py index ba4611d64..8182f8321 100644 --- a/electrum/synchronizer.py +++ b/electrum/synchronizer.py @@ -141,9 +141,7 @@ class Synchronizer(PrintError): async def subscribe_to_address(self, addr): h = address_to_scripthash(addr) self.scripthash_to_address[h] = addr - self.session.scripthash = self.status_queue - status = await self.session.send_request('blockchain.scripthash.subscribe', [h]) - await self.status_queue.put((h, status)) + await self.session.subscribe('blockchain.scripthash.subscribe', [h], self.status_queue) self.requested_addrs.remove(addr) async def send_subscriptions(self, interface):