diff --git a/electrumx_server.py b/electrumx_server.py index 939bd2e..a0a505e 100755 --- a/electrumx_server.py +++ b/electrumx_server.py @@ -32,22 +32,21 @@ def main_loop(): def on_signal(signame): '''Call on receipt of a signal to cleanly shutdown.''' logging.warning('received {} signal, shutting down'.format(signame)) - for task in asyncio.Task.all_tasks(): - task.cancel() + future.cancel() + + server = BlockServer(Env()) + future = asyncio.ensure_future(server.main_loop()) # Install signal handlers for signame in ('SIGINT', 'SIGTERM'): loop.add_signal_handler(getattr(signal, signame), partial(on_signal, signame)) - server = BlockServer(Env()) - future = server.start() try: loop.run_until_complete(future) except asyncio.CancelledError: pass finally: - server.stop() loop.close() diff --git a/server/block_processor.py b/server/block_processor.py index 9b5a8e9..cc959c5 100644 --- a/server/block_processor.py +++ b/server/block_processor.py @@ -329,6 +329,7 @@ class BlockProcessor(server.db.DB): self.daemon.debug_set_height(self.height) self.mempool = MemPool(self) self.touched = set() + self.futures = [] # Meta self.utxo_MB = env.utxo_MB @@ -371,24 +372,30 @@ class BlockProcessor(server.db.DB): self.clean_db() - def start(self): - '''Returns a future that starts the block processor when awaited.''' - return asyncio.gather(self.main_loop(), - self.prefetcher.main_loop()) - async def main_loop(self): '''Main loop for block processing. Safely flushes the DB on clean shutdown. ''' + self.futures.append(asyncio.ensure_future(self.prefetcher.main_loop())) try: while True: await self._wait_for_update() await asyncio.sleep(0) # Yield except asyncio.CancelledError: - self.flush(True) + self.on_cancel() + # This lets the asyncio subsystem process futures cancellations + await asyncio.sleep(0) raise + def on_cancel(self): + '''Called when the main loop is cancelled. + + Intended to be overridden in derived classes.''' + for future in self.futures: + future.cancel() + self.flush(True) + async def _wait_for_update(self): '''Wait for the prefetcher to deliver blocks or a mempool update. diff --git a/server/protocol.py b/server/protocol.py index 05bb5f7..183a3da 100644 --- a/server/protocol.py +++ b/server/protocol.py @@ -42,9 +42,10 @@ class BlockServer(BlockProcessor): self.bs_caught_up = True self.server_mgr.notify(self.height, self.touched) - def stop(self): - '''Close the listening servers.''' + def on_cancel(self): + '''Called when the main loop is cancelled.''' self.server_mgr.stop() + super().on_cancel() class ServerManager(LoggedClass): @@ -58,9 +59,8 @@ class ServerManager(LoggedClass): self.env = env self.servers = [] self.irc = IRC(env) - self.sessions = set() - self.queue = asyncio.Queue() - self.current_task = None + self.sessions = {} + self.futures = [] # At present just the IRC future, if any async def start_server(self, kind, *args, **kw_args): loop = asyncio.get_event_loop() @@ -100,11 +100,9 @@ class ServerManager(LoggedClass): sslc.load_cert_chain(env.ssl_certfile, keyfile=env.ssl_keyfile) await self.start_server('SSL', env.host, env.ssl_port, ssl=sslc) - asyncio.ensure_future(self.run_tasks()) - if env.irc: self.logger.info('starting IRC coroutine') - asyncio.ensure_future(self.irc.start()) + self.futures.append(asyncio.ensure_future(self.irc.start())) else: self.logger.info('IRC disabled') @@ -115,48 +113,25 @@ class ServerManager(LoggedClass): ElectrumX.notify(sessions, height, touched) def stop(self): - '''Close the listening servers.''' + '''Close listening servers.''' for server in self.servers: server.close() + self.servers = [] + for future in self.futures: + future.cancel() + self.futures = [] + sessions = list(self.sessions.keys()) # A copy + for session in sessions: + self.remove_session(session) def add_session(self, session): assert session not in self.sessions - self.sessions.add(session) + coro = session.serve_requests() + self.sessions[session] = asyncio.ensure_future(coro) def remove_session(self, session): - self.sessions.remove(session) - if self.current_task and session == self.current_task.session: - self.current_task.task.cancel() - - def add_task(self, session, request): - assert session in self.sessions - self.queue.put_nowait((session, request)) - - async def run_tasks(self): - '''Asynchronously run through the task queue.''' - while True: - session, request = await self.queue.get() - if not session in self.sessions: - continue - coro = session.handle_json_request(request) - task = asyncio.ensure_future(coro) - try: - self.current_task = self.MgrTask(session, task) - start = time.time() - await task - secs = time.time() - start - if secs > 1: - self.logger.warning('slow request for {} took {:.1f}s: {}' - .format(session.peername(), secs, - request)) - except asyncio.CancelledError: - self.logger.info('running task cancelled') - except Exception: - # Getting here should probably be considered a bug and fixed - self.logger.error('error handling request {}'.format(request)) - traceback.print_exc() - finally: - self.current_task = None + future = self.sessions.pop(session) + future.cancel() def irc_peers(self): return self.irc.peers @@ -205,7 +180,12 @@ class ServerManager(LoggedClass): class Session(JSONRPC): - '''Base class of ElectrumX JSON session protocols.''' + '''Base class of ElectrumX JSON session protocols. + + Each session runs its tasks in asynchronous parallelism with other + sessions. To prevent some sessions blocking othersr, potentially + long-running requests should yield (not yet implemented). + ''' def __init__(self, manager, bp, env, kind): super().__init__() @@ -216,6 +196,8 @@ class Session(JSONRPC): self.coin = bp.coin self.kind = kind self.hash168s = set() + self.requests = asyncio.Queue() + self.current_task = None self.client = 'unknown' def connection_made(self, transport): @@ -240,7 +222,25 @@ class Session(JSONRPC): def on_json_request(self, request): '''Queue the request for asynchronous handling.''' - self.manager.add_task(self, request) + self.requests.put_nowait(request) + + async def serve_requests(self): + '''Asynchronously run through the task queue.''' + while True: + await asyncio.sleep(0) + request = await self.requests.get() + try: + start = time.time() + await self.handle_json_request(request) + secs = time.time() - start + if secs > 1: + self.logger.warning('slow request for {} took {:.1f}s: {}' + .format(session.peername(), secs, + request)) + except Exception: + # Getting here should probably be considered a bug and fixed + self.logger.error('error handling request {}'.format(request)) + traceback.print_exc() def peername(self, *, for_log=True): if not self.peer_info: