Browse Source

Merge branch 'session_mgr' into develop

master
Neil Booth 8 years ago
parent
commit
61e8e3ccad
  1. 95
      server/protocol.py

95
server/protocol.py

@ -12,6 +12,7 @@ import asyncio
import codecs
import json
import traceback
from collections import namedtuple
from functools import partial
from server.daemon import DaemonError
@ -29,45 +30,67 @@ def json_notification(method, params):
return {'id': None, 'method': method, 'params': params}
class JSONRPC(asyncio.Protocol, LoggedClass):
'''Base class that manages a JSONRPC connection.'''
SESSIONS = set()
# Queue for aynchronous job processing.
JOBS = None
AsyncTask = namedtuple('AsyncTask', 'session job')
class SessionManager(LoggedClass):
def __init__(self):
super().__init__()
self.parts = []
self.send_count = 0
self.send_size = 0
self.error_count = 0
self.init_jobs()
@classmethod
def init_jobs(cls):
if not cls.JOBS:
cls.JOBS = asyncio.Queue()
asyncio.ensure_future(cls.run_jobs())
@classmethod
async def run_jobs(cls):
'''Asynchronously run through the job queue.'''
self.sessions = set()
self.tasks = asyncio.Queue()
self.current_task = None
asyncio.ensure_future(self.run_tasks())
def add_session(self, session):
assert session not in self.sessions
self.sessions.add(session)
def remove_session(self, session):
self.sessions.remove(session)
if self.current_task and session == self.current_task.session:
self.logger.info('cancelling running task')
self.current_task.cancel()
def add_task(self, session, job):
assert session in self.sessions
task = asyncio.ensure_future(job)
self.tasks.put_nowait(AsyncTask(session, task))
async def run_tasks(self):
'''Asynchronously run through the task queue.'''
while True:
job = await cls.JOBS.get()
task = await self.tasks.get()
try:
await job
if task.session in self.sessions:
self.current_task = task
await task.job
else:
task.job.cancel()
except asyncio.CancelledError:
raise
self.logger.info('cancelled task noted')
except Exception:
# Getting here should probably be considered a bug and fixed
traceback.print_exc()
finally:
self.current_task = None
class JSONRPC(asyncio.Protocol, LoggedClass):
'''Base class that manages a JSONRPC connection.'''
def __init__(self):
super().__init__()
self.parts = []
self.send_count = 0
self.send_size = 0
self.error_count = 0
def connection_made(self, transport):
'''Handle an incoming client connection.'''
self.transport = transport
self.peername = transport.get_extra_info('peername')
self.logger.info('connection from {}'.format(self.peername))
self.SESSIONS.add(self)
self.SESSION_MGR.add_session(self)
def connection_lost(self, exc):
'''Handle client disconnection.'''
@ -75,7 +98,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
'Sent {:,d} bytes in {:,d} messages {:,d} errors'
.format(self.peername, self.send_size,
self.send_count, self.error_count))
self.SESSIONS.remove(self)
self.SESSION_MGR.remove_session(self)
def data_received(self, data):
'''Handle incoming data (synchronously).
@ -100,7 +123,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
except Exception as e:
self.logger.info('error decoding JSON message: {}'.format(e))
else:
self.JOBS.put_nowait(self.request_handler(message))
self.SESSION_MGR.add_task(self, self.request_handler(message))
async def request_handler(self, request):
'''Called asynchronously.'''
@ -113,13 +136,21 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
self.error_count += 1
error = {'code': 1, 'message': e.args[0]}
payload = {'id': request.get('id'), 'error': error, 'result': result}
self.json_send(payload)
if not self.json_send(payload):
# Let asyncio call connection_lost() so we stop this
# session's tasks
await asyncio.sleep(0)
def json_send(self, payload):
if self.transport.is_closing():
self.logger.info('connection closing, not writing')
return False
data = (json.dumps(payload) + '\n').encode()
self.transport.write(data)
self.send_count += 1
self.send_size += len(data)
return True
def rpc_handler(self, method, params):
handler = None
@ -193,6 +224,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
cls.BLOCK_PROCESSOR = block_processor
cls.DAEMON = daemon
cls.COIN = coin
cls.SESSION_MGR = SessionManager()
@classmethod
def height(cls):
@ -240,7 +272,8 @@ class ElectrumX(JSONRPC):
@classmethod
def watched_address_count(cls):
return sum(len(session.hash168s) for session in self.SESSIONS
sessions = self.SESSION_MGR.sessions
return sum(len(session.hash168s) for session in session
if isinstance(session, cls))
@classmethod
@ -257,7 +290,7 @@ class ElectrumX(JSONRPC):
)
hash168_to_address = cls.COIN.hash168_to_address
for session in cls.SESSIONS:
for session in cls.SESSION_MGR.sessions:
if height != session.notified_height:
session.notified_height = height
if session.subscribe_headers:
@ -519,7 +552,7 @@ class LocalRPC(JSONRPC):
return {
'blocks': self.height(),
'peers': len(ElectrumX.irc_peers()),
'sessions': len(self.SESSIONS),
'sessions': len(self.SESSION_MGR.sessions),
'watched': ElectrumX.watched_address_count(),
'cached': 0,
}
@ -528,7 +561,7 @@ class LocalRPC(JSONRPC):
return []
async def numsessions(self, params):
return len(self.SESSIONS)
return len(self.SESSION_MGR.sessions)
async def peers(self, params):
return tuple(ElectrumX.irc_peers().keys())

Loading…
Cancel
Save