
Enable servers

Neil Booth, 8 years ago (branch: master, commit 3d11afbda2)

Changed files:
    server/block_processor.py    40 lines changed
    server/controller.py         74 lines changed
    server/env.py                 3 lines changed
    server/protocol.py           38 lines changed

server/block_processor.py    (40 lines changed)

@@ -126,17 +126,18 @@ class BlockProcessor(LoggedClass):
     Coordinate backing up in case of chain reorganisations.
     '''
 
-    def __init__(self, env, daemon):
+    def __init__(self, env, daemon, on_catchup=None):
         super().__init__()
 
         self.daemon = daemon
+        self.on_catchup = on_catchup
 
         # Meta
         self.utxo_MB = env.utxo_MB
         self.hist_MB = env.hist_MB
         self.next_cache_check = 0
         self.coin = env.coin
-        self.caught_up = False
+        self.have_caught_up = False
         self.reorg_limit = env.reorg_limit
 
         # Chain state (initialize to genesis in case of new DB)
@@ -192,6 +193,17 @@ class BlockProcessor(LoggedClass):
         else:
             return [self.start(), self.prefetcher.start()]
 
+    async def caught_up(self):
+        '''Call when we catch up to the daemon's height.'''
+        # Flush everything when in caught-up state as queries
+        # are performed on DB and not in-memory.
+        self.flush(True)
+        if not self.have_caught_up:
+            self.have_caught_up = True
+            self.logger.info('caught up to height {:,d}'.format(self.height))
+            if self.on_catchup:
+                await self.on_catchup()
+
     async def start(self):
         '''External entry point for block processing.
@@ -199,32 +211,26 @@ class BlockProcessor(LoggedClass):
         shutdown.
         '''
         try:
-            await self.advance_blocks()
+            # If we're already caught up, start the servers immediately
+            if self.height == await self.daemon.height():
+                await self.caught_up()
+            await self.wait_for_blocks()
         finally:
             self.flush(True)
 
-    async def advance_blocks(self):
+    async def wait_for_blocks(self):
         '''Loop forever processing blocks in the forward direction.'''
         while True:
             blocks = await self.prefetcher.get_blocks()
             for block in blocks:
                 if not self.advance_block(block):
                     await self.handle_chain_reorg()
-                    self.caught_up = False
+                    self.have_caught_up = False
                     break
                 await asyncio.sleep(0)  # Yield
 
-            if self.height != self.daemon.cached_height():
-                continue
-
-            if not self.caught_up:
-                self.caught_up = True
-                self.logger.info('caught up to height {:,d}'
-                                 .format(self.height))
-
-            # Flush everything when in caught-up state as queries
-            # are performed on DB not in-memory
-            self.flush(True)
+            if self.height == self.daemon.cached_height():
+                await self.caught_up()
 
     async def force_chain_reorg(self, to_genesis):
         try:
@@ -360,7 +366,7 @@ class BlockProcessor(LoggedClass):
     def flush_state(self, batch):
         '''Flush chain state to the batch.'''
-        if self.caught_up:
+        if self.have_caught_up:
             self.first_sync = False
         now = time.time()
         self.wall_time += now - self.last_flush
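Aside: the on_catchup hook added in block_processor.py is a callback-at-construction pattern; the processor awaits it the first time it reaches the daemon's height, which is what lets the controller delay opening listening sockets until the index is usable. A minimal, self-contained sketch of the same idea; the Syncer class and its names are illustrative, not part of ElectrumX:

import asyncio


class Syncer:
    '''Toy stand-in for a block processor: awaits on_catchup once synced.'''

    def __init__(self, target_height, on_catchup=None):
        self.height = 0
        self.target_height = target_height
        self.on_catchup = on_catchup
        self.have_caught_up = False

    async def run(self):
        while self.height < self.target_height:
            self.height += 1        # pretend to process one block
            await asyncio.sleep(0)  # yield to the event loop
        if not self.have_caught_up:
            self.have_caught_up = True
            if self.on_catchup:
                await self.on_catchup()


async def main():
    async def start_servers():
        print('caught up; servers would start listening now')

    await Syncer(target_height=5, on_catchup=start_servers).run()


asyncio.run(main())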

server/controller.py    (74 lines changed)

@@ -13,6 +13,7 @@ client-serving data such as histories.
 import asyncio
 import signal
+import ssl
 import traceback
 from functools import partial
@@ -35,51 +36,62 @@ class Controller(LoggedClass):
         self.loop = loop
         self.env = env
         self.daemon = Daemon(env.daemon_url)
-        self.block_processor = BlockProcessor(env, self.daemon)
+        self.block_processor = BlockProcessor(env, self.daemon,
+                                              on_catchup=self.start_servers)
         self.servers = []
         self.sessions = set()
         self.addresses = {}
-        self.jobs = set()
+        self.jobs = asyncio.Queue()
         self.peers = {}
 
     def start(self):
-        '''Prime the event loop with asynchronous servers and jobs.'''
-        env = self.env
-        loop = self.loop
+        '''Prime the event loop with asynchronous jobs.'''
         coros = self.block_processor.coros()
-
-        if False:
-            self.start_servers()
-        coros.append(self.reap_jobs())
+        coros.append(self.run_jobs())
 
         for coro in coros:
            asyncio.ensure_future(coro)
 
         # Signal handlers
         for signame in ('SIGINT', 'SIGTERM'):
-            loop.add_signal_handler(getattr(signal, signame),
-                                    partial(self.on_signal, signame))
+            self.loop.add_signal_handler(getattr(signal, signame),
+                                         partial(self.on_signal, signame))
 
-    def start_servers(self):
+    async def start_servers(self):
+        '''Start listening on RPC, TCP and SSL ports.
+
+        Does not start a server if the port wasn't specified.  Does
+        nothing if servers are already running.
+        '''
+        if self.servers:
+            return
+
+        env = self.env
+        loop = self.loop
+
         protocol = partial(LocalRPC, self)
         if env.rpc_port is not None:
             host = 'localhost'
             rpc_server = loop.create_server(protocol, host, env.rpc_port)
-            self.servers.append(loop.run_until_complete(rpc_server))
+            self.servers.append(await rpc_server)
             self.logger.info('RPC server listening on {}:{:d}'
                              .format(host, env.rpc_port))
 
         protocol = partial(ElectrumX, self, self.daemon, env)
         if env.tcp_port is not None:
             tcp_server = loop.create_server(protocol, env.host, env.tcp_port)
-            self.servers.append(loop.run_until_complete(tcp_server))
+            self.servers.append(await tcp_server)
             self.logger.info('TCP server listening on {}:{:d}'
                              .format(env.host, env.tcp_port))
 
         if env.ssl_port is not None:
-            ssl_server = loop.create_server(protocol, env.host, env.ssl_port)
-            self.servers.append(loop.run_until_complete(ssl_server))
+            # FIXME: update if we want to require Python >= 3.5.3
+            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
+            ssl_context.load_cert_chain(env.ssl_certfile,
+                                        keyfile=env.ssl_keyfile)
+            ssl_server = loop.create_server(protocol, env.host, env.ssl_port,
                                            ssl=ssl_context)
+            self.servers.append(await ssl_server)
             self.logger.info('SSL server listening on {}:{:d}'
                              .format(env.host, env.ssl_port))
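Aside: the SSL branch above follows the stock asyncio recipe of building an ssl.SSLContext, loading the certificate chain, and handing the context to loop.create_server(). A rough standalone sketch; the Echo protocol, serve_tls helper, and file paths are assumptions for illustration, not part of this codebase:

import asyncio
import ssl


class Echo(asyncio.Protocol):
    '''Minimal protocol so there is something to serve.'''

    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data):
        self.transport.write(data)


async def serve_tls(host, port, certfile, keyfile):
    # The commit uses PROTOCOL_TLSv1_2 for older Pythons; PROTOCOL_TLS_SERVER
    # is the usual spelling on current versions.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(certfile, keyfile=keyfile)
    loop = asyncio.get_running_loop()
    return await loop.create_server(Echo, host, port, ssl=context)

# Usage (the cert/key paths are placeholders and must exist on disk):
#     server = await serve_tls('0.0.0.0', 50002, 'server.crt', 'server.key')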
@@ -96,30 +108,28 @@ class Controller(LoggedClass):
             task.cancel()
 
     def add_session(self, session):
         '''Add a session representing one incoming connection.'''
         self.sessions.add(session)
 
     def remove_session(self, session):
         '''Remove a session.'''
         self.sessions.remove(session)
 
     def add_job(self, coro):
         '''Queue a job for asynchronous processing.'''
-        self.jobs.add(asyncio.ensure_future(coro))
+        self.jobs.put_nowait(coro)
 
-    async def reap_jobs(self):
+    async def run_jobs(self):
         '''Asynchronously run through the job queue.'''
         while True:
-            jobs = set()
-            for job in self.jobs:
-                if job.done():
-                    try:
-                        job.result()
-                    except Exception as e:
-                        traceback.print_exc()
-                else:
-                    jobs.add(job)
-            self.logger.info('reaped {:d} jobs, {:d} jobs pending'
-                             .format(len(self.jobs) - len(jobs), len(jobs)))
-            self.jobs = jobs
-            await asyncio.sleep(5)
+            job = await self.jobs.get()
+            try:
+                await job
+            except asyncio.CancelledError:
+                raise
+            except Exception:
+                # Getting here should probably be considered a bug and fixed
+                traceback.print_exc()
 
     def address_status(self, hash168):
         '''Returns status as 32 bytes.'''
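Aside: the reap_jobs to run_jobs change replaces a periodically polled set of futures with an asyncio.Queue drained by a single worker coroutine, so exceptions surface at the point the job is awaited. The shape of that consumer pattern, reduced to a self-contained sketch with illustrative names:

import asyncio
import traceback


async def run_jobs(jobs):
    '''Run queued coroutines one at a time, logging unexpected failures.'''
    while True:
        job = await jobs.get()
        try:
            await job
        except asyncio.CancelledError:
            raise
        except Exception:
            traceback.print_exc()


async def main():
    jobs = asyncio.Queue()

    async def greet(name):
        print('hello,', name)

    jobs.put_nowait(greet('alice'))
    jobs.put_nowait(greet('bob'))

    worker = asyncio.ensure_future(run_jobs(jobs))
    await asyncio.sleep(0.1)  # give the worker time to drain the queue
    worker.cancel()


asyncio.run(main())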

server/env.py    (3 lines changed)

@@ -34,6 +34,9 @@ class Env(LoggedClass):
         # Server stuff
         self.tcp_port = self.integer('TCP_PORT', None)
         self.ssl_port = self.integer('SSL_PORT', None)
+        if self.ssl_port:
+            self.ssl_certfile = self.required('SSL_CERTFILE')
+            self.ssl_keyfile = self.required('SSL_KEYFILE')
         self.rpc_port = self.integer('RPC_PORT', 8000)
         self.max_subscriptions = self.integer('MAX_SUBSCRIPTIONS', 10000)
         self.banner_file = self.default('BANNER_FILE', None)
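Aside: the env.py hunk makes SSL_CERTFILE and SSL_KEYFILE mandatory only when SSL_PORT is set. A rough sketch of that conditional-requirement idea using os.environ; the helper names mirror the ones above, but this ExampleEnv class is an assumption for illustration, not the ElectrumX Env implementation:

import os


class ExampleEnv:
    '''Illustrative settings reader with conditional requirements.'''

    def integer(self, name, default):
        value = os.environ.get(name)
        return default if value is None else int(value)

    def required(self, name):
        value = os.environ.get(name)
        if value is None:
            raise RuntimeError('{} must be set in the environment'.format(name))
        return value

    def __init__(self):
        self.ssl_port = self.integer('SSL_PORT', None)
        if self.ssl_port:
            # Certificate settings are only demanded when SSL is enabled
            self.ssl_certfile = self.required('SSL_CERTFILE')
            self.ssl_keyfile = self.required('SSL_KEYFILE')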

server/protocol.py    (38 lines changed)

@@ -24,6 +24,14 @@ class Error(Exception):
 class JSONRPC(asyncio.Protocol, LoggedClass):
     '''Base class that manages a JSONRPC connection.
+
+    When a request comes in for an RPC method M, then a member
+    function handle_M is called with the request params array, except
+    that periods in M are replaced with underscores. So an RPC call
+    for method 'blockchain.estimatefee' will be passed to
+    handle_blockchain_estimatefee.
     '''
 
     def __init__(self, controller):
         super().__init__()
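Aside: the docstring above describes dispatch by naming convention; periods in the RPC method name become underscores and the result is prefixed with handle_. A compact illustration of that lookup, independent of the ElectrumX classes (the Dispatcher class and its handler are made up for the example):

class Dispatcher:
    '''Route JSON-RPC method names to handle_* coroutines by convention.'''

    async def handle_blockchain_estimatefee(self, params):
        return 0.0001  # placeholder result

    async def dispatch(self, method, params):
        handler = getattr(self, 'handle_' + method.replace('.', '_'), None)
        if handler is None:
            raise ValueError('unknown method: {}'.format(method))
        return await handler(params)

# e.g. fee = await Dispatcher().dispatch('blockchain.estimatefee', [])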
@@ -31,39 +39,41 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
         self.parts = []
 
     def connection_made(self, transport):
+        '''Handle an incoming client connection.'''
         self.transport = transport
-        peername = transport.get_extra_info('peername')
-        self.logger.info('connection from {}'.format(peername))
+        self.peername = transport.get_extra_info('peername')
+        self.logger.info('connection from {}'.format(self.peername))
         self.controller.add_session(self)
 
     def connection_lost(self, exc):
-        self.logger.info('disconnected')
+        '''Handle client disconnection.'''
+        self.logger.info('disconnected: {}'.format(self.peername))
         self.controller.remove_session(self)
 
     def data_received(self, data):
+        '''Handle incoming data (synchronously).
+
+        Requests end in newline characters. Pass complete requests to
+        decode_message for handling.
+        '''
         while True:
             npos = data.find(ord('\n'))
             if npos == -1:
+                self.parts.append(data)
                 break
             tail, data = data[:npos], data[npos + 1:]
-            parts = self.parts
-            self.parts = []
+            parts, self.parts = self.parts, []
             parts.append(tail)
             self.decode_message(b''.join(parts))
-
-        if data:
-            self.parts.append(data)
 
     def decode_message(self, message):
-        '''Message is a binary message.'''
+        '''Decode a binary message and queue it for asynchronous handling.'''
         try:
             message = json.loads(message.decode())
         except Exception as e:
-            self.logger.info('caught exception decoding message'.format(e))
-            return
-        job = self.request_handler(message)
-        self.controller.add_job(job)
+            self.logger.info('error decoding JSON message: {}'.format(e))
+        else:
+            self.controller.add_job(self.request_handler(message))
 
     async def request_handler(self, request):
         '''Called asynchronously.'''
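Aside: data_received implements newline-delimited framing; incoming bytes accumulate in self.parts until a '\n' arrives, then the buffered pieces are joined and decoded as one JSON message. The same buffering logic as a standalone function, a sketch for illustration rather than the class method itself:

import json


def split_messages(buffered, data):
    '''Return complete newline-terminated messages from data, buffering any
    incomplete trailing bytes in the buffered list for the next call.'''
    messages = []
    while True:
        npos = data.find(ord('\n'))
        if npos == -1:
            buffered.append(data)
            break
        tail, data = data[:npos], data[npos + 1:]
        parts, buffered[:] = buffered[:], []
        parts.append(tail)
        messages.append(b''.join(parts))
    return messages


buffered = []
assert split_messages(buffered, b'{"id": 1, "method": "serv') == []
msgs = split_messages(buffered, b'er.version"}\n')
assert json.loads(msgs[0].decode())['method'] == 'server.version'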
