diff --git a/electrumx_rpc.py b/electrumx_rpc.py index abaa330..7cc2fbb 100755 --- a/electrumx_rpc.py +++ b/electrumx_rpc.py @@ -33,6 +33,8 @@ class RPCClient(JSONRPC): async def send_and_wait(self, method, params, timeout=None): # Raise incoming buffer size - presumably connection is trusted self.max_buffer_size = 5000000 + if params: + params = [params] payload = self.request_payload(method, id_=method, params=params) self.encode_and_send_payload(payload) diff --git a/lib/jsonrpc.py b/lib/jsonrpc.py index ccf46e1..91cb405 100644 --- a/lib/jsonrpc.py +++ b/lib/jsonrpc.py @@ -8,19 +8,30 @@ '''Class for handling JSON RPC 2.0 connections, server or client.''' import asyncio +import inspect import json import numbers import time +import traceback from lib.util import LoggedClass +class RPCError(Exception): + '''RPC handlers raise this error.''' + def __init__(self, msg, code=-1, **kw_args): + super().__init__(**kw_args) + self.msg = msg + self.code = code + + class RequestBase(object): '''An object that represents a queued request.''' def __init__(self, remaining): self.remaining = remaining + class SingleRequest(RequestBase): '''An object that represents a single request.''' @@ -62,7 +73,8 @@ class BatchRequest(RequestBase): self.parts.append(part) total_len = sum(len(part) + 2 for part in self.parts) - session.check_oversized_request(total_len) + if session.is_oversized_request(total_len): + raise RPCError('request too large', JSONRPC.INVALID_REQUEST) if not self.remaining: if self.parts: @@ -83,34 +95,31 @@ class JSONRPC(asyncio.Protocol, LoggedClass): Derived classes may want to override connection_made() and connection_lost() but should be sure to call the implementation in - this base class first. They will also want to implement some or - all of the asynchronous functions handle_notification(), - handle_response() and handle_request(). - - handle_request() returns the result to pass over the network, and - must raise an RPCError if there is an error. - handle_notification() and handle_response() should not return - anything or raise any exceptions. All three functions have - default "ignore" implementations supplied by this class. + this base class first. They may also want to implement the asynchronous + function handle_response() which by default does nothing. + + The functions request_handler() and notification_handler() are + passed an RPC method name, and should return an asynchronous + function to call to handle it. The functions' docstrings are used + for help, and the arguments are what can be used as JSONRPC 2.0 + named arguments (and thus become part of the external interface). + If the method is unknown return None. + + Request handlers should return a Python object to return to the + caller, or raise an RPCError on error. Notification handlers + should not return a value or raise any exceptions. ''' # See http://www.jsonrpc.org/specification PARSE_ERROR = -32700 INVALID_REQUEST = -32600 METHOD_NOT_FOUND = -32601 - INVALID_PARAMS = -32602 + INVALID_ARGS = -32602 INTERNAL_ERROR = -32603 ID_TYPES = (type(None), str, numbers.Number) NEXT_SESSION_ID = 0 - class RPCError(Exception): - '''RPC handlers raise this error.''' - def __init__(self, msg, code=-1, **kw_args): - super().__init__(**kw_args) - self.msg = msg - self.code = code - @classmethod def request_payload(cls, method, id_, params=None): payload = {'jsonrpc': '2.0', 'id': id_, 'method': method} @@ -120,8 +129,6 @@ class JSONRPC(asyncio.Protocol, LoggedClass): @classmethod def response_payload(cls, result, id_): - # We should not respond to notifications - assert id_ is not None return {'jsonrpc': '2.0', 'result': result, 'id': id_} @classmethod @@ -133,9 +140,29 @@ class JSONRPC(asyncio.Protocol, LoggedClass): error = {'message': message, 'code': code} return {'jsonrpc': '2.0', 'error': error, 'id': id_} + @classmethod + def check_payload_id(cls, payload): + '''Extract and return the ID from the payload. + + Raises an RPCError if it is missing or invalid.''' + if not 'id' in payload: + raise RPCError('missing id', JSONRPC.INVALID_REQUEST) + + id_ = payload['id'] + if not isinstance(id_, JSONRPC.ID_TYPES): + raise RPCError('invalid id: {}'.format(id_), + JSONRPC.INVALID_REQUEST) + return id_ + @classmethod def payload_id(cls, payload): - return payload.get('id') if isinstance(payload, dict) else None + '''Extract and return the ID from the payload. + + Returns None if it is missing or invalid.''' + try: + return cls.check_payload_id(payload) + except RPCError: + return None def __init__(self): super().__init__() @@ -157,6 +184,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass): self.send_count = 0 self.send_size = 0 self.error_count = 0 + self.close_after_send = False self.peer_info = None # Sends longer than max_send are prevented, instead returning # an oversized request error to other end of the network @@ -260,20 +288,20 @@ class JSONRPC(asyncio.Protocol, LoggedClass): message = message.decode() except UnicodeDecodeError as e: msg = 'cannot decode binary bytes: {}'.format(e) - self.send_json_error(msg, self.PARSE_ERROR, close=True) + self.send_json_error(msg, JSONRPC.PARSE_ERROR) return try: message = json.loads(message) except json.JSONDecodeError as e: msg = 'cannot decode JSON: {}'.format(e) - self.send_json_error(msg, self.PARSE_ERROR, close=True) + self.send_json_error(msg, JSONRPC.PARSE_ERROR) return if isinstance(message, list): - # Batches must have at least one request. + # Batches must have at least one object. if not message: - self.send_json_error('empty batch', self.INVALID_REQUEST) + self.send_json_error('empty batch', JSONRPC.INVALID_REQUEST) return request = BatchRequest(message) else: @@ -284,35 +312,43 @@ class JSONRPC(asyncio.Protocol, LoggedClass): if self.log_me: self.log_info('queued {}'.format(message)) + def send_json_error(self, message, code, id_=None): + '''Send a JSON error.''' + self._send_bytes(self.json_error_bytes(message, code, id_)) + def encode_payload(self, payload): + assert isinstance(payload, dict) + try: binary = json.dumps(payload).encode() except TypeError: msg = 'JSON encoding failure: {}'.format(payload) self.log_error(msg) - return self.send_json_error(msg, self.INTERNAL_ERROR, - self.payload_id(payload)) + binary = self.json_error_bytes(msg, JSONRPC.INTERNAL_ERROR, + payload.get('id')) - self.check_oversized_request(len(binary)) + if self.is_oversized_request(len(binary)): + binary = self.json_error_bytes('request too large', + JSONRPC.INVALID_REQUEST, + payload.get('id')) self.send_count += 1 self.send_size += len(binary) self.using_bandwidth(len(binary)) return binary - def _send_bytes(self, binary, close=False): + def is_oversized_request(self, total_len): + return total_len > max(1000, self.max_send) + + def _send_bytes(self, binary): '''Send JSON text over the transport. Close it if close is True.''' # Confirmed this happens, sometimes a lot if self.transport.is_closing(): return self.transport.write(binary) self.transport.write(b'\n') - if close or self.error_count > 10: + if self.close_after_send: self.close_connection() - def send_json_error(self, message, code, id_=None, close=False): - '''Send a JSON error and close the connection by default.''' - self._send_bytes(self.json_error_bytes(message, code, id_), close) - def encode_and_send_payload(self, payload): '''Encode the payload and send it.''' self._send_bytes(self.encode_payload(payload)) @@ -330,124 +366,134 @@ class JSONRPC(asyncio.Protocol, LoggedClass): return self.encode_payload(self.response_payload(result, id_)) def json_error_bytes(self, message, code, id_=None): - '''Return the bytes of a JSON error.''' + '''Return the bytes of a JSON error. + + Flag the connection to close on a fatal error or too many errors.''' self.error_count += 1 + if (code in (JSONRPC.PARSE_ERROR, JSONRPC.INVALID_REQUEST) + or self.error_count > 10): + self.close_after_send = True return self.encode_payload(self.error_payload(message, code, id_)) async def process_single_payload(self, payload): - '''Return the binary JSON result of a single JSON request, response or - notification. - - The result is empty if nothing is to be sent. - ''' + '''Handle a single JSON request, notification or response. + If it is a request, return the binary response, oterhwise None.''' if not isinstance(payload, dict): return self.json_error_bytes('request must be a dict', - self.INVALID_REQUEST) - - try: - if not 'id' in payload: - return await self.process_json_notification(payload) - - id_ = payload['id'] - if not isinstance(id_, self.ID_TYPES): - return self.json_error_bytes('invalid id: {}'.format(id_), - self.INVALID_REQUEST) + JSONRPC.INVALID_REQUEST) + + # Requests and notifications must have a method. + # Notifications are distinguished by having no 'id'. + if 'method' in payload: + if 'id' in payload: + return await self.process_single_request(payload) + else: + await self.process_single_notification(payload) + else: + await self.process_single_response(payload) - if 'method' in payload: - return await self.process_json_request(payload) + return None - return await self.process_json_response(payload) - except self.RPCError as e: + async def process_single_request(self, payload): + '''Handle a single JSON request and return the binary response.''' + try: + result = await self.handle_payload(payload, self.request_handler) + return self.json_response_bytes(result, payload['id']) + except RPCError as e: return self.json_error_bytes(e.msg, e.code, self.payload_id(payload)) + except Exception: + self.log_error(traceback.format_exc()) + return self.json_error_bytes('internal error processing request', + JSONRPC.INTERNAL_ERROR, + self.payload_id(payload)) - @classmethod - def method_and_params(cls, payload): + async def process_single_notification(self, payload): + '''Handle a single JSON notification.''' + try: + await self.handle_payload(payload, self.notification_handler) + except RPCError: + pass + except Exception: + self.log_error(traceback.format_exc()) + + async def process_single_response(self, payload): + '''Handle a single JSON response.''' + try: + id_ = self.check_payload_id(payload) + # Only one of result and error should exist + if 'error' in payload: + error = payload['error'] + if (not 'result' in payload and isinstance(error, dict) + and 'code' in error and 'message' in error): + await self.handle_response(None, error, id_) + elif 'result' in payload: + await self.handle_response(payload['result'], None, id_) + except RPCError: + pass + except Exception: + self.log_error(traceback.format_exc()) + + async def handle_payload(self, payload, get_handler): + '''Handle a request or notification payload given the handlers.''' + # An argument is the value passed to a function parameter... + args = payload.get('params', []) method = payload.get('method') - params = payload.get('params', []) if not isinstance(method, str): - raise cls.RPCError('invalid method: {}'.format(method), - cls.INVALID_REQUEST) + raise RPCError("invalid method: '{}'".format(method), + JSONRPC.INVALID_REQUEST) - if not isinstance(params, list): - raise cls.RPCError('params should be an array', - cls.INVALID_REQUEST) + handler = get_handler(method) + if not handler: + raise RPCError("unknown method: '{}'".format(method), + JSONRPC.METHOD_NOT_FOUND) - return method, params + if not isinstance(args, (list, dict)): + raise RPCError('arguments should be an array or a dict', + JSONRPC.INVALID_REQUEST) - async def process_json_notification(self, payload): - try: - method, params = self.method_and_params(payload) - except self.RPCError: - pass - else: - await self.handle_notification(method, params) - return b'' - - async def process_json_request(self, payload): - method, params = self.method_and_params(payload) - result = await self.handle_request(method, params) - return self.json_response_bytes(result, payload['id']) - - async def process_json_response(self, payload): - # Only one of result and error should exist; we go with 'error' - # if both are supplied. - if 'error' in payload: - await self.handle_response(None, payload['error'], payload['id']) - elif 'result' in payload: - await self.handle_response(payload['result'], None, payload['id']) - return b'' - - def check_oversized_request(self, total_len): - if total_len > max(1000, self.max_send): - raise self.RPCError('request too large', self.INVALID_REQUEST) - - def raise_unknown_method(self, method): - '''Respond to a request with an unknown method.''' - raise self.RPCError("unknown method: '{}'".format(method), - self.METHOD_NOT_FOUND) - - # Common parameter verification routines - @classmethod - def param_to_non_negative_integer(cls, param): - '''Return param if it is or can be converted to a non-negative - integer, otherwise raise an RPCError.''' - try: - param = int(param) - if param >= 0: - return param - except ValueError: - pass + params = inspect.signature(handler).parameters + names = list(params) + min_args = sum(p.default is p.empty for p in params.values()) - raise cls.RPCError('param {} should be a non-negative integer' - .format(param)) + if len(args) < min_args: + raise RPCError('too few arguments: expected {:d} got {:d}' + .format(min_args, len(args)), JSONRPC.INVALID_ARGS) - @classmethod - def params_to_non_negative_integer(cls, params): - if len(params) == 1: - return cls.param_to_non_negative_integer(params[0]) - raise cls.RPCError('params {} should contain one non-negative integer' - .format(params)) + if len(args) > len(params): + raise RPCError('too many arguments: expected {:d} got {:d}' + .format(len(params), len(args)), + JSONRPC.INVALID_ARGS) - @classmethod - def require_empty_params(cls, params): - if params: - raise cls.RPCError('params {} should be empty'.format(params)) + if isinstance(args, list): + kw_args = {name: arg for name, arg in zip(names, args)} + else: + kw_args = args + bad_names = ['<{}>'.format(name) for name in args + if name not in names] + if bad_names: + raise RPCError('invalid parameter names: {}' + .format(', '.join(bad_names))) + return await handler(**kw_args) # --- derived classes are intended to override these functions def enqueue_request(self, request): '''Enqueue a request for later asynchronous processing.''' raise NotImplementedError - async def handle_notification(self, method, params): - '''Handle a notification.''' + async def handle_response(self, result, error, id_): + '''Handle a JSON response. + + Should not raise an exception. Return values are ignored. + ''' - async def handle_request(self, method, params): - '''Handle a request.''' + def notification_handler(self, method): + '''Return the async handler for the given notification method.''' return None - async def handle_response(self, result, error, id_): - '''Handle a response.''' + def request_handler(self, method): + '''Return the async handler for the given request method.''' + return None diff --git a/server/controller.py b/server/controller.py index d1c4462..73585d3 100644 --- a/server/controller.py +++ b/server/controller.py @@ -6,6 +6,7 @@ # and warranty status of this software. import asyncio +import codecs import json import os import ssl @@ -16,12 +17,14 @@ from functools import partial import pylru -from lib.jsonrpc import JSONRPC, RequestBase +from lib.jsonrpc import JSONRPC, RPCError, RequestBase +from lib.hash import sha256, double_sha256, hash_to_str, hex_str_to_hash import lib.util as util from server.block_processor import BlockProcessor from server.irc import IRC from server.session import LocalRPC, ElectrumX from server.mempool import MemPool +from server.version import VERSION class Controller(util.LoggedClass): @@ -48,7 +51,9 @@ class Controller(util.LoggedClass): super().__init__() self.loop = asyncio.get_event_loop() self.start = time.time() + self.coin = env.coin self.bp = BlockProcessor(env) + self.daemon = self.bp.daemon self.mempool = MemPool(self.bp) self.irc = IRC(env) self.env = env @@ -69,10 +74,27 @@ class Controller(util.LoggedClass): self.queue = asyncio.PriorityQueue() self.delayed_sessions = [] self.next_queue_id = 0 - self.height = 0 + self.cache_height = 0 self.futures = [] env.max_send = max(350000, env.max_send) self.setup_bands() + # Set up the RPC request handlers + cmds = 'disconnect getinfo groups log peers reorg sessions'.split() + self.rpc_handlers = {cmd: getattr(self, 'rpc_' + cmd) for cmd in cmds} + # Set up the ElectrumX request handlers + rpcs = [ + ('blockchain', + 'address.get_balance address.get_history address.get_mempool ' + 'address.get_proof address.listunspent ' + 'block.get_header block.get_chunk estimatefee relayfee ' + 'transaction.get transaction.get_merkle utxo.get_address'), + ('server', + 'banner donation_address peers.subscribe version'), + ] + self.electrumx_handlers = {'.'.join([prefix, suffix]): + getattr(self, suffix.replace('.', '_')) + for prefix, suffixes in rpcs + for suffix in suffixes.split()} async def mempool_transactions(self, hashX): '''Generate (hex_hash, tx_fee, unconfirmed) tuples for mempool @@ -167,7 +189,7 @@ class Controller(util.LoggedClass): await session.serve_requests() async def main_loop(self): - '''Server manager main loop.''' + '''Controller main loop.''' def add_future(coro): self.futures.append(asyncio.ensure_future(coro)) @@ -259,8 +281,8 @@ class Controller(util.LoggedClass): hc = self.history_cache for hashX in set(hc).intersection(touched): del hc[hashX] - if self.bp.db_height != self.height: - self.height = self.bp.db_height + if self.bp.db_height != self.cache_height: + self.cache_height = self.bp.db_height self.header_cache.clear() for session in self.sessions: @@ -280,32 +302,14 @@ class Controller(util.LoggedClass): def electrum_header(self, height): '''Return the binary header at the given height.''' if not 0 <= height <= self.bp.db_height: - raise JSONRPC.RPCError('height {:,d} out of range'.format(height)) + raise RPCError('height {:,d} out of range'.format(height)) if height in self.header_cache: return self.header_cache[height] header = self.bp.read_headers(height, 1) - header = self.env.coin.electrum_header(header, height) + header = self.coin.electrum_header(header, height) self.header_cache[height] = header return header - async def async_get_history(self, hashX): - '''Get history asynchronously to reduce latency.''' - if hashX in self.history_cache: - return self.history_cache[hashX] - - def job(): - # History DoS limit. Each element of history is about 99 - # bytes when encoded as JSON. This limits resource usage - # on bloated history requests, and uses a smaller divisor - # so large requests are logged before refusing them. - limit = self.env.max_send // 97 - return list(self.bp.get_history(hashX, limit=limit)) - - loop = asyncio.get_event_loop() - history = await loop.run_in_executor(None, job) - self.history_cache[hashX] = history - return history - async def shutdown(self): '''Call to shutdown everything. Returns when done.''' self.state = self.SHUTTING_DOWN @@ -400,15 +404,6 @@ class Controller(util.LoggedClass): self.sessions[session] = new_gid self.groups[new_gid] = sessions - def new_subscription(self): - if self.subscription_count >= self.max_subs: - raise JSONRPC.RPCError('server subscription limit {:,d} reached' - .format(self.max_subs)) - self.subscription_count += 1 - - def irc_peers(self): - return self.irc.peers - def session_count(self): '''The number of connections that we've sent something to.''' return len(self.sessions) @@ -416,7 +411,7 @@ class Controller(util.LoggedClass): def server_summary(self): '''A one-line summary of server state.''' return { - 'daemon_height': self.bp.daemon.cached_height(), + 'daemon_height': self.daemon.cached_height(), 'db_height': self.bp.db_height, 'closing': len([s for s in self.sessions if s.is_closing()]), 'errors': sum(s.error_count for s in self.sessions), @@ -522,49 +517,360 @@ class Controller(util.LoggedClass): now - session.start) for session in sessions] - def lookup_session(self, param): + def lookup_session(self, session_id): try: - id_ = int(param) + session_id = int(session_id) except: pass else: for session in self.sessions: - if session.id_ == id_: + if session.id_ == session_id: return session return None - def for_each_session(self, params, operation): + def for_each_session(self, session_ids, operation): + if not isinstance(session_ids, list): + raise RPCError('expected a list of session IDs') + result = [] - for param in params: - session = self.lookup_session(param) + for session_id in session_ids: + session = self.lookup_session(session_id) if session: result.append(operation(session)) else: - result.append('unknown session: {}'.format(param)) + result.append('unknown session: {}'.format(session_id)) return result - async def rpc_disconnect(self, params): - return self.for_each_session(params, self.close_session) + # Local RPC command handlers + + async def rpc_disconnect(self, session_ids): + '''Disconnect sesssions. - async def rpc_log(self, params): - return self.for_each_session(params, self.toggle_logging) + session_ids: array of session IDs + ''' + return self.for_each_session(session_ids, self.close_session) + + async def rpc_log(self, session_ids): + '''Toggle logging of sesssions. + + session_ids: array of session IDs + ''' + return self.for_each_session(session_ids, self.toggle_logging) - async def rpc_getinfo(self, params): + async def rpc_getinfo(self): + '''Return summary information about the server process.''' return self.server_summary() - async def rpc_groups(self, params): + async def rpc_groups(self): + '''Return statistics about the session groups.''' return self.group_data() - async def rpc_sessions(self, params): + async def rpc_sessions(self): + '''Return statistics about connected sessions.''' return self.session_data(for_log=False) - async def rpc_peers(self, params): + async def rpc_peers(self): + '''Return a list of server peers, currently taken from IRC.''' return self.irc.peers - async def rpc_reorg(self, params): - '''Force a reorg of the given number of blocks, 3 by default.''' - count = 3 - if params: - count = JSONRPC.params_to_non_negative_integer(params) + async def rpc_reorg(self, count=3): + '''Force a reorg of the given number of blocks. + + count: number of blocks to reorg (default 3) + ''' + count = self.non_negative_integer(count) if not self.bp.force_chain_reorg(count): - raise JSONRPC.RPCError('still catching up with daemon') + raise RPCError('still catching up with daemon') + return 'scheduled a reorg of {:,d} blocks'.format(count) + + # Helpers for RPC "blockchain" command handlers + + def address_to_hashX(self, address): + if isinstance(address, str): + try: + return self.coin.address_to_hashX(address) + except: + pass + raise RPCError('{} is not a valid address'.format(address)) + + def to_tx_hash(self, value): + '''Raise an RPCError if the value is not a valid transaction + hash.''' + if isinstance(value, str) and len(value) == 64: + try: + bytes.fromhex(value) + return value + except ValueError: + pass + raise RPCError('{} should be a transaction hash'.format(value)) + + def non_negative_integer(self, value): + '''Return param value it is or can be converted to a non-negative + integer, otherwise raise an RPCError.''' + try: + value = int(value) + if value >= 0: + return value + except ValueError: + pass + raise RPCError('{} should be a non-negative integer'.format(value)) + + async def daemon_request(self, method, *args): + '''Catch a DaemonError and convert it to an RPCError.''' + try: + return await getattr(self.daemon, method)(*args) + except DaemonError as e: + raise RPCError('daemon error: {}'.format(e)) + + async def new_subscription(self, address): + if self.subscription_count >= self.max_subs: + raise RPCError('server subscription limit {:,d} reached' + .format(self.max_subs)) + self.subscription_count += 1 + hashX = self.address_to_hashX(address) + status = await self.address_status(hashX) + return hashX, status + + async def tx_merkle(self, tx_hash, height): + '''tx_hash is a hex string.''' + hex_hashes = await self.daemon_request('block_hex_hashes', height, 1) + block = await self.daemon_request('deserialised_block', hex_hashes[0]) + tx_hashes = block['tx'] + try: + pos = tx_hashes.index(tx_hash) + except ValueError: + raise RPCError('tx hash {} not in block {} at height {:,d}' + .format(tx_hash, hex_hashes[0], height)) + + idx = pos + hashes = [hex_str_to_hash(txh) for txh in tx_hashes] + merkle_branch = [] + while len(hashes) > 1: + if len(hashes) & 1: + hashes.append(hashes[-1]) + idx = idx - 1 if (idx & 1) else idx + 1 + merkle_branch.append(hash_to_str(hashes[idx])) + idx //= 2 + hashes = [double_sha256(hashes[n] + hashes[n + 1]) + for n in range(0, len(hashes), 2)] + + return {"block_height": height, "merkle": merkle_branch, "pos": pos} + + async def get_balance(self, hashX): + utxos = await self.get_utxos(hashX) + confirmed = sum(utxo.value for utxo in utxos) + unconfirmed = self.mempool_value(hashX) + return {'confirmed': confirmed, 'unconfirmed': unconfirmed} + + async def unconfirmed_history(self, hashX): + # Note unconfirmed history is unordered in electrum-server + # Height is -1 if unconfirmed txins, otherwise 0 + mempool = await self.mempool_transactions(hashX) + return [{'tx_hash': tx_hash, 'height': -unconfirmed, 'fee': fee} + for tx_hash, fee, unconfirmed in mempool] + + async def get_history(self, hashX): + '''Get history asynchronously to reduce latency.''' + if hashX in self.history_cache: + return self.history_cache[hashX] + + def job(): + # History DoS limit. Each element of history is about 99 + # bytes when encoded as JSON. This limits resource usage + # on bloated history requests, and uses a smaller divisor + # so large requests are logged before refusing them. + limit = self.env.max_send // 97 + return list(self.bp.get_history(hashX, limit=limit)) + + loop = asyncio.get_event_loop() + history = await loop.run_in_executor(None, job) + self.history_cache[hashX] = history + return history + + async def confirmed_and_unconfirmed_history(self, hashX): + # Note history is ordered but unconfirmed is unordered in e-s + history = await self.get_history(hashX) + conf = [{'tx_hash': hash_to_str(tx_hash), 'height': height} + for tx_hash, height in history] + return conf + await self.unconfirmed_history(hashX) + + async def address_status(self, hashX): + '''Returns status as 32 bytes.''' + # Note history is ordered and mempool unordered in electrum-server + # For mempool, height is -1 if unconfirmed txins, otherwise 0 + history = await self.get_history(hashX) + mempool = await self.mempool_transactions(hashX) + + status = ''.join('{}:{:d}:'.format(hash_to_str(tx_hash), height) + for tx_hash, height in history) + status += ''.join('{}:{:d}:'.format(hex_hash, -unconfirmed) + for hex_hash, tx_fee, unconfirmed in mempool) + if status: + return sha256(status.encode()).hex() + return None + + async def get_utxos(self, hashX): + '''Get UTXOs asynchronously to reduce latency.''' + def job(): + return list(self.bp.get_utxos(hashX, limit=None)) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, job) + + def get_chunk(self, index): + '''Return header chunk as hex. Index is a non-negative integer.''' + chunk_size = self.coin.CHUNK_SIZE + next_height = self.bp.db_height + 1 + start_height = min(index * chunk_size, next_height) + count = min(next_height - start_height, chunk_size) + return self.bp.read_headers(start_height, count).hex() + + # Client RPC "blockchain" command handlers + + async def address_get_balance(self, address): + '''Return the confirmed and unconfirmed balance of an address.''' + hashX = self.address_to_hashX(address) + return await self.get_balance(hashX) + + async def address_get_history(self, address): + '''Return the confirmed and unconfirmed history of an address.''' + hashX = self.address_to_hashX(address) + return await self.confirmed_and_unconfirmed_history(hashX) + + async def address_get_mempool(self, address): + '''Return the mempool transactions touching an address.''' + hashX = self.address_to_hashX(address) + return await self.unconfirmed_history(hashX) + + async def address_get_proof(self, address): + '''Return the UTXO proof of an address.''' + hashX = self.address_to_hashX(address) + raise RPCError('address.get_proof is not yet implemented') + + async def address_listunspent(self, address): + '''Return the list of UTXOs of an address.''' + hashX = self.address_to_hashX(address) + return [{'tx_hash': hash_to_str(utxo.tx_hash), 'tx_pos': utxo.tx_pos, + 'height': utxo.height, 'value': utxo.value} + for utxo in sorted(await self.get_utxos(hashX))] + + async def block_get_chunk(self, index): + '''Return a chunk of block headers. + + index: the chunk index''' + index = self.non_negative_integer(index) + return self.get_chunk(index) + + async def block_get_header(self, height): + '''The deserialized header at a given height. + + height: the header's height''' + height = self.non_negative_integer(height) + return self.electrum_header(height) + + async def estimatefee(self, number): + '''The estimated transaction fee per kilobyte to be paid for a + transaction to be included within a certain number of blocks. + + number: the number of blocks + ''' + number = self.non_negative_integer(number) + return await self.daemon_request('estimatefee', [number]) + + async def relayfee(self): + '''The minimum fee a low-priority tx must pay in order to be accepted + to the daemon's memory pool.''' + return await self.daemon_request('relayfee') + + async def transaction_get(self, tx_hash, height=None): + '''Return the serialized raw transaction given its hash + + tx_hash: the transaction hash as a hexadecimal string + height: ignored, do not use + ''' + # For some reason Electrum passes a height. We don't require + # it in anticipation it might be dropped in the future. + tx_hash = self.to_tx_hash(tx_hash) + return await self.daemon_request('getrawtransaction', tx_hash) + + async def transaction_get_merkle(self, tx_hash, height): + '''Return the markle tree to a confirmed transaction given its hash + and height. + + tx_hash: the transaction hash as a hexadecimal string + height: the height of the block it is in + ''' + tx_hash = self.to_tx_hash(tx_hash) + height = self.non_negative_integer(height) + return await self.tx_merkle(tx_hash, height) + + async def utxo_get_address(self, tx_hash, index): + '''Returns the address sent to in a UTXO, or null if the UTXO + cannot be found. + + tx_hash: the transaction hash of the UTXO + index: the index of the UTXO in the transaction''' + # Used only for electrum client command-line requests. We no + # longer index by address, so need to request the raw + # transaction. So it works for any TXO not just UTXOs. + tx_hash = self.to_tx_hash(tx_hash) + index = self.non_negative_integer(index) + raw_tx = await self.daemon_request('getrawtransaction', tx_hash) + if not raw_tx: + return None + raw_tx = bytes.fromhex(raw_tx) + deserializer = self.coin.deserializer() + tx, tx_hash = deserializer(raw_tx).read_tx() + if index >= len(tx.outputs): + return None + return self.coin.address_from_script(tx.outputs[index].pk_script) + + # Client RPC "server" command handlers + + async def banner(self): + '''Return the server banner text.''' + banner = 'Welcome to Electrum!' + if self.env.banner_file: + try: + with codecs.open(self.env.banner_file, 'r', 'utf-8') as f: + banner = f.read() + except Exception as e: + self.log_error('reading banner file {}: {}' + .format(self.env.banner_file, e)) + else: + network_info = await self.daemon_request('getnetworkinfo') + version = network_info['version'] + major, minor = divmod(version, 1000000) + minor, revision = divmod(minor, 10000) + revision //= 100 + version = '{:d}.{:d}.{:d}'.format(major, minor, revision) + for pair in [ + ('$VERSION', VERSION), + ('$DAEMON_VERSION', version), + ('$DAEMON_SUBVERSION', network_info['subversion']), + ('$DONATION_ADDRESS', self.env.donation_address), + ]: + banner = banner.replace(*pair) + + return banner + + async def donation_address(self): + '''Return the donation address as a string, empty if there is none.''' + return self.env.donation_address + + async def peers_subscribe(self): + '''Returns the server peers as a list of (ip, host, ports) tuples. + + Despite the name this is not currently treated as a subscription.''' + return list(self.irc.peers.values()) + + async def version(self, client_name=None, protocol_version=None): + '''Returns the server version as a string. + + client_name: a string identifying the client + protocol_version: the protocol version spoken by the client + ''' + if client_name: + self.client = str(client_name)[:15] + if protocol_version is not None: + self.protocol_version = protocol_version + return VERSION diff --git a/server/session.py b/server/session.py index 2f5d30a..d646c4f 100644 --- a/server/session.py +++ b/server/session.py @@ -9,13 +9,10 @@ import asyncio -import codecs import traceback -from lib.hash import sha256, double_sha256, hash_to_str, hex_str_to_hash -from lib.jsonrpc import JSONRPC +from lib.jsonrpc import JSONRPC, RPCError from server.daemon import DaemonError -from server.version import VERSION class Session(JSONRPC): @@ -26,13 +23,12 @@ class Session(JSONRPC): long-running requests should yield. ''' - def __init__(self, manager, bp, env, kind): + def __init__(self, controller, bp, env, kind): super().__init__() - self.manager = manager + self.controller = controller self.bp = bp self.env = env self.daemon = bp.daemon - self.coin = bp.coin self.kind = kind self.client = 'unknown' self.anon_logs = env.anon_logs @@ -53,7 +49,7 @@ class Session(JSONRPC): status += 'C' if self.log_me: status += 'L' - status += str(self.manager.session_priority(self)) + status += str(self.controller.session_priority(self)) return status def requests_remaining(self): @@ -63,7 +59,7 @@ class Session(JSONRPC): '''Add a request to the session's list.''' self.requests.append(request) if len(self.requests) == 1: - self.manager.enqueue_session(self) + self.controller.enqueue_session(self) async def serve_requests(self): '''Serve requests in batches.''' @@ -90,68 +86,27 @@ class Session(JSONRPC): self.requests = [req for req in self.requests if req.remaining and not req in errs] if self.requests: - self.manager.enqueue_session(self) + self.controller.enqueue_session(self) def connection_made(self, transport): '''Handle an incoming client connection.''' super().connection_made(transport) - self.manager.add_session(self) + self.controller.add_session(self) def connection_lost(self, exc): '''Handle client disconnection.''' super().connection_lost(exc) - if (self.pause or self.manager.is_deprioritized(self) + if (self.pause or self.controller.is_deprioritized(self) or self.send_size >= 1024*1024 or self.error_count): self.log_info('disconnected. Sent {:,d} bytes in {:,d} messages ' '{:,d} errors' .format(self.send_size, self.send_count, self.error_count)) - self.manager.remove_session(self) - - async def handle_request(self, method, params): - '''Handle a request.''' - handler = self.handlers.get(method) - if not handler: - self.raise_unknown_method(method) - - return await handler(params) + self.controller.remove_session(self) def sub_count(self): return 0 - async def daemon_request(self, method, *args): - '''Catch a DaemonError and convert it to an RPCError.''' - try: - return await getattr(self.daemon, method)(*args) - except DaemonError as e: - raise self.RPCError('daemon error: {}'.format(e)) - - def param_to_tx_hash(self, param): - '''Raise an RPCError if the parameter is not a valid transaction - hash.''' - if isinstance(param, str) and len(param) == 64: - try: - bytes.fromhex(param) - return param - except ValueError: - pass - raise self.RPCError('parameter should be a transaction hash: {}' - .format(param)) - - def param_to_hashX(self, param): - if isinstance(param, str): - try: - return self.coin.address_to_hashX(param) - except: - pass - raise self.RPCError('param {} is not a valid address'.format(param)) - - def params_to_hashX(self, params): - if len(params) == 1: - return self.param_to_hashX(params[0]) - raise self.RPCError('params {} should contain a single address' - .format(params)) - class ElectrumX(Session): '''A TCP server that handles incoming Electrum connections.''' @@ -163,20 +118,12 @@ class ElectrumX(Session): self.notified_height = None self.max_subs = self.env.max_session_subs self.hashX_subs = {} - rpcs = [ - ('blockchain', - 'address.get_balance address.get_history address.get_mempool ' - 'address.get_proof address.listunspent address.subscribe ' - 'block.get_header block.get_chunk estimatefee headers.subscribe ' - 'numblocks.subscribe relayfee transaction.broadcast ' - 'transaction.get transaction.get_merkle utxo.get_address'), - ('server', - 'banner donation_address peers.subscribe version'), - ] - self.handlers = {'.'.join([prefix, suffix]): - getattr(self, suffix.replace('.', '_')) - for prefix, suffixes in rpcs - for suffix in suffixes.split()} + self.electrumx_handlers = { + 'blockchain.address.subscribe': self.address_subscribe, + 'blockchain.headers.subscribe': self.headers_subscribe, + 'blockchain.numblocks.subscribe': self.numblocks_subscribe, + 'blockchain.transaction.broadcast': self.transaction_broadcast, + } def sub_count(self): return len(self.hashX_subs) @@ -191,7 +138,7 @@ class ElectrumX(Session): if self.subscribe_headers: payload = self.notification_payload( 'blockchain.headers.subscribe', - (self.manager.electrum_header(height), ), + (self.controller.electrum_header(height), ), ) self.encode_and_send_payload(payload) @@ -205,7 +152,7 @@ class ElectrumX(Session): matches = touched.intersection(self.hashX_subs) for hashX in matches: address = self.hashX_subs[hashX] - status = await self.address_status(hashX) + status = await self.controller.address_status(hashX) payload = self.notification_payload( 'blockchain.address.subscribe', (address, status)) self.encode_and_send_payload(payload) @@ -219,162 +166,44 @@ class ElectrumX(Session): def current_electrum_header(self): '''Used as response to a headers subscription request.''' - return self.manager.electrum_header(self.height()) - - async def address_status(self, hashX): - '''Returns status as 32 bytes.''' - # Note history is ordered and mempool unordered in electrum-server - # For mempool, height is -1 if unconfirmed txins, otherwise 0 - history = await self.manager.async_get_history(hashX) - mempool = await self.manager.mempool_transactions(hashX) - - status = ''.join('{}:{:d}:'.format(hash_to_str(tx_hash), height) - for tx_hash, height in history) - status += ''.join('{}:{:d}:'.format(hex_hash, -unconfirmed) - for hex_hash, tx_fee, unconfirmed in mempool) - if status: - return sha256(status.encode()).hex() - return None - - async def tx_merkle(self, tx_hash, height): - '''tx_hash is a hex string.''' - hex_hashes = await self.daemon_request('block_hex_hashes', height, 1) - block = await self.daemon_request('deserialised_block', hex_hashes[0]) - tx_hashes = block['tx'] - try: - pos = tx_hashes.index(tx_hash) - except ValueError: - raise self.RPCError('tx hash {} not in block {} at height {:,d}' - .format(tx_hash, hex_hashes[0], height)) - - idx = pos - hashes = [hex_str_to_hash(txh) for txh in tx_hashes] - merkle_branch = [] - while len(hashes) > 1: - if len(hashes) & 1: - hashes.append(hashes[-1]) - idx = idx - 1 if (idx & 1) else idx + 1 - merkle_branch.append(hash_to_str(hashes[idx])) - idx //= 2 - hashes = [double_sha256(hashes[n] + hashes[n + 1]) - for n in range(0, len(hashes), 2)] - - return {"block_height": height, "merkle": merkle_branch, "pos": pos} - - async def unconfirmed_history(self, hashX): - # Note unconfirmed history is unordered in electrum-server - # Height is -1 if unconfirmed txins, otherwise 0 - mempool = await self.manager.mempool_transactions(hashX) - return [{'tx_hash': tx_hash, 'height': -unconfirmed, 'fee': fee} - for tx_hash, fee, unconfirmed in mempool] - - async def get_history(self, hashX): - # Note history is ordered but unconfirmed is unordered in e-s - history = await self.manager.async_get_history(hashX) - conf = [{'tx_hash': hash_to_str(tx_hash), 'height': height} - for tx_hash, height in history] - - return conf + await self.unconfirmed_history(hashX) - - def get_chunk(self, index): - '''Return header chunk as hex. Index is a non-negative integer.''' - chunk_size = self.coin.CHUNK_SIZE - next_height = self.height() + 1 - start_height = min(index * chunk_size, next_height) - count = min(next_height - start_height, chunk_size) - return self.bp.read_headers(start_height, count).hex() - - async def get_utxos(self, hashX): - '''Get UTXOs asynchronously to reduce latency.''' - def job(): - return list(self.bp.get_utxos(hashX, limit=None)) - loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, job) - - async def get_balance(self, hashX): - utxos = await self.get_utxos(hashX) - confirmed = sum(utxo.value for utxo in utxos) - unconfirmed = self.manager.mempool_value(hashX) - return {'confirmed': confirmed, 'unconfirmed': unconfirmed} - - async def list_unspent(self, hashX): - return [{'tx_hash': hash_to_str(utxo.tx_hash), 'tx_pos': utxo.tx_pos, - 'height': utxo.height, 'value': utxo.value} - for utxo in sorted(await self.get_utxos(hashX))] - - # --- blockchain commands - - async def address_get_balance(self, params): - hashX = self.params_to_hashX(params) - return await self.get_balance(hashX) - - async def address_get_history(self, params): - hashX = self.params_to_hashX(params) - return await self.get_history(hashX) - - async def address_get_mempool(self, params): - hashX = self.params_to_hashX(params) - return await self.unconfirmed_history(hashX) - - async def address_get_proof(self, params): - hashX = self.params_to_hashX(params) - raise self.RPCError('get_proof is not yet implemented') - - async def address_listunspent(self, params): - hashX = self.params_to_hashX(params) - return await self.list_unspent(hashX) - - async def address_subscribe(self, params): - hashX = self.params_to_hashX(params) - if len(self.hashX_subs) >= self.max_subs: - raise self.RPCError('your address subscription limit {:,d} reached' - .format(self.max_subs)) - result = await self.address_status(hashX) - # add_subscription can raise so call it before adding - self.manager.new_subscription() - self.hashX_subs[hashX] = params[0] - return result - - async def block_get_chunk(self, params): - index = self.params_to_non_negative_integer(params) - return self.get_chunk(index) - - async def block_get_header(self, params): - height = self.params_to_non_negative_integer(params) - return self.manager.electrum_header(height) - - async def estimatefee(self, params): - return await self.daemon_request('estimatefee', params) - - async def headers_subscribe(self, params): - self.require_empty_params(params) + return self.controller.electrum_header(self.height()) + + async def headers_subscribe(self): + '''Subscribe to get headers of new blocks.''' self.subscribe_headers = True return self.current_electrum_header() - async def numblocks_subscribe(self, params): - self.require_empty_params(params) + async def numblocks_subscribe(self): + '''Subscribe to get height of new blocks.''' self.subscribe_height = True return self.height() - async def relayfee(self, params): - '''The minimum fee a low-priority tx must pay in order to be accepted - to the daemon's memory pool.''' - self.require_empty_params(params) - return await self.daemon_request('relayfee') + async def address_subscribe(self, address): + '''Subscribe to an address. - async def transaction_broadcast(self, params): - '''Pass through the parameters to the daemon. + address: the address to subscribe to''' + # First check our limit. + if len(self.hashX_subs) >= self.max_subs: + raise RPCError('your address subscription limit {:,d} reached' + .format(self.max_subs)) + # Now let the controller check its limit + hashX, status = await self.controller.new_subscription(address) + self.hashX_subs[hashX] = address + return status - An ugly API: current Electrum clients only pass the raw - transaction in hex and expect error messages to be returned in - the result field. And the server shouldn't be doing the client's - user interface job here. - ''' + async def transaction_broadcast(self, raw_tx): + '''Broadcast a raw transaction to the network. + + raw_tx: the raw transaction as a hexadecimal string''' + # An ugly API: current Electrum clients only pass the raw + # transaction in hex and expect error messages to be returned in + # the result field. And the server shouldn't be doing the client's + # user interface job here. try: - tx_hash = await self.daemon.sendrawtransaction(params) + tx_hash = await self.daemon.sendrawtransaction([raw_tx]) self.txs_sent += 1 self.log_info('sent tx: {}'.format(tx_hash)) - self.manager.sent_tx(tx_hash) + self.controller.sent_tx(tx_hash) return tx_hash except DaemonError as e: error = e.args[0] @@ -390,105 +219,15 @@ class ElectrumX(Session): return ( 'The transaction was rejected by network rules. ({})\n[{}]' - .format(message, params[0]) + .format(message, raw_tx) ) - async def transaction_get(self, params): - '''Return the serialized raw transaction.''' - # For some reason Electrum passes a height. Don't require it - # in anticipation it might be dropped in the future. - if 1 <= len(params) <= 2: - tx_hash = self.param_to_tx_hash(params[0]) - return await self.daemon_request('getrawtransaction', tx_hash) - - raise self.RPCError('params wrong length: {}'.format(params)) - - async def transaction_get_merkle(self, params): - if len(params) == 2: - tx_hash = self.param_to_tx_hash(params[0]) - height = self.param_to_non_negative_integer(params[1]) - return await self.tx_merkle(tx_hash, height) - - raise self.RPCError('params should contain a transaction hash ' - 'and height') - - async def utxo_get_address(self, params): - '''Returns the address for a TXO. - - Used only for electrum client command-line requests. We no - longer index by address, so need to request the raw - transaction. So it works for any TXO not just UTXOs. - ''' - if len(params) == 2: - tx_hash = self.param_to_tx_hash(params[0]) - index = self.param_to_non_negative_integer(params[1]) - raw_tx = await self.daemon_request('getrawtransaction', tx_hash) - if not raw_tx: - return None - raw_tx = bytes.fromhex(raw_tx) - deserializer = self.coin.deserializer() - tx, tx_hash = deserializer(raw_tx).read_tx() - if index >= len(tx.outputs): - return None - return self.coin.address_from_script(tx.outputs[index].pk_script) - - raise self.RPCError('params should contain a transaction hash ' - 'and index') - - # --- server commands - - async def banner(self, params): - '''Return the server banner.''' - self.require_empty_params(params) - banner = 'Welcome to Electrum!' - if self.env.banner_file: - try: - with codecs.open(self.env.banner_file, 'r', 'utf-8') as f: - banner = f.read() - except Exception as e: - self.log_error('reading banner file {}: {}' - .format(self.env.banner_file, e)) - else: - network_info = await self.daemon.getnetworkinfo() - version = network_info['version'] - major, minor = divmod(version, 1000000) - minor, revision = divmod(minor, 10000) - revision //= 100 - version = '{:d}.{:d}.{:d}'.format(major, minor, revision) - for pair in [ - ('$VERSION', VERSION), - ('$DAEMON_VERSION', version), - ('$DAEMON_SUBVERSION', network_info['subversion']), - ('$DONATION_ADDRESS', self.env.donation_address), - ]: - banner = banner.replace(*pair) - - return banner - - async def donation_address(self, params): - '''Return the donation address as a string. - - If none is specified return the empty string. - ''' - self.require_empty_params(params) - return self.env.donation_address - - async def peers_subscribe(self, params): - '''Returns the peer (ip, host, ports) tuples. - - Despite the name electrum-server does not treat this as a - subscription. - ''' - self.require_empty_params(params) - return list(self.manager.irc_peers().values()) - - async def version(self, params): - '''Return the server version as a string.''' - if params: - self.client = str(params[0])[:15] - if len(params) > 1: - self.protocol_version = params[1] - return VERSION + def request_handler(self, method): + '''Return the async handler for the given request method.''' + handler = self.electrumx_handlers.get(method) + if not handler: + handler = self.controller.electrumx_handlers.get(method) + return handler class LocalRPC(Session): @@ -496,8 +235,9 @@ class LocalRPC(Session): def __init__(self, *args): super().__init__(*args) - cmds = 'disconnect getinfo groups log peers reorg sessions'.split() - self.handlers = {cmd: getattr(self.manager, 'rpc_{}'.format(cmd)) - for cmd in cmds} self.client = 'RPC' self.max_send = 5000000 + + def request_handler(self, method): + '''Return the async handler for the given request method.''' + return self.controller.rpc_handlers.get(method)