You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
185 lines
6.3 KiB
185 lines
6.3 KiB
# Copyright (c) 2016, Neil Booth
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
|
|
|
|
'''Server controller.

Coordinates the parts of the server. Serves as a cache for
client-serving data such as histories.
'''
|
|
|
|
import asyncio
|
|
import signal
|
|
import ssl
|
|
import traceback
|
|
from functools import partial
|
|
|
|
from server.daemon import Daemon, DaemonError
|
|
from server.block_processor import BlockProcessor
|
|
from server.protocol import ElectrumX, LocalRPC
|
|
from lib.hash import (sha256, double_sha256, hash_to_str,
|
|
Base58, hex_str_to_hash)
|
|
from lib.util import LoggedClass
|
|
|
|
|
|
class Controller(LoggedClass):
    '''Coordinates the daemon, block processor and network servers.

    Also caches per-address status data for client sessions.
    '''

    def __init__(self, loop, env):
        '''Set up the controller.

        Creates the Daemon and BlockProcessor instances.  The listening
        servers are not started here: start_servers is handed to the
        block processor as its on_catchup callback.

        loop: the asyncio event loop the controller schedules work on.
        env: configuration object (daemon_url, host, rpc/tcp/ssl ports,
             SSL cert/key files, ...).
        '''
        super().__init__()
        self.loop = loop
        self.env = env
        self.daemon = Daemon(env.daemon_url)
        self.block_processor = BlockProcessor(env, self.daemon,
                                              on_catchup=self.start_servers)
        self.servers = []            # asyncio servers opened by start_servers
        self.sessions = set()        # live sessions (add/remove_session)
        self.addresses = {}          # hash168 -> cached status (address_status)
        self.jobs = asyncio.Queue()  # unbounded queue drained by run_jobs
        self.peers = {}              # returned as-is by get_peers

    def start(self):
        '''Prime the event loop with asynchronous jobs.

        Schedules the block processor's coroutines plus the job runner on
        self.loop, and installs SIGINT/SIGTERM handlers that call
        on_signal.
        '''
        coros = self.block_processor.coros()
        coros.append(self.run_jobs())

        for coro in coros:
            # Bind explicitly to self.loop: on_signal cancels the tasks of
            # self.loop, so these must not land on a different default loop.
            asyncio.ensure_future(coro, loop=self.loop)

        # Signal handlers
        for signame in ('SIGINT', 'SIGTERM'):
            self.loop.add_signal_handler(getattr(signal, signame),
                                         partial(self.on_signal, signame))

    async def start_servers(self):
        '''Start listening on RPC, TCP and SSL ports.

        Does not start a server if its port wasn't specified (is None).
        Does nothing if servers are already running.  The RPC server
        listens on localhost only; TCP and SSL listen on env.host.
        '''
        if self.servers:
            return

        env = self.env
        loop = self.loop

        protocol = partial(LocalRPC, self)
        if env.rpc_port is not None:
            host = 'localhost'
            rpc_server = loop.create_server(protocol, host, env.rpc_port)
            self.servers.append(await rpc_server)
            self.logger.info('RPC server listening on {}:{:d}'
                             .format(host, env.rpc_port))

        # The TCP and SSL servers share one protocol factory.
        protocol = partial(ElectrumX, self, self.daemon, env)
        if env.tcp_port is not None:
            tcp_server = loop.create_server(protocol, env.host, env.tcp_port)
            self.servers.append(await tcp_server)
            self.logger.info('TCP server listening on {}:{:d}'
                             .format(env.host, env.tcp_port))

        if env.ssl_port is not None:
            # FIXME: update if we want to require Python >= 3.5.3
            # (PROTOCOL_TLSv1_2 pins TLS 1.2 only and is deprecated on
            # newer Pythons.)
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
            ssl_context.load_cert_chain(env.ssl_certfile,
                                        keyfile=env.ssl_keyfile)
            ssl_server = loop.create_server(protocol, env.host, env.ssl_port,
                                            ssl=ssl_context)
            self.servers.append(await ssl_server)
            self.logger.info('SSL server listening on {}:{:d}'
                             .format(env.host, env.ssl_port))

    def stop(self):
        '''Close the listening servers.

        The closed servers remain in self.servers, so a later call to
        start_servers returns early without reopening them.
        '''
        for server in self.servers:
            server.close()

    def on_signal(self, signame):
        '''Call on receipt of a signal to cleanly shutdown.

        Logs the signal name and cancels every task on self.loop.
        '''
        self.logger.warning('received {} signal, preparing to shut down'
                            .format(signame))
        # asyncio.Task.all_tasks was removed in Python 3.9; use the
        # module-level asyncio.all_tasks (3.7+) and fall back to the old
        # classmethod on interpreters that predate it.
        all_tasks = getattr(asyncio, 'all_tasks', None)
        if all_tasks is None:
            all_tasks = asyncio.Task.all_tasks
        for task in all_tasks(self.loop):
            task.cancel()

    def add_session(self, session):
        '''Add a session representing one incoming connection.'''
        self.sessions.add(session)

    def remove_session(self, session):
        '''Remove a session.

        Raises KeyError (from set.remove) if the session was never added.
        '''
        self.sessions.remove(session)

    def add_job(self, coro):
        '''Queue a job for asynchronous processing by run_jobs.

        The queue is unbounded, so put_nowait does not raise QueueFull.
        '''
        self.jobs.put_nowait(coro)

    async def run_jobs(self):
        '''Asynchronously run through the job queue.

        Awaits each queued coroutine in turn, forever.  Cancellation is
        propagated so shutdown works; any other exception raised by a job
        has its traceback printed and the next job is processed.
        '''
        while True:
            job = await self.jobs.get()
            try:
                await job
            except asyncio.CancelledError:
                # Before Python 3.8 CancelledError subclasses Exception, so
                # it must be re-raised explicitly before the catch-all.
                raise
            except Exception:
                # Getting here should probably be considered a bug and fixed
                traceback.print_exc()

    def address_status(self, hash168):
        '''Return the status of hash168, computing and caching it on a miss.

        The status is the sha256 (32 bytes) of the concatenated
        "<tx_hash>:<height>:" entries of the address's history, or the
        empty string '' when the history is empty.  Results are memoized
        in self.addresses; this method never invalidates that cache.
        '''
        status = self.addresses.get(hash168)
        if status is None:
            # NOTE(review): this uses get_history's default limit, whereas
            # get_history() below passes limit=None.  Confirm the default
            # covers the full history, otherwise the status is computed
            # over a truncated history.
            history = self.block_processor.get_history(hash168)
            status = ''.join('{}:{:d}:'.format(hash_to_str(tx_hash), height)
                             for tx_hash, height in history)
            if status:
                status = sha256(status.encode())
            self.addresses[hash168] = status

        return status

    async def get_merkle(self, tx_hash, height):
        '''Return the merkle branch of a transaction.

        tx_hash is a hex string; height is the height of the block
        containing it.  Returns a dict with "block_height", "merkle" (the
        sibling hashes, from the leaf level upwards, as strings) and "pos"
        (the transaction's index within the block).

        Raises ValueError (from list.index) if tx_hash is not in the block.
        '''
        block_hash = await self.daemon.send_single('getblockhash', (height,))
        block = await self.daemon.send_single('getblock', (block_hash, True))
        tx_hashes = block['tx']
        # This will throw if the tx_hash is bad
        pos = tx_hashes.index(tx_hash)

        idx = pos
        hashes = [hex_str_to_hash(txh) for txh in tx_hashes]
        merkle_branch = []
        while len(hashes) > 1:
            # A level with an odd count pairs its last hash with itself.
            if len(hashes) & 1:
                hashes.append(hashes[-1])
            # Record our sibling at this level, then step up to the parent.
            idx = idx - 1 if (idx & 1) else idx + 1
            merkle_branch.append(hash_to_str(hashes[idx]))
            idx //= 2
            hashes = [double_sha256(hashes[n] + hashes[n + 1])
                      for n in range(0, len(hashes), 2)]

        return {"block_height": height, "merkle": merkle_branch, "pos": pos}

    def get_peers(self):
        '''Returns a dictionary of IRC nick to (ip, host, ports) tuples, one
        per peer.

        This class only initializes self.peers to {}; presumably it is
        populated elsewhere.'''
        return self.peers

    def height(self):
        '''Return the block processor's current height.'''
        return self.block_processor.height

    def get_current_header(self):
        '''Return the block processor's current header.'''
        return self.block_processor.get_current_header()

    def get_history(self, hash168):
        '''Return the full history of hash168 (no limit).

        Each entry is a dict {'tx_hash': <hash string>, 'height': <int>}.
        '''
        history = self.block_processor.get_history(hash168, limit=None)
        return [
            {'tx_hash': hash_to_str(tx_hash), 'height': height}
            for tx_hash, height in history
        ]
|
|
|