Browse Source

Split out daemon handler into separate file.

master
Neil Booth 9 years ago
parent
commit
8452d0c016
  1. 85
      server/controller.py
  2. 70
      server/daemon.py
  3. 14
      server/protocol.py

85
server/controller.py

@ -2,13 +2,11 @@
# and warranty status of this software. # and warranty status of this software.
import asyncio import asyncio
import json
import signal import signal
import traceback import traceback
from functools import partial from functools import partial
import aiohttp from server.daemon import Daemon, DaemonError
from server.db import DB from server.db import DB
from server.protocol import ElectrumX, LocalRPC from server.protocol import ElectrumX, LocalRPC
from lib.hash import (sha256, double_sha256, hash_to_str, from lib.hash import (sha256, double_sha256, hash_to_str,
@ -19,10 +17,15 @@ from lib.util import LoggedClass
class Controller(LoggedClass): class Controller(LoggedClass):
def __init__(self, env): def __init__(self, env):
'''Create up the controller.
Creates DB, Daemon and BlockCache instances.
'''
super().__init__() super().__init__()
self.env = env self.env = env
self.db = DB(env) self.db = DB(env)
self.block_cache = BlockCache(env, self.db) self.daemon = Daemon(env.daemon_url)
self.block_cache = BlockCache(self.db, self.daemon)
self.servers = [] self.servers = []
self.sessions = set() self.sessions = set()
self.addresses = {} self.addresses = {}
@ -30,6 +33,7 @@ class Controller(LoggedClass):
self.peers = {} self.peers = {}
def start(self, loop): def start(self, loop):
'''Prime the event loop with asynchronous servers and jobs.'''
env = self.env env = self.env
if False: if False:
@ -41,7 +45,7 @@ class Controller(LoggedClass):
self.logger.info('RPC server listening on {}:{:d}' self.logger.info('RPC server listening on {}:{:d}'
.format(host, env.rpc_port)) .format(host, env.rpc_port))
protocol = partial(ElectrumX, self, env) protocol = partial(ElectrumX, self, self.db, self.daemon, env)
if env.tcp_port is not None: if env.tcp_port is not None:
tcp_server = loop.create_server(protocol, env.host, env.tcp_port) tcp_server = loop.create_server(protocol, env.host, env.tcp_port)
self.servers.append(loop.run_until_complete(tcp_server)) self.servers.append(loop.run_until_complete(tcp_server))
@ -68,10 +72,12 @@ class Controller(LoggedClass):
partial(self.on_signal, loop, signame)) partial(self.on_signal, loop, signame))
def stop(self): def stop(self):
'''Close the listening servers.'''
for server in self.servers: for server in self.servers:
server.close() server.close()
def on_signal(self, loop, signame): def on_signal(self, loop, signame):
'''Call on receipt of a signal to cleanly shutdown.'''
self.logger.warning('received {} signal, preparing to shut down' self.logger.warning('received {} signal, preparing to shut down'
.format(signame)) .format(signame))
for task in asyncio.Task.all_tasks(loop): for task in asyncio.Task.all_tasks(loop):
@ -119,9 +125,8 @@ class Controller(LoggedClass):
async def get_merkle(self, tx_hash, height): async def get_merkle(self, tx_hash, height):
'''tx_hash is a hex string.''' '''tx_hash is a hex string.'''
daemon_send = self.block_cache.send_single block_hash = await self.daemon.send_single('getblockhash', (height,))
block_hash = await daemon_send('getblockhash', (height,)) block = await self.daemon.send_single('getblock', (block_hash, True))
block = await daemon_send('getblock', (block_hash, True))
tx_hashes = block['tx'] tx_hashes = block['tx']
# This will throw if the tx_hash is bad # This will throw if the tx_hash is bad
pos = tx_hashes.index(tx_hash) pos = tx_hashes.index(tx_hash)
@ -151,13 +156,10 @@ class BlockCache(LoggedClass):
block chain reorganisations. block chain reorganisations.
''' '''
class DaemonError: def __init__(self, db, daemon):
pass
def __init__(self, env, db):
super().__init__() super().__init__()
self.db = db self.db = db
self.daemon_url = env.daemon_url self.daemon = daemon
# Target cache size. Has little effect on sync time. # Target cache size. Has little effect on sync time.
self.target_cache_size = 10 * 1024 * 1024 self.target_cache_size = 10 * 1024 * 1024
self.daemon_height = 0 self.daemon_height = 0
@ -166,8 +168,6 @@ class BlockCache(LoggedClass):
self.queue_size = 0 self.queue_size = 0
self.recent_sizes = [0] self.recent_sizes = [0]
self.logger.info('using daemon URL {}'.format(self.daemon_url))
def flush_db(self): def flush_db(self):
self.db.flush(self.daemon_height, True) self.db.flush(self.daemon_height, True)
@ -194,8 +194,8 @@ class BlockCache(LoggedClass):
while True: while True:
try: try:
await self.maybe_prefetch() await self.maybe_prefetch()
except self.DaemonError: except DaemonError as e:
pass self.logger.info('ignoring daemon errors: {}'.format(e))
await asyncio.sleep(2) await asyncio.sleep(2)
def cache_used(self): def cache_used(self):
@ -208,9 +208,10 @@ class BlockCache(LoggedClass):
async def maybe_prefetch(self): async def maybe_prefetch(self):
'''Prefetch blocks if there are any to prefetch.''' '''Prefetch blocks if there are any to prefetch.'''
daemon = self.daemon
while self.queue_size < self.target_cache_size: while self.queue_size < self.target_cache_size:
# Keep going by getting a whole new cache_limit of blocks # Keep going by getting a whole new cache_limit of blocks
self.daemon_height = await self.send_single('getblockcount') self.daemon_height = await daemon.send_single('getblockcount')
max_count = min(self.daemon_height - self.fetched_height, 4000) max_count = min(self.daemon_height - self.fetched_height, 4000)
count = min(max_count, self.prefill_count(self.target_cache_size)) count = min(max_count, self.prefill_count(self.target_cache_size))
if not count: if not count:
@ -218,11 +219,11 @@ class BlockCache(LoggedClass):
first = self.fetched_height + 1 first = self.fetched_height + 1
param_lists = [[height] for height in range(first, first + count)] param_lists = [[height] for height in range(first, first + count)]
hashes = await self.send_vector('getblockhash', param_lists) hashes = await daemon.send_vector('getblockhash', param_lists)
# Hashes is an array of hex strings # Hashes is an array of hex strings
param_lists = [(h, False) for h in hashes] param_lists = [(h, False) for h in hashes]
blocks = await self.send_vector('getblock', param_lists) blocks = await daemon.send_vector('getblock', param_lists)
self.fetched_height += count self.fetched_height += count
# Convert hex string to bytes # Convert hex string to bytes
@ -237,47 +238,3 @@ class BlockCache(LoggedClass):
excess = len(self.recent_sizes) - 50 excess = len(self.recent_sizes) - 50
if excess > 0: if excess > 0:
self.recent_sizes = self.recent_sizes[excess:] self.recent_sizes = self.recent_sizes[excess:]
async def send_single(self, method, params=None):
payload = {'method': method}
if params:
payload['params'] = params
result, = await self.send((payload, ))
return result
async def send_many(self, mp_pairs):
payload = [{'method': method, 'params': params}
for method, params in mp_pairs]
return await self.send(payload)
async def send_vector(self, method, params_list):
payload = [{'method': method, 'params': params}
for params in params_list]
return await self.send(payload)
async def send(self, payload):
assert isinstance(payload, (tuple, list))
data = json.dumps(payload)
while True:
try:
async with aiohttp.post(self.daemon_url, data=data) as resp:
result = await resp.json()
except asyncio.CancelledError:
raise
except Exception as e:
msg = 'aiohttp error: {}'.format(e)
secs = 3
else:
errs = tuple(item['error'] for item in result)
if not any(errs):
return tuple(item['result'] for item in result)
if any(err.get('code') == -28 for err in errs):
msg = 'daemon still warming up.'
secs = 30
else:
msg = 'daemon errors: {}'.format(errs)
raise self.DaemonError(msg)
self.logger.error('{}. Sleeping {:d}s and trying again...'
.format(msg, secs))
await asyncio.sleep(secs)

70
server/daemon.py

@ -0,0 +1,70 @@
# See the file "LICENSE" for information about the copyright
# and warranty status of this software.
'''Classes for handling asynchronous connections to a blockchain
daemon.'''
import asyncio
import json
import aiohttp
from lib.util import LoggedClass
class DaemonError(Exception):
'''Raised when the daemon returns an error in its results that
cannot be remedied by retrying.'''
class Daemon(LoggedClass):
'''Handles connections to a daemon at the given URL.'''
def __init__(self, url):
super().__init__()
self.url = url
self.logger.info('connecting to daemon at URL {}'.format(url))
async def send_single(self, method, params=None):
payload = {'method': method}
if params:
payload['params'] = params
result, = await self.send((payload, ))
return result
async def send_many(self, mp_pairs):
payload = [{'method': method, 'params': params}
for method, params in mp_pairs]
return await self.send(payload)
async def send_vector(self, method, params_list):
payload = [{'method': method, 'params': params}
for params in params_list]
return await self.send(payload)
async def send(self, payload):
assert isinstance(payload, (tuple, list))
data = json.dumps(payload)
while True:
try:
async with aiohttp.post(self.url, data=data) as resp:
result = await resp.json()
except asyncio.CancelledError:
raise
except Exception as e:
msg = 'aiohttp error: {}'.format(e)
secs = 3
else:
errs = tuple(item['error'] for item in result)
if not any(errs):
return tuple(item['result'] for item in result)
if any(err.get('code') == -28 for err in errs):
msg = 'daemon still warming up.'
secs = 30
else:
msg = '{}'.format(errs)
raise DaemonError(msg)
self.logger.error('{}. Sleeping {:d}s and trying again...'
.format(msg, secs))
await asyncio.sleep(secs)

14
server/protocol.py

@ -100,11 +100,12 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
class ElectrumX(JSONRPC): class ElectrumX(JSONRPC):
'''A TCP server that handles incoming Electrum connections.'''
def __init__(self, controller, env): def __init__(self, controller, db, daemon, env):
super().__init__(controller) super().__init__(controller)
self.BC = controller.block_cache self.db = db
self.db = controller.db self.daemon = daemon
self.env = env self.env = env
self.addresses = set() self.addresses = set()
self.subscribe_headers = False self.subscribe_headers = False
@ -134,7 +135,7 @@ class ElectrumX(JSONRPC):
return status.hex() if status else None return status.hex() if status else None
async def handle_blockchain_estimatefee(self, params): async def handle_blockchain_estimatefee(self, params):
result = await self.BC.send_single('estimatefee', params) result = await self.daemon.send_single('estimatefee', params)
return result return result
async def handle_blockchain_headers_subscribe(self, params): async def handle_blockchain_headers_subscribe(self, params):
@ -145,7 +146,7 @@ class ElectrumX(JSONRPC):
'''The minimum fee a low-priority tx must pay in order to be accepted '''The minimum fee a low-priority tx must pay in order to be accepted
to this daemon's memory pool. to this daemon's memory pool.
''' '''
net_info = await self.BC.send_single('getnetworkinfo') net_info = await self.daemon.send_single('getnetworkinfo')
return net_info['relayfee'] return net_info['relayfee']
async def handle_blockchain_transaction_get(self, params): async def handle_blockchain_transaction_get(self, params):
@ -153,7 +154,7 @@ class ElectrumX(JSONRPC):
raise Error(Error.BAD_REQUEST, raise Error(Error.BAD_REQUEST,
'params should contain a transaction hash') 'params should contain a transaction hash')
tx_hash = params[0] tx_hash = params[0]
return await self.BC.send_single('getrawtransaction', (tx_hash, 0)) return await self.daemon.send_single('getrawtransaction', (tx_hash, 0))
async def handle_blockchain_transaction_get_merkle(self, params): async def handle_blockchain_transaction_get_merkle(self, params):
if len(params) != 2: if len(params) != 2:
@ -196,6 +197,7 @@ class ElectrumX(JSONRPC):
class LocalRPC(JSONRPC): class LocalRPC(JSONRPC):
'''A local TCP RPC server for querying status.'''
async def handle_getinfo(self, params): async def handle_getinfo(self, params):
return { return {

Loading…
Cancel
Save