diff --git a/server/block_processor.py b/server/block_processor.py index d9fd11f..9373bc9 100644 --- a/server/block_processor.py +++ b/server/block_processor.py @@ -208,8 +208,7 @@ class BlockProcessor(server.db.DB): await self._wait_for_update() except asyncio.CancelledError: self.on_cancel() - # This lets the asyncio subsystem process futures cancellations - await asyncio.sleep(0) + await self.wait_shutdown() def on_cancel(self): '''Called when the main loop is cancelled. @@ -219,6 +218,10 @@ class BlockProcessor(server.db.DB): future.cancel() self.flush(True) + async def wait_shutdown(self): + '''Wait for shutdown to complete cleanly, and return.''' + await asyncio.sleep(0) + async def _wait_for_update(self): '''Wait for the prefetcher to deliver blocks or a mempool update. diff --git a/server/protocol.py b/server/protocol.py index d539e86..07c4073 100644 --- a/server/protocol.py +++ b/server/protocol.py @@ -54,6 +54,11 @@ class BlockServer(BlockProcessor): self.server_mgr.stop() super().on_cancel() + async def wait_shutdown(self): + '''Wait for shutdown to complete cleanly, and return.''' + await self.server_mgr.wait_shutdown() + await super().wait_shutdown() + def mempool_transactions(self, hash168): '''Generate (hex_hash, tx_fee, unconfirmed) tuples for mempool entries for the hash168. @@ -141,7 +146,7 @@ class MemPool(LoggedClass): for n, (hex_hash, tx) in enumerate(new_txs.items()): # Yield to process e.g. signals - if n % 100 == 0: + if n % 20 == 0: await asyncio.sleep(0) txout_pairs = [txout_pair(txout) for txout in tx.outputs] self.txs[hex_hash] = (None, txout_pairs, None) @@ -162,8 +167,7 @@ class MemPool(LoggedClass): # Now add the inputs for n, (hex_hash, tx) in enumerate(new_txs.items()): # Yield to process e.g. signals - if n % 10 == 0: - await asyncio.sleep(0) + await asyncio.sleep(0) if initial and time.time() > next_log: next_log = time.time() + 20 @@ -248,7 +252,7 @@ class ServerManager(LoggedClass): self.sessions = {} self.max_subs = env.max_subs self.subscription_count = 0 - self.futures = [] # At present just the IRC future, if any + self.irc_future = None self.logger.info('max subscriptions across all sessions: {:,d}' .format(self.max_subs)) self.logger.info('max subscriptions per session: {:,d}' @@ -263,8 +267,6 @@ class ServerManager(LoggedClass): host, port = args[:2] try: self.servers.append(await server) - except asyncio.CancelledError: - raise except Exception as e: self.logger.error('{} server failed to listen on {}:{:d} :{}' .format(kind, host, port, e)) @@ -294,7 +296,7 @@ class ServerManager(LoggedClass): if env.irc: self.logger.info('starting IRC coroutine') - self.futures.append(asyncio.ensure_future(self.irc.start())) + self.irc_future = asyncio.ensure_future(self.irc.start()) else: self.logger.info('IRC disabled') @@ -308,24 +310,42 @@ class ServerManager(LoggedClass): def stop(self): '''Close listening servers.''' + self.logger.info('cleanly closing client sessions, please wait...') for server in self.servers: server.close() + if self.irc_future: + self.irc_future.cancel() + for session in self.sessions: + session.transport.close() + + async def wait_shutdown(self): + # Wait for servers to close + for server in self.servers: + await server.wait_closed() + # Just in case a connection came in + await asyncio.sleep(0) self.servers = [] - for future in self.futures: - future.cancel() - self.futures = [] - sessions = list(self.sessions.keys()) # A copy - for session in sessions: - self.remove_session(session) + self.logger.info('server listening sockets closed') + limit = time.time() + 10 + while self.sessions and time.time() < limit: + self.logger.info('{:,d} sessions remaining' + .format(len(self.sessions))) + await asyncio.sleep(2) + if self.sessions: + self.logger.info('forcibly closing {:,d} stragglers' + .format(len(self.sessions))) + for future in self.sessions.values(): + future.cancel() + await asyncio.sleep(0) def add_session(self, session): + assert self.servers assert session not in self.sessions coro = session.serve_requests() self.sessions[session] = asyncio.ensure_future(coro) def remove_session(self, session): - if isinstance(session, ElectrumX): - self.subscription_count -= len(session.hash168s) + self.subscription_count -= session.sub_count() future = self.sessions.pop(session) future.cancel() @@ -346,12 +366,6 @@ class ServerManager(LoggedClass): async def rpc_getinfo(self, params): '''The RPC 'getinfo' call.''' - # FIXME: remove later - indep_count = sum(len(session.hash168s) for session in self.sessions - if isinstance(session, ElectrumX)) - if indep_count != self.subscription_count: - self.logger.error('sub count {:,d} but session total {:,d}' - .format(self.subscription_count, indep_count)) return { 'blocks': self.bp.height, 'peers': len(self.irc.peers),