# Copyright (c) 2016, Neil Booth
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.

'''Class for handling JSON RPC 2.0 connections, server or client.'''

import asyncio
import json
import numbers
import time

from lib.util import LoggedClass


class RequestBase(object):
    '''An object that represents a queued request.'''

    def __init__(self, remaining):
        self.remaining = remaining


class SingleRequest(RequestBase):
    '''An object that represents a single request.'''

    def __init__(self, payload):
        super().__init__(1)
        self.payload = payload

    async def process(self, session):
        '''Asynchronously handle the JSON request.'''
        self.remaining = 0
        binary = await session.process_single_payload(self.payload)
        if binary:
            session._send_bytes(binary)

    def __str__(self):
        return str(self.payload)


class BatchRequest(RequestBase):
    '''An object that represents a batch request and its processing state.

    Batches are processed in chunks.
    '''

    def __init__(self, payload):
        super().__init__(len(payload))
        self.payload = payload
        self.parts = []

    async def process(self, session):
        '''Asynchronously handle the JSON batch according to the JSON RPC
        2.0 spec.'''
        # Process at most 4 items of the batch per call; the caller is
        # expected to call process() again whilst self.remaining is non-zero.
        target = max(self.remaining - 4, 0)
        while self.remaining > target:
            item = self.payload[len(self.payload) - self.remaining]
            self.remaining -= 1
            part = await session.process_single_payload(item)
            if part:
                self.parts.append(part)

        total_len = sum(len(part) + 2 for part in self.parts)
        session.check_oversized_request(total_len)

        if not self.remaining:
            if self.parts:
                binary = b'[' + b', '.join(self.parts) + b']'
                session._send_bytes(binary)

    def __str__(self):
        return str(self.payload)


class JSONRPC(asyncio.Protocol, LoggedClass):
    '''Manages a JSONRPC connection.

    Assumes JSON messages are newline-separated and that newlines
    cannot appear in the JSON other than to separate lines. Incoming
    requests are passed to enqueue_request(), which should arrange for
    their asynchronous handling via the request's process() method.

    Derived classes may want to override connection_made() and
    connection_lost() but should be sure to call the implementation in
    this base class first. They will also want to implement some or
    all of the asynchronous functions handle_notification(),
    handle_response() and handle_request().

    handle_request() returns the result to pass over the network, and
    must raise an RPCError if there is an error.
    handle_notification() and handle_response() should not return
    anything or raise any exceptions. All three functions have
    default "ignore" implementations supplied by this class.

    An illustrative sketch of a derived class appears at the end of
    this module.
    '''

    # See http://www.jsonrpc.org/specification
    PARSE_ERROR = -32700
    INVALID_REQUEST = -32600
    METHOD_NOT_FOUND = -32601
    INVALID_PARAMS = -32602
    INTERNAL_ERROR = -32603

    ID_TYPES = (type(None), str, numbers.Number)
    NEXT_SESSION_ID = 0

    class RPCError(Exception):
        '''RPC handlers raise this error.'''
        def __init__(self, msg, code=-1, **kw_args):
            super().__init__(**kw_args)
            self.msg = msg
            self.code = code

    @classmethod
    def request_payload(cls, method, id_, params=None):
        payload = {'jsonrpc': '2.0', 'id': id_, 'method': method}
        if params:
            payload['params'] = params
        return payload

    @classmethod
    def response_payload(cls, result, id_):
        # We should not respond to notifications
        assert id_ is not None
        return {'jsonrpc': '2.0', 'result': result, 'id': id_}

    @classmethod
    def notification_payload(cls, method, params=None):
        return cls.request_payload(method, None, params)

    @classmethod
    def error_payload(cls, message, code, id_=None):
        error = {'message': message, 'code': code}
        return {'jsonrpc': '2.0', 'error': error, 'id': id_}

    @classmethod
    def payload_id(cls, payload):
        return payload.get('id') if isinstance(payload, dict) else None
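
    # For reference, the payload shapes the helpers above produce, e.g. for
    # method 'ping' with id 1 (the method name is only illustrative):
    #   request:      {'jsonrpc': '2.0', 'id': 1, 'method': 'ping'}
    #   notification: {'jsonrpc': '2.0', 'id': None, 'method': 'ping'}
    #   response:     {'jsonrpc': '2.0', 'result': 'pong', 'id': 1}
    #   error:        {'jsonrpc': '2.0', 'id': 1,
    #                  'error': {'message': 'msg', 'code': -32600}}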

    def __init__(self):
        super().__init__()
        self.start = time.time()
        self.stop = 0
        self.last_recv = self.start
        self.bandwidth_start = self.start
        self.bandwidth_interval = 3600
        self.bandwidth_used = 0
        self.bandwidth_limit = 5000000
        self.transport = None
        self.pause = False
        # Parts of an incomplete JSON line. We buffer them until
        # getting a newline.
        self.parts = []
        # recv_count is JSON messages, not calls to data_received()
        self.recv_count = 0
        self.recv_size = 0
        self.send_count = 0
        self.send_size = 0
        self.error_count = 0
        self.peer_info = None
        # Sends longer than max_send are prevented, instead returning
        # an oversized request error to the other end of the network
        # connection. The request causing it is logged. Values under
        # 1000 are treated as 1000.
        self.max_send = 0
        # If buffered incoming data exceeds this the connection is closed
        self.max_buffer_size = 1000000
        self.anon_logs = False
        self.id_ = JSONRPC.NEXT_SESSION_ID
        JSONRPC.NEXT_SESSION_ID += 1
        self.log_prefix = '[{:d}] '.format(self.id_)
        self.log_me = False

    def peername(self, *, for_log=True):
        '''Return the peer name of this connection.'''
        if not self.peer_info:
            return 'unknown'
        if for_log and self.anon_logs:
            return 'xx.xx.xx.xx:xx'
        return '{}:{}'.format(self.peer_info[0], self.peer_info[1])

    def connection_made(self, transport):
        '''Handle an incoming client connection.'''
        self.transport = transport
        self.peer_info = transport.get_extra_info('peername')
        transport.set_write_buffer_limits(high=500000)

    def connection_lost(self, exc):
        '''Handle client disconnection.'''
        pass

    def pause_writing(self):
        '''Called by asyncio when the write buffer is full.'''
        self.log_info('pausing request processing whilst socket drains')
        self.pause = True

    def resume_writing(self):
        '''Called by asyncio when the write buffer has room.'''
        self.log_info('resuming request processing')
        self.pause = False

    def close_connection(self):
        self.stop = time.time()
        if self.transport:
            self.transport.close()

    def using_bandwidth(self, amount):
        now = time.time()
        # Reduce the recorded usage in proportion to the elapsed time
        elapsed = now - self.bandwidth_start
        self.bandwidth_start = now
        refund = int(elapsed / self.bandwidth_interval * self.bandwidth_limit)
        refund = min(refund, self.bandwidth_used)
        self.bandwidth_used += amount - refund
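
    # Worked example of the refund above: with the default limit of
    # 5,000,000 bytes per 3,600 second interval, 60 seconds of elapsed time
    # refunds int(60 / 3600 * 5000000) = 83,333 bytes of recorded usage
    # (capped at the amount actually recorded).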

    def data_received(self, data):
        '''Handle incoming data (synchronously).

        Requests end in newline characters. Pass complete requests to
        decode_message for handling.
        '''
        self.recv_size += len(data)
        self.using_bandwidth(len(data))

        # Close abusive connections where buffered data exceeds limit
        buffer_size = len(data) + sum(len(part) for part in self.parts)
        if buffer_size > self.max_buffer_size:
            self.log_error('read buffer of {:,d} bytes exceeds {:,d} '
                           'byte limit, closing {}'
                           .format(buffer_size, self.max_buffer_size,
                                   self.peername()))
            self.close_connection()

        # Do nothing if this connection is closing
        if self.transport.is_closing():
            return

        while True:
            npos = data.find(ord('\n'))
            if npos == -1:
                self.parts.append(data)
                break
            self.last_recv = time.time()
            self.recv_count += 1
            tail, data = data[:npos], data[npos + 1:]
            parts, self.parts = self.parts, []
            parts.append(tail)
            self.decode_message(b''.join(parts))
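
    # Framing example (illustrative): if data_received() is called with
    # b'{"id": 1}\n{"id": 2}\n{"id"' the loop above passes b'{"id": 1}' and
    # b'{"id": 2}' to decode_message() and buffers the trailing b'{"id"'
    # in self.parts until a later call supplies the rest of the line.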

    def decode_message(self, message):
        '''Decode a binary message and queue it for asynchronous handling.

        Messages that cannot be decoded are logged and dropped.
        '''
        try:
            message = message.decode()
        except UnicodeDecodeError as e:
            msg = 'cannot decode binary bytes: {}'.format(e)
            self.send_json_error(msg, self.PARSE_ERROR, close=True)
            return

        try:
            message = json.loads(message)
        except json.JSONDecodeError as e:
            msg = 'cannot decode JSON: {}'.format(e)
            self.send_json_error(msg, self.PARSE_ERROR, close=True)
            return

        if isinstance(message, list):
            # Batches must have at least one request.
            if not message:
                self.send_json_error('empty batch', self.INVALID_REQUEST)
                return
            request = BatchRequest(message)
        else:
            request = SingleRequest(message)

        # Queue the request for asynchronous handling.
        self.enqueue_request(request)
        if self.log_me:
            self.log_info('queued {}'.format(message))

    def encode_payload(self, payload):
        try:
            binary = json.dumps(payload).encode()
        except TypeError:
            msg = 'JSON encoding failure: {}'.format(payload)
            self.log_error(msg)
            return self.json_error_bytes(msg, self.INTERNAL_ERROR,
                                         self.payload_id(payload))

        self.check_oversized_request(len(binary))
        self.send_count += 1
        self.send_size += len(binary)
        self.using_bandwidth(len(binary))
        return binary

    def _send_bytes(self, binary, close=False):
        '''Send JSON text over the transport. Close it if close is True.'''
        # Confirmed this happens, sometimes a lot
        if self.transport.is_closing():
            return
        self.transport.write(binary)
        self.transport.write(b'\n')
        if close or self.error_count > 10:
            self.close_connection()

    def send_json_error(self, message, code, id_=None, close=False):
        '''Send a JSON error, closing the connection if close is True.'''
        self._send_bytes(self.json_error_bytes(message, code, id_), close)

    def encode_and_send_payload(self, payload):
        '''Encode the payload and send it.'''
        self._send_bytes(self.encode_payload(payload))

    def json_notification_bytes(self, method, params):
        '''Return the bytes of a json notification.'''
        return self.encode_payload(self.notification_payload(method, params))

    def json_request_bytes(self, method, id_, params=None):
        '''Return the bytes of a JSON request.'''
        return self.encode_payload(self.request_payload(method, id_, params))

    def json_response_bytes(self, result, id_):
        '''Return the bytes of a JSON response.'''
        return self.encode_payload(self.response_payload(result, id_))

    def json_error_bytes(self, message, code, id_=None):
        '''Return the bytes of a JSON error.'''
        self.error_count += 1
        return self.encode_payload(self.error_payload(message, code, id_))

    async def process_single_payload(self, payload):
        '''Return the binary JSON result of a single JSON request, response or
        notification.

        The result is empty if nothing is to be sent.
        '''

        if not isinstance(payload, dict):
            return self.json_error_bytes('request must be a dict',
                                         self.INVALID_REQUEST)

        try:
            if 'id' not in payload:
                return await self.process_json_notification(payload)

            id_ = payload['id']
            if not isinstance(id_, self.ID_TYPES):
                return self.json_error_bytes('invalid id: {}'.format(id_),
                                             self.INVALID_REQUEST)

            if 'method' in payload:
                return await self.process_json_request(payload)

            return await self.process_json_response(payload)
        except self.RPCError as e:
            return self.json_error_bytes(e.msg, e.code,
                                         self.payload_id(payload))

    @classmethod
    def method_and_params(cls, payload):
        method = payload.get('method')
        params = payload.get('params', [])

        if not isinstance(method, str):
            raise cls.RPCError('invalid method: {}'.format(method),
                               cls.INVALID_REQUEST)

        if not isinstance(params, list):
            raise cls.RPCError('params should be an array',
                               cls.INVALID_REQUEST)

        return method, params

    async def process_json_notification(self, payload):
        try:
            method, params = self.method_and_params(payload)
        except self.RPCError:
            pass
        else:
            await self.handle_notification(method, params)
        return b''

    async def process_json_request(self, payload):
        method, params = self.method_and_params(payload)
        result = await self.handle_request(method, params)
        return self.json_response_bytes(result, payload['id'])

    async def process_json_response(self, payload):
        # Only one of result and error should exist; we go with 'error'
        # if both are supplied.
        if 'error' in payload:
            await self.handle_response(None, payload['error'], payload['id'])
        elif 'result' in payload:
            await self.handle_response(payload['result'], None, payload['id'])
        return b''

    def check_oversized_request(self, total_len):
        if total_len > max(1000, self.max_send):
            raise self.RPCError('request too large', self.INVALID_REQUEST)

    def raise_unknown_method(self, method):
        '''Respond to a request with an unknown method.'''
        raise self.RPCError("unknown method: '{}'".format(method),
                            self.METHOD_NOT_FOUND)

    # Common parameter verification routines
    @classmethod
    def param_to_non_negative_integer(cls, param):
        '''Return param if it is or can be converted to a non-negative
        integer, otherwise raise an RPCError.'''
        try:
            param = int(param)
            if param >= 0:
                return param
        except (ValueError, TypeError):
            # Fall through to raise an RPCError for non-numeric params too.
            pass

        raise cls.RPCError('param {} should be a non-negative integer'
                           .format(param))

    @classmethod
    def params_to_non_negative_integer(cls, params):
        if len(params) == 1:
            return cls.param_to_non_negative_integer(params[0])
        raise cls.RPCError('params {} should contain one non-negative integer'
                           .format(params))

    @classmethod
    def require_empty_params(cls, params):
        if params:
            raise cls.RPCError('params {} should be empty'.format(params))

    # --- derived classes are intended to override these functions
    def enqueue_request(self, request):
        '''Enqueue a request for later asynchronous processing.'''
        raise NotImplementedError

    async def handle_notification(self, method, params):
        '''Handle a notification.'''

    async def handle_request(self, method, params):
        '''Handle a request.'''
        return None

    async def handle_response(self, result, error, id_):
        '''Handle a response.'''
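

# Illustrative sketch only: a minimal derived class might look roughly like
# the following. The class name, the handled method name and the use of
# asyncio.ensure_future() are assumptions for illustration, not part of this
# module's API.
#
# class ExampleSession(JSONRPC):
#
#     def enqueue_request(self, request):
#         # Schedule asynchronous processing of the request. A real session
#         # would also re-queue the request while request.remaining is
#         # non-zero, since batches are processed in chunks.
#         asyncio.ensure_future(request.process(self))
#
#     async def handle_request(self, method, params):
#         if method == 'server.ping':
#             self.require_empty_params(params)
#             return 'pong'
#         self.raise_unknown_method(method)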