# See the file "LICENSE" for information about the copyright
# and warranty status of this software.

import asyncio
import json
import signal
import traceback
from functools import partial

import aiohttp

from server.db import DB
from server.protocol import ElectrumX, LocalRPC
from lib.hash import (sha256, double_sha256, hash_to_str,
                      Base58, hex_str_to_hash)
from lib.util import LoggedClass


class Controller(LoggedClass):
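    '''Coordinates the database, the block cache, client sessions and
    the network servers.'''
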
def __init__(self, env):
super().__init__()
self.env = env
self.db = DB(env)
self.block_cache = BlockCache(env, self.db)
self.servers = []
self.sessions = set()
self.addresses = {}
self.jobs = set()
self.peers = {}

    def start(self, loop):
        '''Start the network servers, kick off the block fetching
        coroutines and install the signal handlers.'''
env = self.env
        # Note: creation of the local RPC server is currently disabled
        if False:
protocol = partial(LocalRPC, self)
if env.rpc_port is not None:
host = 'localhost'
rpc_server = loop.create_server(protocol, host, env.rpc_port)
self.servers.append(loop.run_until_complete(rpc_server))
self.logger.info('RPC server listening on {}:{:d}'
.format(host, env.rpc_port))

        protocol = partial(ElectrumX, self, env)
if env.tcp_port is not None:
tcp_server = loop.create_server(protocol, env.host, env.tcp_port)
self.servers.append(loop.run_until_complete(tcp_server))
self.logger.info('TCP server listening on {}:{:d}'
.format(env.host, env.tcp_port))
if env.ssl_port is not None:
ssl_server = loop.create_server(protocol, env.host, env.ssl_port)
self.servers.append(loop.run_until_complete(ssl_server))
self.logger.info('SSL server listening on {}:{:d}'
.format(env.host, env.ssl_port))
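
        # Start the prefetcher and the block processor as independent
        # tasks; they communicate through the BlockCache's queue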
coros = [
self.block_cache.prefetcher(),
self.block_cache.process_blocks(),
]
for coro in coros:
asyncio.ensure_future(coro)
# Signal handlers
for signame in ('SIGINT', 'SIGTERM'):
loop.add_signal_handler(getattr(signal, signame),
partial(self.on_signal, loop, signame))

    def stop(self):
        '''Close the listening servers.'''
for server in self.servers:
server.close()

    def on_signal(self, loop, signame):
        '''Cancel all tasks so the event loop can shut down cleanly.'''
self.logger.warning('received {} signal, preparing to shut down'
.format(signame))
for task in asyncio.Task.all_tasks(loop):
task.cancel()

    def add_session(self, session):
self.sessions.add(session)

    def remove_session(self, session):
self.sessions.remove(session)

    def add_job(self, coro):
'''Queue a job for asynchronous processing.'''
self.jobs.add(asyncio.ensure_future(coro))

    async def reap_jobs(self):
        '''Periodically drop completed jobs, logging any exceptions they
        raised.'''
while True:
jobs = set()
for job in self.jobs:
if job.done():
try:
job.result()
                    except Exception:
traceback.print_exc()
else:
jobs.add(job)
self.logger.info('reaped {:d} jobs, {:d} jobs pending'
.format(len(self.jobs) - len(jobs), len(jobs)))
self.jobs = jobs
await asyncio.sleep(5)

    def address_status(self, hash168):
        '''Return the status of an address: the sha256 of its history
        string as 32 bytes, or an empty string if it has no history.'''
status = self.addresses.get(hash168)
if status is None:
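            # The status preimage is the address's confirmed history,
            # one 'tx_hash:height:' entry per transaction, in order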
status = ''.join(
'{}:{:d}:'.format(hash_to_str(tx_hash), height)
for tx_hash, height in self.db.get_history(hash168)
)
if status:
status = sha256(status.encode())
self.addresses[hash168] = status
return status

    async def get_merkle(self, tx_hash, height):
'''tx_hash is a hex string.'''
daemon_send = self.block_cache.send_single
block_hash = await daemon_send('getblockhash', (height,))
block = await daemon_send('getblock', (block_hash, True))
tx_hashes = block['tx']
        # index() raises ValueError if tx_hash is not in the block
pos = tx_hashes.index(tx_hash)
idx = pos
hashes = [hex_str_to_hash(txh) for txh in tx_hashes]
merkle_branch = []
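        # Standard merkle tree construction: at each level, pair up the
        # hashes (duplicating the last one when the count is odd),
        # record our sibling hash in the branch, then move up to the
        # parent level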
while len(hashes) > 1:
if len(hashes) & 1:
hashes.append(hashes[-1])
idx = idx - 1 if (idx & 1) else idx + 1
merkle_branch.append(hash_to_str(hashes[idx]))
idx //= 2
hashes = [double_sha256(hashes[n] + hashes[n + 1])
for n in range(0, len(hashes), 2)]
return {"block_height": height, "merkle": merkle_branch, "pos": pos}

    def get_peers(self):
'''Returns a dictionary of IRC nick to (ip, host, ports) tuples, one
per peer.'''
return self.peers


class BlockCache(LoggedClass):
'''Requests and caches blocks ahead of time from the daemon. Serves
them to the blockchain processor. Coordinates backing up in case of
block chain reorganisations.
'''

    class DaemonError(Exception):
        '''Raised when the daemon returns an error in its response.'''

    def __init__(self, env, db):
super().__init__()
self.db = db
self.daemon_url = env.daemon_url
# Target cache size. Has little effect on sync time.
self.target_cache_size = 10 * 1024 * 1024
self.daemon_height = 0
self.fetched_height = db.height
self.queue = asyncio.Queue()
self.queue_size = 0
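        # Seeded with one entry so prefill_count never averages over an
        # empty list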
self.recent_sizes = [0]
self.logger.info('using daemon URL {}'.format(self.daemon_url))
    def flush_db(self):
        '''Flush the DB to disk.'''
self.db.flush(self.daemon_height, True)
    async def process_blocks(self):
        '''Take raw blocks off the queue and have the DB process them,
        flushing once we have caught up with the daemon.'''
try:
while True:
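                # Each queue entry is (list of raw blocks, total byte size)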
blocks, total_size = await self.queue.get()
self.queue_size -= total_size
for block in blocks:
self.db.process_block(block, self.daemon_height)
                    # Yield so the prefetcher is not starved while we
                    # index
                    await asyncio.sleep(0)
if self.db.height == self.daemon_height:
self.logger.info('caught up to height {:d}'
.format(self.daemon_height))
self.flush_db()
finally:
self.flush_db()

    async def prefetcher(self):
'''Loops forever polling for more blocks.'''
self.logger.info('prefetching blocks...')
while True:
try:
await self.maybe_prefetch()
except self.DaemonError:
pass
await asyncio.sleep(2)
    def cache_used(self):
        '''Return the number of raw block bytes currently queued.'''
        return self.queue_size

    def prefill_count(self, room):
        '''Return how many blocks to request to fill room bytes, based
        on the average size of recently fetched blocks.'''
ave_size = sum(self.recent_sizes) // len(self.recent_sizes)
count = room // ave_size if ave_size else 0
return max(count, 10)

    async def maybe_prefetch(self):
'''Prefetch blocks if there are any to prefetch.'''
while self.queue_size < self.target_cache_size:
            # Refresh our view of the daemon's chain height
            self.daemon_height = await self.send_single('getblockcount')
max_count = min(self.daemon_height - self.fetched_height, 4000)
count = min(max_count, self.prefill_count(self.target_cache_size))
if not count:
break
first = self.fetched_height + 1
param_lists = [[height] for height in range(first, first + count)]
hashes = await self.send_vector('getblockhash', param_lists)

            # 'hashes' is a list of hex strings; fetch each block in
            # raw (non-verbose) form as hex
            param_lists = [(h, False) for h in hashes]
blocks = await self.send_vector('getblock', param_lists)
self.fetched_height += count

            # Convert hex strings to raw bytes
blocks = [bytes.fromhex(block) for block in blocks]
sizes = [len(block) for block in blocks]
total_size = sum(sizes)
self.queue.put_nowait((blocks, total_size))
self.queue_size += total_size
# Keep 50 most recent block sizes for fetch count estimation
self.recent_sizes.extend(sizes)
excess = len(self.recent_sizes) - 50
if excess > 0:
self.recent_sizes = self.recent_sizes[excess:]

    async def send_single(self, method, params=None):
        '''Send a single JSON RPC request to the daemon and return its
        result.'''
payload = {'method': method}
if params:
payload['params'] = params
result, = await self.send((payload, ))
return result

    async def send_many(self, mp_pairs):
        '''Send a batch of JSON RPC requests; mp_pairs is a sequence of
        (method, params) pairs.'''
payload = [{'method': method, 'params': params}
for method, params in mp_pairs]
return await self.send(payload)

    async def send_vector(self, method, params_list):
        '''Send the same JSON RPC method once for each params in
        params_list, as a single batch.'''
payload = [{'method': method, 'params': params}
for params in params_list]
return await self.send(payload)

    async def send(self, payload):
        '''POST the payload to the daemon as JSON, retrying on transient
        failures, and return a tuple of the results.'''
assert isinstance(payload, (tuple, list))
data = json.dumps(payload)
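        # The daemon treats a JSON array as a batch request and replies
        # with an array of {'result': ..., 'error': ...} objects, one
        # per request, in the same order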
while True:
try:
async with aiohttp.post(self.daemon_url, data=data) as resp:
result = await resp.json()
except asyncio.CancelledError:
raise
except Exception as e:
msg = 'aiohttp error: {}'.format(e)
secs = 3
else:
errs = tuple(item['error'] for item in result)
if not any(errs):
return tuple(item['result'] for item in result)
                if any(err and err.get('code') == -28 for err in errs):
msg = 'daemon still warming up.'
secs = 30
else:
msg = 'daemon errors: {}'.format(errs)
raise self.DaemonError(msg)
self.logger.error('{}. Sleeping {:d}s and trying again...'
.format(msg, secs))
await asyncio.sleep(secs)
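

# A minimal sketch of one way to drive this module.  The Env import is
# an assumption (this file does not define Env); adapt it to however
# your configuration object is actually constructed.
if __name__ == '__main__':
    from server.env import Env  # assumed location of the config class

    loop = asyncio.get_event_loop()
    controller = Controller(Env())
    controller.start(loop)
    # start() does not schedule reap_jobs(), so do it here
    asyncio.ensure_future(controller.reap_jobs())
    try:
        loop.run_forever()
    finally:
        controller.stop()
        loop.close()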