From b116040365652123367995908451af2dedda6697 Mon Sep 17 00:00:00 2001 From: Neil Booth Date: Sun, 4 Dec 2016 16:59:25 +0900 Subject: [PATCH] Clean up param verification code --- lib/jsonrpc.py | 39 +++++++++++++++++++++++++---- server/protocol.py | 61 +++++++++++++++------------------------------- 2 files changed, 53 insertions(+), 47 deletions(-) diff --git a/lib/jsonrpc.py b/lib/jsonrpc.py index d3d7390..512fcb2 100644 --- a/lib/jsonrpc.py +++ b/lib/jsonrpc.py @@ -304,17 +304,18 @@ class JSONRPC(asyncio.Protocol, LoggedClass): return await self.json_response(message) - def method_and_params(self, message): + @classmethod + def method_and_params(cls, message): method = message.get('method') params = message.get('params', []) if not isinstance(method, str): - raise self.RPCError('invalid method: {}'.format(method), - self.INVALID_REQUEST) + raise cls.RPCError('invalid method: {}'.format(method), + cls.INVALID_REQUEST) if not isinstance(params, list): - raise self.RPCError('params should be an array', - self.INVALID_REQUEST) + raise cls.RPCError('params should be an array', + cls.INVALID_REQUEST) return method, params @@ -349,6 +350,34 @@ class JSONRPC(asyncio.Protocol, LoggedClass): raise self.RPCError("unknown method: '{}'".format(method), self.METHOD_NOT_FOUND) + # Common parameter verification routines + @classmethod + def param_to_non_negative_integer(cls, param): + '''Return param if it is or can be converted to a non-negative + integer, otherwise raise an RPCError.''' + try: + param = int(param) + if param >= 0: + return param + except ValueError: + pass + + raise cls.RPCError('param {} should be a non-negative integer' + .format(param)) + + @classmethod + def params_to_non_negative_integer(cls, params): + if len(params) == 1: + return cls.param_to_non_negative_integer(params[0]) + raise cls.RPCError('params {} should contain one non-negative integer' + .format(params)) + + @classmethod + def require_empty_params(cls, params): + if params: + raise cls.RPCError('params {} should be empty'.format(params)) + + # --- derived classes are intended to override these functions async def handle_notification(self, method, params): '''Handle a notification.''' diff --git a/server/protocol.py b/server/protocol.py index 70d5014..de07679 100644 --- a/server/protocol.py +++ b/server/protocol.py @@ -623,7 +623,7 @@ class Session(JSONRPC): except DaemonError as e: raise self.RPCError('daemon error: {}'.format(e)) - def tx_hash_from_param(self, param): + def param_to_tx_hash(self, param): '''Raise an RPCError if the parameter is not a valid transaction hash.''' if isinstance(param, str) and len(param) == 64: @@ -635,43 +635,20 @@ class Session(JSONRPC): raise self.RPCError('parameter should be a transaction hash: {}' .format(param)) - def hash168_from_param(self, param): + def param_to_hash168(self, param): if isinstance(param, str): try: return self.coin.address_to_hash168(param) except: pass - raise self.RPCError('parameter should be a valid address: {}' - .format(param)) - - def non_negative_integer_from_param(self, param): - try: - param = int(param) - except ValueError: - pass - else: - if param >= 0: - return param + raise self.RPCError('param {} is not a valid address'.format(param)) - raise self.RPCError('param should be a non-negative integer: {}' - .format(param)) - - def extract_hash168(self, params): + def params_to_hash168(self, params): if len(params) == 1: - return self.hash168_from_param(params[0]) - raise self.RPCError('params should contain a single address: {}' + return self.param_to_hash168(params[0]) + raise self.RPCError('params {} should contain a single address' .format(params)) - def extract_non_negative_integer(self, params): - if len(params) == 1: - return self.non_negative_integer_from_param(params[0]) - raise self.RPCError('params should contain a non-negative integer: {}' - .format(params)) - - def require_empty_params(self, params): - if params: - raise self.RPCError('params should be empty: {}'.format(params)) - class ElectrumX(Session): '''A TCP server that handles incoming Electrum connections.''' @@ -837,27 +814,27 @@ class ElectrumX(Session): # --- blockchain commands async def address_get_balance(self, params): - hash168 = self.extract_hash168(params) + hash168 = self.params_to_hash168(params) return await self.get_balance(hash168) async def address_get_history(self, params): - hash168 = self.extract_hash168(params) + hash168 = self.params_to_hash168(params) return await self.get_history(hash168) async def address_get_mempool(self, params): - hash168 = self.extract_hash168(params) + hash168 = self.params_to_hash168(params) return self.unconfirmed_history(hash168) async def address_get_proof(self, params): - hash168 = self.extract_hash168(params) + hash168 = self.params_to_hash168(params) raise self.RPCError('get_proof is not yet implemented') async def address_listunspent(self, params): - hash168 = self.extract_hash168(params) + hash168 = self.params_to_hash168(params) return await self.list_unspent(hash168) async def address_subscribe(self, params): - hash168 = self.extract_hash168(params) + hash168 = self.params_to_hash168(params) if len(self.hash168s) >= self.max_subs: raise self.RPCError('your address subscription limit {:,d} reached' .format(self.max_subs)) @@ -868,11 +845,11 @@ class ElectrumX(Session): return result async def block_get_chunk(self, params): - index = self.extract_non_negative_integer(params) + index = self.params_to_non_negative_integer(params) return self.get_chunk(index) async def block_get_header(self, params): - height = self.extract_non_negative_integer(params) + height = self.params_to_non_negative_integer(params) return self.electrum_header(height) async def estimatefee(self, params): @@ -929,15 +906,15 @@ class ElectrumX(Session): # For some reason Electrum passes a height. Don't require it # in anticipation it might be dropped in the future. if 1 <= len(params) <= 2: - tx_hash = self.tx_hash_from_param(params[0]) + tx_hash = self.param_to_tx_hash(params[0]) return await self.daemon_request('getrawtransaction', tx_hash) raise self.RPCError('params wrong length: {}'.format(params)) async def transaction_get_merkle(self, params): if len(params) == 2: - tx_hash = self.tx_hash_from_param(params[0]) - height = self.non_negative_integer_from_param(params[1]) + tx_hash = self.param_to_tx_hash(params[0]) + height = self.param_to_non_negative_integer(params[1]) return await self.tx_merkle(tx_hash, height) raise self.RPCError('params should contain a transaction hash ' @@ -945,8 +922,8 @@ class ElectrumX(Session): async def utxo_get_address(self, params): if len(params) == 2: - tx_hash = self.tx_hash_from_param(params[0]) - index = self.non_negative_integer_from_param(params[1]) + tx_hash = self.param_to_tx_hash(params[0]) + index = self.param_to_non_negative_integer(params[1]) tx_hash = hex_str_to_hash(tx_hash) hash168 = self.bp.get_utxo_hash168(tx_hash, index) if hash168: