Browse Source

Clean up param verification code

master
Neil Booth 8 years ago
parent
commit
b116040365
  1. 39
      lib/jsonrpc.py
  2. 61
      server/protocol.py

39
lib/jsonrpc.py

@ -304,17 +304,18 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
return await self.json_response(message)
def method_and_params(self, message):
@classmethod
def method_and_params(cls, message):
method = message.get('method')
params = message.get('params', [])
if not isinstance(method, str):
raise self.RPCError('invalid method: {}'.format(method),
self.INVALID_REQUEST)
raise cls.RPCError('invalid method: {}'.format(method),
cls.INVALID_REQUEST)
if not isinstance(params, list):
raise self.RPCError('params should be an array',
self.INVALID_REQUEST)
raise cls.RPCError('params should be an array',
cls.INVALID_REQUEST)
return method, params
@ -349,6 +350,34 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
raise self.RPCError("unknown method: '{}'".format(method),
self.METHOD_NOT_FOUND)
# Common parameter verification routines
@classmethod
def param_to_non_negative_integer(cls, param):
'''Return param if it is or can be converted to a non-negative
integer, otherwise raise an RPCError.'''
try:
param = int(param)
if param >= 0:
return param
except ValueError:
pass
raise cls.RPCError('param {} should be a non-negative integer'
.format(param))
@classmethod
def params_to_non_negative_integer(cls, params):
if len(params) == 1:
return cls.param_to_non_negative_integer(params[0])
raise cls.RPCError('params {} should contain one non-negative integer'
.format(params))
@classmethod
def require_empty_params(cls, params):
if params:
raise cls.RPCError('params {} should be empty'.format(params))
# --- derived classes are intended to override these functions
async def handle_notification(self, method, params):
'''Handle a notification.'''

61
server/protocol.py

@ -623,7 +623,7 @@ class Session(JSONRPC):
except DaemonError as e:
raise self.RPCError('daemon error: {}'.format(e))
def tx_hash_from_param(self, param):
def param_to_tx_hash(self, param):
'''Raise an RPCError if the parameter is not a valid transaction
hash.'''
if isinstance(param, str) and len(param) == 64:
@ -635,43 +635,20 @@ class Session(JSONRPC):
raise self.RPCError('parameter should be a transaction hash: {}'
.format(param))
def hash168_from_param(self, param):
def param_to_hash168(self, param):
if isinstance(param, str):
try:
return self.coin.address_to_hash168(param)
except:
pass
raise self.RPCError('parameter should be a valid address: {}'
.format(param))
def non_negative_integer_from_param(self, param):
try:
param = int(param)
except ValueError:
pass
else:
if param >= 0:
return param
raise self.RPCError('param {} is not a valid address'.format(param))
raise self.RPCError('param should be a non-negative integer: {}'
.format(param))
def extract_hash168(self, params):
def params_to_hash168(self, params):
if len(params) == 1:
return self.hash168_from_param(params[0])
raise self.RPCError('params should contain a single address: {}'
return self.param_to_hash168(params[0])
raise self.RPCError('params {} should contain a single address'
.format(params))
def extract_non_negative_integer(self, params):
if len(params) == 1:
return self.non_negative_integer_from_param(params[0])
raise self.RPCError('params should contain a non-negative integer: {}'
.format(params))
def require_empty_params(self, params):
if params:
raise self.RPCError('params should be empty: {}'.format(params))
class ElectrumX(Session):
'''A TCP server that handles incoming Electrum connections.'''
@ -837,27 +814,27 @@ class ElectrumX(Session):
# --- blockchain commands
async def address_get_balance(self, params):
hash168 = self.extract_hash168(params)
hash168 = self.params_to_hash168(params)
return await self.get_balance(hash168)
async def address_get_history(self, params):
hash168 = self.extract_hash168(params)
hash168 = self.params_to_hash168(params)
return await self.get_history(hash168)
async def address_get_mempool(self, params):
hash168 = self.extract_hash168(params)
hash168 = self.params_to_hash168(params)
return self.unconfirmed_history(hash168)
async def address_get_proof(self, params):
hash168 = self.extract_hash168(params)
hash168 = self.params_to_hash168(params)
raise self.RPCError('get_proof is not yet implemented')
async def address_listunspent(self, params):
hash168 = self.extract_hash168(params)
hash168 = self.params_to_hash168(params)
return await self.list_unspent(hash168)
async def address_subscribe(self, params):
hash168 = self.extract_hash168(params)
hash168 = self.params_to_hash168(params)
if len(self.hash168s) >= self.max_subs:
raise self.RPCError('your address subscription limit {:,d} reached'
.format(self.max_subs))
@ -868,11 +845,11 @@ class ElectrumX(Session):
return result
async def block_get_chunk(self, params):
index = self.extract_non_negative_integer(params)
index = self.params_to_non_negative_integer(params)
return self.get_chunk(index)
async def block_get_header(self, params):
height = self.extract_non_negative_integer(params)
height = self.params_to_non_negative_integer(params)
return self.electrum_header(height)
async def estimatefee(self, params):
@ -929,15 +906,15 @@ class ElectrumX(Session):
# For some reason Electrum passes a height. Don't require it
# in anticipation it might be dropped in the future.
if 1 <= len(params) <= 2:
tx_hash = self.tx_hash_from_param(params[0])
tx_hash = self.param_to_tx_hash(params[0])
return await self.daemon_request('getrawtransaction', tx_hash)
raise self.RPCError('params wrong length: {}'.format(params))
async def transaction_get_merkle(self, params):
if len(params) == 2:
tx_hash = self.tx_hash_from_param(params[0])
height = self.non_negative_integer_from_param(params[1])
tx_hash = self.param_to_tx_hash(params[0])
height = self.param_to_non_negative_integer(params[1])
return await self.tx_merkle(tx_hash, height)
raise self.RPCError('params should contain a transaction hash '
@ -945,8 +922,8 @@ class ElectrumX(Session):
async def utxo_get_address(self, params):
if len(params) == 2:
tx_hash = self.tx_hash_from_param(params[0])
index = self.non_negative_integer_from_param(params[1])
tx_hash = self.param_to_tx_hash(params[0])
index = self.param_to_non_negative_integer(params[1])
tx_hash = hex_str_to_hash(tx_hash)
hash168 = self.bp.get_utxo_hash168(tx_hash, index)
if hash168:

Loading…
Cancel
Save