
Fully integrate aiorpcX

patch-2
Neil Booth 7 years ago
parent commit bc6093a8fe
1. lib/env_base.py (7)
2. lib/jsonrpc.py (806)
3. lib/server_base.py (7)
4. lib/util.py (29)
5. server/block_processor.py (7)
6. server/controller.py (181)
7. server/daemon.py (14)
8. server/db.py (24)
9. server/env.py (6)
10. server/mempool.py (6)
11. server/peers.py (54)
12. server/session.py (157)
13. server/version.py (2)
14. setup.py (2)
15. tests/lib/test_util.py (45)
16. tests/server/test_api.py (9)

lib/env_base.py (7)

@@ -8,19 +8,18 @@
'''Class for server environment configuration and defaults.'''
import logging
from os import environ
import lib.util as lib_util
class EnvBase(lib_util.LoggedClass):
class EnvBase(object):
'''Wraps environment configuration.'''
class Error(Exception):
pass
def __init__(self):
super().__init__()
self.logger = logging.getLogger(self.__class__.__name__)
self.allow_root = self.boolean('ALLOW_ROOT', False)
self.host = self.default('HOST', 'localhost')
self.rpc_host = self.default('RPC_HOST', 'localhost')

lib/jsonrpc.py (806)

@@ -1,806 +0,0 @@
# Copyright (c) 2016-2017, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''Classes for acting as a peer over a transport and speaking the JSON
RPC versions 1.0 and 2.0.
JSONSessionBase can use an arbitrary transport.
JSONSession integrates asyncio.Protocol to provide the transport.
'''
import asyncio
import collections
import inspect
import json
import numbers
import time
import traceback
import lib.util as util
class RPCError(Exception):
'''RPC handlers raise this error.'''
def __init__(self, msg, code=-1, **kw_args):
super().__init__(**kw_args)
self.msg = msg
self.code = code
class JSONRPC(object):
'''Base class of JSON RPC versions.'''
# See http://www.jsonrpc.org/specification
PARSE_ERROR = -32700
INVALID_REQUEST = -32600
METHOD_NOT_FOUND = -32601
INVALID_ARGS = -32602
INTERNAL_ERROR = -32603
# Codes for this library
INVALID_RESPONSE = -100
ERROR_CODE_UNAVAILABLE = -101
REQUEST_TIMEOUT = -102
FATAL_ERROR = -103
ID_TYPES = (type(None), str, numbers.Number)
HAS_BATCHES = False
@classmethod
def canonical_error(cls, error):
'''Convert an error to a JSON RPC 2.0 error.
Handlers then only have a single form of error to deal with.
'''
if isinstance(error, int):
error = {'code': error}
elif isinstance(error, str):
error = {'message': error}
elif not isinstance(error, dict):
error = {'data': error}
error['code'] = error.get('code', JSONRPC.ERROR_CODE_UNAVAILABLE)
error['message'] = error.get('message', 'error message unavailable')
return error
@classmethod
def timeout_error(cls):
return {'message': 'request timed out',
'code': JSONRPC.REQUEST_TIMEOUT}
class JSONRPCv1(JSONRPC):
'''JSON RPC version 1.0.'''
@classmethod
def request_payload(cls, id_, method, params=None):
'''JSON v1 request payload. Params is mandatory.'''
return {'method': method, 'params': params or [], 'id': id_}
@classmethod
def notification_payload(cls, method, params=None):
'''JSON v1 notification payload. Params and id are mandatory.'''
return {'method': method, 'params': params or [], 'id': None}
@classmethod
def response_payload(cls, result, id_):
'''JSON v1 response payload. error is present and None.'''
return {'id': id_, 'result': result, 'error': None}
@classmethod
def error_payload(cls, message, code, id_):
'''JSON v1 error payload. result is present and None.'''
return {'id': id_, 'result': None,
'error': {'message': message, 'code': code}}
@classmethod
def handle_response(cls, handler, payload):
'''JSON v1 response handler. Both 'error' and 'result'
should exist with exactly one being None.
Unfortunately many 1.0 clients behave like 2.0, and just send
one or the other.
'''
error = payload.get('error')
if error is None:
handler(payload.get('result'), None)
else:
handler(None, cls.canonical_error(error))
@classmethod
def is_request(cls, payload):
'''Returns True if the payload (which has a method) is a request.
False means it is a notification.'''
return payload.get('id') is not None
class JSONRPCv2(JSONRPC):
'''JSON RPC version 2.0.'''
HAS_BATCHES = True
@classmethod
def request_payload(cls, id_, method, params=None):
'''JSON v2 request payload. Params is optional.'''
payload = {'jsonrpc': '2.0', 'method': method, 'id': id_}
if params:
payload['params'] = params
return payload
@classmethod
def notification_payload(cls, method, params=None):
'''JSON v2 notification payload. There must be no id.'''
payload = {'jsonrpc': '2.0', 'method': method}
if params:
payload['params'] = params
return payload
@classmethod
def response_payload(cls, result, id_):
'''JSON v2 response payload. error is not present.'''
return {'jsonrpc': '2.0', 'id': id_, 'result': result}
@classmethod
def error_payload(cls, message, code, id_):
'''JSON v2 error payload. result is not present.'''
return {'jsonrpc': '2.0', 'id': id_,
'error': {'message': message, 'code': code}}
@classmethod
def handle_response(cls, handler, payload):
'''JSON v2 response handler. Exactly one of 'error' and 'result'
must exist. Errors must have 'code' and 'message' members.
'''
if 'error' in payload:
handler(None, cls.canonical_error(payload['error']))
elif 'result' in payload:
handler(payload['result'], None)
else:
error = {'message': 'no error or result returned',
'code': JSONRPC.INVALID_RESPONSE}
handler(None, cls.canonical_error(error))
@classmethod
def batch_size(cls, parts):
'''Return the size of a JSON batch from its parts.'''
return sum(len(part) for part in parts) + 2 * len(parts)
@classmethod
def batch_bytes(cls, parts):
'''Return the bytes of a JSON batch from its parts.'''
if parts:
return b'[' + b', '.join(parts) + b']'
return b''
@classmethod
def is_request(cls, payload):
'''Returns True if the payload (which has a method) is a request.
False means it is a notification.'''
return 'id' in payload
class JSONRPCCompat(JSONRPC):
'''Intended to be used until receiving a response from the peer, at
which point detect_version should be used to choose which version
to use.
Sends requests compatible with v1 and v2. Errors cannot be
compatible so v2 errors are sent.
Does not send responses or notifications, nor handle responses.
'''
@classmethod
def request_payload(cls, id_, method, params=None):
'''JSON v2 request payload but with params present.'''
return {'jsonrpc': '2.0', 'id': id_,
'method': method, 'params': params or []}
@classmethod
def error_payload(cls, message, code, id_):
'''JSON v2 error payload. result is not present.'''
return {'jsonrpc': '2.0', 'id': id_,
'error': {'message': message, 'code': code}}
@classmethod
def detect_version(cls, payload):
'''Return a best guess at a version compatible with the received
payload.
Return None if one cannot be determined.
'''
def item_version(item):
if isinstance(item, dict):
version = item.get('jsonrpc')
if version is None:
return JSONRPCv1
if version == '2.0':
return JSONRPCv2
return None
if isinstance(payload, list) and payload:
version = item_version(payload[0])
# If a batch return at least JSONRPCv2
if version in (JSONRPCv1, None):
version = JSONRPCv2
else:
version = item_version(payload)
return version
class JSONSessionBase(util.LoggedClass):
'''Acts as the application layer session, communicating via JSON RPC
over an underlying transport.
Processes incoming and sends outgoing requests, notifications and
responses. Incoming messages are queued. When the queue goes
from empty to non-empty, the items_event is set.
'''
_next_session_id = 0
_pending_reqs = {} # Outgoing requests waiting for a response
@classmethod
def next_session_id(cls):
'''Return the next unique session ID.'''
session_id = cls._next_session_id
cls._next_session_id += 1
return session_id
def _pending_request_keys(self):
'''Return a list of the pending request keys for this session.'''
return [key for key in self._pending_reqs if key[0] is self]
def has_pending_requests(self):
'''Return True if this session has pending requests.'''
return bool(self._pending_request_keys())
def pop_response_handler(self, msg_id):
'''Return the response handler for the given message ID.'''
return self._pending_reqs.pop((self, msg_id), (None, None))[0]
def timeout_session(self):
'''Trigger timeouts for all of the session's pending requests.'''
self._timeout_requests(self._pending_request_keys())
@classmethod
def timeout_check(cls):
'''Trigger timeouts where necessary for all pending requests.'''
now = time.time()
keys = [key for key, value in cls._pending_reqs.items()
if value[1] < now]
cls._timeout_requests(keys)
@classmethod
def _timeout_requests(cls, keys):
'''Trigger timeouts for the given lookup keys.'''
values = [cls._pending_reqs.pop(key) for key in keys]
handlers = [handler for handler, timeout in values]
timeout_error = JSONRPC.timeout_error()
for handler in handlers:
handler(None, timeout_error)
def __init__(self, version=JSONRPCCompat):
super().__init__()
# Parts of an incomplete JSON line. We buffer them until
# getting a newline.
self.parts = []
self.version = version
self.log_me = False
self.session_id = None
# Count of incoming complete JSON requests and the time of the
# last one. A batch counts as just one here.
self.last_recv = time.time()
self.send_count = 0
self.send_size = 0
self.recv_size = 0
self.recv_count = 0
self.error_count = 0
self.pause = False
# Handling of incoming items
self.items = collections.deque()
self.items_event = asyncio.Event()
self.batch_results = []
# Handling of outgoing requests
self.next_request_id = 0
# If buffered incoming data exceeds this the connection is closed
self.max_buffer_size = 1000000
self.max_send = 50000
self.close_after_send = False
def pause_writing(self):
'''Transport calls when the send buffer is full.'''
self.log_info('pausing processing whilst socket drains')
self.pause = True
def resume_writing(self):
'''Transport calls when the send buffer has room.'''
self.log_info('resuming processing')
self.pause = False
def is_oversized(self, length, id_):
'''Return an error payload if the given outgoing message size is too
large, or False if not.
'''
if self.max_send and length > max(1000, self.max_send):
msg = 'response too large (at least {:d} bytes)'.format(length)
return self.error_bytes(msg, JSONRPC.INVALID_REQUEST, id_)
return False
def send_binary(self, binary):
'''Pass the bytes through to the transport.
Close the connection if close_after_send is set.
'''
if self.is_closing():
return
self.using_bandwidth(len(binary))
self.send_count += 1
self.send_size += len(binary)
self.send_bytes(binary)
if self.close_after_send:
self.close_connection()
def payload_id(self, payload):
'''Extract and return the ID from the payload.
Returns None if it is missing or invalid.'''
try:
return self.check_payload_id(payload)
except RPCError:
return None
def check_payload_id(self, payload):
'''Extract and return the ID from the payload.
Raises an RPCError if it is missing or invalid.'''
if 'id' not in payload:
raise RPCError('missing id', JSONRPC.INVALID_REQUEST)
id_ = payload['id']
if not isinstance(id_, self.version.ID_TYPES):
raise RPCError('invalid id type {}'.format(type(id_)),
JSONRPC.INVALID_REQUEST)
return id_
def request_bytes(self, id_, method, params=None):
'''Return the bytes of a JSON request.'''
payload = self.version.request_payload(id_, method, params)
return self.encode_payload(payload)
def notification_bytes(self, method, params=None):
payload = self.version.notification_payload(method, params)
return self.encode_payload(payload)
def response_bytes(self, result, id_):
'''Return the bytes of a JSON response.'''
return self.encode_payload(self.version.response_payload(result, id_))
def error_bytes(self, message, code, id_=None):
'''Return the bytes of a JSON error.
Flag the connection to close on a fatal error or too many errors.'''
version = self.version
self.error_count += 1
if not self.close_after_send:
fatal_log = None
if code in (version.PARSE_ERROR, version.INVALID_REQUEST,
version.FATAL_ERROR):
fatal_log = message
elif self.error_count >= 10:
fatal_log = 'too many errors, last: {}'.format(message)
if fatal_log:
self.log_info(fatal_log)
self.close_after_send = True
return self.encode_payload(self.version.error_payload
(message, code, id_))
def encode_payload(self, payload):
'''Encode a Python object as binary bytes.'''
assert isinstance(payload, dict)
id_ = payload.get('id')
try:
binary = json.dumps(payload).encode()
except TypeError:
msg = 'JSON encoding failure: {}'.format(payload)
self.log_error(msg)
binary = self.error_bytes(msg, JSONRPC.INTERNAL_ERROR, id_)
error_bytes = self.is_oversized(len(binary), id_)
return error_bytes or binary
def decode_message(self, payload):
'''Decode a binary message and pass it on to process_single_item or
process_batch as appropriate.
Messages that cannot be decoded are logged and dropped.
'''
try:
payload = payload.decode()
except UnicodeDecodeError as e:
msg = 'cannot decode message: {}'.format(e)
self.send_error(msg, JSONRPC.PARSE_ERROR)
return
try:
payload = json.loads(payload)
except json.JSONDecodeError as e:
msg = 'cannot decode JSON: {}'.format(e)
self.send_error(msg, JSONRPC.PARSE_ERROR)
return
if self.version is JSONRPCCompat:
# Attempt to detect peer's JSON RPC version
version = self.version.detect_version(payload)
if not version:
version = JSONRPCv2
self.log_info('unable to detect JSON RPC version, using 2.0')
self.version = version
# Batches must have at least one object.
if payload == [] and self.version.HAS_BATCHES:
self.send_error('empty batch', JSONRPC.INVALID_REQUEST)
return
self.items.append(payload)
self.items_event.set()
async def process_batch(self, batch, count):
'''Processes count items from the batch according to the JSON 2.0
spec.
If any remain, puts what is left of the batch back in the deque
and returns None. Otherwise returns the binary batch result.'''
results = self.batch_results
self.batch_results = []
for n in range(count):
item = batch.pop()
result = await self.process_single_item(item)
if result:
results.append(result)
if not batch:
return self.version.batch_bytes(results)
error_bytes = self.is_oversized(self.version.batch_size(results), None)
if error_bytes:
return error_bytes
self.items.appendleft(item)
self.batch_results = results
return None
async def process_single_item(self, payload):
'''Handle a single JSON request, notification or response.
If it is a request, return the binary response, otherwise None.'''
if self.log_me:
self.log_info('processing {}'.format(payload))
if not isinstance(payload, dict):
return self.error_bytes('request must be a dictionary',
JSONRPC.INVALID_REQUEST)
try:
# Requests and notifications must have a method.
if 'method' in payload:
if self.version.is_request(payload):
return await self.process_single_request(payload)
else:
await self.process_single_notification(payload)
else:
self.process_single_response(payload)
return None
except asyncio.CancelledError:
raise
except Exception:
self.log_error(traceback.format_exc())
return self.error_bytes('internal error processing request',
JSONRPC.INTERNAL_ERROR,
self.payload_id(payload))
async def process_single_request(self, payload):
'''Handle a single JSON request and return the binary response.'''
try:
result = await self.handle_payload(payload, self.request_handler)
return self.response_bytes(result, payload['id'])
except RPCError as e:
return self.error_bytes(e.msg, e.code, self.payload_id(payload))
except asyncio.CancelledError:
raise
except Exception:
self.log_error(traceback.format_exc())
return self.error_bytes('internal error processing request',
JSONRPC.INTERNAL_ERROR,
self.payload_id(payload))
async def process_single_notification(self, payload):
'''Handle a single JSON notification.'''
try:
await self.handle_payload(payload, self.notification_handler)
except RPCError:
pass
except Exception:
self.log_error(traceback.format_exc())
def process_single_response(self, payload):
'''Handle a single JSON response.'''
try:
id_ = self.check_payload_id(payload)
handler = self.pop_response_handler(id_)
if handler:
self.version.handle_response(handler, payload)
else:
self.log_info('response for unsent id {}'.format(id_),
throttle=True)
except RPCError:
pass
except Exception:
self.log_error(traceback.format_exc())
async def handle_payload(self, payload, get_handler):
'''Handle a request or notification payload given the handlers.'''
# An argument is the value passed to a function parameter...
args = payload.get('params', [])
method = payload.get('method')
if not isinstance(method, str):
raise RPCError("invalid method type {}".format(type(method)),
JSONRPC.INVALID_REQUEST)
handler = get_handler(method)
if not handler:
raise RPCError("unknown method: '{}'".format(method),
JSONRPC.METHOD_NOT_FOUND)
if not isinstance(args, (list, dict)):
raise RPCError('arguments should be an array or dictionary',
JSONRPC.INVALID_REQUEST)
params = inspect.signature(handler).parameters
names = list(params)
min_args = sum(p.default is p.empty for p in params.values())
if len(args) < min_args:
raise RPCError('too few arguments to {}: expected {:d} got {:d}'
.format(method, min_args, len(args)),
JSONRPC.INVALID_ARGS)
if len(args) > len(params):
raise RPCError('too many arguments to {}: expected {:d} got {:d}'
.format(method, len(params), len(args)),
JSONRPC.INVALID_ARGS)
if isinstance(args, list):
kw_args = {name: arg for name, arg in zip(names, args)}
else:
kw_args = args
bad_names = ['<{}>'.format(name) for name in args
if name not in names]
if bad_names:
raise RPCError('invalid parameter names: {}'
.format(', '.join(bad_names)))
if inspect.iscoroutinefunction(handler):
return await handler(**kw_args)
else:
return handler(**kw_args)
# ---- External Interface ----
async def process_pending_items(self, limit=8):
'''Processes up to LIMIT pending items asynchronously.'''
while limit > 0 and self.items:
item = self.items.popleft()
if isinstance(item, list) and self.version.HAS_BATCHES:
count = min(limit, len(item))
binary = await self.process_batch(item, count)
limit -= count
else:
binary = await self.process_single_item(item)
limit -= 1
if binary:
self.send_binary(binary)
if not self.items:
self.items_event.clear()
def count_pending_items(self):
'''Counts the number of pending items.'''
return sum(len(item) if isinstance(item, list) else 1
for item in self.items)
def connection_made(self):
'''Call when an incoming client connection is established.'''
self.session_id = self.next_session_id()
self.log_prefix = '[{:d}] '.format(self.session_id)
def data_received(self, data):
'''Underlying transport calls this when new data comes in.
Look for newline separators terminating full requests.
'''
if self.is_closing():
return
self.using_bandwidth(len(data))
self.recv_size += len(data)
# Close abusive connections where buffered data exceeds limit
buffer_size = len(data) + sum(len(part) for part in self.parts)
if buffer_size > self.max_buffer_size:
self.log_error('read buffer of {:,d} bytes over {:,d} byte limit'
.format(buffer_size, self.max_buffer_size))
self.close_connection()
return
while True:
npos = data.find(ord('\n'))
if npos == -1:
self.parts.append(data)
break
tail, data = data[:npos], data[npos + 1:]
parts, self.parts = self.parts, []
parts.append(tail)
self.recv_count += 1
self.last_recv = time.time()
self.decode_message(b''.join(parts))
def send_error(self, message, code, id_=None):
'''Send a JSON error.'''
self.send_binary(self.error_bytes(message, code, id_))
def send_request(self, handler, method, params=None, timeout=30):
'''Sends a request and arranges for handler to be called with the
response when it comes in.
A call to request_timeout_check() will result in pending
responses that have been waiting more than timeout seconds to
call the handler with a REQUEST_TIMEOUT error.
'''
id_ = self.next_request_id
self.next_request_id += 1
self.send_binary(self.request_bytes(id_, method, params))
self._pending_reqs[(self, id_)] = (handler, time.time() + timeout)
def send_notification(self, method, params=None):
'''Send a notification.'''
self.send_binary(self.notification_bytes(method, params))
def send_notifications(self, mp_iterable):
'''Send an iterable of (method, params) notification pairs.
A 1-tuple is also valid in which case there are no params.'''
if False and self.version.HAS_BATCHES:
parts = [self.notification_bytes(*pair) for pair in mp_iterable]
self.send_binary(self.version.batch_bytes(parts))
else:
for pair in mp_iterable:
self.send_notification(*pair)
# -- derived classes are intended to override these functions
# Transport layer
def is_closing(self):
'''Return True if the underlying transport is closing.'''
raise NotImplementedError
def close_connection(self):
'''Close the connection.'''
raise NotImplementedError
def send_bytes(self, binary):
'''Pass the bytes through to the underlying transport.'''
raise NotImplementedError
# App layer
def using_bandwidth(self, amount):
'''Called as bandwidth is consumed.
Override to implement bandwidth management.
'''
pass
def notification_handler(self, method):
'''Return the handler for the given notification.
The handler can be synchronous or asynchronous.'''
return None
def request_handler(self, method):
'''Return the handler for the given request method.
The handler can be synchronous or asynchronous.'''
return None
class JSONSession(JSONSessionBase, asyncio.Protocol):
'''A JSONSessionBase instance specialized for use with
asyncio.protocol to implement the transport layer.
The app should await on items_event, which is set when unprocessed
incoming items remain and cleared when the queue is empty, and
then arrange to call process_pending_items asynchronously.
Derived classes may want to override the request and notification
handlers.
'''
def __init__(self, version=JSONRPCCompat):
super().__init__(version=version)
self.transport = None
self.write_buffer_high = 500000
def peer_info(self):
'''Returns information about the peer.'''
try:
# get_extra_info can throw even if self.transport is not None
return self.transport.get_extra_info('peername')
except Exception:
return None
def abort(self):
'''Cut the connection abruptly.'''
self.transport.abort()
def connection_made(self, transport):
'''Handle an incoming client connection.'''
transport.set_write_buffer_limits(high=self.write_buffer_high)
self.transport = transport
super().connection_made()
def connection_lost(self, exc):
'''Trigger timeouts of all pending requests.'''
self.timeout_session()
def is_closing(self):
'''True if the underlying transport is closing.'''
return self.transport and self.transport.is_closing()
def close_connection(self):
'''Close the connection.'''
if self.transport:
self.transport.close()
def send_bytes(self, binary):
'''Send JSON text over the transport.'''
self.transport.writelines((binary, b'\n'))
def peer_addr(self, anon=True):
'''Return the peer address and port.'''
peer_info = self.peer_info()
if not peer_info:
return 'unknown'
if anon:
return 'xx.xx.xx.xx:xx'
if ':' in peer_info[0]:
return '[{}]:{}'.format(peer_info[0], peer_info[1])
else:
return '{}:{}'.format(peer_info[0], peer_info[1])
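
All of lib/jsonrpc.py is deleted: framing, JSON RPC version detection, batching, request timeouts and response dispatch now come from aiorpcX instead of being hand-rolled. As a rough sketch (not part of this commit) of the shape a server session takes afterwards, using only names the new code itself imports (ServerSession, JSONRPCAutoDetect, RPCError) plus an illustrative 'echo' method:

# Hedged sketch of a post-aiorpcX session; BAD_REQUEST matches the
# constant defined in server/session.py further down.
from aiorpcx import ServerSession, JSONRPCAutoDetect, RPCError

BAD_REQUEST = 1

class EchoSession(ServerSession):
    '''Transport, framing and protocol negotiation are inherited
    from aiorpcX rather than implemented by hand.'''

    def __init__(self):
        super().__init__(rpc_protocol=JSONRPCAutoDetect)

    def request_handler(self, method):
        # The same hook LocalRPC overrides below: map a method name
        # to a handler, or return None for "method not found".
        return {'echo': self.echo}.get(method)

    def echo(self, message):
        if not isinstance(message, str):
            raise RPCError(BAD_REQUEST, 'message must be a string')
        return message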

lib/server_base.py (7)

@@ -6,16 +6,15 @@
# and warranty status of this software.
import asyncio
import logging
import os
import signal
import sys
import time
from functools import partial
import lib.util as util
class ServerBase(util.LoggedClass):
class ServerBase(object):
'''Base class server implementation.
Derived classes are expected to:
@@ -37,7 +36,7 @@ class ServerBase(util.LoggedClass):
'''Save the environment, perform basic sanity checks, and set the
event loop policy.
'''
super().__init__()
self.logger = logging.getLogger(self.__class__.__name__)
self.env = env
# Sanity checks

29
lib/util.py

@@ -37,30 +37,11 @@ from collections import Container, Mapping
from struct import pack, Struct
class LoggedClass(object):
def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)
self.logger.setLevel(logging.INFO)
self.log_prefix = ''
self.throttled = 0
def log_info(self, msg, throttle=False):
# Prevent annoying log messages by throttling them if there
# are too many in a short period
if throttle:
self.throttled += 1
if self.throttled > 3:
return
if self.throttled == 3:
msg += ' (throttling later logs)'
self.logger.info(self.log_prefix + msg)
def log_warning(self, msg):
self.logger.warning(self.log_prefix + msg)
def log_error(self, msg):
self.logger.error(self.log_prefix + msg)
class ConnectionLogger(logging.LoggerAdapter):
'''Prepends a connection identifier to a logging message.'''
def process(self, msg, kwargs):
conn_id = self.extra.get('conn_id', 'unknown')
return f'[{conn_id}] {msg}', kwargs
# Method decorator. To be used for calculations that will always
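
ConnectionLogger replaces LoggedClass's hand-rolled log_prefix: it is a plain logging.LoggerAdapter whose extra dict carries a conn_id that process() prepends to every message. A small usage sketch (the '42' identifier is illustrative):

import logging
from lib.util import ConnectionLogger

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('Session')

# Sessions wrap their logger with a connection identifier, as
# server/session.py and server/peers.py do below.
session_logger = ConnectionLogger(logger, {'conn_id': '42'})
session_logger.info('connection made')   # logs "[42] connection made"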

server/block_processor.py (7)

@@ -11,6 +11,7 @@
import array
import asyncio
import logging
from struct import pack, unpack
import time
from collections import defaultdict
@@ -19,15 +20,15 @@ from functools import partial
from server.daemon import DaemonError
from server.version import VERSION
from lib.hash import hash_to_str
from lib.util import chunks, formatted_time, LoggedClass
from lib.util import chunks, formatted_time
import server.db
class Prefetcher(LoggedClass):
class Prefetcher(object):
'''Prefetches blocks (in the forward direction only).'''
def __init__(self, bp):
super().__init__()
self.logger = logging.getLogger(self.__class__.__name__)
self.bp = bp
self.caught_up = False
# Access to fetched_height should be protected by the semaphore

server/controller.py (181)

@@ -6,6 +6,7 @@
# and warranty status of this software.
import asyncio
import itertools
import json
import os
import ssl
@@ -18,7 +19,7 @@ from functools import partial
import pylru
from lib.jsonrpc import JSONSessionBase, RPCError
from aiorpcx import RPCError
from lib.hash import double_sha256, hash_to_str, hex_str_to_hash
from lib.peer import Peer
from lib.server_base import ServerBase
@@ -26,7 +27,15 @@ import lib.util as util
from server.daemon import DaemonError
from server.mempool import MemPool
from server.peers import PeerManager
from server.session import LocalRPC
from server.session import LocalRPC, BAD_REQUEST, DAEMON_ERROR
class SessionGroup(object):
def __init__(self, gid):
self.gid = gid
# Concurrency per group
self.semaphore = asyncio.Semaphore(20)
class Controller(ServerBase):
@@ -36,7 +45,6 @@ class Controller(ServerBase):
up with the daemon.
'''
BANDS = 5
CATCHING_UP, LISTENING, PAUSED, SHUTTING_DOWN = range(4)
def __init__(self, env):
@@ -45,9 +53,9 @@
self.coin = env.coin
self.servers = {}
# Map of session to the key of its list in self.groups
self.sessions = {}
self.groups = defaultdict(list)
self.sessions = set()
self.groups = set()
self.cur_group = self._new_group(0)
self.txs_sent = 0
self.next_log_sessions = 0
self.state = self.CATCHING_UP
@@ -62,7 +70,6 @@
self.header_cache = pylru.lrucache(8)
self.cache_height = 0
env.max_send = max(350000, env.max_send)
self.setup_bands()
# Set up the RPC request handlers
cmds = ('add_peer daemon_url disconnect getinfo groups log peers reorg '
'sessions stop'.split())
@@ -128,29 +135,6 @@
'''Call when a TX is sent.'''
self.txs_sent += 1
def setup_bands(self):
bands = []
limit = self.env.bandwidth_limit
for n in range(self.BANDS):
bands.append(limit)
limit //= 4
limit = self.env.bandwidth_limit
for n in range(self.BANDS):
limit += limit // 2
bands.append(limit)
self.bands = sorted(bands)
def session_priority(self, session):
if isinstance(session, LocalRPC):
return 0
gid = self.sessions[session]
group_bw = sum(session.bw_used for session in self.groups[gid])
return 1 + (bisect_left(self.bands, session.bw_used)
+ bisect_left(self.bands, group_bw)) // 2
def is_deprioritized(self, session):
return self.session_priority(session) > self.BANDS
async def run_in_executor(self, func, *args):
'''Wait whilst running func in the executor.'''
return await self.loop.run_in_executor(None, func, *args)
@@ -177,7 +161,7 @@
except asyncio.CancelledError:
pass
except Exception:
self.log_error(traceback.format_exc())
self.logger.error(traceback.format_exc())
async def housekeeping(self):
'''Regular housekeeping checks.'''
@@ -185,7 +169,6 @@
while True:
n += 1
await asyncio.sleep(15)
JSONSessionBase.timeout_check()
if n % 10 == 0:
self.clear_stale_sessions()
@@ -253,7 +236,6 @@
.format(self.max_subs))
self.logger.info('max subscriptions per session: {:,d}'
.format(self.env.max_session_subs))
self.logger.info('bands: {}'.format(self.bands))
if self.env.drop_client is not None:
self.logger.info('drop clients matching: {}'
.format(self.env.drop_client.pattern))
@@ -304,7 +286,7 @@
'''Return the binary header at the given height.'''
header, n = self.bp.read_headers(height, 1)
if n != 1:
raise RPCError('height {:,d} out of range'.format(height))
raise RPCError(BAD_REQUEST, f'height {height:,d} out of range')
return header
def electrum_header(self, height):
@@ -315,74 +297,54 @@
height)
return self.header_cache[height]
def session_delay(self, session):
priority = self.session_priority(session)
excess = max(0, priority - self.BANDS)
if excess != session.last_delay:
session.last_delay = excess
if excess:
session.log_info('high bandwidth use, deprioritizing by '
'delaying responses {:d}s'.format(excess))
else:
session.log_info('stopped delaying responses')
return max(int(session.pause), excess)
async def process_items(self, session):
'''Waits for incoming session items and processes them.'''
while True:
await session.items_event.wait()
await asyncio.sleep(self.session_delay(session))
if not session.pause:
await session.process_pending_items()
def _new_group(self, gid):
group = SessionGroup(gid)
self.groups.add(group)
return group
def add_session(self, session):
session.items_future = self.ensure_future(self.process_items(session))
gid = int(session.start_time - self.start_time) // 900
self.groups[gid].append(session)
self.sessions[session] = gid
session.log_info('{} {}, {:,d} total'
.format(session.kind, session.peername(),
len(self.sessions)))
self.sessions.add(session)
if (len(self.sessions) >= self.max_sessions
and self.state == self.LISTENING):
self.state = self.PAUSED
session.log_info('maximum sessions {:,d} reached, stopping new '
'connections until count drops to {:,d}'
.format(self.max_sessions, self.low_watermark))
session.logger.info('maximum sessions {:,d} reached, stopping new '
'connections until count drops to {:,d}'
.format(self.max_sessions, self.low_watermark))
self.close_servers(['TCP', 'SSL'])
gid = int(session.start_time - self.start_time) // 900
if self.cur_group.gid != gid:
self.cur_group = self._new_group(gid)
return self.cur_group
def remove_session(self, session):
'''Remove a session from our sessions list if there.'''
session.items_future.cancel()
if session in self.sessions:
gid = self.sessions.pop(session)
assert gid in self.groups
self.groups[gid].remove(session)
self.sessions.remove(session)
def close_session(self, session):
'''Close the session's transport and cancel its future.'''
session.close_connection()
'''Close the session's transport.'''
session.close()
return 'disconnected {:d}'.format(session.session_id)
def toggle_logging(self, session):
'''Toggle logging of the session.'''
session.log_me = not session.log_me
session.toggle_logging()
return 'log {:d}: {}'.format(session.session_id, session.log_me)
def clear_stale_sessions(self, grace=15):
'''Cut off sessions that haven't done anything for 10 minutes. Force
close stubborn connections that won't close cleanly after a
short grace period.
'''
def _group_map(self):
group_map = defaultdict(list)
for session in self.sessions:
group_map[session.group].append(session)
return group_map
def clear_stale_sessions(self):
'''Cut off sessions that haven't done anything for 10 minutes.'''
now = time.time()
shutdown_cutoff = now - grace
stale_cutoff = now - self.env.session_timeout
stale = []
for session in self.sessions:
if session.is_closing():
if session.close_time <= shutdown_cutoff:
session.abort()
pass
elif session.last_recv < stale_cutoff:
self.close_session(session)
stale.append(session.session_id)
@@ -390,16 +352,18 @@ class Controller(ServerBase):
self.logger.info('closing stale connections {}'.format(stale))
# Consolidate small groups
gids = [gid for gid, l in self.groups.items() if len(l) <= 4
and sum(session.bw_used for session in l) < 10000]
if len(gids) > 1:
sessions = sum([self.groups[gid] for gid in gids], [])
new_gid = max(gids)
for gid in gids:
del self.groups[gid]
for session in sessions:
self.sessions[session] = new_gid
self.groups[new_gid] = sessions
bw_limit = self.env.bandwidth_limit
group_map = self._group_map()
groups = [group for group, sessions in group_map.items()
if len(sessions) <= 5 and
sum(s.bw_charge for s in sessions) < bw_limit]
if len(groups) > 1:
new_gid = max(group.gid for group in groups)
new_group = self._new_group(new_gid)
for group in groups:
self.groups.remove(group)
for session in group_map[group]:
session.group = new_group
def session_count(self):
'''The number of connections that we've sent something to.'''
@@ -412,10 +376,10 @@
'daemon_height': self.daemon.cached_height(),
'db_height': self.bp.db_height,
'closing': len([s for s in self.sessions if s.is_closing()]),
'errors': sum(s.error_count for s in self.sessions),
'errors': sum(s.rpc.errors for s in self.sessions),
'groups': len(self.groups),
'logged': len([s for s in self.sessions if s.log_me]),
'paused': sum(s.pause for s in self.sessions),
'paused': sum(s.paused for s in self.sessions),
'pid': os.getpid(),
'peers': self.peer_mgr.info(),
'requests': sum(s.count_pending_items() for s in self.sessions),
@@ -454,11 +418,11 @@
def group_data(self):
'''Returned to the RPC 'groups' call.'''
result = []
for gid in sorted(self.groups.keys()):
sessions = self.groups[gid]
result.append([gid,
group_map = self._group_map()
for group, sessions in group_map.items():
result.append([group.gid,
len(sessions),
sum(s.bw_used for s in sessions),
sum(s.bw_charge for s in sessions),
sum(s.count_pending_items() for s in sessions),
sum(s.txs_sent for s in sessions),
sum(s.sub_count() for s in sessions),
@@ -531,7 +495,7 @@
sessions = sorted(self.sessions, key=lambda s: s.start_time)
return [(session.session_id,
session.flags(),
session.peername(for_log=for_log),
session.peer_address_str(for_log=for_log),
session.client,
session.protocol_version,
session.count_pending_items(),
@@ -555,7 +519,7 @@
def for_each_session(self, session_ids, operation):
if not isinstance(session_ids, list):
raise RPCError('expected a list of session IDs')
raise RPCError(BAD_REQUEST, 'expected a list of session IDs')
result = []
for session_id in session_ids:
@@ -597,7 +561,7 @@
try:
self.daemon.set_urls(self.env.coin.daemon_urls(daemon_url))
except Exception as e:
raise RPCError('an error occurred: {}'.format(e))
raise RPCError(BAD_REQUEST, f'an error occurred: {e}')
return 'now using daemon at {}'.format(self.daemon.logged_url())
def rpc_stop(self):
@@ -628,7 +592,7 @@
'''
count = self.non_negative_integer(count)
if not self.bp.force_chain_reorg(count):
raise RPCError('still catching up with daemon')
raise RPCError(BAD_REQUEST, 'still catching up with daemon')
return 'scheduled a reorg of {:,d} blocks'.format(count)
# Helpers for RPC "blockchain" command handlers
@@ -638,7 +602,7 @@
return self.coin.address_to_hashX(address)
except Exception:
pass
raise RPCError('{} is not a valid address'.format(address))
raise RPCError(BAD_REQUEST, f'{address} is not a valid address')
def scripthash_to_hashX(self, scripthash):
try:
@@ -647,7 +611,7 @@
return bin_hash[:self.coin.HASHX_LEN]
except Exception:
pass
raise RPCError('{} is not a valid script hash'.format(scripthash))
raise RPCError(BAD_REQUEST, f'{scripthash} is not a valid script hash')
def assert_tx_hash(self, value):
'''Raise an RPCError if the value is not a valid transaction
@@ -657,7 +621,7 @@
return
except Exception:
pass
raise RPCError('{} should be a transaction hash'.format(value))
raise RPCError(BAD_REQUEST, f'{value} should be a transaction hash')
def non_negative_integer(self, value):
'''Return the value if it is or can be converted to a non-negative
@@ -668,21 +632,22 @@
return value
except ValueError:
pass
raise RPCError('{} should be a non-negative integer'.format(value))
raise RPCError(BAD_REQUEST,
f'{value} should be a non-negative integer')
async def daemon_request(self, method, *args):
'''Catch a DaemonError and convert it to an RPCError.'''
try:
return await getattr(self.daemon, method)(*args)
except DaemonError as e:
raise RPCError('daemon error: {}'.format(e))
raise RPCError(DAEMON_ERROR, f'daemon error: {e}')
def new_subscription(self):
if self.subs_room <= 0:
self.subs_room = self.max_subs - self.sub_count()
if self.subs_room <= 0:
raise RPCError('server subscription limit {:,d} reached'
.format(self.max_subs))
raise RPCError(BAD_REQUEST, f'server subscription limit '
f'{self.max_subs:,d} reached')
self.subs_room -= 1
async def tx_merkle(self, tx_hash, height):
@@ -693,8 +658,8 @@
try:
pos = tx_hashes.index(tx_hash)
except ValueError:
raise RPCError('tx hash {} not in block {} at height {:,d}'
.format(tx_hash, hex_hashes[0], height))
raise RPCError(BAD_REQUEST, f'tx hash {tx_hash} not in '
f'block {hex_hashes[0]} at height {height:,d}')
idx = pos
hashes = [hex_str_to_hash(txh) for txh in tx_hashes]
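
The bandwidth bands and per-session priority machinery are gone. Sessions are instead bucketed into SessionGroups by 15-minute connection windows, and throttling falls out of each group's shared 20-slot semaphore. A worked sketch of the assignment rule used by add_session above (timestamps illustrative):

import time

server_start = time.time() - 2000   # server came up ~33 minutes ago
session_start = time.time()         # a client connects now

# Same arithmetic as add_session(): 900-second windows since start.
gid = int(session_start - server_start) // 900
print(gid)   # 2: the session joins the third 15-minute group

Group identity is not permanent: the housekeeping pass above merges small, low-bandwidth groups back together under the largest gid.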

server/daemon.py (14)

@@ -10,24 +10,24 @@ daemon.'''
import asyncio
import json
import logging
import time
import traceback
from calendar import timegm
from struct import pack
from time import strptime
import aiohttp
from lib.util import LoggedClass, int_to_varint, hex_to_bytes
from lib.util import int_to_varint, hex_to_bytes
from lib.hash import hex_str_to_hash
from lib.jsonrpc import JSONRPC
from aiorpcx import JSONRPC
class DaemonError(Exception):
'''Raised when the daemon returns an error in its results.'''
class Daemon(LoggedClass):
class Daemon(object):
'''Handles connections to a daemon at the given URL.'''
WARMING_UP = -28
@@ -37,7 +37,7 @@ class Daemon(LoggedClass):
'''Raised when the daemon returns an error in its results.'''
def __init__(self, env):
super().__init__()
self.logger = logging.getLogger(self.__class__.__name__)
self.coin = env.coin
self.set_urls(env.coin.daemon_urls(env.daemon_url))
self._height = None
@@ -151,8 +151,8 @@ class Daemon(LoggedClass):
log_error('starting up checking blocks.')
except (asyncio.CancelledError, DaemonError):
raise
except Exception:
self.log_error(traceback.format_exc())
except Exception as e:
self.logger.exception(f'uncaught exception: {e}')
await asyncio.sleep(secs)
secs = min(max_secs, secs * 2, 1)

server/db.py (24)

@@ -11,6 +11,7 @@
import array
import ast
import logging
import os
from struct import pack, unpack
from bisect import bisect_left, bisect_right
@@ -25,7 +26,7 @@ from server.version import VERSION, PROTOCOL_MIN, PROTOCOL_MAX
UTXO = namedtuple("UTXO", "tx_num tx_pos tx_hash height value")
class DB(util.LoggedClass):
class DB(object):
'''Simple wrapper of the backend database for querying.
Performs no DB update, though the DB will be cleaned on opening if
@@ -41,7 +42,7 @@ class DB(util.LoggedClass):
'''Raised on general DB errors generally indicating corruption.'''
def __init__(self, env):
super().__init__()
self.logger = logging.getLogger(self.__class__.__name__)
self.env = env
self.coin = env.coin
@@ -581,9 +582,10 @@ class DB(util.LoggedClass):
full_hist = b''.join(hist_list)
nrows = (len(full_hist) + max_row_size - 1) // max_row_size
if nrows > 4:
self.log_info('hashX {} is large: {:,d} entries across {:,d} rows'
.format(hash_to_str(hashX), len(full_hist) // 4,
nrows))
self.logger.info('hashX {} is large: {:,d} entries across '
'{:,d} rows'
.format(hash_to_str(hashX), len(full_hist) // 4,
nrows))
# Find what history needs to be written, and what keys need to
# be deleted. Start by assuming all keys are to be deleted,
@@ -652,11 +654,11 @@
max_rows = self.comp_flush_count + 1
self._flush_compaction(cursor, write_items, keys_to_delete)
self.log_info('history compaction: wrote {:,d} rows ({:.1f} MB), '
'removed {:,d} rows, largest: {:,d}, {:.1f}% complete'
.format(len(write_items), write_size / 1000000,
len(keys_to_delete), max_rows,
100 * cursor / 65536))
self.logger.info('history compaction: wrote {:,d} rows ({:.1f} MB), '
'removed {:,d} rows, largest: {:,d}, {:.1f}% complete'
.format(len(write_items), write_size / 1000000,
len(keys_to_delete), max_rows,
100 * cursor / 65536))
return write_size
async def compact_history(self, loop):
@@ -673,7 +675,7 @@
while self.comp_cursor != -1:
if self.semaphore.locked:
self.log_info('compact_history: waiting on semaphore...')
self.logger.info('compact_history: waiting on semaphore...')
async with self.semaphore:
await loop.run_in_executor(None, self._compact_history, limit)

server/env.py (6)

@@ -86,9 +86,9 @@ class Env(EnvBase):
# We give the DB 250 files; allow ElectrumX 100 for itself
value = max(0, min(env_value, nofile_limit - 350))
if value < env_value:
self.log_warning('lowered maximum sessions from {:,d} to {:,d} '
'because your open file limit is {:,d}'
.format(env_value, value, nofile_limit))
self.logger.warning('lowered maximum sessions from {:,d} to {:,d} '
'because your open file limit is {:,d}'
.format(env_value, value, nofile_limit))
return value
def clearnet_identity(self):

server/mempool.py (6)

@@ -9,16 +9,16 @@
import asyncio
import itertools
import logging
import time
from collections import defaultdict
from lib.hash import hash_to_str, hex_str_to_hash
import lib.util as util
from server.daemon import DaemonError
from server.db import UTXO
class MemPool(util.LoggedClass):
class MemPool(object):
'''Representation of the daemon's mempool.
Updated regularly in caught-up state. Goal is to enable efficient
@@ -33,7 +33,7 @@ class MemPool(util.LoggedClass):
'''
def __init__(self, bp, controller):
super().__init__()
self.logger = logging.getLogger(self.__class__.__name__)
self.daemon = bp.daemon
self.controller = controller
self.coin = bp.coin

server/peers.py (54)

@@ -8,6 +8,7 @@
'''Peer management.'''
import asyncio
import logging
import random
import socket
import ssl
@@ -18,7 +19,7 @@ from functools import partial
import aiorpcx
from lib.peer import Peer
import lib.util as util
from lib.util import ConnectionLogger
import server.version as version
@@ -27,16 +28,17 @@ STALE_SECS = 24 * 3600
WAKEUP_SECS = 300
class PeerSession(aiorpcx.ClientSession, util.LoggedClass):
class PeerSession(aiorpcx.ClientSession):
'''An outgoing session to a peer.'''
def __init__(self, peer, peer_mgr, kind, host, port, **kwargs):
super().__init__(host, port, **kwargs)
util.LoggedClass.__init__(self)
self.peer = peer
self.peer_mgr = peer_mgr
self.kind = kind
self.timeout = 20 if self.peer.is_tor else 10
context = {'conn_id': f'{host}'}
self.logger = ConnectionLogger(self.logger, context)
def connection_made(self, transport):
'''Handle an incoming client connection.'''
@@ -53,6 +55,15 @@ class PeerSession(aiorpcx.ClientSession, util.LoggedClass):
self.send_request('server.version', args, self.on_version,
timeout=self.timeout)
def _header_notification(self, header):
pass
def notification_handler(self, method):
# We subscribe so might be unlucky enough to get a notification...
if method == 'blockchain.headers.subscribe':
return self._header_notification
return None
def is_good(self, request, instance):
try:
result = request.result()
@@ -73,12 +84,12 @@ class PeerSession(aiorpcx.ClientSession, util.LoggedClass):
return False
def fail(self, request, reason):
self.logger.error(f'[{self.peer.host}] {request} failed: {reason}')
self.logger.error(f'{request} failed: {reason}')
self.peer_mgr.set_verification_status(self.peer, self.kind, False)
self.close()
def bad(self, reason):
self.logger.error(f'[{self.peer.host}] marking bad: {reason}')
self.logger.error(f'marking bad: {reason}')
self.peer.mark_bad()
self.peer_mgr.set_verification_status(self.peer, self.kind, False)
self.close()
@@ -180,8 +191,7 @@ class PeerSession(aiorpcx.ClientSession, util.LoggedClass):
features = self.peer_mgr.features_to_register(self.peer, peers)
if features:
self.logger.info(f'[{self.peer.host}] registering ourself with '
'"server.add_peer"')
self.logger.info(f'registering ourself with "server.add_peer"')
self.send_request('server.add_peer', [features],
self.on_add_peer, timeout=self.timeout)
else:
@@ -201,14 +211,14 @@ class PeerSession(aiorpcx.ClientSession, util.LoggedClass):
self.peer_mgr.set_verification_status(self.peer, self.kind, True)
class PeerManager(util.LoggedClass):
class PeerManager(object):
'''Looks after the DB of peer network servers.
Attempts to maintain a connection with up to 8 peers.
Issues a 'peers.subscribe' RPC to them and tells them our data.
'''
def __init__(self, env, controller):
super().__init__()
self.logger = logging.getLogger(self.__class__.__name__)
# Initialise the Peer class
Peer.DEFAULT_PORTS = env.coin.PEER_DEFAULT_PORTS
self.env = env
@@ -336,12 +346,12 @@ class PeerManager(util.LoggedClass):
async def on_add_peer(self, features, source_info):
'''Add a peer (but only if the peer resolves to the source).'''
if not source_info:
self.log_info('ignored add_peer request: no source info')
self.logger.info('ignored add_peer request: no source info')
return False
source = source_info[0]
peers = Peer.peers_from_features(features, source)
if not peers:
self.log_info('ignored add_peer request: no peers given')
self.logger.info('ignored add_peer request: no peers given')
return False
# Just look at the first peer, require it
@@ -362,12 +372,12 @@ class PeerManager(util.LoggedClass):
reason = 'source-destination mismatch'
if permit:
self.log_info('accepted add_peer request from {} for {}'
.format(source, host))
self.logger.info('accepted add_peer request from {} for {}'
.format(source, host))
self.add_peers([peer], check_ports=True)
else:
self.log_warning('rejected add_peer request from {} for {} ({})'
.format(source, host, reason))
self.logger.warning('rejected add_peer request from {} for {} ({})'
.format(source, host, reason))
return permit
@@ -438,13 +448,13 @@ class PeerManager(util.LoggedClass):
ports = [9050, 9150, 1080]
else:
ports = [self.env.tor_proxy_port]
self.log_info(f'trying to detect proxy on "{host}" ports {ports}')
self.logger.info(f'trying to detect proxy on "{host}" ports {ports}')
cls = aiorpcx.SOCKSProxy
result = await cls.auto_detect_host(host, ports, None, loop=self.loop)
if isinstance(result, cls):
self.proxy = result
self.log_info(f'detected {self.proxy}')
self.logger.info(f'detected {self.proxy}')
def proxy_peername(self):
'''Return the peername of the proxy, if there is a proxy, otherwise
@@ -544,10 +554,10 @@ class PeerManager(util.LoggedClass):
if exception:
session.close()
kind, port = port_pairs.pop(0)
self.log_info('failed connecting to {} at {} port {:d} '
'in {:.1f}s: {}'
.format(peer, kind, port,
time.time() - peer.last_try, exception))
self.logger.info('failed connecting to {} at {} port {:d} '
'in {:.1f}s: {}'
.format(peer, kind, port,
time.time() - peer.last_try, exception))
if port_pairs:
self.retry_peer(peer, port_pairs)
else:
@@ -562,7 +572,7 @@ class PeerManager(util.LoggedClass):
how = 'via {} at {}'.format(kind, peer.ip_addr)
status = 'verified' if good else 'failed to verify'
elapsed = now - peer.last_try
self.log_info('{} {} {} in {:.1f}s'.format(status, peer, how, elapsed))
self.logger.info(f'{status} {peer} {how} in {elapsed:.1f}s')
if good:
peer.try_count = 0
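
PeerSession now leans entirely on aiorpcx.ClientSession: an outgoing request names a callback that fires with the completed request, and incoming notifications are routed through notification_handler(). A condensed sketch of that pattern using only calls that appear in the hunks above (the class and version arguments are illustrative):

import aiorpcx

class ProbeSession(aiorpcx.ClientSession):
    '''Outgoing session that asks a peer for its version, mirroring
    the callback style PeerSession uses.'''

    def connection_made(self, transport):
        super().connection_made(transport)
        # The handler is invoked with the completed request once a
        # response arrives or the timeout expires.
        self.send_request('server.version', ['ElectrumX 1.3.1a', '1.2'],
                          self.on_version, timeout=10)

    def on_version(self, request):
        print('peer version:', request.result())  # raises on error/timeout
        self.close()

    def notification_handler(self, method):
        # Nothing is subscribed here, so ignore all notifications.
        return None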

server/session.py (157)

@@ -8,17 +8,38 @@
'''Classes for local RPC server and remote client TCP/SSL servers.'''
import codecs
import itertools
import time
from functools import partial
from aiorpcx import ServerSession, JSONRPCAutoDetect, RPCError
from lib.hash import sha256, hash_to_str
from lib.jsonrpc import JSONSession, RPCError, JSONRPCv2, JSONRPC
import lib.util as util
from server.daemon import DaemonError
import server.version as version
BAD_REQUEST = 1
DAEMON_ERROR = 2
class Semaphores(object):
def __init__(self, semaphores):
self.semaphores = semaphores
self.acquired = []
async def __aenter__(self):
for semaphore in self.semaphores:
await semaphore.acquire()
self.acquired.append(semaphore)
async def __aexit__(self, exc_type, exc_value, traceback):
for semaphore in self.acquired:
semaphore.release()
class SessionBase(JSONSession):
class SessionBase(ServerSession):
'''Base class of ElectrumX JSON sessions.
Each session runs its tasks in asynchronous parallelism with other
@@ -26,11 +47,10 @@
'''
MAX_CHUNK_SIZE = 2016
session_counter = itertools.count()
def __init__(self, controller, kind):
# Force v2 as a temporary hack for old Coinomi wallets
# Remove in April 2017
super().__init__(version=JSONRPCv2)
super().__init__(rpc_protocol=JSONRPCAutoDetect)
self.kind = kind # 'RPC', 'TCP' etc.
self.controller = controller
self.bp = controller.bp
@@ -39,24 +59,28 @@
self.client = 'unknown'
self.client_version = (1, )
self.anon_logs = self.env.anon_logs
self.last_delay = 0
self.txs_sent = 0
self.requests = []
self.start_time = time.time()
self.close_time = 0
self.log_me = False
self.bw_limit = self.env.bandwidth_limit
self.bw_time = self.start_time
self.bw_interval = 3600
self.bw_used = 0
self._orig_mr = self.rpc.message_received
def close_connection(self):
'''Call this to close the connection.'''
self.close_time = time.time()
super().close_connection()
def peer_address_str(self, *, for_log=True):
'''Returns the peer's IP address and port as a human-readable
string, respecting anon logs if the output is for a log.'''
if for_log and self.anon_logs:
return 'xx.xx.xx.xx:xx'
return super().peer_address_str()
def peername(self, *, for_log=True):
'''Return the peer address and port.'''
return self.peer_addr(anon=for_log and self.anon_logs)
def message_received(self, message):
self.logger.info(f'processing {message}')
self._orig_mr(message)
def toggle_logging(self):
self.log_me = not self.log_me
if self.log_me:
self.rpc.message_received = self.message_received
else:
self.rpc.message_received = self._orig_mr
def flags(self):
'''Status flags.'''
@@ -65,38 +89,42 @@
status += 'C'
if self.log_me:
status += 'L'
status += str(self.controller.session_priority(self))
status += str(self.concurrency.max_concurrent)
return status
def connection_made(self, transport):
'''Handle an incoming client connection.'''
super().connection_made(transport)
self.controller.add_session(self)
self.session_id = next(self.session_counter)
context = {'conn_id': f'{self.session_id}'}
self.logger = util.ConnectionLogger(self.logger, context)
self.rpc.logger = self.logger
self.group = self.controller.add_session(self)
self.logger.info(f'{self.kind} {self.peer_address_str()}, '
f'{len(self.controller.sessions):,d} total')
def connection_lost(self, exc):
'''Handle client disconnection.'''
super().connection_lost(exc)
self.controller.remove_session(self)
msg = ''
if self.pause:
if self.paused:
msg += ' whilst paused'
if self.controller.is_deprioritized(self):
msg += ' whilst deprioritized'
if self.concurrency.max_concurrent != self.max_concurrent:
msg += ' whilst throttled'
if self.send_size >= 1024*1024:
msg += ('. Sent {:,d} bytes in {:,d} messages'
.format(self.send_size, self.send_count))
if msg:
msg = 'disconnected' + msg
self.log_info(msg)
self.controller.remove_session(self)
self.logger.info(msg)
self.group = None
def count_pending_items(self):
return self.rpc.pending_requests
def using_bandwidth(self, amount):
now = time.time()
# Reduce the recorded usage in proportion to the elapsed time
elapsed = now - self.bw_time
self.bandwidth_start = now
refund = int(elapsed / self.bw_interval * self.bw_limit)
refund = min(refund, self.bw_used)
self.bw_used += amount - refund
def semaphore(self):
return Semaphores([self.concurrency.semaphore, self.group.semaphore])
def sub_count(self):
return 0
@@ -111,7 +139,7 @@ class ElectrumX(SessionBase):
self.subscribe_headers_raw = False
self.subscribe_height = False
self.notified_height = None
self.max_send = self.env.max_send
self.max_response_size = self.env.max_send
self.max_subs = self.env.max_session_subs
self.hashX_subs = {}
self.mempool_statuses = {}
@@ -148,8 +176,8 @@
if changed:
es = '' if len(changed) == 1 else 'es'
self.log_info('notified of {:,d} address{}'
.format(len(changed), es))
self.logger.info('notified of {:,d} address{}'
.format(len(changed), es))
def notify(self, height, touched):
'''Notify the client about changes to touched addresses (from mempool
@@ -185,7 +213,7 @@
'''Return the value if it is a boolean, otherwise raise an RPCError.'''
if value in (False, True):
return value
raise RPCError('{} should be a boolean value'.format(value))
raise RPCError(BAD_REQUEST, f'{value} should be a boolean value')
def subscribe_headers_result(self, height):
'''The result of a header subscription for the given height.'''
@@ -209,7 +237,7 @@
async def add_peer(self, features):
'''Add a peer (but only if the peer resolves to the source).'''
peer_mgr = self.controller.peer_mgr
return await peer_mgr.on_add_peer(features, self.peer_info())
return await peer_mgr.on_add_peer(features, self.peer_address())
def peers_subscribe(self):
'''Return the server peers as a list of (ip, host, details) tuples.'''
@@ -244,8 +272,8 @@
async def hashX_subscribe(self, hashX, alias):
# First check our limit.
if len(self.hashX_subs) >= self.max_subs:
raise RPCError('your address subscription limit {:,d} reached'
.format(self.max_subs))
raise RPCError(BAD_REQUEST, 'your address subscription limit '
f'{self.max_subs:,d} reached')
# Now let the controller check its limit
self.controller.new_subscription()
@@ -299,8 +327,8 @@
peername = self.controller.peer_mgr.proxy_peername()
if not peername:
return False
peer_info = self.peer_info()
return peer_info and peer_info[0] == peername[0]
peer_address = self.peer_address()
return peer_address and peer_address[0] == peername[0]
async def replaced_banner(self, banner):
network_info = await self.controller.daemon_request('getnetworkinfo')
@@ -338,8 +366,7 @@
with codecs.open(banner_file, 'r', 'utf-8') as f:
banner = f.read()
except Exception as e:
self.log_error('reading banner file {}: {}'
.format(banner_file, e))
self.logger.error(f'reading banner file {banner_file}: {e}')
else:
banner = await self.replaced_banner(banner)
@@ -360,8 +387,9 @@
if client_name:
if self.env.drop_client is not None and \
self.env.drop_client.match(client_name):
raise RPCError('unsupported client: {}'
.format(client_name), JSONRPC.FATAL_ERROR)
self.close_after_send = True
raise RPCError(BAD_REQUEST,
f'unsupported client: {client_name}')
self.client = str(client_name)[:17]
try:
self.client_version = tuple(int(part) for part
@@ -376,10 +404,11 @@
# From protocol version 1.1, protocol_version cannot be omitted
if ptuple is None or (ptuple >= (1, 1) and protocol_version is None):
self.log_info('unsupported protocol version request {}'
.format(protocol_version))
raise RPCError('unsupported protocol version: {}'
.format(protocol_version), JSONRPC.FATAL_ERROR)
self.logger.info('unsupported protocol version request {}'
.format(protocol_version))
self.close_after_send = True
raise RPCError(BAD_REQUEST,
f'unsupported protocol version: {protocol_version}')
self.set_protocol_handlers(ptuple)
@@ -397,16 +426,15 @@
try:
tx_hash = await self.daemon.sendrawtransaction([raw_tx])
self.txs_sent += 1
self.log_info('sent tx: {}'.format(tx_hash))
self.logger.info('sent tx: {}'.format(tx_hash))
self.controller.sent_tx(tx_hash)
return tx_hash
except DaemonError as e:
error, = e.args
message = error['message']
self.log_info('sendrawtransaction: {}'.format(message),
throttle=True)
raise RPCError('the transaction was rejected by network rules.'
'\n\n{}\n[{}]'.format(message, raw_tx))
self.logger.info('sendrawtransaction: {}'.format(message))
raise RPCError(BAD_REQUEST, 'the transaction was rejected by '
f'network rules.\n\n{message}\n[{raw_tx}]')
async def transaction_broadcast_1_0(self, raw_tx):
'''Broadcast a raw transaction to the network.
@@ -505,7 +533,7 @@ class LocalRPC(SessionBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = 'RPC'
self.max_send = 0
self.max_response_size = 0
self.protocol_version = 'RPC'
def request_handler(self, method):
@@ -535,13 +563,8 @@ class DashElectrumX(ElectrumX):
for masternode in self.mns:
status = self.daemon.masternode_list(['status', masternode])
payload = {
'id': None,
'method': 'masternode.subscribe',
'params': [masternode],
'result': status.get(masternode),
}
self.send_binary(self.encode_payload(payload))
self.send_notification('masternode.subscribe',
[masternode, status.get(masternode)])
return result
# Masternode command handlers
@@ -553,9 +576,9 @@
except DaemonError as e:
error, = e.args
message = error['message']
self.log_info('masternode_broadcast: {}'.format(message))
raise RPCError('the masternode broadcast was rejected.'
'\n\n{}\n[{}]'.format(message, signmnb))
self.logger.info('masternode_broadcast: {}'.format(message))
raise RPCError(BAD_REQUEST, 'the masternode broadcast was '
f'rejected.\n\n{message}\n[{signmnb}]')
async def masternode_announce_broadcast_1_0(self, signmnb):
'''Pass through the masternode announce message to be broadcast
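
The Semaphores helper added at the top of this file lets one request hold both its session's concurrency slot and its group's slot for the duration of a call, via session.semaphore(). A self-contained usage sketch (the class body is repeated so the example runs standalone; the limits are illustrative):

import asyncio

class Semaphores(object):
    # Mirrors the class added in server/session.py above.
    def __init__(self, semaphores):
        self.semaphores = semaphores
        self.acquired = []

    async def __aenter__(self):
        for semaphore in self.semaphores:
            await semaphore.acquire()
            self.acquired.append(semaphore)

    async def __aexit__(self, exc_type, exc_value, traceback):
        for semaphore in self.acquired:
            semaphore.release()

async def handle_request():
    session_limit = asyncio.Semaphore(10)   # per-session cap
    group_limit = asyncio.Semaphore(20)     # per-group cap, as in SessionGroup
    async with Semaphores([session_limit, group_limit]):
        print('both slots held while one request is processed')

asyncio.run(handle_request())

Releasing only the acquired list means a cancellation part-way through acquisition still frees whatever was already taken.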

server/version.py (2)

@@ -1,5 +1,5 @@
# Server name and protocol versions
VERSION = 'ElectrumX 1.3.1'
VERSION = 'ElectrumX 1.3.1a'
PROTOCOL_MIN = '0.9'
PROTOCOL_MAX = '1.2'

setup.py (2)

@@ -11,7 +11,7 @@ setuptools.setup(
# "x11_hash" package (1.4) is required to sync DASH network.
# "tribus_hash" package is required to sync Denarius network.
# "blake256" package is required to sync Decred network.
install_requires=['aiorpcX >= 0.4.4', 'plyvel', 'pylru', 'aiohttp >= 1'],
install_requires=['aiorpcX >= 0.5.2', 'plyvel', 'pylru', 'aiohttp >= 1'],
packages=setuptools.find_packages(exclude=['tests']),
description='ElectrumX Server',
author='Neil Booth',

tests/lib/test_util.py (45)

@@ -5,51 +5,6 @@ import pytest
from lib import util
class LoggedClassTest(util.LoggedClass):
def __init__(self):
super().__init__()
self.logger.info = self.note_info
self.logger.warning = self.note_warning
self.logger.error = self.note_error
def note_info(self, msg):
self.info_msg = msg
def note_warning(self, msg):
self.warning_msg = msg
def note_error(self, msg):
self.error_msg = msg
def test_LoggedClass():
test = LoggedClassTest()
assert test.log_prefix == ''
test.log_prefix = 'prefix'
test.log_error('an error')
assert test.error_msg == 'prefixan error'
test.log_warning('a warning')
assert test.warning_msg == 'prefixa warning'
test.log_info('some info')
assert test.info_msg == 'prefixsome info'
assert test.throttled == 0
test.log_info('some info', throttle=True)
assert test.throttled == 1
assert test.info_msg == 'prefixsome info'
test.log_info('some info', throttle=True)
assert test.throttled == 2
assert test.info_msg == 'prefixsome info'
test.log_info('some info', throttle=True)
assert test.throttled == 3
assert test.info_msg == 'prefixsome info (throttling later logs)'
test.info_msg = ''
test.log_info('some info', throttle=True)
assert test.throttled == 4
assert test.info_msg == ''
def test_cachedproperty():
class Target:

tests/server/test_api.py (9)

@@ -1,7 +1,7 @@
import asyncio
from unittest import mock
from lib.jsonrpc import RPCError
from aiorpcx import RPCError
from server.env import Env
from server.controller import Controller
@@ -27,8 +27,8 @@ async def coro(res):
return res
def raise_exception(exc, msg):
raise exc(msg)
def raise_exception(msg):
raise RPCError(1, msg)
def ensure_text_exception(test, exception):
@@ -82,7 +82,8 @@ def test_transaction_get():
env = set_env()
sut = Controller(env)
sut.daemon_request = mock.Mock()
sut.daemon_request.return_value = coro(raise_exception(RPCError, 'some unhandled error'))
sut.daemon_request.return_value = coro(
raise_exception('some unhandled error'))
await sut.transaction_get('ff' * 32, True)
async def test_wrong_txhash():
