Browse Source

Merge branch 'session_mgr' into develop

master
Neil Booth 8 years ago
parent
commit
61e8e3ccad
  1. 95
      server/protocol.py

95
server/protocol.py

@ -12,6 +12,7 @@ import asyncio
import codecs import codecs
import json import json
import traceback import traceback
from collections import namedtuple
from functools import partial from functools import partial
from server.daemon import DaemonError from server.daemon import DaemonError
@ -29,45 +30,67 @@ def json_notification(method, params):
return {'id': None, 'method': method, 'params': params} return {'id': None, 'method': method, 'params': params}
class JSONRPC(asyncio.Protocol, LoggedClass): AsyncTask = namedtuple('AsyncTask', 'session job')
'''Base class that manages a JSONRPC connection.'''
SESSIONS = set() class SessionManager(LoggedClass):
# Queue for aynchronous job processing.
JOBS = None
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.parts = [] self.sessions = set()
self.send_count = 0 self.tasks = asyncio.Queue()
self.send_size = 0 self.current_task = None
self.error_count = 0 asyncio.ensure_future(self.run_tasks())
self.init_jobs()
def add_session(self, session):
@classmethod assert session not in self.sessions
def init_jobs(cls): self.sessions.add(session)
if not cls.JOBS:
cls.JOBS = asyncio.Queue() def remove_session(self, session):
asyncio.ensure_future(cls.run_jobs()) self.sessions.remove(session)
if self.current_task and session == self.current_task.session:
@classmethod self.logger.info('cancelling running task')
async def run_jobs(cls): self.current_task.cancel()
'''Asynchronously run through the job queue.'''
def add_task(self, session, job):
assert session in self.sessions
task = asyncio.ensure_future(job)
self.tasks.put_nowait(AsyncTask(session, task))
async def run_tasks(self):
'''Asynchronously run through the task queue.'''
while True: while True:
job = await cls.JOBS.get() task = await self.tasks.get()
try: try:
await job if task.session in self.sessions:
self.current_task = task
await task.job
else:
task.job.cancel()
except asyncio.CancelledError: except asyncio.CancelledError:
raise self.logger.info('cancelled task noted')
except Exception: except Exception:
# Getting here should probably be considered a bug and fixed # Getting here should probably be considered a bug and fixed
traceback.print_exc() traceback.print_exc()
finally:
self.current_task = None
class JSONRPC(asyncio.Protocol, LoggedClass):
'''Base class that manages a JSONRPC connection.'''
def __init__(self):
super().__init__()
self.parts = []
self.send_count = 0
self.send_size = 0
self.error_count = 0
def connection_made(self, transport): def connection_made(self, transport):
'''Handle an incoming client connection.''' '''Handle an incoming client connection.'''
self.transport = transport self.transport = transport
self.peername = transport.get_extra_info('peername') self.peername = transport.get_extra_info('peername')
self.logger.info('connection from {}'.format(self.peername)) self.logger.info('connection from {}'.format(self.peername))
self.SESSIONS.add(self) self.SESSION_MGR.add_session(self)
def connection_lost(self, exc): def connection_lost(self, exc):
'''Handle client disconnection.''' '''Handle client disconnection.'''
@ -75,7 +98,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
'Sent {:,d} bytes in {:,d} messages {:,d} errors' 'Sent {:,d} bytes in {:,d} messages {:,d} errors'
.format(self.peername, self.send_size, .format(self.peername, self.send_size,
self.send_count, self.error_count)) self.send_count, self.error_count))
self.SESSIONS.remove(self) self.SESSION_MGR.remove_session(self)
def data_received(self, data): def data_received(self, data):
'''Handle incoming data (synchronously). '''Handle incoming data (synchronously).
@ -100,7 +123,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
except Exception as e: except Exception as e:
self.logger.info('error decoding JSON message: {}'.format(e)) self.logger.info('error decoding JSON message: {}'.format(e))
else: else:
self.JOBS.put_nowait(self.request_handler(message)) self.SESSION_MGR.add_task(self, self.request_handler(message))
async def request_handler(self, request): async def request_handler(self, request):
'''Called asynchronously.''' '''Called asynchronously.'''
@ -113,13 +136,21 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
self.error_count += 1 self.error_count += 1
error = {'code': 1, 'message': e.args[0]} error = {'code': 1, 'message': e.args[0]}
payload = {'id': request.get('id'), 'error': error, 'result': result} payload = {'id': request.get('id'), 'error': error, 'result': result}
self.json_send(payload) if not self.json_send(payload):
# Let asyncio call connection_lost() so we stop this
# session's tasks
await asyncio.sleep(0)
def json_send(self, payload): def json_send(self, payload):
if self.transport.is_closing():
self.logger.info('connection closing, not writing')
return False
data = (json.dumps(payload) + '\n').encode() data = (json.dumps(payload) + '\n').encode()
self.transport.write(data) self.transport.write(data)
self.send_count += 1 self.send_count += 1
self.send_size += len(data) self.send_size += len(data)
return True
def rpc_handler(self, method, params): def rpc_handler(self, method, params):
handler = None handler = None
@ -193,6 +224,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
cls.BLOCK_PROCESSOR = block_processor cls.BLOCK_PROCESSOR = block_processor
cls.DAEMON = daemon cls.DAEMON = daemon
cls.COIN = coin cls.COIN = coin
cls.SESSION_MGR = SessionManager()
@classmethod @classmethod
def height(cls): def height(cls):
@ -240,7 +272,8 @@ class ElectrumX(JSONRPC):
@classmethod @classmethod
def watched_address_count(cls): def watched_address_count(cls):
return sum(len(session.hash168s) for session in self.SESSIONS sessions = self.SESSION_MGR.sessions
return sum(len(session.hash168s) for session in session
if isinstance(session, cls)) if isinstance(session, cls))
@classmethod @classmethod
@ -257,7 +290,7 @@ class ElectrumX(JSONRPC):
) )
hash168_to_address = cls.COIN.hash168_to_address hash168_to_address = cls.COIN.hash168_to_address
for session in cls.SESSIONS: for session in cls.SESSION_MGR.sessions:
if height != session.notified_height: if height != session.notified_height:
session.notified_height = height session.notified_height = height
if session.subscribe_headers: if session.subscribe_headers:
@ -519,7 +552,7 @@ class LocalRPC(JSONRPC):
return { return {
'blocks': self.height(), 'blocks': self.height(),
'peers': len(ElectrumX.irc_peers()), 'peers': len(ElectrumX.irc_peers()),
'sessions': len(self.SESSIONS), 'sessions': len(self.SESSION_MGR.sessions),
'watched': ElectrumX.watched_address_count(), 'watched': ElectrumX.watched_address_count(),
'cached': 0, 'cached': 0,
} }
@ -528,7 +561,7 @@ class LocalRPC(JSONRPC):
return [] return []
async def numsessions(self, params): async def numsessions(self, params):
return len(self.SESSIONS) return len(self.SESSION_MGR.sessions)
async def peers(self, params): async def peers(self, params):
return tuple(ElectrumX.irc_peers().keys()) return tuple(ElectrumX.irc_peers().keys())

Loading…
Cancel
Save