# Copyright (c) 2016, Neil Booth
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.

'''Class for handling JSON RPC 2.0 connections, server or client.'''

import asyncio
import json
import numbers
import time

from lib.util import LoggedClass


class RequestBase(object):
    '''An object that represents a queued request.

    remaining is the number of payload items that have not yet been
    processed; the process() method of a derived class decrements it.
    '''

    def __init__(self, remaining):
        self.remaining = remaining


class SingleRequest(RequestBase):
    '''An object that represents a single request.'''

    def __init__(self, payload):
        super().__init__(1)
        self.payload = payload

    async def process(self, session):
        '''Asynchronously handle the JSON request.

        The result, if there is one, is written to the session.
        '''
        # Mark the request as consumed before awaiting the handler
        self.remaining = 0
        binary = await session.process_single_payload(self.payload)
        if binary:
            session._send_bytes(binary)

    def __str__(self):
        return str(self.payload)


class BatchRequest(RequestBase):
    '''An object that represents a batch request and its processing state.

    Batches are processed in chunks.
    '''

    def __init__(self, payload):
        super().__init__(len(payload))
        self.payload = payload
        # Encoded responses accumulated so far across calls to process()
        self.parts = []

    async def process(self, session):
        '''Asynchronously handle the JSON batch according to the JSON 2.0
        spec.

        Each call handles at most 4 items of the batch.  Once no items
        remain the accumulated responses, if any, are sent as a single
        JSON array.
        '''
        target = max(self.remaining - 4, 0)
        while self.remaining > target:
            # Items are consumed in order from the front of the payload
            item = self.payload[len(self.payload) - self.remaining]
            self.remaining -= 1
            part = await session.process_single_payload(item)
            if part:
                self.parts.append(part)

        # The +2 accounts for the ', ' separator (or the enclosing
        # brackets) that accompanies each part in the response
        total_len = sum(len(part) + 2 for part in self.parts)
        session.check_oversized_request(total_len)

        if not self.remaining:
            if self.parts:
                binary = b'[' + b', '.join(self.parts) + b']'
                session._send_bytes(binary)

    def __str__(self):
        return str(self.payload)


class JSONRPC(asyncio.Protocol, LoggedClass):
    '''Manages a JSONRPC connection.

    Assumes JSON messages are newline-separated and that newlines
    cannot appear in the JSON other than to separate lines.  Incoming
    requests are passed to enqueue_request(), which should arrange for
    their asynchronous handling via the request's process() method.

    Derived classes may want to override connection_made() and
    connection_lost() but should be sure to call the implementation in
    this base class first.  They will also want to implement some or
    all of the asynchronous functions handle_notification(),
    handle_response() and handle_request().

    handle_request() returns the result to pass over the network, and
    must raise an RPCError if there is an error.
    handle_notification() and handle_response() should not return
    anything or raise any exceptions.  All three functions have
    default "ignore" implementations supplied by this class.
    '''

    # See http://www.jsonrpc.org/specification
    PARSE_ERROR = -32700
    INVALID_REQUEST = -32600
    METHOD_NOT_FOUND = -32601
    INVALID_PARAMS = -32602
    INTERNAL_ERROR = -32603

    # Types accepted for the 'id' member of an incoming payload
    ID_TYPES = (type(None), str, numbers.Number)
    # Source of the per-connection id_ used in the log prefix
    NEXT_SESSION_ID = 0

    class RPCError(Exception):
        '''RPC handlers raise this error.

        msg is the error message and code the JSON RPC error code that
        are sent to the peer.  Any keyword arguments are forwarded
        unchanged to the base Exception initialiser.
        '''

        def __init__(self, msg, code=-1, **kw_args):
            # Pass msg on so that str() of the exception, and hence
            # tracebacks and logs, show the message
            super().__init__(msg, **kw_args)
            self.msg = msg
            self.code = code

    @classmethod
    def request_payload(cls, method, id_, params=None):
        '''Return a request payload dictionary.  'params' is omitted
        if params is empty or None.'''
        payload = {'jsonrpc': '2.0', 'id': id_, 'method': method}
        if params:
            payload['params'] = params
        return payload

    @classmethod
    def response_payload(cls, result, id_):
        '''Return a successful response payload dictionary.'''
        # We should not respond to notifications
        assert id_ is not None
        return {'jsonrpc': '2.0', 'result': result, 'id': id_}

    @classmethod
    def notification_payload(cls, method, params=None):
        '''Return a notification payload dictionary; this is a request
        payload whose 'id' is None.'''
        return cls.request_payload(method, None, params)

    @classmethod
    def error_payload(cls, message, code, id_=None):
        '''Return an error response payload dictionary.'''
        error = {'message': message, 'code': code}
        return {'jsonrpc': '2.0', 'error': error, 'id': id_}

    @classmethod
    def payload_id(cls, payload):
        '''Return the 'id' of a payload, or None if the payload is not
        a dictionary or has no 'id'.'''
        return payload.get('id') if isinstance(payload, dict) else None

    def __init__(self):
        super().__init__()
        self.start = time.time()
        self.stop = 0
        self.last_recv = self.start
        self.bandwidth_start = self.start
        self.bandwidth_interval = 3600
        self.bandwidth_used = 0
        self.bandwidth_limit = 5000000
        # Decayed over time by using_bandwidth(), which reads it on
        # every call, so it must exist before any data is received or
        # sent.  Derived classes may set it to a different value.
        self.throttled = 0
        self.transport = None
        self.pause = False
        # Parts of an incomplete JSON line.  We buffer them until
        # getting a newline.
        self.parts = []
        # recv_count is JSON messages not calls to data_received()
        self.recv_count = 0
        self.recv_size = 0
        self.send_count = 0
        self.send_size = 0
        self.error_count = 0
        self.peer_info = None
        # Sends longer than max_send are prevented, instead returning
        # an oversized request error to other end of the network
        # connection.  The request causing it is logged.  Values under
        # 1000 are treated as 1000.
        self.max_send = 0
        # If buffered incoming data exceeds this the connection is closed
        self.max_buffer_size = 1000000
        self.anon_logs = False
        self.id_ = JSONRPC.NEXT_SESSION_ID
        JSONRPC.NEXT_SESSION_ID += 1
        self.log_prefix = '[{:d}] '.format(self.id_)
        self.log_me = False

    def peername(self, *, for_log=True):
        '''Return the peer name of this connection.

        If for_log is true and anon_logs is set a placeholder is
        returned instead of the real address.
        '''
        if not self.peer_info:
            return 'unknown'
        if for_log and self.anon_logs:
            return 'xx.xx.xx.xx:xx'
        # A colon in the host part is bracketed, as for an IPv6 address
        if ':' in self.peer_info[0]:
            return '[{}]:{}'.format(self.peer_info[0], self.peer_info[1])
        else:
            return '{}:{}'.format(self.peer_info[0], self.peer_info[1])

    def connection_made(self, transport):
        '''Handle an incoming client connection.'''
        self.transport = transport
        self.peer_info = transport.get_extra_info('peername')
        transport.set_write_buffer_limits(high=500000)

    def connection_lost(self, exc):
        '''Handle client disconnection.'''
        pass

    def pause_writing(self):
        '''Called by asyncio when the write buffer is full.'''
        self.log_info('pausing request processing whilst socket drains')
        self.pause = True

    def resume_writing(self):
        '''Called by asyncio when the write buffer has room.'''
        self.log_info('resuming request processing')
        self.pause = False

    def close_connection(self):
        '''Record the stop time and close the transport, if any.'''
        self.stop = time.time()
        if self.transport:
            self.transport.close()

    def using_bandwidth(self, amount):
        '''Record the use of amount bytes of bandwidth.

        Recorded usage is first reduced in proportion to the time
        elapsed since the previous call, and throttled is reduced by
        one for each whole minute elapsed, to a minimum of zero.
        '''
        now = time.time()
        # Reduce the recorded usage in proportion to the elapsed time
        elapsed = now - self.bandwidth_start
        self.bandwidth_start = now
        refund = int(elapsed / self.bandwidth_interval * self.bandwidth_limit)
        refund = min(refund, self.bandwidth_used)
        self.bandwidth_used += amount - refund
        self.throttled = max(0, self.throttled - int(elapsed) // 60)

    def data_received(self, data):
        '''Handle incoming data (synchronously).

        Requests end in newline characters.  Pass complete requests to
        decode_message for handling.
        '''
        self.recv_size += len(data)
        self.using_bandwidth(len(data))

        # Close abusive connections where buffered data exceeds limit
        buffer_size = len(data) + sum(len(part) for part in self.parts)
        if buffer_size > self.max_buffer_size:
            self.log_error('read buffer of {:,d} bytes exceeds {:,d} '
                           'byte limit, closing {}'
                           .format(buffer_size, self.max_buffer_size,
                                   self.peername()))
            self.close_connection()

        # Do nothing if this connection is closing
        if self.transport.is_closing():
            return

        while True:
            npos = data.find(b'\n')
            if npos == -1:
                # No complete message left; keep the remainder buffered
                self.parts.append(data)
                break
            self.last_recv = time.time()
            self.recv_count += 1
            tail, data = data[:npos], data[npos + 1:]
            # Join the previously buffered fragments with the tail to
            # form the complete message
            parts, self.parts = self.parts, []
            parts.append(tail)
            self.decode_message(b''.join(parts))

    def decode_message(self, message):
        '''Decode a binary message and queue it for asynchronous handling.

        A message that cannot be decoded as text or parsed as JSON
        results in a parse error being sent and the connection being
        closed.  An empty batch results in an invalid request error.
        '''
        try:
            message = message.decode()
        except UnicodeDecodeError as e:
            msg = 'cannot decode binary bytes: {}'.format(e)
            self.send_json_error(msg, self.PARSE_ERROR, close=True)
            return

        try:
            message = json.loads(message)
        except json.JSONDecodeError as e:
            msg = 'cannot decode JSON: {}'.format(e)
            self.send_json_error(msg, self.PARSE_ERROR, close=True)
            return

        if isinstance(message, list):
            # Batches must have at least one request.
            if not message:
                self.send_json_error('empty batch', self.INVALID_REQUEST)
                return
            request = BatchRequest(message)
        else:
            request = SingleRequest(message)

        # Queue the request for asynchronous handling.
        self.enqueue_request(request)
        if self.log_me:
            self.log_info('queued {}'.format(message))

    def encode_payload(self, payload):
        '''Return the payload encoded as JSON bytes, updating the send
        statistics and bandwidth usage.

        If the payload cannot be JSON-encoded the failure is logged and
        the bytes of an internal error response are returned instead.
        Raises RPCError, via check_oversized_request(), if the encoded
        payload is too large.
        '''
        try:
            binary = json.dumps(payload).encode()
        except TypeError:
            msg = 'JSON encoding failure: {}'.format(payload)
            self.log_error(msg)
            # Only echo back an id of a type we accept, so that the
            # error response does not embed the unencodable value
            id_ = self.payload_id(payload)
            if not isinstance(id_, self.ID_TYPES):
                id_ = None
            return self.json_error_bytes(msg, self.INTERNAL_ERROR, id_)

        self.check_oversized_request(len(binary))
        self.send_count += 1
        self.send_size += len(binary)
        self.using_bandwidth(len(binary))
        return binary

    def _send_bytes(self, binary, close=False):
        '''Send JSON text over the transport.  Close it if close is True.

        The connection is also closed once more than 10 errors have
        been counted on it.
        '''
        # Confirmed this happens, sometimes a lot
        if self.transport.is_closing():
            return
        self.transport.write(binary)
        self.transport.write(b'\n')
        if close or self.error_count > 10:
            self.close_connection()

    def send_json_error(self, message, code, id_=None, close=False):
        '''Send a JSON error.  Close the connection afterwards only if
        close is True.'''
        self._send_bytes(self.json_error_bytes(message, code, id_), close)

    def encode_and_send_payload(self, payload):
        '''Encode the payload and send it.'''
        self._send_bytes(self.encode_payload(payload))

    def json_notification_bytes(self, method, params):
        '''Return the bytes of a json notification.'''
        return self.encode_payload(self.notification_payload(method, params))

    def json_request_bytes(self, method, id_, params=None):
        '''Return the bytes of a JSON request.'''
        return self.encode_payload(self.request_payload(method, id_, params))

    def json_response_bytes(self, result, id_):
        '''Return the bytes of a JSON response.'''
        return self.encode_payload(self.response_payload(result, id_))

    def json_error_bytes(self, message, code, id_=None):
        '''Return the bytes of a JSON error.

        Each call increments the connection's error count.
        '''
        self.error_count += 1
        return self.encode_payload(self.error_payload(message, code, id_))

    async def process_single_payload(self, payload):
        '''Return the binary JSON result of a single JSON request, response
        or notification.

        The result is empty if nothing is to be sent.
        '''
        if not isinstance(payload, dict):
            return self.json_error_bytes('request must be a dict',
                                         self.INVALID_REQUEST)

        try:
            # A payload with no 'id' is treated as a notification
            if 'id' not in payload:
                return await self.process_json_notification(payload)

            id_ = payload['id']
            if not isinstance(id_, self.ID_TYPES):
                return self.json_error_bytes('invalid id: {}'.format(id_),
                                             self.INVALID_REQUEST)

            # Payloads with an 'id' are requests if they name a
            # method and responses otherwise
            if 'method' in payload:
                return await self.process_json_request(payload)

            return await self.process_json_response(payload)
        except self.RPCError as e:
            return self.json_error_bytes(e.msg, e.code,
                                         self.payload_id(payload))

    @classmethod
    def method_and_params(cls, payload):
        '''Return a (method, params) pair extracted from the payload.

        params defaults to an empty list.  Raises RPCError if method is
        not a string or params is not a list.
        '''
        method = payload.get('method')
        params = payload.get('params', [])

        if not isinstance(method, str):
            raise cls.RPCError('invalid method: {}'.format(method),
                               cls.INVALID_REQUEST)

        if not isinstance(params, list):
            raise cls.RPCError('params should be an array',
                               cls.INVALID_REQUEST)

        return method, params

    async def process_json_notification(self, payload):
        '''Pass a well-formed notification to handle_notification().

        Malformed notifications are silently ignored as notifications
        are never replied to.  Always returns empty bytes.
        '''
        try:
            method, params = self.method_and_params(payload)
        except self.RPCError:
            pass
        else:
            await self.handle_notification(method, params)
        return b''

    async def process_json_request(self, payload):
        '''Pass a request to handle_request() and return the bytes of
        the JSON response.  RPCError exceptions propagate to the
        caller.'''
        method, params = self.method_and_params(payload)
        result = await self.handle_request(method, params)
        return self.json_response_bytes(result, payload['id'])

    async def process_json_response(self, payload):
        '''Pass a response to handle_response().  A payload with neither
        'error' nor 'result' is ignored.  Always returns empty
        bytes.'''
        # Only one of result and error should exist; we go with 'error'
        # if both are supplied.
        if 'error' in payload:
            await self.handle_response(None, payload['error'], payload['id'])
        elif 'result' in payload:
            await self.handle_response(payload['result'], None, payload['id'])
        return b''

    def check_oversized_request(self, total_len):
        '''Raise an RPCError if total_len exceeds max_send, where
        values of max_send under 1000 are treated as 1000.'''
        if total_len > max(1000, self.max_send):
            raise self.RPCError('request too large', self.INVALID_REQUEST)

    def raise_unknown_method(self, method):
        '''Respond to a request with an unknown method.'''
        raise self.RPCError("unknown method: '{}'".format(method),
                            self.METHOD_NOT_FOUND)

    # Common parameter verification routines
    @classmethod
    def param_to_non_negative_integer(cls, param):
        '''Return param if it is or can be converted to a non-negative
        integer, otherwise raise an RPCError.'''
        try:
            param = int(param)
            if param >= 0:
                return param
        except (ValueError, TypeError):
            # int() raises TypeError, not ValueError, for values such
            # as None, lists and dictionaries
            pass

        raise cls.RPCError('param {} should be a non-negative integer'
                           .format(param))

    @classmethod
    def params_to_non_negative_integer(cls, params):
        '''Return the single non-negative integer that params holds,
        otherwise raise an RPCError.'''
        if len(params) == 1:
            return cls.param_to_non_negative_integer(params[0])
        raise cls.RPCError('params {} should contain one non-negative integer'
                           .format(params))

    @classmethod
    def require_empty_params(cls, params):
        '''Raise an RPCError if params is not empty.'''
        if params:
            raise cls.RPCError('params {} should be empty'.format(params))

    # --- derived classes are intended to override these functions
    def enqueue_request(self, request):
        '''Enqueue a request for later asynchronous processing.'''
        raise NotImplementedError

    async def handle_notification(self, method, params):
        '''Handle a notification.'''

    async def handle_request(self, method, params):
        '''Handle a request.'''
        return None

    async def handle_response(self, result, error, id_):
        '''Handle a response.'''