diff --git a/RELEASE-NOTES b/RELEASE-NOTES index 78c83b1..2075d39 100644 --- a/RELEASE-NOTES +++ b/RELEASE-NOTES @@ -1,3 +1,12 @@ +version 0.8.5 +------------- + +- rework of JSON layer to better handle batch requests. This is + preparatory work for improved DoS resistance. + +I'm aware recent versions don't sync efficiently; please use 0.8.0 to sync +until I find time to fix it. + version 0.8.4 ------------- diff --git a/lib/jsonrpc.py b/lib/jsonrpc.py index d3d7390..7b5b396 100644 --- a/lib/jsonrpc.py +++ b/lib/jsonrpc.py @@ -33,6 +33,9 @@ def json_request_payload(method, id_, params=None): def json_notification_payload(method, params=None): return json_request_payload(method, None, params) +def json_payload_id(payload): + return payload.get('id') if isinstance(payload, dict) else None + class JSONRPC(asyncio.Protocol, LoggedClass): '''Manages a JSONRPC connection. @@ -179,16 +182,14 @@ class JSONRPC(asyncio.Protocol, LoggedClass): message = message.decode() except UnicodeDecodeError as e: msg = 'cannot decode binary bytes: {}'.format(e) - self.send_json_error(msg, self.PARSE_ERROR) - self.transport.close() + self.send_json_error(msg, self.INVALID_REQUEST) return try: message = json.loads(message) except json.JSONDecodeError as e: msg = 'cannot decode JSON: {}'.format(e) - self.send_json_error(msg, self.PARSE_ERROR) - self.transport.close() + self.send_json_error(msg, self.INVALID_REQUEST) return '''Queue the request for asynchronous handling.''' @@ -196,55 +197,98 @@ class JSONRPC(asyncio.Protocol, LoggedClass): if self.log_me: self.log_info('queued {}'.format(message)) - def send_json_notification(self, method, params): - '''Create a json notification.''' - self.send_json(json_notification_payload(method, params)) + def encode_payload(self, payload): + try: + text = (json.dumps(payload) + '\n').encode() + except TypeError: + msg = 'JSON encoding failure: {}'.format(payload) + self.log_error(msg) + return self.json_error(msg, self.INTERNAL_ERROR, + json_payload_id(payload)) + + self.check_oversized_request(len(text)) + if 'error' in payload: + self.error_count += 1 + self.send_count += 1 + self.send_size += len(text) + self.using_bandwidth(len(text)) + return text + + def send_text(self, text, close): + '''Send JSON text over the transport. Close it if close is True.''' + # Confirmed this happens, sometimes a lot + if self.transport.is_closing(): + return + self.transport.write(text) + if close: + self.transport.close() - def send_json_request(self, method, id_, params=None): - '''Send a JSON request.''' - self.send_json(json_request_payload(method, id_, params)) + def send_json_error(self, message, code, id_=None, close=True): + '''Send a JSON error and close the connection by default.''' + self.send_text(self.json_error_text(message, code, id_), close) - def send_json_response(self, result, id_): - '''Send a JSON result.''' - self.send_json(json_response_payload(result, id_)) + def encode_and_send_payload(self, payload): + '''Encode the payload and send it.''' + self.send_text(self.encode_payload(payload), False) - def send_json_error(self, message, code, id_=None): - '''Send a JSON error.''' - self.send_json(json_error_payload(message, code, id_)) - self.error_count += 1 - # Close abusive clients - if self.error_count >= 10: - self.transport.close() + def json_notification_text(self, method, params): + '''Return the text of a json notification.''' + return self.encode_payload(json_notification_payload(method, params)) - def send_json(self, payload): - '''Send a JSON payload.''' - # Confirmed this happens, sometimes a lot - if self.transport.is_closing(): - return + def json_request_text(self, method, id_, params=None): + '''Return the text of a JSON request.''' + return self.encode_payload(json_request_payload(method, params)) - id_ = payload.get('id') if isinstance(payload, dict) else None - try: - data = (json.dumps(payload) + '\n').encode() - except TypeError: - msg = 'JSON encoding failure: {}'.format(payload) - self.log_error(msg) - self.send_json_error(msg, self.INTERNAL_ERROR, id_) - else: - if len(data) > max(1000, self.max_send): - self.send_json_error('request too large', - self.INVALID_REQUEST, id_) - raise self.LargeRequestError - else: - self.send_count += 1 - self.send_size += len(data) - self.using_bandwidth(len(data)) - self.transport.write(data) + def json_response_text(self, result, id_): + '''Return the text of a JSON response.''' + return self.encode_payload(json_response_payload(result, id_)) + + def json_error_text(self, message, code, id_=None): + '''Return the text of a JSON error.''' + return self.encode_payload(json_error_payload(message, code, id_)) - async def handle_message(self, message): + async def handle_message(self, payload): '''Asynchronously handle a JSON request or response. Handles batches according to the JSON 2.0 spec. ''' + try: + if isinstance(payload, list): + text = await self.process_json_batch(payload) + else: + text = await self.process_single_json(payload) + except self.RPCError as e: + text = self.json_error_text(e.msg, e.code, + json_payload_id(payload)) + + if text: + self.send_text(text, self.error_count > 10) + + async def process_json_batch(self, batch): + '''Return the text response to a JSON batch request.''' + # Batches must have at least one request. + if not batch: + return self.json_error_text('empty batch', self.INVALID_REQUEST) + + # PYTHON 3.6: use asynchronous comprehensions when supported + parts = [] + total_len = 0 + for item in batch: + part = await self.process_single_json(item) + if part: + parts.append(part) + total_len += len(part) + 2 + self.check_oversized_request(total_len) + if parts: + return '{' + ', '.join(parts) + '}' + return '' + + async def process_single_json(self, payload): + '''Return the JSON result of a single JSON request, response or + notification. + + Return None if the request is a notification or a response. + ''' # Throttle high-bandwidth connections by delaying processing # their requests. Delay more the higher the excessive usage. excess = self.bandwidth_used - self.bandwidth_limit @@ -255,100 +299,98 @@ class JSONRPC(asyncio.Protocol, LoggedClass): .format(self.bandwidth_used, secs)) await asyncio.sleep(secs) - if isinstance(message, list): - payload = await self.batch_payload(message) - else: - payload = await self.single_payload(message) - - if payload: - try: - self.send_json(payload) - except self.LargeRequestError: - self.log_warning('blocked large request {}'.format(message)) - - async def batch_payload(self, batch): - '''Return the JSON payload corresponding to a batch JSON request.''' - # Batches must have at least one request. - if not batch: - return json_error_payload('empty request list', - self.INVALID_REQUEST) - - # PYTHON 3.6: use asynchronous comprehensions when supported - payload = [] - for message in batch: - message_payload = await self.single_payload(message) - if message_payload: - payload.append(message_payload) - return payload - - async def single_payload(self, message): - '''Return the JSON payload corresponding to a single JSON request, - response or notification. - - Return None if the request is a notification or a response. - ''' - if not isinstance(message, dict): - return json_error_payload('request must be a dict', - self.INVALID_REQUEST) + if not isinstance(payload, dict): + return self.json_error_text('request must be a dict', + self.INVALID_REQUEST) - if not 'id' in message: - return await self.json_notification(message) + if not 'id' in payload: + return await self.process_json_notification(payload) - id_ = message['id'] + id_ = payload['id'] if not isinstance(id_, self.ID_TYPES): - return json_error_payload('invalid id: {}'.format(id_), - self.INVALID_REQUEST) + return self.json_error_text('invalid id: {}'.format(id_), + self.INVALID_REQUEST) - if 'method' in message: - return await self.json_request(message) + if 'method' in payload: + return await self.process_json_request(payload) - return await self.json_response(message) + return await self.process_json_response(payload) - def method_and_params(self, message): - method = message.get('method') - params = message.get('params', []) + @classmethod + def method_and_params(cls, payload): + method = payload.get('method') + params = payload.get('params', []) if not isinstance(method, str): - raise self.RPCError('invalid method: {}'.format(method), - self.INVALID_REQUEST) + raise cls.RPCError('invalid method: {}'.format(method), + cls.INVALID_REQUEST) if not isinstance(params, list): - raise self.RPCError('params should be an array', - self.INVALID_REQUEST) + raise cls.RPCError('params should be an array', + cls.INVALID_REQUEST) return method, params - async def json_notification(self, message): + async def process_json_notification(self, payload): try: - method, params = self.method_and_params(message) + method, params = self.method_and_params(payload) except self.RPCError: pass else: await self.handle_notification(method, params) - return None + return '' - async def json_request(self, message): - try: - method, params = self.method_and_params(message) - result = await self.handle_request(method, params) - return json_response_payload(result, message['id']) - except self.RPCError as e: - return json_error_payload(e.msg, e.code, message['id']) + async def process_json_request(self, payload): + method, params = self.method_and_params(payload) + result = await self.handle_request(method, params) + return self.json_response_text(result, payload['id']) - async def json_response(self, message): + async def process_json_response(self, payload): # Only one of result and error should exist; we go with 'error' # if both are supplied. - if 'error' in message: - await self.handle_response(None, message['error'], message['id']) - elif 'result' in message: - await self.handle_response(message['result'], None, message['id']) - return None + if 'error' in payload: + await self.handle_response(None, payload['error'], payload['id']) + elif 'result' in payload: + await self.handle_response(payload['result'], None, payload['id']) + return '' + + def check_oversized_request(self, total_len): + if total_len > max(1000, self.max_send): + raise self.RPCError('request too large', self.INVALID_REQUEST) def raise_unknown_method(self, method): '''Respond to a request with an unknown method.''' raise self.RPCError("unknown method: '{}'".format(method), self.METHOD_NOT_FOUND) + # Common parameter verification routines + @classmethod + def param_to_non_negative_integer(cls, param): + '''Return param if it is or can be converted to a non-negative + integer, otherwise raise an RPCError.''' + try: + param = int(param) + if param >= 0: + return param + except ValueError: + pass + + raise cls.RPCError('param {} should be a non-negative integer' + .format(param)) + + @classmethod + def params_to_non_negative_integer(cls, params): + if len(params) == 1: + return cls.param_to_non_negative_integer(params[0]) + raise cls.RPCError('params {} should contain one non-negative integer' + .format(params)) + + @classmethod + def require_empty_params(cls, params): + if params: + raise cls.RPCError('params {} should be empty'.format(params)) + + # --- derived classes are intended to override these functions async def handle_notification(self, method, params): '''Handle a notification.''' diff --git a/server/protocol.py b/server/protocol.py index 70d5014..07f5aa5 100644 --- a/server/protocol.py +++ b/server/protocol.py @@ -419,7 +419,6 @@ class ServerManager(util.LoggedClass): cutoff = now - self.env.session_timeout stale = [session for session in self.sessions if session.last_recv < cutoff - and session.client != 'all_seeing_eye' and not session.is_closing()] for session in stale: self.close_session(session) @@ -623,7 +622,7 @@ class Session(JSONRPC): except DaemonError as e: raise self.RPCError('daemon error: {}'.format(e)) - def tx_hash_from_param(self, param): + def param_to_tx_hash(self, param): '''Raise an RPCError if the parameter is not a valid transaction hash.''' if isinstance(param, str) and len(param) == 64: @@ -635,43 +634,20 @@ class Session(JSONRPC): raise self.RPCError('parameter should be a transaction hash: {}' .format(param)) - def hash168_from_param(self, param): + def param_to_hash168(self, param): if isinstance(param, str): try: return self.coin.address_to_hash168(param) except: pass - raise self.RPCError('parameter should be a valid address: {}' - .format(param)) - - def non_negative_integer_from_param(self, param): - try: - param = int(param) - except ValueError: - pass - else: - if param >= 0: - return param + raise self.RPCError('param {} is not a valid address'.format(param)) - raise self.RPCError('param should be a non-negative integer: {}' - .format(param)) - - def extract_hash168(self, params): + def params_to_hash168(self, params): if len(params) == 1: - return self.hash168_from_param(params[0]) - raise self.RPCError('params should contain a single address: {}' + return self.param_to_hash168(params[0]) + raise self.RPCError('params {} should contain a single address' .format(params)) - def extract_non_negative_integer(self, params): - if len(params) == 1: - return self.non_negative_integer_from_param(params[0]) - raise self.RPCError('params should contain a non-negative integer: {}' - .format(params)) - - def require_empty_params(self, params): - if params: - raise self.RPCError('params should be empty: {}'.format(params)) - class ElectrumX(Session): '''A TCP server that handles incoming Electrum connections.''' @@ -715,14 +691,14 @@ class ElectrumX(Session): 'blockchain.headers.subscribe', (self.electrum_header(height), ), ) - self.send_json(cache[key]) + self.encode_and_send_payload(cache[key]) if self.subscribe_height: payload = json_notification_payload( 'blockchain.numblocks.subscribe', (height, ), ) - self.send_json(payload) + self.encode_and_send_payload(payload) hash168_to_address = self.coin.hash168_to_address matches = self.hash168s.intersection(touched) @@ -731,7 +707,7 @@ class ElectrumX(Session): status = await self.address_status(hash168) payload = json_notification_payload( 'blockchain.address.subscribe', (address, status)) - self.send_json(payload) + self.encode_and_send_payload(payload) if matches: self.log_info('notified of {:,d} addresses'.format(len(matches))) @@ -837,27 +813,27 @@ class ElectrumX(Session): # --- blockchain commands async def address_get_balance(self, params): - hash168 = self.extract_hash168(params) + hash168 = self.params_to_hash168(params) return await self.get_balance(hash168) async def address_get_history(self, params): - hash168 = self.extract_hash168(params) + hash168 = self.params_to_hash168(params) return await self.get_history(hash168) async def address_get_mempool(self, params): - hash168 = self.extract_hash168(params) + hash168 = self.params_to_hash168(params) return self.unconfirmed_history(hash168) async def address_get_proof(self, params): - hash168 = self.extract_hash168(params) + hash168 = self.params_to_hash168(params) raise self.RPCError('get_proof is not yet implemented') async def address_listunspent(self, params): - hash168 = self.extract_hash168(params) + hash168 = self.params_to_hash168(params) return await self.list_unspent(hash168) async def address_subscribe(self, params): - hash168 = self.extract_hash168(params) + hash168 = self.params_to_hash168(params) if len(self.hash168s) >= self.max_subs: raise self.RPCError('your address subscription limit {:,d} reached' .format(self.max_subs)) @@ -868,11 +844,11 @@ class ElectrumX(Session): return result async def block_get_chunk(self, params): - index = self.extract_non_negative_integer(params) + index = self.params_to_non_negative_integer(params) return self.get_chunk(index) async def block_get_header(self, params): - height = self.extract_non_negative_integer(params) + height = self.params_to_non_negative_integer(params) return self.electrum_header(height) async def estimatefee(self, params): @@ -929,15 +905,15 @@ class ElectrumX(Session): # For some reason Electrum passes a height. Don't require it # in anticipation it might be dropped in the future. if 1 <= len(params) <= 2: - tx_hash = self.tx_hash_from_param(params[0]) + tx_hash = self.param_to_tx_hash(params[0]) return await self.daemon_request('getrawtransaction', tx_hash) raise self.RPCError('params wrong length: {}'.format(params)) async def transaction_get_merkle(self, params): if len(params) == 2: - tx_hash = self.tx_hash_from_param(params[0]) - height = self.non_negative_integer_from_param(params[1]) + tx_hash = self.param_to_tx_hash(params[0]) + height = self.param_to_non_negative_integer(params[1]) return await self.tx_merkle(tx_hash, height) raise self.RPCError('params should contain a transaction hash ' @@ -945,8 +921,8 @@ class ElectrumX(Session): async def utxo_get_address(self, params): if len(params) == 2: - tx_hash = self.tx_hash_from_param(params[0]) - index = self.non_negative_integer_from_param(params[1]) + tx_hash = self.param_to_tx_hash(params[0]) + index = self.param_to_non_negative_integer(params[1]) tx_hash = hex_str_to_hash(tx_hash) hash168 = self.bp.get_utxo_hash168(tx_hash, index) if hash168: diff --git a/server/version.py b/server/version.py index 23e3031..845ec4c 100644 --- a/server/version.py +++ b/server/version.py @@ -1 +1 @@ -VERSION = "ElectrumX 0.8.4" +VERSION = "ElectrumX 0.8.5"