diff --git a/lib/env_base.py b/lib/env_base.py index a94eabb..0a929a7 100644 --- a/lib/env_base.py +++ b/lib/env_base.py @@ -8,19 +8,18 @@ '''Class for server environment configuration and defaults.''' +import logging from os import environ -import lib.util as lib_util - -class EnvBase(lib_util.LoggedClass): +class EnvBase(object): '''Wraps environment configuration.''' class Error(Exception): pass def __init__(self): - super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) self.allow_root = self.boolean('ALLOW_ROOT', False) self.host = self.default('HOST', 'localhost') self.rpc_host = self.default('RPC_HOST', 'localhost') diff --git a/lib/jsonrpc.py b/lib/jsonrpc.py deleted file mode 100644 index ffc1045..0000000 --- a/lib/jsonrpc.py +++ /dev/null @@ -1,806 +0,0 @@ -# Copyright (c) 2016-2017, Neil Booth -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -'''Classes for acting as a peer over a transport and speaking the JSON -RPC versions 1.0 and 2.0. - -JSONSessionBase can use an arbitrary transport. -JSONSession integrates asyncio.Protocol to provide the transport. -''' - -import asyncio -import collections -import inspect -import json -import numbers -import time -import traceback - -import lib.util as util - - -class RPCError(Exception): - '''RPC handlers raise this error.''' - def __init__(self, msg, code=-1, **kw_args): - super().__init__(**kw_args) - self.msg = msg - self.code = code - - -class JSONRPC(object): - '''Base class of JSON RPC versions.''' - - # See http://www.jsonrpc.org/specification - PARSE_ERROR = -32700 - INVALID_REQUEST = -32600 - METHOD_NOT_FOUND = -32601 - INVALID_ARGS = -32602 - INTERNAL_ERROR = -32603 - - # Codes for this library - INVALID_RESPONSE = -100 - ERROR_CODE_UNAVAILABLE = -101 - REQUEST_TIMEOUT = -102 - FATAL_ERROR = -103 - - ID_TYPES = (type(None), str, numbers.Number) - HAS_BATCHES = False - - @classmethod - def canonical_error(cls, error): - '''Convert an error to a JSON RPC 2.0 error. - - Handlers then only have a single form of error to deal with. - ''' - if isinstance(error, int): - error = {'code': error} - elif isinstance(error, str): - error = {'message': error} - elif not isinstance(error, dict): - error = {'data': error} - error['code'] = error.get('code', JSONRPC.ERROR_CODE_UNAVAILABLE) - error['message'] = error.get('message', 'error message unavailable') - return error - - @classmethod - def timeout_error(cls): - return {'message': 'request timed out', - 'code': JSONRPC.REQUEST_TIMEOUT} - - -class JSONRPCv1(JSONRPC): - '''JSON RPC version 1.0.''' - - @classmethod - def request_payload(cls, id_, method, params=None): - '''JSON v1 request payload. Params is mandatory.''' - return {'method': method, 'params': params or [], 'id': id_} - - @classmethod - def notification_payload(cls, method, params=None): - '''JSON v1 notification payload. Params and id are mandatory.''' - return {'method': method, 'params': params or [], 'id': None} - - @classmethod - def response_payload(cls, result, id_): - '''JSON v1 response payload. error is present and None.''' - return {'id': id_, 'result': result, 'error': None} - - @classmethod - def error_payload(cls, message, code, id_): - '''JSON v1 error payload. result is present and None.''' - return {'id': id_, 'result': None, - 'error': {'message': message, 'code': code}} - - @classmethod - def handle_response(cls, handler, payload): - '''JSON v1 response handler. Both 'error' and 'result' - should exist with exactly one being None. - - Unfortunately many 1.0 clients behave like 2.0, and just send - one or the other. - ''' - error = payload.get('error') - if error is None: - handler(payload.get('result'), None) - else: - handler(None, cls.canonical_error(error)) - - @classmethod - def is_request(cls, payload): - '''Returns True if the payload (which has a method) is a request. - False means it is a notification.''' - return payload.get('id') is not None - - -class JSONRPCv2(JSONRPC): - '''JSON RPC version 2.0.''' - - HAS_BATCHES = True - - @classmethod - def request_payload(cls, id_, method, params=None): - '''JSON v2 request payload. Params is optional.''' - payload = {'jsonrpc': '2.0', 'method': method, 'id': id_} - if params: - payload['params'] = params - return payload - - @classmethod - def notification_payload(cls, method, params=None): - '''JSON v2 notification payload. There must be no id.''' - payload = {'jsonrpc': '2.0', 'method': method} - if params: - payload['params'] = params - return payload - - @classmethod - def response_payload(cls, result, id_): - '''JSON v2 response payload. error is not present.''' - return {'jsonrpc': '2.0', 'id': id_, 'result': result} - - @classmethod - def error_payload(cls, message, code, id_): - '''JSON v2 error payload. result is not present.''' - return {'jsonrpc': '2.0', 'id': id_, - 'error': {'message': message, 'code': code}} - - @classmethod - def handle_response(cls, handler, payload): - '''JSON v2 response handler. Exactly one of 'error' and 'result' - must exist. Errors must have 'code' and 'message' members. - ''' - if 'error' in payload: - handler(None, cls.canonical_error(payload['error'])) - elif 'result' in payload: - handler(payload['result'], None) - else: - error = {'message': 'no error or result returned', - 'code': JSONRPC.INVALID_RESPONSE} - handler(None, cls.canonical_error(error)) - - @classmethod - def batch_size(cls, parts): - '''Return the size of a JSON batch from its parts.''' - return sum(len(part) for part in parts) + 2 * len(parts) - - @classmethod - def batch_bytes(cls, parts): - '''Return the bytes of a JSON batch from its parts.''' - if parts: - return b'[' + b', '.join(parts) + b']' - return b'' - - @classmethod - def is_request(cls, payload): - '''Returns True if the payload (which has a method) is a request. - False means it is a notification.''' - return 'id' in payload - - -class JSONRPCCompat(JSONRPC): - '''Intended to be used until receiving a response from the peer, at - which point detect_version should be used to choose which version - to use. - - Sends requests compatible with v1 and v2. Errors cannot be - compatible so v2 errors are sent. - - Does not send responses or notifications, nor handle responses. - - ''' - @classmethod - def request_payload(cls, id_, method, params=None): - '''JSON v2 request payload but with params present.''' - return {'jsonrpc': '2.0', 'id': id_, - 'method': method, 'params': params or []} - - @classmethod - def error_payload(cls, message, code, id_): - '''JSON v2 error payload. result is not present.''' - return {'jsonrpc': '2.0', 'id': id_, - 'error': {'message': message, 'code': code}} - - @classmethod - def detect_version(cls, payload): - '''Return a best guess at a version compatible with the received - payload. - - Return None if one cannot be determined. - ''' - def item_version(item): - if isinstance(item, dict): - version = item.get('jsonrpc') - if version is None: - return JSONRPCv1 - if version == '2.0': - return JSONRPCv2 - return None - - if isinstance(payload, list) and payload: - version = item_version(payload[0]) - # If a batch return at least JSONRPCv2 - if version in (JSONRPCv1, None): - version = JSONRPCv2 - else: - version = item_version(payload) - - return version - - -class JSONSessionBase(util.LoggedClass): - '''Acts as the application layer session, communicating via JSON RPC - over an underlying transport. - - Processes incoming and sends outgoing requests, notifications and - responses. Incoming messages are queued. When the queue goes - from empty - ''' - _next_session_id = 0 - _pending_reqs = {} # Outgoing requests waiting for a response - - @classmethod - def next_session_id(cls): - '''Return the next unique session ID.''' - session_id = cls._next_session_id - cls._next_session_id += 1 - return session_id - - def _pending_request_keys(self): - '''Return a generator of pending request keys for this session.''' - return [key for key in self._pending_reqs if key[0] is self] - - def has_pending_requests(self): - '''Return True if this session has pending requests.''' - return bool(self._pending_request_keys()) - - def pop_response_handler(self, msg_id): - '''Return the response handler for the given message ID.''' - return self._pending_reqs.pop((self, msg_id), (None, None))[0] - - def timeout_session(self): - '''Trigger timeouts for all of the session's pending requests.''' - self._timeout_requests(self._pending_request_keys()) - - @classmethod - def timeout_check(cls): - '''Trigger timeouts where necessary for all pending requests.''' - now = time.time() - keys = [key for key, value in cls._pending_reqs.items() - if value[1] < now] - cls._timeout_requests(keys) - - @classmethod - def _timeout_requests(cls, keys): - '''Trigger timeouts for the given lookup keys.''' - values = [cls._pending_reqs.pop(key) for key in keys] - handlers = [handler for handler, timeout in values] - timeout_error = JSONRPC.timeout_error() - for handler in handlers: - handler(None, timeout_error) - - def __init__(self, version=JSONRPCCompat): - super().__init__() - - # Parts of an incomplete JSON line. We buffer them until - # getting a newline. - self.parts = [] - self.version = version - self.log_me = False - self.session_id = None - # Count of incoming complete JSON requests and the time of the - # last one. A batch counts as just one here. - self.last_recv = time.time() - self.send_count = 0 - self.send_size = 0 - self.recv_size = 0 - self.recv_count = 0 - self.error_count = 0 - self.pause = False - # Handling of incoming items - self.items = collections.deque() - self.items_event = asyncio.Event() - self.batch_results = [] - # Handling of outgoing requests - self.next_request_id = 0 - # If buffered incoming data exceeds this the connection is closed - self.max_buffer_size = 1000000 - self.max_send = 50000 - self.close_after_send = False - - def pause_writing(self): - '''Transport calls when the send buffer is full.''' - self.log_info('pausing processing whilst socket drains') - self.pause = True - - def resume_writing(self): - '''Transport calls when the send buffer has room.''' - self.log_info('resuming processing') - self.pause = False - - def is_oversized(self, length, id_): - '''Return an error payload if the given outgoing message size is too - large, or False if not. - ''' - if self.max_send and length > max(1000, self.max_send): - msg = 'response too large (at least {:d} bytes)'.format(length) - return self.error_bytes(msg, JSONRPC.INVALID_REQUEST, id_) - return False - - def send_binary(self, binary): - '''Pass the bytes through to the transport. - - Close the connection if close_after_send is set. - ''' - if self.is_closing(): - return - self.using_bandwidth(len(binary)) - self.send_count += 1 - self.send_size += len(binary) - self.send_bytes(binary) - if self.close_after_send: - self.close_connection() - - def payload_id(self, payload): - '''Extract and return the ID from the payload. - - Returns None if it is missing or invalid.''' - try: - return self.check_payload_id(payload) - except RPCError: - return None - - def check_payload_id(self, payload): - '''Extract and return the ID from the payload. - - Raises an RPCError if it is missing or invalid.''' - if 'id' not in payload: - raise RPCError('missing id', JSONRPC.INVALID_REQUEST) - - id_ = payload['id'] - if not isinstance(id_, self.version.ID_TYPES): - raise RPCError('invalid id type {}'.format(type(id_)), - JSONRPC.INVALID_REQUEST) - return id_ - - def request_bytes(self, id_, method, params=None): - '''Return the bytes of a JSON request.''' - payload = self.version.request_payload(id_, method, params) - return self.encode_payload(payload) - - def notification_bytes(self, method, params=None): - payload = self.version.notification_payload(method, params) - return self.encode_payload(payload) - - def response_bytes(self, result, id_): - '''Return the bytes of a JSON response.''' - return self.encode_payload(self.version.response_payload(result, id_)) - - def error_bytes(self, message, code, id_=None): - '''Return the bytes of a JSON error. - - Flag the connection to close on a fatal error or too many errors.''' - version = self.version - self.error_count += 1 - if not self.close_after_send: - fatal_log = None - if code in (version.PARSE_ERROR, version.INVALID_REQUEST, - version.FATAL_ERROR): - fatal_log = message - elif self.error_count >= 10: - fatal_log = 'too many errors, last: {}'.format(message) - if fatal_log: - self.log_info(fatal_log) - self.close_after_send = True - return self.encode_payload(self.version.error_payload - (message, code, id_)) - - def encode_payload(self, payload): - '''Encode a Python object as binary bytes.''' - assert isinstance(payload, dict) - - id_ = payload.get('id') - try: - binary = json.dumps(payload).encode() - except TypeError: - msg = 'JSON encoding failure: {}'.format(payload) - self.log_error(msg) - binary = self.error_bytes(msg, JSONRPC.INTERNAL_ERROR, id_) - - error_bytes = self.is_oversized(len(binary), id_) - return error_bytes or binary - - def decode_message(self, payload): - '''Decode a binary message and pass it on to process_single_item or - process_batch as appropriate. - - Messages that cannot be decoded are logged and dropped. - ''' - try: - payload = payload.decode() - except UnicodeDecodeError as e: - msg = 'cannot decode message: {}'.format(e) - self.send_error(msg, JSONRPC.PARSE_ERROR) - return - - try: - payload = json.loads(payload) - except json.JSONDecodeError as e: - msg = 'cannot decode JSON: {}'.format(e) - self.send_error(msg, JSONRPC.PARSE_ERROR) - return - - if self.version is JSONRPCCompat: - # Attempt to detect peer's JSON RPC version - version = self.version.detect_version(payload) - if not version: - version = JSONRPCv2 - self.log_info('unable to detect JSON RPC version, using 2.0') - self.version = version - - # Batches must have at least one object. - if payload == [] and self.version.HAS_BATCHES: - self.send_error('empty batch', JSONRPC.INVALID_REQUEST) - return - - self.items.append(payload) - self.items_event.set() - - async def process_batch(self, batch, count): - '''Processes count items from the batch according to the JSON 2.0 - spec. - - If any remain, puts what is left of the batch back in the deque - and returns None. Otherwise returns the binary batch result.''' - results = self.batch_results - self.batch_results = [] - - for n in range(count): - item = batch.pop() - result = await self.process_single_item(item) - if result: - results.append(result) - - if not batch: - return self.version.batch_bytes(results) - - error_bytes = self.is_oversized(self.batch_size(results), None) - if error_bytes: - return error_bytes - - self.items.appendleft(item) - self.batch_results = results - return None - - async def process_single_item(self, payload): - '''Handle a single JSON request, notification or response. - - If it is a request, return the binary response, oterhwise None.''' - if self.log_me: - self.log_info('processing {}'.format(payload)) - - if not isinstance(payload, dict): - return self.error_bytes('request must be a dictionary', - JSONRPC.INVALID_REQUEST) - - try: - # Requests and notifications must have a method. - if 'method' in payload: - if self.version.is_request(payload): - return await self.process_single_request(payload) - else: - await self.process_single_notification(payload) - else: - self.process_single_response(payload) - - return None - except asyncio.CancelledError: - raise - except Exception: - self.log_error(traceback.format_exc()) - return self.error_bytes('internal error processing request', - JSONRPC.INTERNAL_ERROR, - self.payload_id(payload)) - - async def process_single_request(self, payload): - '''Handle a single JSON request and return the binary response.''' - try: - result = await self.handle_payload(payload, self.request_handler) - return self.response_bytes(result, payload['id']) - except RPCError as e: - return self.error_bytes(e.msg, e.code, self.payload_id(payload)) - except asyncio.CancelledError: - raise - except Exception: - self.log_error(traceback.format_exc()) - return self.error_bytes('internal error processing request', - JSONRPC.INTERNAL_ERROR, - self.payload_id(payload)) - - async def process_single_notification(self, payload): - '''Handle a single JSON notification.''' - try: - await self.handle_payload(payload, self.notification_handler) - except RPCError: - pass - except Exception: - self.log_error(traceback.format_exc()) - - def process_single_response(self, payload): - '''Handle a single JSON response.''' - try: - id_ = self.check_payload_id(payload) - handler = self.pop_response_handler(id_) - if handler: - self.version.handle_response(handler, payload) - else: - self.log_info('response for unsent id {}'.format(id_), - throttle=True) - except RPCError: - pass - except Exception: - self.log_error(traceback.format_exc()) - - async def handle_payload(self, payload, get_handler): - '''Handle a request or notification payload given the handlers.''' - # An argument is the value passed to a function parameter... - args = payload.get('params', []) - method = payload.get('method') - - if not isinstance(method, str): - raise RPCError("invalid method type {}".format(type(method)), - JSONRPC.INVALID_REQUEST) - - handler = get_handler(method) - if not handler: - raise RPCError("unknown method: '{}'".format(method), - JSONRPC.METHOD_NOT_FOUND) - - if not isinstance(args, (list, dict)): - raise RPCError('arguments should be an array or dictionary', - JSONRPC.INVALID_REQUEST) - - params = inspect.signature(handler).parameters - names = list(params) - min_args = sum(p.default is p.empty for p in params.values()) - - if len(args) < min_args: - raise RPCError('too few arguments to {}: expected {:d} got {:d}' - .format(method, min_args, len(args)), - JSONRPC.INVALID_ARGS) - - if len(args) > len(params): - raise RPCError('too many arguments to {}: expected {:d} got {:d}' - .format(method, len(params), len(args)), - JSONRPC.INVALID_ARGS) - - if isinstance(args, list): - kw_args = {name: arg for name, arg in zip(names, args)} - else: - kw_args = args - bad_names = ['<{}>'.format(name) for name in args - if name not in names] - if bad_names: - raise RPCError('invalid parameter names: {}' - .format(', '.join(bad_names))) - - if inspect.iscoroutinefunction(handler): - return await handler(**kw_args) - else: - return handler(**kw_args) - - # ---- External Interface ---- - - async def process_pending_items(self, limit=8): - '''Processes up to LIMIT pending items asynchronously.''' - while limit > 0 and self.items: - item = self.items.popleft() - if isinstance(item, list) and self.version.HAS_BATCHES: - count = min(limit, len(item)) - binary = await self.process_batch(item, count) - limit -= count - else: - binary = await self.process_single_item(item) - limit -= 1 - - if binary: - self.send_binary(binary) - - if not self.items: - self.items_event.clear() - - def count_pending_items(self): - '''Counts the number of pending items.''' - return sum(len(item) if isinstance(item, list) else 1 - for item in self.items) - - def connection_made(self): - '''Call when an incoming client connection is established.''' - self.session_id = self.next_session_id() - self.log_prefix = '[{:d}] '.format(self.session_id) - - def data_received(self, data): - '''Underlying transport calls this when new data comes in. - - Look for newline separators terminating full requests. - ''' - if self.is_closing(): - return - self.using_bandwidth(len(data)) - self.recv_size += len(data) - - # Close abusive connections where buffered data exceeds limit - buffer_size = len(data) + sum(len(part) for part in self.parts) - if buffer_size > self.max_buffer_size: - self.log_error('read buffer of {:,d} bytes over {:,d} byte limit' - .format(buffer_size, self.max_buffer_size)) - self.close_connection() - return - - while True: - npos = data.find(ord('\n')) - if npos == -1: - self.parts.append(data) - break - tail, data = data[:npos], data[npos + 1:] - parts, self.parts = self.parts, [] - parts.append(tail) - self.recv_count += 1 - self.last_recv = time.time() - self.decode_message(b''.join(parts)) - - def send_error(self, message, code, id_=None): - '''Send a JSON error.''' - self.send_binary(self.error_bytes(message, code, id_)) - - def send_request(self, handler, method, params=None, timeout=30): - '''Sends a request and arranges for handler to be called with the - response when it comes in. - - A call to request_timeout_check() will result in pending - responses that have been waiting more than timeout seconds to - call the handler with a REQUEST_TIMEOUT error. - ''' - id_ = self.next_request_id - self.next_request_id += 1 - self.send_binary(self.request_bytes(id_, method, params)) - self._pending_reqs[(self, id_)] = (handler, time.time() + timeout) - - def send_notification(self, method, params=None): - '''Send a notification.''' - self.send_binary(self.notification_bytes(method, params)) - - def send_notifications(self, mp_iterable): - '''Send an iterable of (method, params) notification pairs. - - A 1-tuple is also valid in which case there are no params.''' - if False and self.version.HAS_BATCHES: - parts = [self.notification_bytes(*pair) for pair in mp_iterable] - self.send_binary(self.version.batch_bytes(parts)) - else: - for pair in mp_iterable: - self.send_notification(*pair) - - # -- derived classes are intended to override these functions - - # Transport layer - - def is_closing(self): - '''Return True if the underlying transport is closing.''' - raise NotImplementedError - - def close_connection(self): - '''Close the connection.''' - raise NotImplementedError - - def send_bytes(self, binary): - '''Pass the bytes through to the underlying transport.''' - raise NotImplementedError - - # App layer - - def using_bandwidth(self, amount): - '''Called as bandwidth is consumed. - - Override to implement bandwidth management. - ''' - pass - - def notification_handler(self, method): - '''Return the handler for the given notification. - - The handler can be synchronous or asynchronous.''' - return None - - def request_handler(self, method): - '''Return the handler for the given request method. - - The handler can be synchronous or asynchronous.''' - return None - - -class JSONSession(JSONSessionBase, asyncio.Protocol): - '''A JSONSessionBase instance specialized for use with - asyncio.protocol to implement the transport layer. - - The app should await on items_event, which is set when unprocessed - incoming items remain and cleared when the queue is empty, and - then arrange to call process_pending_items asynchronously. - - Derived classes may want to override the request and notification - handlers. - ''' - - def __init__(self, version=JSONRPCCompat): - super().__init__(version=version) - self.transport = None - self.write_buffer_high = 500000 - - def peer_info(self): - '''Returns information about the peer.''' - try: - # get_extra_info can throw even if self.transport is not None - return self.transport.get_extra_info('peername') - except Exception: - return None - - def abort(self): - '''Cut the connection abruptly.''' - self.transport.abort() - - def connection_made(self, transport): - '''Handle an incoming client connection.''' - transport.set_write_buffer_limits(high=self.write_buffer_high) - self.transport = transport - super().connection_made() - - def connection_lost(self, exc): - '''Trigger timeouts of all pending requests.''' - self.timeout_session() - - def is_closing(self): - '''True if the underlying transport is closing.''' - return self.transport and self.transport.is_closing() - - def close_connection(self): - '''Close the connection.''' - if self.transport: - self.transport.close() - - def send_bytes(self, binary): - '''Send JSON text over the transport.''' - self.transport.writelines((binary, b'\n')) - - def peer_addr(self, anon=True): - '''Return the peer address and port.''' - peer_info = self.peer_info() - if not peer_info: - return 'unknown' - if anon: - return 'xx.xx.xx.xx:xx' - if ':' in peer_info[0]: - return '[{}]:{}'.format(peer_info[0], peer_info[1]) - else: - return '{}:{}'.format(peer_info[0], peer_info[1]) diff --git a/lib/server_base.py b/lib/server_base.py index e59df48..f6e1de2 100644 --- a/lib/server_base.py +++ b/lib/server_base.py @@ -6,16 +6,15 @@ # and warranty status of this software. import asyncio +import logging import os import signal import sys import time from functools import partial -import lib.util as util - -class ServerBase(util.LoggedClass): +class ServerBase(object): '''Base class server implementation. Derived classes are expected to: @@ -37,7 +36,7 @@ class ServerBase(util.LoggedClass): '''Save the environment, perform basic sanity checks, and set the event loop policy. ''' - super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) self.env = env # Sanity checks diff --git a/lib/util.py b/lib/util.py index f94a1ed..88d06d3 100644 --- a/lib/util.py +++ b/lib/util.py @@ -37,30 +37,11 @@ from collections import Container, Mapping from struct import pack, Struct -class LoggedClass(object): - - def __init__(self): - self.logger = logging.getLogger(self.__class__.__name__) - self.logger.setLevel(logging.INFO) - self.log_prefix = '' - self.throttled = 0 - - def log_info(self, msg, throttle=False): - # Prevent annoying log messages by throttling them if there - # are too many in a short period - if throttle: - self.throttled += 1 - if self.throttled > 3: - return - if self.throttled == 3: - msg += ' (throttling later logs)' - self.logger.info(self.log_prefix + msg) - - def log_warning(self, msg): - self.logger.warning(self.log_prefix + msg) - - def log_error(self, msg): - self.logger.error(self.log_prefix + msg) +class ConnectionLogger(logging.LoggerAdapter): + '''Prepends a connection identifier to a logging message.''' + def process(self, msg, kwargs): + conn_id = self.extra.get('conn_id', 'unknown') + return f'[{conn_id}] {msg}', kwargs # Method decorator. To be used for calculations that will always diff --git a/server/block_processor.py b/server/block_processor.py index ea4265e..cc31b64 100644 --- a/server/block_processor.py +++ b/server/block_processor.py @@ -11,6 +11,7 @@ import array import asyncio +import logging from struct import pack, unpack import time from collections import defaultdict @@ -19,15 +20,15 @@ from functools import partial from server.daemon import DaemonError from server.version import VERSION from lib.hash import hash_to_str -from lib.util import chunks, formatted_time, LoggedClass +from lib.util import chunks, formatted_time import server.db -class Prefetcher(LoggedClass): +class Prefetcher(object): '''Prefetches blocks (in the forward direction only).''' def __init__(self, bp): - super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) self.bp = bp self.caught_up = False # Access to fetched_height should be protected by the semaphore diff --git a/server/controller.py b/server/controller.py index 46b29dd..3d9d278 100644 --- a/server/controller.py +++ b/server/controller.py @@ -6,6 +6,7 @@ # and warranty status of this software. import asyncio +import itertools import json import os import ssl @@ -18,7 +19,7 @@ from functools import partial import pylru -from lib.jsonrpc import JSONSessionBase, RPCError +from aiorpcx import RPCError from lib.hash import double_sha256, hash_to_str, hex_str_to_hash from lib.peer import Peer from lib.server_base import ServerBase @@ -26,7 +27,15 @@ import lib.util as util from server.daemon import DaemonError from server.mempool import MemPool from server.peers import PeerManager -from server.session import LocalRPC +from server.session import LocalRPC, BAD_REQUEST, DAEMON_ERROR + + +class SessionGroup(object): + + def __init__(self, gid): + self.gid = gid + # Concurrency per group + self.semaphore = asyncio.Semaphore(20) class Controller(ServerBase): @@ -36,7 +45,6 @@ class Controller(ServerBase): up with the daemon. ''' - BANDS = 5 CATCHING_UP, LISTENING, PAUSED, SHUTTING_DOWN = range(4) def __init__(self, env): @@ -45,9 +53,9 @@ class Controller(ServerBase): self.coin = env.coin self.servers = {} - # Map of session to the key of its list in self.groups - self.sessions = {} - self.groups = defaultdict(list) + self.sessions = set() + self.groups = set() + self.cur_group = self._new_group(0) self.txs_sent = 0 self.next_log_sessions = 0 self.state = self.CATCHING_UP @@ -62,7 +70,6 @@ class Controller(ServerBase): self.header_cache = pylru.lrucache(8) self.cache_height = 0 env.max_send = max(350000, env.max_send) - self.setup_bands() # Set up the RPC request handlers cmds = ('add_peer daemon_url disconnect getinfo groups log peers reorg ' 'sessions stop'.split()) @@ -128,29 +135,6 @@ class Controller(ServerBase): '''Call when a TX is sent.''' self.txs_sent += 1 - def setup_bands(self): - bands = [] - limit = self.env.bandwidth_limit - for n in range(self.BANDS): - bands.append(limit) - limit //= 4 - limit = self.env.bandwidth_limit - for n in range(self.BANDS): - limit += limit // 2 - bands.append(limit) - self.bands = sorted(bands) - - def session_priority(self, session): - if isinstance(session, LocalRPC): - return 0 - gid = self.sessions[session] - group_bw = sum(session.bw_used for session in self.groups[gid]) - return 1 + (bisect_left(self.bands, session.bw_used) - + bisect_left(self.bands, group_bw)) // 2 - - def is_deprioritized(self, session): - return self.session_priority(session) > self.BANDS - async def run_in_executor(self, func, *args): '''Wait whilst running func in the executor.''' return await self.loop.run_in_executor(None, func, *args) @@ -177,7 +161,7 @@ class Controller(ServerBase): except asyncio.CancelledError: pass except Exception: - self.log_error(traceback.format_exc()) + self.logger.error(traceback.format_exc()) async def housekeeping(self): '''Regular housekeeping checks.''' @@ -185,7 +169,6 @@ class Controller(ServerBase): while True: n += 1 await asyncio.sleep(15) - JSONSessionBase.timeout_check() if n % 10 == 0: self.clear_stale_sessions() @@ -253,7 +236,6 @@ class Controller(ServerBase): .format(self.max_subs)) self.logger.info('max subscriptions per session: {:,d}' .format(self.env.max_session_subs)) - self.logger.info('bands: {}'.format(self.bands)) if self.env.drop_client is not None: self.logger.info('drop clients matching: {}' .format(self.env.drop_client.pattern)) @@ -304,7 +286,7 @@ class Controller(ServerBase): '''Return the binary header at the given height.''' header, n = self.bp.read_headers(height, 1) if n != 1: - raise RPCError('height {:,d} out of range'.format(height)) + raise RPCError(BAD_REQUEST, f'height {height:,d} out of range') return header def electrum_header(self, height): @@ -315,74 +297,54 @@ class Controller(ServerBase): height) return self.header_cache[height] - def session_delay(self, session): - priority = self.session_priority(session) - excess = max(0, priority - self.BANDS) - if excess != session.last_delay: - session.last_delay = excess - if excess: - session.log_info('high bandwidth use, deprioritizing by ' - 'delaying responses {:d}s'.format(excess)) - else: - session.log_info('stopped delaying responses') - return max(int(session.pause), excess) - - async def process_items(self, session): - '''Waits for incoming session items and processes them.''' - while True: - await session.items_event.wait() - await asyncio.sleep(self.session_delay(session)) - if not session.pause: - await session.process_pending_items() + def _new_group(self, gid): + group = SessionGroup(gid) + self.groups.add(group) + return group def add_session(self, session): - session.items_future = self.ensure_future(self.process_items(session)) - gid = int(session.start_time - self.start_time) // 900 - self.groups[gid].append(session) - self.sessions[session] = gid - session.log_info('{} {}, {:,d} total' - .format(session.kind, session.peername(), - len(self.sessions))) + self.sessions.add(session) if (len(self.sessions) >= self.max_sessions and self.state == self.LISTENING): self.state = self.PAUSED - session.log_info('maximum sessions {:,d} reached, stopping new ' - 'connections until count drops to {:,d}' - .format(self.max_sessions, self.low_watermark)) + session.logger.info('maximum sessions {:,d} reached, stopping new ' + 'connections until count drops to {:,d}' + .format(self.max_sessions, self.low_watermark)) self.close_servers(['TCP', 'SSL']) + gid = int(session.start_time - self.start_time) // 900 + if self.cur_group.gid != gid: + self.cur_group = self._new_group(gid) + return self.cur_group def remove_session(self, session): '''Remove a session from our sessions list if there.''' - session.items_future.cancel() - if session in self.sessions: - gid = self.sessions.pop(session) - assert gid in self.groups - self.groups[gid].remove(session) + self.sessions.remove(session) def close_session(self, session): - '''Close the session's transport and cancel its future.''' - session.close_connection() + '''Close the session's transport.''' + session.close() return 'disconnected {:d}'.format(session.session_id) def toggle_logging(self, session): '''Toggle logging of the session.''' - session.log_me = not session.log_me + session.toggle_logging() return 'log {:d}: {}'.format(session.session_id, session.log_me) - def clear_stale_sessions(self, grace=15): - '''Cut off sessions that haven't done anything for 10 minutes. Force - close stubborn connections that won't close cleanly after a - short grace period. - ''' + def _group_map(self): + group_map = defaultdict(list) + for session in self.sessions: + group_map[session.group].append(session) + return group_map + + def clear_stale_sessions(self): + '''Cut off sessions that haven't done anything for 10 minutes.''' now = time.time() - shutdown_cutoff = now - grace stale_cutoff = now - self.env.session_timeout stale = [] for session in self.sessions: if session.is_closing(): - if session.close_time <= shutdown_cutoff: - session.abort() + pass elif session.last_recv < stale_cutoff: self.close_session(session) stale.append(session.session_id) @@ -390,16 +352,18 @@ class Controller(ServerBase): self.logger.info('closing stale connections {}'.format(stale)) # Consolidate small groups - gids = [gid for gid, l in self.groups.items() if len(l) <= 4 - and sum(session.bw_used for session in l) < 10000] - if len(gids) > 1: - sessions = sum([self.groups[gid] for gid in gids], []) - new_gid = max(gids) - for gid in gids: - del self.groups[gid] - for session in sessions: - self.sessions[session] = new_gid - self.groups[new_gid] = sessions + bw_limit = self.env.bandwidth_limit + group_map = self._group_map() + groups = [group for group, sessions in group_map.items() + if len(sessions) <= 5 and + sum(s.bw_charge for s in sessions) < bw_limit] + if len(groups) > 1: + new_gid = max(group.gid for group in groups) + new_group = self._new_group(new_gid) + for group in groups: + self.groups.remove(group) + for session in group_map[group]: + session.group = new_group def session_count(self): '''The number of connections that we've sent something to.''' @@ -412,10 +376,10 @@ class Controller(ServerBase): 'daemon_height': self.daemon.cached_height(), 'db_height': self.bp.db_height, 'closing': len([s for s in self.sessions if s.is_closing()]), - 'errors': sum(s.error_count for s in self.sessions), + 'errors': sum(s.rpc.errors for s in self.sessions), 'groups': len(self.groups), 'logged': len([s for s in self.sessions if s.log_me]), - 'paused': sum(s.pause for s in self.sessions), + 'paused': sum(s.paused for s in self.sessions), 'pid': os.getpid(), 'peers': self.peer_mgr.info(), 'requests': sum(s.count_pending_items() for s in self.sessions), @@ -454,11 +418,11 @@ class Controller(ServerBase): def group_data(self): '''Returned to the RPC 'groups' call.''' result = [] - for gid in sorted(self.groups.keys()): - sessions = self.groups[gid] - result.append([gid, + group_map = self._group_map() + for group, sessions in group_map.items(): + result.append([group.gid, len(sessions), - sum(s.bw_used for s in sessions), + sum(s.bw_charge for s in sessions), sum(s.count_pending_items() for s in sessions), sum(s.txs_sent for s in sessions), sum(s.sub_count() for s in sessions), @@ -531,7 +495,7 @@ class Controller(ServerBase): sessions = sorted(self.sessions, key=lambda s: s.start_time) return [(session.session_id, session.flags(), - session.peername(for_log=for_log), + session.peer_address_str(for_log=for_log), session.client, session.protocol_version, session.count_pending_items(), @@ -555,7 +519,7 @@ class Controller(ServerBase): def for_each_session(self, session_ids, operation): if not isinstance(session_ids, list): - raise RPCError('expected a list of session IDs') + raise RPCError(BAD_REQUEST, 'expected a list of session IDs') result = [] for session_id in session_ids: @@ -597,7 +561,7 @@ class Controller(ServerBase): try: self.daemon.set_urls(self.env.coin.daemon_urls(daemon_url)) except Exception as e: - raise RPCError('an error occured: {}'.format(e)) + raise RPCError(BAD_REQUEST, f'an error occured: {e}') return 'now using daemon at {}'.format(self.daemon.logged_url()) def rpc_stop(self): @@ -628,7 +592,7 @@ class Controller(ServerBase): ''' count = self.non_negative_integer(count) if not self.bp.force_chain_reorg(count): - raise RPCError('still catching up with daemon') + raise RPCError(BAD_REQUEST, 'still catching up with daemon') return 'scheduled a reorg of {:,d} blocks'.format(count) # Helpers for RPC "blockchain" command handlers @@ -638,7 +602,7 @@ class Controller(ServerBase): return self.coin.address_to_hashX(address) except Exception: pass - raise RPCError('{} is not a valid address'.format(address)) + raise RPCError(BAD_REQUEST, f'{address} is not a valid address') def scripthash_to_hashX(self, scripthash): try: @@ -647,7 +611,7 @@ class Controller(ServerBase): return bin_hash[:self.coin.HASHX_LEN] except Exception: pass - raise RPCError('{} is not a valid script hash'.format(scripthash)) + raise RPCError(BAD_REQUEST, f'{scripthash} is not a valid script hash') def assert_tx_hash(self, value): '''Raise an RPCError if the value is not a valid transaction @@ -657,7 +621,7 @@ class Controller(ServerBase): return except Exception: pass - raise RPCError('{} should be a transaction hash'.format(value)) + raise RPCError(BAD_REQUEST, f'{value} should be a transaction hash') def non_negative_integer(self, value): '''Return param value it is or can be converted to a non-negative @@ -668,21 +632,22 @@ class Controller(ServerBase): return value except ValueError: pass - raise RPCError('{} should be a non-negative integer'.format(value)) + raise RPCError(BAD_REQUEST, + f'{value} should be a non-negative integer') async def daemon_request(self, method, *args): '''Catch a DaemonError and convert it to an RPCError.''' try: return await getattr(self.daemon, method)(*args) except DaemonError as e: - raise RPCError('daemon error: {}'.format(e)) + raise RPCError(DAEMON_ERROR, f'daemon error: {e}') def new_subscription(self): if self.subs_room <= 0: self.subs_room = self.max_subs - self.sub_count() if self.subs_room <= 0: - raise RPCError('server subscription limit {:,d} reached' - .format(self.max_subs)) + raise RPCError(BAD_REQUEST, f'server subscription limit ' + f'{self.max_subs:,d} reached') self.subs_room -= 1 async def tx_merkle(self, tx_hash, height): @@ -693,8 +658,8 @@ class Controller(ServerBase): try: pos = tx_hashes.index(tx_hash) except ValueError: - raise RPCError('tx hash {} not in block {} at height {:,d}' - .format(tx_hash, hex_hashes[0], height)) + raise RPCError(BAD_REQUEST, f'tx hash {tx_hash} not in ' + f'block {hex_hashes[0]} at height {height:,d}') idx = pos hashes = [hex_str_to_hash(txh) for txh in tx_hashes] diff --git a/server/daemon.py b/server/daemon.py index bd70a3c..70ccf54 100644 --- a/server/daemon.py +++ b/server/daemon.py @@ -10,24 +10,24 @@ daemon.''' import asyncio import json +import logging import time -import traceback from calendar import timegm from struct import pack from time import strptime import aiohttp -from lib.util import LoggedClass, int_to_varint, hex_to_bytes +from lib.util import int_to_varint, hex_to_bytes from lib.hash import hex_str_to_hash -from lib.jsonrpc import JSONRPC +from aiorpcx import JSONRPC class DaemonError(Exception): '''Raised when the daemon returns an error in its results.''' -class Daemon(LoggedClass): +class Daemon(object): '''Handles connections to a daemon at the given URL.''' WARMING_UP = -28 @@ -37,7 +37,7 @@ class Daemon(LoggedClass): '''Raised when the daemon returns an error in its results.''' def __init__(self, env): - super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) self.coin = env.coin self.set_urls(env.coin.daemon_urls(env.daemon_url)) self._height = None @@ -151,8 +151,8 @@ class Daemon(LoggedClass): log_error('starting up checking blocks.') except (asyncio.CancelledError, DaemonError): raise - except Exception: - self.log_error(traceback.format_exc()) + except Exception as e: + self.logger.exception(f'uncaught exception: {e}') await asyncio.sleep(secs) secs = min(max_secs, secs * 2, 1) diff --git a/server/db.py b/server/db.py index 123dbbd..32bcd49 100644 --- a/server/db.py +++ b/server/db.py @@ -11,6 +11,7 @@ import array import ast +import logging import os from struct import pack, unpack from bisect import bisect_left, bisect_right @@ -25,7 +26,7 @@ from server.version import VERSION, PROTOCOL_MIN, PROTOCOL_MAX UTXO = namedtuple("UTXO", "tx_num tx_pos tx_hash height value") -class DB(util.LoggedClass): +class DB(object): '''Simple wrapper of the backend database for querying. Performs no DB update, though the DB will be cleaned on opening if @@ -41,7 +42,7 @@ class DB(util.LoggedClass): '''Raised on general DB errors generally indicating corruption.''' def __init__(self, env): - super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) self.env = env self.coin = env.coin @@ -581,9 +582,10 @@ class DB(util.LoggedClass): full_hist = b''.join(hist_list) nrows = (len(full_hist) + max_row_size - 1) // max_row_size if nrows > 4: - self.log_info('hashX {} is large: {:,d} entries across {:,d} rows' - .format(hash_to_str(hashX), len(full_hist) // 4, - nrows)) + self.logger.info('hashX {} is large: {:,d} entries across ' + '{:,d} rows' + .format(hash_to_str(hashX), len(full_hist) // 4, + nrows)) # Find what history needs to be written, and what keys need to # be deleted. Start by assuming all keys are to be deleted, @@ -652,11 +654,11 @@ class DB(util.LoggedClass): max_rows = self.comp_flush_count + 1 self._flush_compaction(cursor, write_items, keys_to_delete) - self.log_info('history compaction: wrote {:,d} rows ({:.1f} MB), ' - 'removed {:,d} rows, largest: {:,d}, {:.1f}% complete' - .format(len(write_items), write_size / 1000000, - len(keys_to_delete), max_rows, - 100 * cursor / 65536)) + self.logger.info('history compaction: wrote {:,d} rows ({:.1f} MB), ' + 'removed {:,d} rows, largest: {:,d}, {:.1f}% complete' + .format(len(write_items), write_size / 1000000, + len(keys_to_delete), max_rows, + 100 * cursor / 65536)) return write_size async def compact_history(self, loop): @@ -673,7 +675,7 @@ class DB(util.LoggedClass): while self.comp_cursor != -1: if self.semaphore.locked: - self.log_info('compact_history: waiting on semaphore...') + self.logger.info('compact_history: waiting on semaphore...') async with self.semaphore: await loop.run_in_executor(None, self._compact_history, limit) diff --git a/server/env.py b/server/env.py index cd7966d..1cc7e76 100644 --- a/server/env.py +++ b/server/env.py @@ -86,9 +86,9 @@ class Env(EnvBase): # We give the DB 250 files; allow ElectrumX 100 for itself value = max(0, min(env_value, nofile_limit - 350)) if value < env_value: - self.log_warning('lowered maximum sessions from {:,d} to {:,d} ' - 'because your open file limit is {:,d}' - .format(env_value, value, nofile_limit)) + self.logger.warning('lowered maximum sessions from {:,d} to {:,d} ' + 'because your open file limit is {:,d}' + .format(env_value, value, nofile_limit)) return value def clearnet_identity(self): diff --git a/server/mempool.py b/server/mempool.py index 4236f06..535b683 100644 --- a/server/mempool.py +++ b/server/mempool.py @@ -9,16 +9,16 @@ import asyncio import itertools +import logging import time from collections import defaultdict from lib.hash import hash_to_str, hex_str_to_hash -import lib.util as util from server.daemon import DaemonError from server.db import UTXO -class MemPool(util.LoggedClass): +class MemPool(object): '''Representation of the daemon's mempool. Updated regularly in caught-up state. Goal is to enable efficient @@ -33,7 +33,7 @@ class MemPool(util.LoggedClass): ''' def __init__(self, bp, controller): - super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) self.daemon = bp.daemon self.controller = controller self.coin = bp.coin diff --git a/server/peers.py b/server/peers.py index 47f54c6..5af6866 100644 --- a/server/peers.py +++ b/server/peers.py @@ -8,6 +8,7 @@ '''Peer management.''' import asyncio +import logging import random import socket import ssl @@ -18,7 +19,7 @@ from functools import partial import aiorpcx from lib.peer import Peer -import lib.util as util +from lib.util import ConnectionLogger import server.version as version @@ -27,16 +28,17 @@ STALE_SECS = 24 * 3600 WAKEUP_SECS = 300 -class PeerSession(aiorpcx.ClientSession, util.LoggedClass): +class PeerSession(aiorpcx.ClientSession): '''An outgoing session to a peer.''' def __init__(self, peer, peer_mgr, kind, host, port, **kwargs): super().__init__(host, port, **kwargs) - util.LoggedClass.__init__(self) self.peer = peer self.peer_mgr = peer_mgr self.kind = kind self.timeout = 20 if self.peer.is_tor else 10 + context = {'conn_id': f'{host}'} + self.logger = ConnectionLogger(self.logger, context) def connection_made(self, transport): '''Handle an incoming client connection.''' @@ -53,6 +55,15 @@ class PeerSession(aiorpcx.ClientSession, util.LoggedClass): self.send_request('server.version', args, self.on_version, timeout=self.timeout) + def _header_notification(self, header): + pass + + def notification_handler(self, method): + # We subscribe so might be unlucky enough to get a notification... + if method == 'blockchain.headers.subscribe': + return self._header_notification + return None + def is_good(self, request, instance): try: result = request.result() @@ -73,12 +84,12 @@ class PeerSession(aiorpcx.ClientSession, util.LoggedClass): return False def fail(self, request, reason): - self.logger.error(f'[{self.peer.host}] {request} failed: {reason}') + self.logger.error(f'{request} failed: {reason}') self.peer_mgr.set_verification_status(self.peer, self.kind, False) self.close() def bad(self, reason): - self.logger.error(f'[{self.peer.host}] marking bad: {reason}') + self.logger.error(f'marking bad: {reason}') self.peer.mark_bad() self.peer_mgr.set_verification_status(self.peer, self.kind, False) self.close() @@ -180,8 +191,7 @@ class PeerSession(aiorpcx.ClientSession, util.LoggedClass): features = self.peer_mgr.features_to_register(self.peer, peers) if features: - self.logger.info(f'[{self.peer.host}] registering ourself with ' - '"server.add_peer"') + self.logger.info(f'registering ourself with "server.add_peer"') self.send_request('server.add_peer', [features], self.on_add_peer, timeout=self.timeout) else: @@ -201,14 +211,14 @@ class PeerSession(aiorpcx.ClientSession, util.LoggedClass): self.peer_mgr.set_verification_status(self.peer, self.kind, True) -class PeerManager(util.LoggedClass): +class PeerManager(object): '''Looks after the DB of peer network servers. Attempts to maintain a connection with up to 8 peers. Issues a 'peers.subscribe' RPC to them and tells them our data. ''' def __init__(self, env, controller): - super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) # Initialise the Peer class Peer.DEFAULT_PORTS = env.coin.PEER_DEFAULT_PORTS self.env = env @@ -336,12 +346,12 @@ class PeerManager(util.LoggedClass): async def on_add_peer(self, features, source_info): '''Add a peer (but only if the peer resolves to the source).''' if not source_info: - self.log_info('ignored add_peer request: no source info') + self.logger.info('ignored add_peer request: no source info') return False source = source_info[0] peers = Peer.peers_from_features(features, source) if not peers: - self.log_info('ignored add_peer request: no peers given') + self.logger.info('ignored add_peer request: no peers given') return False # Just look at the first peer, require it @@ -362,12 +372,12 @@ class PeerManager(util.LoggedClass): reason = 'source-destination mismatch' if permit: - self.log_info('accepted add_peer request from {} for {}' - .format(source, host)) + self.logger.info('accepted add_peer request from {} for {}' + .format(source, host)) self.add_peers([peer], check_ports=True) else: - self.log_warning('rejected add_peer request from {} for {} ({})' - .format(source, host, reason)) + self.logger.warning('rejected add_peer request from {} for {} ({})' + .format(source, host, reason)) return permit @@ -438,13 +448,13 @@ class PeerManager(util.LoggedClass): ports = [9050, 9150, 1080] else: ports = [self.env.tor_proxy_port] - self.log_info(f'trying to detect proxy on "{host}" ports {ports}') + self.logger.info(f'trying to detect proxy on "{host}" ports {ports}') cls = aiorpcx.SOCKSProxy result = await cls.auto_detect_host(host, ports, None, loop=self.loop) if isinstance(result, cls): self.proxy = result - self.log_info(f'detected {self.proxy}') + self.logger.info(f'detected {self.proxy}') def proxy_peername(self): '''Return the peername of the proxy, if there is a proxy, otherwise @@ -544,10 +554,10 @@ class PeerManager(util.LoggedClass): if exception: session.close() kind, port = port_pairs.pop(0) - self.log_info('failed connecting to {} at {} port {:d} ' - 'in {:.1f}s: {}' - .format(peer, kind, port, - time.time() - peer.last_try, exception)) + self.logger.info('failed connecting to {} at {} port {:d} ' + 'in {:.1f}s: {}' + .format(peer, kind, port, + time.time() - peer.last_try, exception)) if port_pairs: self.retry_peer(peer, port_pairs) else: @@ -562,7 +572,7 @@ class PeerManager(util.LoggedClass): how = 'via {} at {}'.format(kind, peer.ip_addr) status = 'verified' if good else 'failed to verify' elapsed = now - peer.last_try - self.log_info('{} {} {} in {:.1f}s'.format(status, peer, how, elapsed)) + self.logger.info(f'{status} {peer} {how} in {elapsed:.1f}s') if good: peer.try_count = 0 diff --git a/server/session.py b/server/session.py index 18ff5aa..87567d1 100644 --- a/server/session.py +++ b/server/session.py @@ -8,17 +8,38 @@ '''Classes for local RPC server and remote client TCP/SSL servers.''' import codecs +import itertools import time from functools import partial +from aiorpcx import ServerSession, JSONRPCAutoDetect, RPCError + from lib.hash import sha256, hash_to_str -from lib.jsonrpc import JSONSession, RPCError, JSONRPCv2, JSONRPC import lib.util as util from server.daemon import DaemonError import server.version as version +BAD_REQUEST = 1 +DAEMON_ERROR = 2 + + +class Semaphores(object): + + def __init__(self, semaphores): + self.semaphores = semaphores + self.acquired = [] + + async def __aenter__(self): + for semaphore in self.semaphores: + await semaphore.acquire() + self.acquired.append(semaphore) + + async def __aexit__(self, exc_type, exc_value, traceback): + for semaphore in self.acquired: + semaphore.release() + -class SessionBase(JSONSession): +class SessionBase(ServerSession): '''Base class of ElectrumX JSON sessions. Each session runs its tasks in asynchronous parallelism with other @@ -26,11 +47,10 @@ class SessionBase(JSONSession): ''' MAX_CHUNK_SIZE = 2016 + session_counter = itertools.count() def __init__(self, controller, kind): - # Force v2 as a temporary hack for old Coinomi wallets - # Remove in April 2017 - super().__init__(version=JSONRPCv2) + super().__init__(rpc_protocol=JSONRPCAutoDetect) self.kind = kind # 'RPC', 'TCP' etc. self.controller = controller self.bp = controller.bp @@ -39,24 +59,28 @@ class SessionBase(JSONSession): self.client = 'unknown' self.client_version = (1, ) self.anon_logs = self.env.anon_logs - self.last_delay = 0 self.txs_sent = 0 - self.requests = [] - self.start_time = time.time() - self.close_time = 0 + self.log_me = False self.bw_limit = self.env.bandwidth_limit - self.bw_time = self.start_time - self.bw_interval = 3600 - self.bw_used = 0 + self._orig_mr = self.rpc.message_received - def close_connection(self): - '''Call this to close the connection.''' - self.close_time = time.time() - super().close_connection() + def peer_address_str(self, *, for_log=True): + '''Returns the peer's IP address and port as a human-readable + string, respecting anon logs if the output is for a log.''' + if for_log and self.anon_logs: + return 'xx.xx.xx.xx:xx' + return super().peer_address_str() - def peername(self, *, for_log=True): - '''Return the peer address and port.''' - return self.peer_addr(anon=for_log and self.anon_logs) + def message_received(self, message): + self.logger.info(f'processing {message}') + self._orig_mr(message) + + def toggle_logging(self): + self.log_me = not self.log_me + if self.log_me: + self.rpc.message_received = self.message_received + else: + self.rpc.message_received = self._orig_mr def flags(self): '''Status flags.''' @@ -65,38 +89,42 @@ class SessionBase(JSONSession): status += 'C' if self.log_me: status += 'L' - status += str(self.controller.session_priority(self)) + status += str(self.concurrency.max_concurrent) return status def connection_made(self, transport): '''Handle an incoming client connection.''' super().connection_made(transport) - self.controller.add_session(self) + self.session_id = next(self.session_counter) + context = {'conn_id': f'{self.session_id}'} + self.logger = util.ConnectionLogger(self.logger, context) + self.rpc.logger = self.logger + self.group = self.controller.add_session(self) + self.logger.info(f'{self.kind} {self.peer_address_str()}, ' + f'{len(self.controller.sessions):,d} total') def connection_lost(self, exc): '''Handle client disconnection.''' super().connection_lost(exc) + self.controller.remove_session(self) msg = '' - if self.pause: + if self.paused: msg += ' whilst paused' - if self.controller.is_deprioritized(self): - msg += ' whilst deprioritized' + if self.concurrency.max_concurrent != self.max_concurrent: + msg += ' whilst throttled' if self.send_size >= 1024*1024: msg += ('. Sent {:,d} bytes in {:,d} messages' .format(self.send_size, self.send_count)) if msg: msg = 'disconnected' + msg - self.log_info(msg) - self.controller.remove_session(self) + self.logger.info(msg) + self.group = None + + def count_pending_items(self): + return self.rpc.pending_requests - def using_bandwidth(self, amount): - now = time.time() - # Reduce the recorded usage in proportion to the elapsed time - elapsed = now - self.bw_time - self.bandwidth_start = now - refund = int(elapsed / self.bw_interval * self.bw_limit) - refund = min(refund, self.bw_used) - self.bw_used += amount - refund + def semaphore(self): + return Semaphores([self.concurrency.semaphore, self.group.semaphore]) def sub_count(self): return 0 @@ -111,7 +139,7 @@ class ElectrumX(SessionBase): self.subscribe_headers_raw = False self.subscribe_height = False self.notified_height = None - self.max_send = self.env.max_send + self.max_response_size = self.env.max_send self.max_subs = self.env.max_session_subs self.hashX_subs = {} self.mempool_statuses = {} @@ -148,8 +176,8 @@ class ElectrumX(SessionBase): if changed: es = '' if len(changed) == 1 else 'es' - self.log_info('notified of {:,d} address{}' - .format(len(changed), es)) + self.logger.info('notified of {:,d} address{}' + .format(len(changed), es)) def notify(self, height, touched): '''Notify the client about changes to touched addresses (from mempool @@ -185,7 +213,7 @@ class ElectrumX(SessionBase): '''Return param value it is boolean otherwise raise an RPCError.''' if value in (False, True): return value - raise RPCError('{} should be a boolean value'.format(value)) + raise RPCError(BAD_REQUEST, f'{value} should be a boolean value') def subscribe_headers_result(self, height): '''The result of a header subscription for the given height.''' @@ -209,7 +237,7 @@ class ElectrumX(SessionBase): async def add_peer(self, features): '''Add a peer (but only if the peer resolves to the source).''' peer_mgr = self.controller.peer_mgr - return await peer_mgr.on_add_peer(features, self.peer_info()) + return await peer_mgr.on_add_peer(features, self.peer_address()) def peers_subscribe(self): '''Return the server peers as a list of (ip, host, details) tuples.''' @@ -244,8 +272,8 @@ class ElectrumX(SessionBase): async def hashX_subscribe(self, hashX, alias): # First check our limit. if len(self.hashX_subs) >= self.max_subs: - raise RPCError('your address subscription limit {:,d} reached' - .format(self.max_subs)) + raise RPCError(BAD_REQUEST, 'your address subscription limit ' + f'{self.max_subs:,d} reached') # Now let the controller check its limit self.controller.new_subscription() @@ -299,8 +327,8 @@ class ElectrumX(SessionBase): peername = self.controller.peer_mgr.proxy_peername() if not peername: return False - peer_info = self.peer_info() - return peer_info and peer_info[0] == peername[0] + peer_address = self.peer_address() + return peer_address and peer_address[0] == peername[0] async def replaced_banner(self, banner): network_info = await self.controller.daemon_request('getnetworkinfo') @@ -338,8 +366,7 @@ class ElectrumX(SessionBase): with codecs.open(banner_file, 'r', 'utf-8') as f: banner = f.read() except Exception as e: - self.log_error('reading banner file {}: {}' - .format(banner_file, e)) + self.loggererror(f'reading banner file {banner_file}: {e}') else: banner = await self.replaced_banner(banner) @@ -360,8 +387,9 @@ class ElectrumX(SessionBase): if client_name: if self.env.drop_client is not None and \ self.env.drop_client.match(client_name): - raise RPCError('unsupported client: {}' - .format(client_name), JSONRPC.FATAL_ERROR) + self.close_after_send = True + raise RPCError(BAD_REQUEST, + f'unsupported client: {client_name}') self.client = str(client_name)[:17] try: self.client_version = tuple(int(part) for part @@ -376,10 +404,11 @@ class ElectrumX(SessionBase): # From protocol version 1.1, protocol_version cannot be omitted if ptuple is None or (ptuple >= (1, 1) and protocol_version is None): - self.log_info('unsupported protocol version request {}' - .format(protocol_version)) - raise RPCError('unsupported protocol version: {}' - .format(protocol_version), JSONRPC.FATAL_ERROR) + self.logger.info('unsupported protocol version request {}' + .format(protocol_version)) + self.close_after_send = True + raise RPCError(BAD_REQUEST, + f'unsupported protocol version: {protocol_version}') self.set_protocol_handlers(ptuple) @@ -397,16 +426,15 @@ class ElectrumX(SessionBase): try: tx_hash = await self.daemon.sendrawtransaction([raw_tx]) self.txs_sent += 1 - self.log_info('sent tx: {}'.format(tx_hash)) + self.logger.info('sent tx: {}'.format(tx_hash)) self.controller.sent_tx(tx_hash) return tx_hash except DaemonError as e: error, = e.args message = error['message'] - self.log_info('sendrawtransaction: {}'.format(message), - throttle=True) - raise RPCError('the transaction was rejected by network rules.' - '\n\n{}\n[{}]'.format(message, raw_tx)) + self.logger.info('sendrawtransaction: {}'.format(message)) + raise RPCError(BAD_REQUEST, 'the transaction was rejected by ' + f'network rules.\n\n{message}\n[{raw_tx}]') async def transaction_broadcast_1_0(self, raw_tx): '''Broadcast a raw transaction to the network. @@ -505,7 +533,7 @@ class LocalRPC(SessionBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.client = 'RPC' - self.max_send = 0 + self.max_response_size = 0 self.protocol_version = 'RPC' def request_handler(self, method): @@ -535,13 +563,8 @@ class DashElectrumX(ElectrumX): for masternode in self.mns: status = self.daemon.masternode_list(['status', masternode]) - payload = { - 'id': None, - 'method': 'masternode.subscribe', - 'params': [masternode], - 'result': status.get(masternode), - } - self.send_binary(self.encode_payload(payload)) + self.send_notification('masternode.subscribe', + [masternode, status.get(masternode)]) return result # Masternode command handlers @@ -553,9 +576,9 @@ class DashElectrumX(ElectrumX): except DaemonError as e: error, = e.args message = error['message'] - self.log_info('masternode_broadcast: {}'.format(message)) - raise RPCError('the masternode broadcast was rejected.' - '\n\n{}\n[{}]'.format(message, signmnb)) + self.logger.info('masternode_broadcast: {}'.format(message)) + raise RPCError(BAD_REQUEST, 'the masternode broadcast was ' + f'rejected.\n\n{message}\n[{signmnb}]') async def masternode_announce_broadcast_1_0(self, signmnb): '''Pass through the masternode announce message to be broadcast diff --git a/server/version.py b/server/version.py index 998f88a..aa201b5 100644 --- a/server/version.py +++ b/server/version.py @@ -1,5 +1,5 @@ # Server name and protocol versions -VERSION = 'ElectrumX 1.3.1' +VERSION = 'ElectrumX 1.3.1a' PROTOCOL_MIN = '0.9' PROTOCOL_MAX = '1.2' diff --git a/setup.py b/setup.py index 5aaad99..73a0e37 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( # "x11_hash" package (1.4) is required to sync DASH network. # "tribus_hash" package is required to sync Denarius network. # "blake256" package is required to sync Decred network. - install_requires=['aiorpcX >= 0.4.4', 'plyvel', 'pylru', 'aiohttp >= 1'], + install_requires=['aiorpcX >= 0.5.2', 'plyvel', 'pylru', 'aiohttp >= 1'], packages=setuptools.find_packages(exclude=['tests']), description='ElectrumX Server', author='Neil Booth', diff --git a/tests/lib/test_util.py b/tests/lib/test_util.py index 5a3e9e5..d9cb517 100644 --- a/tests/lib/test_util.py +++ b/tests/lib/test_util.py @@ -5,51 +5,6 @@ import pytest from lib import util -class LoggedClassTest(util.LoggedClass): - - def __init__(self): - super().__init__() - self.logger.info = self.note_info - self.logger.warning = self.note_warning - self.logger.error = self.note_error - - def note_info(self, msg): - self.info_msg = msg - - def note_warning(self, msg): - self.warning_msg = msg - - def note_error(self, msg): - self.error_msg = msg - - -def test_LoggedClass(): - test = LoggedClassTest() - assert test.log_prefix == '' - test.log_prefix = 'prefix' - test.log_error('an error') - assert test.error_msg == 'prefixan error' - test.log_warning('a warning') - assert test.warning_msg == 'prefixa warning' - test.log_info('some info') - assert test.info_msg == 'prefixsome info' - - assert test.throttled == 0 - test.log_info('some info', throttle=True) - assert test.throttled == 1 - assert test.info_msg == 'prefixsome info' - test.log_info('some info', throttle=True) - assert test.throttled == 2 - assert test.info_msg == 'prefixsome info' - test.log_info('some info', throttle=True) - assert test.throttled == 3 - assert test.info_msg == 'prefixsome info (throttling later logs)' - test.info_msg = '' - test.log_info('some info', throttle=True) - assert test.throttled == 4 - assert test.info_msg == '' - - def test_cachedproperty(): class Target: diff --git a/tests/server/test_api.py b/tests/server/test_api.py index 4d52629..a6ec461 100644 --- a/tests/server/test_api.py +++ b/tests/server/test_api.py @@ -1,7 +1,7 @@ import asyncio from unittest import mock -from lib.jsonrpc import RPCError +from aiorpcx import RPCError from server.env import Env from server.controller import Controller @@ -27,8 +27,8 @@ async def coro(res): return res -def raise_exception(exc, msg): - raise exc(msg) +def raise_exception(msg): + raise RPCError(1, msg) def ensure_text_exception(test, exception): @@ -82,7 +82,8 @@ def test_transaction_get(): env = set_env() sut = Controller(env) sut.daemon_request = mock.Mock() - sut.daemon_request.return_value = coro(raise_exception(RPCError, 'some unhandled error')) + sut.daemon_request.return_value = coro( + raise_exception('some unhandled error')) await sut.transaction_get('ff' * 32, True) async def test_wrong_txhash():