From e40db63bebabb2b5cde067081a5dabdf64957d01 Mon Sep 17 00:00:00 2001 From: Neil Booth Date: Thu, 8 Dec 2016 00:29:46 +0900 Subject: [PATCH] Queue requests, which have a process method. --- electrumx_rpc.py | 4 +- lib/jsonrpc.py | 148 ++++++++++++++++++++++++--------------------- server/protocol.py | 21 +++---- 3 files changed, 92 insertions(+), 81 deletions(-) diff --git a/electrumx_rpc.py b/electrumx_rpc.py index 3ec6062..8de320d 100755 --- a/electrumx_rpc.py +++ b/electrumx_rpc.py @@ -31,12 +31,12 @@ class RPCClient(JSONRPC): future = asyncio.ensure_future(self.messages.get()) for f in asyncio.as_completed([future], timeout=timeout): try: - message = await f + request = await f except asyncio.TimeoutError: future.cancel() print('request timed out after {}s'.format(timeout)) else: - await self.handle_message(message) + await request.process() async def handle_response(self, result, error, method): if result and method == 'sessions': diff --git a/lib/jsonrpc.py b/lib/jsonrpc.py index 5eb51ee..0496dee 100644 --- a/lib/jsonrpc.py +++ b/lib/jsonrpc.py @@ -21,7 +21,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass): Assumes JSON messages are newline-separated and that newlines cannot appear in the JSON other than to separate lines. Incoming messages are queued on the messages queue for later asynchronous - processing, and should be passed to the handle_message() function. + processing, and should be passed to the handle_request() function. Derived classes may want to override connection_made() and connection_lost() but should be sure to call the implementation in @@ -53,8 +53,47 @@ class JSONRPC(asyncio.Protocol, LoggedClass): self.msg = msg self.code = code - class LargeRequestError(Exception): - '''Raised if a large request was prevented from being sent.''' + class SingleRequest(object): + '''An object that represents a single request.''' + def __init__(self, session, payload): + self.payload = payload + self.session = session + + async def process(self): + '''Asynchronously handle the JSON request.''' + binary = await self.session.process_single_payload(self.payload) + if binary: + self.session._send_bytes(binary) + + class BatchRequest(object): + '''An object that represents a batch request and its processing + state.''' + def __init__(self, session, payload): + self.session = session + self.payload = payload + self.done = 0 + self.parts = [] + + async def process(self): + '''Asynchronously handle the JSON batch according to the JSON 2.0 + spec.''' + if not self.payload: + raise JSONRPC.RPCError('empty batch', self.INVALID_REQUEST) + for n in range(self.session.batch_limit): + if self.done >= len(self.payload): + if self.parts: + binary = b'[' + b', '.join(self.parts) + b']' + self.session._send_bytes(binary) + return + item = self.payload[self.done] + part = await self.session.process_single_payload(item) + if part: + self.parts.append(part) + self.done += 1 + + # Re-enqueue to continue the rest later + self.session.enqueue_request(self) + return b'' @classmethod def request_payload(cls, method, id_, params=None): @@ -90,6 +129,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass): self.bandwidth_interval = 3600 self.bandwidth_used = 0 self.bandwidth_limit = 5000000 + self.batch_limit = 4 self.transport = None # Parts of an incomplete JSON line. We buffer them until # getting a newline. @@ -184,18 +224,27 @@ class JSONRPC(asyncio.Protocol, LoggedClass): message = message.decode() except UnicodeDecodeError as e: msg = 'cannot decode binary bytes: {}'.format(e) - self.send_json_error(msg, self.PARSE_ERROR) + self.send_json_error(msg, self.PARSE_ERROR, close=True) return try: message = json.loads(message) except json.JSONDecodeError as e: msg = 'cannot decode JSON: {}'.format(e) - self.send_json_error(msg, self.PARSE_ERROR) + self.send_json_error(msg, self.PARSE_ERROR, close=True) return + if isinstance(message, list): + # Batches must have at least one request. + if not message: + self.send_json_error('empty batch', self.INVALID_REQUEST) + return + request = self.BatchRequest(self, message) + else: + request = self.SingleRequest(self, message) + '''Queue the request for asynchronous handling.''' - self.messages.put_nowait(message) + self.enqueue_request(request) if self.log_me: self.log_info('queued {}'.format(message)) @@ -214,23 +263,23 @@ class JSONRPC(asyncio.Protocol, LoggedClass): self.using_bandwidth(len(binary)) return binary - def _send_bytes(self, text, close): + def _send_bytes(self, binary, close=False): '''Send JSON text over the transport. Close it if close is True.''' # Confirmed this happens, sometimes a lot if self.transport.is_closing(): return - self.transport.write(text) + self.transport.write(binary) self.transport.write(b'\n') - if close: + if close or self.error_count > 10: self.transport.close() - def send_json_error(self, message, code, id_=None, close=True): + def send_json_error(self, message, code, id_=None, close=False): '''Send a JSON error and close the connection by default.''' self._send_bytes(self.json_error_bytes(message, code, id_), close) def encode_and_send_payload(self, payload): '''Encode the payload and send it.''' - self._send_bytes(self.encode_payload(payload), False) + self._send_bytes(self.encode_payload(payload)) def json_notification_bytes(self, method, params): '''Return the bytes of a json notification.''' @@ -249,74 +298,33 @@ class JSONRPC(asyncio.Protocol, LoggedClass): self.error_count += 1 return self.encode_payload(self.error_payload(message, code, id_)) - async def handle_message(self, payload): - '''Asynchronously handle a JSON request or response. - - Handles batches according to the JSON 2.0 spec. - ''' - try: - if isinstance(payload, list): - binary = await self.process_json_batch(payload) - else: - binary = await self.process_single_json(payload) - except self.RPCError as e: - binary = self.json_error_bytes(e.msg, e.code, - self.payload_id(payload)) - - if binary: - self._send_bytes(binary, self.error_count > 10) - - async def process_json_batch(self, batch): - '''Return the text response to a JSON batch request.''' - # Batches must have at least one request. - if not batch: - return self.json_error_bytes('empty batch', self.INVALID_REQUEST) - - # PYTHON 3.6: use asynchronous comprehensions when supported - parts = [] - total_len = 0 - for item in batch: - part = await self.process_single_json(item) - if part: - parts.append(part) - total_len += len(part) + 2 - self.check_oversized_request(total_len) - if parts: - return b'[' + b', '.join(parts) + b']' - return b'' - - async def process_single_json(self, payload): - '''Return the JSON result of a single JSON request, response or + async def process_single_payload(self, payload): + '''Return the binary JSON result of a single JSON request, response or notification. - Return None if the request is a notification or a response. + The result is empty if nothing is to be sent. ''' - # Throttle high-bandwidth connections by delaying processing - # their requests. Delay more the higher the excessive usage. - excess = self.bandwidth_used - self.bandwidth_limit - if excess > 0: - secs = 1 + excess // self.bandwidth_limit - self.log_warning('high bandwidth use of {:,d} bytes, ' - 'sleeping {:d}s' - .format(self.bandwidth_used, secs)) - await asyncio.sleep(secs) if not isinstance(payload, dict): return self.json_error_bytes('request must be a dict', self.INVALID_REQUEST) - if not 'id' in payload: - return await self.process_json_notification(payload) + try: + if not 'id' in payload: + return await self.process_json_notification(payload) - id_ = payload['id'] - if not isinstance(id_, self.ID_TYPES): - return self.json_error_bytes('invalid id: {}'.format(id_), - self.INVALID_REQUEST) + id_ = payload['id'] + if not isinstance(id_, self.ID_TYPES): + return self.json_error_bytes('invalid id: {}'.format(id_), + self.INVALID_REQUEST) - if 'method' in payload: - return await self.process_json_request(payload) + if 'method' in payload: + return await self.process_json_request(payload) - return await self.process_json_response(payload) + return await self.process_json_response(payload) + except self.RPCError as e: + return self.json_error_bytes(e.msg, e.code, + self.payload_id(payload)) @classmethod def method_and_params(cls, payload): @@ -394,6 +402,10 @@ class JSONRPC(asyncio.Protocol, LoggedClass): # --- derived classes are intended to override these functions + def enqueue_request(self, request): + '''Enqueue a request for later asynchronous processing.''' + self.messages.put_nowait(request) + async def handle_notification(self, method, params): '''Handle a notification.''' diff --git a/server/protocol.py b/server/protocol.py index 7864ff8..2711ee9 100644 --- a/server/protocol.py +++ b/server/protocol.py @@ -316,6 +316,10 @@ class ServerManager(util.LoggedClass): sslc.load_cert_chain(env.ssl_certfile, keyfile=env.ssl_keyfile) await self.start_server('SSL', env.host, env.ssl_port, ssl=sslc) + class NotificationRequest(object): + def __init__(self, fn_call): + self.process = fn_call + def notify(self, touched): '''Notify sessions about height changes and touched addresses.''' # Remove invalidated history cache @@ -325,9 +329,9 @@ class ServerManager(util.LoggedClass): cache = {} for session in self.sessions: if isinstance(session, ElectrumX): - # Use a tuple to distinguish from JSON - triple = (self.bp.db_height, touched, cache) - session.messages.put_nowait(triple) + fn_call = partial(session.notify, self.bp.db_height, touched, + cache) + session.enqueue_request(self.NotificationRequest(fn_call)) # Periodically log sessions if self.env.log_sessions and time.time() > self.next_log_sessions: data = self.session_data(for_log=True) @@ -597,19 +601,14 @@ class Session(JSONRPC): async def serve_requests(self): '''Asynchronously run through the task queue.''' while True: - await asyncio.sleep(0) - message = await self.messages.get() + request = await self.messages.get() try: - # Height / mempool notification? - if isinstance(message, tuple): - await self.notify(*message) - else: - await self.handle_message(message) + await request.process() except asyncio.CancelledError: break except Exception: # Getting here should probably be considered a bug and fixed - self.log_error('error handling request {}'.format(message)) + self.log_error('error handling request {}'.format(request)) traceback.print_exc() def sub_count(self):