diff --git a/electrumx/server/session.py b/electrumx/server/session.py index 8fbd763..3cdf1ac 100644 --- a/electrumx/server/session.py +++ b/electrumx/server/session.py @@ -21,7 +21,8 @@ from functools import partial from aiorpcx import ( ServerSession, JSONRPCAutoDetect, JSONRPCConnection, - TaskGroup, handler_invocation, RPCError, Request, ignore_after + TaskGroup, handler_invocation, RPCError, Request, ignore_after, sleep, + Event ) import electrumx @@ -107,8 +108,6 @@ class SessionGroup(object): class SessionManager(object): '''Holds global state about all sessions.''' - CATCHING_UP, LISTENING, PAUSED, SHUTTING_DOWN = range(4) - def __init__(self, env, db, bp, daemon, mempool, notifications, shutdown_event): env.max_send = max(350000, env.max_send) @@ -122,11 +121,8 @@ class SessionManager(object): self.logger = util.class_logger(__name__, self.__class__.__name__) self.servers = {} self.sessions = set() - self.max_sessions = env.max_sessions - self.low_watermark = self.max_sessions * 19 // 20 self.max_subs = env.max_subs self.cur_group = SessionGroup(0) - self.state = self.CATCHING_UP self.txs_sent = 0 self.start_time = time.time() self.history_cache = pylru.lrucache(256) @@ -138,7 +134,8 @@ class SessionManager(object): self.mn_cache_height = 0 self.mn_cache = [] # Event triggered when electrumx is listening for incoming requests. - self.server_listening = asyncio.Event() + self.server_listening = Event() + self.session_event = Event() # Tell sessions about subscription changes notifications.add_callback(self._notify_sessions) @@ -179,8 +176,6 @@ class SessionManager(object): sslc = ssl.SSLContext(ssl.PROTOCOL_TLS) sslc.load_cert_chain(env.ssl_certfile, keyfile=env.ssl_keyfile) await self._start_server('SSL', host, env.ssl_port, ssl=sslc) - # Change state - self.state = self.LISTENING self.server_listening.set() async def _close_servers(self, kinds): @@ -194,21 +189,31 @@ class SessionManager(object): server.close() await server.wait_closed() - async def _restart_if_paused(self): + async def _manage_servers(self): + paused = False + max_sessions = self.env.max_sessions + low_watermark = max_sessions * 19 // 20 while True: - await asyncio.sleep(15) + await self.session_event.wait() + self.session_event.clear() + if not paused and len(self.sessions) >= max_sessions: + session.logger.info(f'maximum sessions {max_sessions:,d} ' + f'reached, stopping new connections until ' + f'count drops to {low_watermark:,d}') + await self._close_servers(['TCP', 'SSL']) + paused = True # Start listening for incoming connections if paused and # session count has fallen - if (self.state == self.PAUSED and - len(self.sessions) <= self.low_watermark): + if paused and len(self.sessions) <= low_watermark: await self._start_external_servers() + paused = False async def _log_sessions(self): '''Periodically log sessions.''' log_interval = self.env.log_sessions if log_interval: while True: - await asyncio.sleep(log_interval) + await sleep(log_interval) data = self._session_data(for_log=True) for line in text.sessions_lines(data): self.logger.info(line) @@ -250,7 +255,7 @@ class SessionManager(object): async def _clear_stale_sessions(self): '''Cut off sessions that haven't done anything for 10 minutes.''' while True: - await asyncio.sleep(60) + await sleep(60) stale_cutoff = time.time() - self.env.session_timeout stale_sessions = [session for session in self.sessions if session.last_recv < stale_cutoff] @@ -480,7 +485,7 @@ class SessionManager(object): await self._start_server('RPC', self.env.cs_host(for_rpc=True), self.env.rpc_port) await event.wait() - self.logger.info(f'max session count: {self.max_sessions:,d}') + self.logger.info(f'max session count: {self.env.max_sessions:,d}') self.logger.info(f'session timeout: ' f'{self.env.session_timeout:,d} seconds') self.logger.info('session bandwidth limit {:,d} bytes' @@ -501,10 +506,9 @@ class SessionManager(object): await group.spawn(self.peer_mgr.discover_peers()) await group.spawn(self._clear_stale_sessions()) await group.spawn(self._log_sessions()) - await group.spawn(self._restart_if_paused()) + await group.spawn(self._manage_servers()) finally: # Close servers and sessions - self.state = self.SHUTTING_DOWN await self._close_servers(list(self.servers.keys())) async with TaskGroup() as group: for session in list(self.sessions): @@ -568,14 +572,7 @@ class SessionManager(object): def add_session(self, session): self.sessions.add(session) - if (len(self.sessions) >= self.max_sessions - and self.state == self.LISTENING): - self.state = self.PAUSED - session.logger.info('maximum sessions {:,d} reached, stopping new ' - 'connections until count drops to {:,d}' - .format(self.max_sessions, self.low_watermark)) - loop = asyncio.get_event_loop() - loop.call_soon(self._close_servers(['TCP', 'SSL'])) + self.session_event.set() gid = int(session.start_time - self.start_time) // 900 if self.cur_group.gid != gid: self.cur_group = SessionGroup(gid) @@ -584,6 +581,7 @@ class SessionManager(object): def remove_session(self, session): '''Remove a session from our sessions list if there.''' self.sessions.remove(session) + self.session_event.set() def new_subscription(self): if self.subs_room <= 0: