diff --git a/server/protocol.py b/server/protocol.py index 4efad5b..53c9572 100644 --- a/server/protocol.py +++ b/server/protocol.py @@ -12,6 +12,7 @@ import asyncio import codecs import json import traceback +from collections import namedtuple from functools import partial from server.daemon import DaemonError @@ -29,45 +30,67 @@ def json_notification(method, params): return {'id': None, 'method': method, 'params': params} -class JSONRPC(asyncio.Protocol, LoggedClass): - '''Base class that manages a JSONRPC connection.''' - SESSIONS = set() - # Queue for aynchronous job processing. - JOBS = None +AsyncTask = namedtuple('AsyncTask', 'session job') + +class SessionManager(LoggedClass): def __init__(self): super().__init__() - self.parts = [] - self.send_count = 0 - self.send_size = 0 - self.error_count = 0 - self.init_jobs() - - @classmethod - def init_jobs(cls): - if not cls.JOBS: - cls.JOBS = asyncio.Queue() - asyncio.ensure_future(cls.run_jobs()) - - @classmethod - async def run_jobs(cls): - '''Asynchronously run through the job queue.''' + self.sessions = set() + self.tasks = asyncio.Queue() + self.current_task = None + asyncio.ensure_future(self.run_tasks()) + + def add_session(self, session): + assert session not in self.sessions + self.sessions.add(session) + + def remove_session(self, session): + self.sessions.remove(session) + if self.current_task and session == self.current_task.session: + self.logger.info('cancelling running task') + self.current_task.cancel() + + def add_task(self, session, job): + assert session in self.sessions + task = asyncio.ensure_future(job) + self.tasks.put_nowait(AsyncTask(session, task)) + + async def run_tasks(self): + '''Asynchronously run through the task queue.''' while True: - job = await cls.JOBS.get() + task = await self.tasks.get() try: - await job + if task.session in self.sessions: + self.current_task = task + await task.job + else: + task.job.cancel() except asyncio.CancelledError: - raise + self.logger.info('cancelled task noted') except Exception: # Getting here should probably be considered a bug and fixed traceback.print_exc() + finally: + self.current_task = None + + +class JSONRPC(asyncio.Protocol, LoggedClass): + '''Base class that manages a JSONRPC connection.''' + + def __init__(self): + super().__init__() + self.parts = [] + self.send_count = 0 + self.send_size = 0 + self.error_count = 0 def connection_made(self, transport): '''Handle an incoming client connection.''' self.transport = transport self.peername = transport.get_extra_info('peername') self.logger.info('connection from {}'.format(self.peername)) - self.SESSIONS.add(self) + self.SESSION_MGR.add_session(self) def connection_lost(self, exc): '''Handle client disconnection.''' @@ -75,7 +98,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass): 'Sent {:,d} bytes in {:,d} messages {:,d} errors' .format(self.peername, self.send_size, self.send_count, self.error_count)) - self.SESSIONS.remove(self) + self.SESSION_MGR.remove_session(self) def data_received(self, data): '''Handle incoming data (synchronously). @@ -100,7 +123,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass): except Exception as e: self.logger.info('error decoding JSON message: {}'.format(e)) else: - self.JOBS.put_nowait(self.request_handler(message)) + self.SESSION_MGR.add_task(self, self.request_handler(message)) async def request_handler(self, request): '''Called asynchronously.''' @@ -113,13 +136,21 @@ class JSONRPC(asyncio.Protocol, LoggedClass): self.error_count += 1 error = {'code': 1, 'message': e.args[0]} payload = {'id': request.get('id'), 'error': error, 'result': result} - self.json_send(payload) + if not self.json_send(payload): + # Let asyncio call connection_lost() so we stop this + # session's tasks + await asyncio.sleep(0) def json_send(self, payload): + if self.transport.is_closing(): + self.logger.info('connection closing, not writing') + return False + data = (json.dumps(payload) + '\n').encode() self.transport.write(data) self.send_count += 1 self.send_size += len(data) + return True def rpc_handler(self, method, params): handler = None @@ -193,6 +224,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass): cls.BLOCK_PROCESSOR = block_processor cls.DAEMON = daemon cls.COIN = coin + cls.SESSION_MGR = SessionManager() @classmethod def height(cls): @@ -240,7 +272,8 @@ class ElectrumX(JSONRPC): @classmethod def watched_address_count(cls): - return sum(len(session.hash168s) for session in self.SESSIONS + sessions = self.SESSION_MGR.sessions + return sum(len(session.hash168s) for session in session if isinstance(session, cls)) @classmethod @@ -257,7 +290,7 @@ class ElectrumX(JSONRPC): ) hash168_to_address = cls.COIN.hash168_to_address - for session in cls.SESSIONS: + for session in cls.SESSION_MGR.sessions: if height != session.notified_height: session.notified_height = height if session.subscribe_headers: @@ -519,7 +552,7 @@ class LocalRPC(JSONRPC): return { 'blocks': self.height(), 'peers': len(ElectrumX.irc_peers()), - 'sessions': len(self.SESSIONS), + 'sessions': len(self.SESSION_MGR.sessions), 'watched': ElectrumX.watched_address_count(), 'cached': 0, } @@ -528,7 +561,7 @@ class LocalRPC(JSONRPC): return [] async def numsessions(self, params): - return len(self.SESSIONS) + return len(self.SESSION_MGR.sessions) async def peers(self, params): return tuple(ElectrumX.irc_peers().keys())