Browse Source

Split out server and session management

master
Neil Booth 8 years ago
parent
commit
7523735f99
  1. 329
      server/protocol.py

329
server/protocol.py

@ -27,28 +27,43 @@ from server.version import VERSION
class BlockServer(BlockProcessor): class BlockServer(BlockProcessor):
'''Like BlockProcessor but also starts servers when caught up.''' '''Like BlockProcessor but also has a server manager and starts
servers when caught up.'''
def __init__(self, env): def __init__(self, env):
super().__init__(env) super().__init__(env)
self.servers = [] self.server_mgr = ServerManager(self, env)
self.irc = IRC(env)
async def caught_up(self, mempool_hashes): async def caught_up(self, mempool_hashes):
await super().caught_up(mempool_hashes) await super().caught_up(mempool_hashes)
if not self.servers: self.server_mgr.notify(self.height, self.touched)
await self.start_servers()
if self.env.irc: def stop(self):
self.logger.info('starting IRC coroutine') '''Close the listening servers.'''
asyncio.ensure_future(self.irc.start()) self.server_mgr.stop()
else:
self.logger.info('IRC disabled')
ElectrumX.notify(self.height, self.touched) class ServerManager(LoggedClass):
'''Manages the servers.'''
AsyncTask = namedtuple('AsyncTask', 'session job')
def __init__(self, bp, env):
super().__init__()
self.bp = bp
self.env = env
self.servers = []
self.irc = IRC(env)
self.sessions = set()
self.tasks = asyncio.Queue()
self.current_task = None
async def start_server(self, class_name, kind, host, port, *, ssl=None): async def start_server(self, kind, *args, **kw_args):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
protocol = partial(class_name, self.env, kind) protocol_class = LocalRPC if kind == 'RPC' else ElectrumX
server = loop.create_server(protocol, host, port, ssl=ssl) protocol = partial(protocol_class, self, self.bp, self.env, kind)
server = loop.create_server(protocol, *args, **kw_args)
try: try:
self.servers.append(await server) self.servers.append(await server)
except asyncio.CancelledError: except asyncio.CancelledError:
@ -61,44 +76,49 @@ class BlockServer(BlockProcessor):
.format(kind, host, port)) .format(kind, host, port))
async def start_servers(self): async def start_servers(self):
'''Start listening on RPC, TCP and SSL ports. '''Connect to IRC and start listening for incoming connections.
Does not start a server if the port wasn't specified. Only connect to IRC if enabled. Start listening on RCP, TCP
and SSL ports only if the port wasn pecified.
''' '''
env = self.env env = self.env
Session.init(self, self.daemon, self.coin)
if env.rpc_port is not None: if env.rpc_port is not None:
await self.start_server(LocalRPC, 'RPC', 'localhost', env.rpc_port) await self.start_server('RPC', 'localhost', env.rpc_port)
if env.tcp_port is not None: if env.tcp_port is not None:
await self.start_server(ElectrumX, 'TCP', env.host, env.tcp_port) await self.start_server('TCP', env.host, env.tcp_port)
if env.ssl_port is not None: if env.ssl_port is not None:
# FIXME: update if we want to require Python >= 3.5.3 # FIXME: update if we want to require Python >= 3.5.3
sslc = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) sslc = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
sslc.load_cert_chain(env.ssl_certfile, keyfile=env.ssl_keyfile) sslc.load_cert_chain(env.ssl_certfile, keyfile=env.ssl_keyfile)
await self.start_server(ElectrumX, 'SSL', env.host, await self.start_server('SSL', env.host, env.ssl_port, ssl=sslc)
env.ssl_port, ssl=sslc)
def stop(self): asyncio.ensure_future(self.run_tasks())
'''Close the listening servers.'''
for server in self.servers:
server.close()
def irc_peers(self): if env.irc:
return self.irc.peers self.logger.info('starting IRC coroutine')
asyncio.ensure_future(self.irc.start())
else:
self.logger.info('IRC disabled')
async def notify(self, height, touched):
'''Notify electrum clients about height changes and touched addresses.
AsyncTask = namedtuple('AsyncTask', 'session job') Start listening if not yet listening.
'''
if not self.servers:
await self.start_servers()
class SessionManager(LoggedClass): sessions = [session for session in self.sessions
if isinstance(session, ElectrumX)]
self.ElectrumX.notify(sessions, height, touched)
def __init__(self): def stop(self):
super().__init__() '''Close the listening servers.'''
self.sessions = set() for server in self.servers:
self.tasks = asyncio.Queue() server.close()
self.current_task = None
asyncio.ensure_future(self.run_tasks())
def add_session(self, session): def add_session(self, session):
assert session not in self.sessions assert session not in self.sessions
@ -113,7 +133,7 @@ class SessionManager(LoggedClass):
def add_task(self, session, job): def add_task(self, session, job):
assert session in self.sessions assert session in self.sessions
task = asyncio.ensure_future(job) task = asyncio.ensure_future(job)
self.tasks.put_nowait(AsyncTask(session, task)) self.tasks.put_nowait(self.AsyncTask(session, task))
async def run_tasks(self): async def run_tasks(self):
'''Asynchronously run through the task queue.''' '''Asynchronously run through the task queue.'''
@ -133,22 +153,55 @@ class SessionManager(LoggedClass):
finally: finally:
self.current_task = None self.current_task = None
def irc_peers(self):
return self.irc.peers
def session_count(self):
return len(self.manager.sessions)
def info(self):
'''Returned in the RPC 'getinfo' call.'''
address_count = sum(len(session.hash168s)
for session in self.sessions
if isinstance(session, ElectrumX))
return {
'blocks': self.bp.height,
'peers': len(self.irc_peers()),
'sessions': self.session_count(),
'watched': address_count,
'cached': 0,
}
def sessions_info(self):
'''Returned to the RPC 'sessions' call.'''
now = time.time()
return [(session.kind,
session.peername(),
len(session.hash168s),
'RPC' if isinstance(session, LocalRPC) else session.client,
now - session.start)
for session in self.sessions]
class Session(JSONRPC): class Session(JSONRPC):
'''Base class of ElectrumX JSON session protocols.''' '''Base class of ElectrumX JSON session protocols.'''
def __init__(self, env, kind): def __init__(self, manager, bp, env, kind):
super().__init__() super().__init__()
self.hash168s = set() self.manager = manager
self.client = 'unknown' self.bp = bp
self.env = env self.env = env
self.daemon = bp.daemon
self.coin = bp.coin
self.kind = kind self.kind = kind
self.hash168s = set()
self.client = 'unknown'
def connection_made(self, transport): def connection_made(self, transport):
'''Handle an incoming client connection.''' '''Handle an incoming client connection.'''
super().connection_made(transport) super().connection_made(transport)
self.logger.info('connection from {}'.format(self.peername())) self.logger.info('connection from {}'.format(self.peername()))
self.SESSION_MGR.add_session(self) self.manager.add_session(self)
def connection_lost(self, exc): def connection_lost(self, exc):
'''Handle client disconnection.''' '''Handle client disconnection.'''
@ -158,7 +211,7 @@ class Session(JSONRPC):
'Sent {:,d} bytes in {:,d} messages {:,d} errors' 'Sent {:,d} bytes in {:,d} messages {:,d} errors'
.format(self.peername(), self.send_size, .format(self.peername(), self.send_size,
self.send_count, self.error_count)) self.send_count, self.error_count))
self.SESSION_MGR.remove_session(self) self.maanger.remove_session(self)
def method_handler(self, method): def method_handler(self, method):
'''Return the handler that will handle the RPC method.''' '''Return the handler that will handle the RPC method.'''
@ -166,14 +219,13 @@ class Session(JSONRPC):
def on_json_request(self, request): def on_json_request(self, request):
'''Queue the request for asynchronous handling.''' '''Queue the request for asynchronous handling.'''
self.SESSION_MGR.add_task(self, self.handle_json_request(request)) self.manager.add_task(self, self.handle_json_request(request))
def peername(self): def peername(self):
info = self.peer_info() info = self.peer_info()
return 'unknown' if not info else '{}:{}'.format(info[0], info[1]) return 'unknown' if not info else '{}:{}'.format(info[0], info[1])
@classmethod def tx_hash_from_param(self, param):
def tx_hash_from_param(cls, param):
'''Raise an RPCError if the parameter is not a valid transaction '''Raise an RPCError if the parameter is not a valid transaction
hash.''' hash.'''
if isinstance(param, str) and len(param) == 64: if isinstance(param, str) and len(param) == 64:
@ -185,17 +237,15 @@ class Session(JSONRPC):
raise RPCError('parameter should be a transaction hash: {}' raise RPCError('parameter should be a transaction hash: {}'
.format(param)) .format(param))
@classmethod def hash168_from_param(self, param):
def hash168_from_param(cls, param):
if isinstance(param, str): if isinstance(param, str):
try: try:
return cls.COIN.address_to_hash168(param) return self.coin.address_to_hash168(param)
except: except:
pass pass
raise RPCError('parameter should be a valid address: {}'.format(param)) raise RPCError('parameter should be a valid address: {}'.format(param))
@classmethod def non_negative_integer_from_param(self, param):
def non_negative_integer_from_param(cls, param):
try: try:
param = int(param) param = int(param)
except ValueError: except ValueError:
@ -207,60 +257,28 @@ class Session(JSONRPC):
raise RPCError('param should be a non-negative integer: {}' raise RPCError('param should be a non-negative integer: {}'
.format(param)) .format(param))
@classmethod def extract_hash168(self, params):
def extract_hash168(cls, params):
if len(params) == 1: if len(params) == 1:
return cls.hash168_from_param(params[0]) return self.hash168_from_param(params[0])
raise RPCError('params should contain a single address: {}' raise RPCError('params should contain a single address: {}'
.format(params)) .format(params))
@classmethod def extract_non_negative_integer(self, params):
def extract_non_negative_integer(cls, params):
if len(params) == 1: if len(params) == 1:
return cls.non_negative_integer_from_param(params[0]) return self.non_negative_integer_from_param(params[0])
raise RPCError('params should contain a non-negative integer: {}' raise RPCError('params should contain a non-negative integer: {}'
.format(params)) .format(params))
@classmethod def require_empty_params(self, params):
def require_empty_params(cls, params):
if params: if params:
raise RPCError('params should be empty: {}'.format(params)) raise RPCError('params should be empty: {}'.format(params))
@classmethod
def init(cls, block_processor, daemon, coin):
cls.BLOCK_PROCESSOR = block_processor
cls.DAEMON = daemon
cls.COIN = coin
cls.SESSION_MGR = SessionManager()
@classmethod
def irc_peers(cls):
return cls.BLOCK_PROCESSOR.irc_peers()
@classmethod
def height(cls):
'''Return the current height.'''
return cls.BLOCK_PROCESSOR.height
@classmethod
def electrum_header(cls, height=None):
'''Return the binary header at the given height.'''
if not 0 <= height <= cls.height():
raise RPCError('height {:,d} out of range'.format(height))
header = cls.BLOCK_PROCESSOR.read_headers(height, 1)
return cls.COIN.electrum_header(header, height)
@classmethod
def current_electrum_header(cls):
'''Used as response to a headers subscription request.'''
return cls.electrum_header(cls.height())
class ElectrumX(Session): class ElectrumX(Session):
'''A TCP server that handles incoming Electrum connections.''' '''A TCP server that handles incoming Electrum connections.'''
def __init__(self, env, kind): def __init__(self, *args):
super().__init__(env, kind) super().__init__(*args)
self.subscribe_headers = False self.subscribe_headers = False
self.subscribe_height = False self.subscribe_height = False
self.notified_height = None self.notified_height = None
@ -280,49 +298,57 @@ class ElectrumX(Session):
for suffix in suffixes.split()} for suffix in suffixes.split()}
@classmethod @classmethod
def watched_address_count(cls): def notify(cls, sessions, height, touched):
sessions = cls.SESSION_MGR.sessions headers_payload = height_payload = None
return sum(len(session.hash168s) for session in sessions)
@classmethod for session in sessions:
def notify(cls, height, touched): if height != session.notified_height:
'''Notify electrum clients about height changes and touched session.notified_height = height
addresses.''' if session.subscribe_headers:
if headers_payload is None:
headers_payload = json_notification_payload( headers_payload = json_notification_payload(
'blockchain.headers.subscribe', 'blockchain.headers.subscribe',
(cls.electrum_header(height), ), (session.electrum_header(height), ),
) )
session.send_json(headers_payload)
if session.subscribe_height:
if height_payload is None:
height_payload = json_notification_payload( height_payload = json_notification_payload(
'blockchain.numblocks.subscribe', 'blockchain.numblocks.subscribe',
(height, ), (height, ),
) )
hash168_to_address = cls.COIN.hash168_to_address
for session in cls.SESSION_MGR.sessions:
if not isinstance(session, ElectrumX):
continue
if height != session.notified_height:
session.notified_height = height
if session.subscribe_headers:
session.send_json(headers_payload)
if session.subscribe_height:
session.send_json(height_payload) session.send_json(height_payload)
hash168_to_address = session.coin.hash168_to_address
for hash168 in session.hash168s.intersection(touched): for hash168 in session.hash168s.intersection(touched):
address = hash168_to_address(hash168) address = hash168_to_address(hash168)
status = cls.address_status(hash168) status = session.address_status(hash168)
payload = json_notification_payload( payload = json_notification_payload(
'blockchain.address.subscribe', (address, status)) 'blockchain.address.subscribe', (address, status))
session.send_json(payload) session.send_json(payload)
@classmethod def height(self):
def address_status(cls, hash168): '''Return the block processor's current height.'''
return self.bp.height
def current_electrum_header(self):
'''Used as response to a headers subscription request.'''
return self.electrum_header(self.height())
def electrum_header(self, height):
'''Return the binary header at the given height.'''
if not 0 <= height <= self.height():
raise RPCError('height {:,d} out of range'.format(height))
header = self.bp.read_headers(height, 1)
return self.coin.electrum_header(header, height)
def address_status(self, hash168):
'''Returns status as 32 bytes.''' '''Returns status as 32 bytes.'''
# Note history is ordered and mempool unordered in electrum-server # Note history is ordered and mempool unordered in electrum-server
# For mempool, height is -1 if unconfirmed txins, otherwise 0 # For mempool, height is -1 if unconfirmed txins, otherwise 0
history = cls.BLOCK_PROCESSOR.get_history(hash168) history = self.bp.get_history(hash168)
mempool = cls.BLOCK_PROCESSOR.mempool_transactions(hash168) mempool = self.bp.mempool_transactions(hash168)
status = ''.join('{}:{:d}:'.format(hash_to_str(tx_hash), height) status = ''.join('{}:{:d}:'.format(hash_to_str(tx_hash), height)
for tx_hash, height in history) for tx_hash, height in history)
@ -332,11 +358,10 @@ class ElectrumX(Session):
return sha256(status.encode()).hex() return sha256(status.encode()).hex()
return None return None
@classmethod async def tx_merkle(self, tx_hash, height):
async def tx_merkle(cls, tx_hash, height):
'''tx_hash is a hex string.''' '''tx_hash is a hex string.'''
hex_hashes = await cls.DAEMON.block_hex_hashes(height, 1) hex_hashes = await self.daemon.block_hex_hashes(height, 1)
block = await cls.DAEMON.deserialised_block(hex_hashes[0]) block = await self.daemon.deserialised_block(hex_hashes[0])
tx_hashes = block['tx'] tx_hashes = block['tx']
# This will throw if the tx_hash is bad # This will throw if the tx_hash is bad
pos = tx_hashes.index(tx_hash) pos = tx_hashes.index(tx_hash)
@ -355,16 +380,11 @@ class ElectrumX(Session):
return {"block_height": height, "merkle": merkle_branch, "pos": pos} return {"block_height": height, "merkle": merkle_branch, "pos": pos}
@classmethod def get_history(self, hash168):
def height(cls):
return cls.BLOCK_PROCESSOR.height
@classmethod
def get_history(cls, hash168):
# Note history is ordered and mempool unordered in electrum-server # Note history is ordered and mempool unordered in electrum-server
# For mempool, height is -1 if unconfirmed txins, otherwise 0 # For mempool, height is -1 if unconfirmed txins, otherwise 0
history = cls.BLOCK_PROCESSOR.get_history(hash168, limit=None) history = self.bp.get_history(hash168, limit=None)
mempool = cls.BLOCK_PROCESSOR.mempool_transactions(hash168) mempool = self.bp.mempool_transactions(hash168)
conf = tuple({'tx_hash': hash_to_str(tx_hash), 'height': height} conf = tuple({'tx_hash': hash_to_str(tx_hash), 'height': height}
for tx_hash, height in history) for tx_hash, height in history)
@ -372,24 +392,21 @@ class ElectrumX(Session):
for tx_hash, fee, unconfirmed in mempool) for tx_hash, fee, unconfirmed in mempool)
return conf + unconf return conf + unconf
@classmethod def get_chunk(self, index):
def get_chunk(cls, index):
'''Return header chunk as hex. Index is a non-negative integer.''' '''Return header chunk as hex. Index is a non-negative integer.'''
chunk_size = cls.COIN.CHUNK_SIZE chunk_size = self.coin.CHUNK_SIZE
next_height = cls.height() + 1 next_height = self.height() + 1
start_height = min(index * chunk_size, next_height) start_height = min(index * chunk_size, next_height)
count = min(next_height - start_height, chunk_size) count = min(next_height - start_height, chunk_size)
return cls.BLOCK_PROCESSOR.read_headers(start_height, count).hex() return self.bp.read_headers(start_height, count).hex()
@classmethod def get_balance(self, hash168):
def get_balance(cls, hash168): confirmed = self.bp.get_balance(hash168)
confirmed = cls.BLOCK_PROCESSOR.get_balance(hash168) unconfirmed = self.bp.mempool_value(hash168)
unconfirmed = cls.BLOCK_PROCESSOR.mempool_value(hash168)
return {'confirmed': confirmed, 'unconfirmed': unconfirmed} return {'confirmed': confirmed, 'unconfirmed': unconfirmed}
@classmethod def list_unspent(self, hash168):
def list_unspent(cls, hash168): utxos = self.bp.get_utxos_sorted(hash168)
utxos = cls.BLOCK_PROCESSOR.get_utxos_sorted(hash168)
return tuple({'tx_hash': hash_to_str(utxo.tx_hash), return tuple({'tx_hash': hash_to_str(utxo.tx_hash),
'tx_pos': utxo.tx_pos, 'height': utxo.height, 'tx_pos': utxo.tx_pos, 'height': utxo.height,
'value': utxo.value} 'value': utxo.value}
@ -431,7 +448,7 @@ class ElectrumX(Session):
return self.electrum_header(height) return self.electrum_header(height)
async def estimatefee(self, params): async def estimatefee(self, params):
return await self.DAEMON.estimatefee(params) return await self.daemon.estimatefee(params)
async def headers_subscribe(self, params): async def headers_subscribe(self, params):
self.require_empty_params(params) self.require_empty_params(params)
@ -447,7 +464,7 @@ class ElectrumX(Session):
'''The minimum fee a low-priority tx must pay in order to be accepted '''The minimum fee a low-priority tx must pay in order to be accepted
to the daemon's memory pool.''' to the daemon's memory pool.'''
self.require_empty_params(params) self.require_empty_params(params)
return await self.DAEMON.relayfee() return await self.daemon.relayfee()
async def transaction_broadcast(self, params): async def transaction_broadcast(self, params):
'''Pass through the parameters to the daemon. '''Pass through the parameters to the daemon.
@ -458,7 +475,7 @@ class ElectrumX(Session):
user interface job here. user interface job here.
''' '''
try: try:
tx_hash = await self.DAEMON.sendrawtransaction(params) tx_hash = await self.daemon.sendrawtransaction(params)
self.logger.info('sent tx: {}'.format(tx_hash)) self.logger.info('sent tx: {}'.format(tx_hash))
return tx_hash return tx_hash
except DaemonError as e: except DaemonError as e:
@ -483,7 +500,7 @@ class ElectrumX(Session):
# in anticipation it might be dropped in the future. # in anticipation it might be dropped in the future.
if 1 <= len(params) <= 2: if 1 <= len(params) <= 2:
tx_hash = self.tx_hash_from_param(params[0]) tx_hash = self.tx_hash_from_param(params[0])
return await self.DAEMON.getrawtransaction(tx_hash) return await self.daemon.getrawtransaction(tx_hash)
raise RPCError('params wrong length: {}'.format(params)) raise RPCError('params wrong length: {}'.format(params))
@ -500,9 +517,9 @@ class ElectrumX(Session):
tx_hash = self.tx_hash_from_param(params[0]) tx_hash = self.tx_hash_from_param(params[0])
index = self.non_negative_integer_from_param(params[1]) index = self.non_negative_integer_from_param(params[1])
tx_hash = hex_str_to_hash(tx_hash) tx_hash = hex_str_to_hash(tx_hash)
hash168 = self.BLOCK_PROCESSOR.get_utxo_hash168(tx_hash, index) hash168 = self.bp.get_utxo_hash168(tx_hash, index)
if hash168: if hash168:
return self.COIN.hash168_to_address(hash168) return self.coin.hash168_to_address(hash168)
return None return None
raise RPCError('params should contain a transaction hash and index') raise RPCError('params should contain a transaction hash and index')
@ -537,7 +554,7 @@ class ElectrumX(Session):
subscription. subscription.
''' '''
self.require_empty_params(params) self.require_empty_params(params)
return list(self.irc_peers().values()) return list(self.manager.irc_peers().values())
async def version(self, params): async def version(self, params):
'''Return the server version as a string.''' '''Return the server version as a string.'''
@ -550,34 +567,22 @@ class ElectrumX(Session):
class LocalRPC(Session): class LocalRPC(Session):
'''A local TCP RPC server for querying status.''' '''A local TCP RPC server for querying status.'''
def __init__(self, env, kind): def __init__(self, *args):
super().__init__(env, kind) super().__init__(*args)
cmds = 'getinfo sessions numsessions peers numpeers'.split() cmds = 'getinfo sessions numsessions peers numpeers'.split()
self.handlers = {cmd: getattr(self, cmd) for cmd in cmds} self.handlers = {cmd: getattr(self, cmd) for cmd in cmds}
async def getinfo(self, params): async def getinfo(self, params):
return { return self.manager.info()
'blocks': self.height(),
'peers': len(self.irc_peers()),
'sessions': len(self.SESSION_MGR.sessions),
'watched': ElectrumX.watched_address_count(),
'cached': 0,
}
async def sessions(self, params): async def sessions(self, params):
now = time.time() return self.manager.sessions_info()
return [(session.kind,
'' if session == self else session.peername(),
len(session.hash168s),
'this RPC client' if session == self else session.client,
now - session.start)
for session in self.SESSION_MGR.sessions]
async def numsessions(self, params): async def numsessions(self, params):
return len(self.SESSION_MGR.sessions) return self.manager.session_count()
async def peers(self, params): async def peers(self, params):
return self.irc_peers() return self.manager.irc_peers()
async def numpeers(self, params): async def numpeers(self, params):
return len(self.irc_peers()) return len(self.manager.irc_peers())

Loading…
Cancel
Save