Browse Source

Merge branch 'async_mempool' into develop

master
Neil Booth 8 years ago
parent
commit
009750bacb
  1. 97
      server/block_processor.py
  2. 42
      server/protocol.py

97
server/block_processor.py

@ -43,8 +43,8 @@ class Prefetcher(LoggedClass):
self.semaphore = asyncio.Semaphore() self.semaphore = asyncio.Semaphore()
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
self.queue_size = 0 self.queue_size = 0
self.caught_up = False
self.fetched_height = height self.fetched_height = height
self.mempool_hashes = []
# Target cache size. Has little effect on sync time. # Target cache size. Has little effect on sync time.
self.target_cache_size = 10 * 1024 * 1024 self.target_cache_size = 10 * 1024 * 1024
# First fetch to be 10 blocks # First fetch to be 10 blocks
@ -64,13 +64,14 @@ class Prefetcher(LoggedClass):
self.fetched_height = height self.fetched_height = height
async def get_blocks(self): async def get_blocks(self):
'''Returns a list of prefetched blocks and the mempool.''' '''Blocking function that returns prefetched blocks.
blocks, height, size = await self.queue.get()
The returned result empty just once - when the prefetcher
has caught up with the daemon.
'''
blocks, size = await self.queue.get()
self.queue_size -= size self.queue_size -= size
if height == self.daemon.cached_height(): return blocks
return blocks, self.mempool_hashes
else:
return blocks, None
async def main_loop(self): async def main_loop(self):
'''Loop forever polling for more blocks.''' '''Loop forever polling for more blocks.'''
@ -78,39 +79,19 @@ class Prefetcher(LoggedClass):
.format(await self.daemon.height())) .format(await self.daemon.height()))
while True: while True:
try: try:
if await self._caught_up(): with await self.semaphore:
await asyncio.sleep(5) await self._prefetch()
else: await asyncio.sleep(5 if self.caught_up else 0)
await asyncio.sleep(0)
except DaemonError as e: except DaemonError as e:
self.logger.info('ignoring daemon error: {}'.format(e)) self.logger.info('ignoring daemon error: {}'.format(e))
except asyncio.CancelledError: except asyncio.CancelledError:
break break
async def _caught_up(self):
'''Poll for new blocks and mempool state.
Mempool is only queried if caught up with daemon.'''
with await self.semaphore:
blocks, size = await self._prefetch()
self.fetched_height += len(blocks)
caught_up = self.fetched_height == self.daemon.cached_height()
if caught_up:
self.mempool_hashes = await self.daemon.mempool_hashes()
# Wake up block processor if we have something
if blocks or caught_up:
self.queue.put_nowait((blocks, self.fetched_height, size))
self.queue_size += size
return caught_up
async def _prefetch(self): async def _prefetch(self):
'''Prefetch blocks unless the prefetch queue is full.''' '''Prefetch blocks unless the prefetch queue is full.'''
if self.queue_size >= self.target_cache_size: if self.queue_size >= self.target_cache_size:
return [], 0 return
caught_up = self.daemon.cached_height() == self.fetched_height
daemon_height = await self.daemon.height() daemon_height = await self.daemon.height()
cache_room = self.target_cache_size // self.ave_size cache_room = self.target_cache_size // self.ave_size
@ -119,15 +100,18 @@ class Prefetcher(LoggedClass):
count = min(daemon_height - self.fetched_height, cache_room) count = min(daemon_height - self.fetched_height, cache_room)
count = min(4000, max(count, 0)) count = min(4000, max(count, 0))
if not count: if not count:
return [], 0 # Indicate when we have caught up for the first time only
if not self.caught_up:
self.caught_up = True
self.queue.put_nowait(([], 0))
return
first = self.fetched_height + 1 first = self.fetched_height + 1
hex_hashes = await self.daemon.block_hex_hashes(first, count) hex_hashes = await self.daemon.block_hex_hashes(first, count)
if caught_up: if self.caught_up:
self.logger.info('new block height {:,d} hash {}' self.logger.info('new block height {:,d} hash {}'
.format(first + count - 1, hex_hashes[-1])) .format(first + count - 1, hex_hashes[-1]))
blocks = await self.daemon.raw_blocks(hex_hashes) blocks = await self.daemon.raw_blocks(hex_hashes)
size = sum(len(block) for block in blocks) size = sum(len(block) for block in blocks)
# Update our recent average block size estimate # Update our recent average block size estimate
@ -136,7 +120,9 @@ class Prefetcher(LoggedClass):
else: else:
self.ave_size = (size + (10 - count) * self.ave_size) // 10 self.ave_size = (size + (10 - count) * self.ave_size) // 10
return blocks, size self.fetched_height += len(blocks)
self.queue.put_nowait((blocks, size))
self.queue_size += size
class ChainReorg(Exception): class ChainReorg(Exception):
@ -162,6 +148,7 @@ class BlockProcessor(server.db.DB):
self.daemon = Daemon(env.daemon_url, env.debug) self.daemon = Daemon(env.daemon_url, env.debug)
self.daemon.debug_set_height(self.height) self.daemon.debug_set_height(self.height)
self.caught_up = False
self.touched = set() self.touched = set()
self.futures = [] self.futures = []
@ -223,41 +210,51 @@ class BlockProcessor(server.db.DB):
await asyncio.sleep(0) await asyncio.sleep(0)
async def _wait_for_update(self): async def _wait_for_update(self):
'''Wait for the prefetcher to deliver blocks or a mempool update. '''Wait for the prefetcher to deliver blocks.
Blocks are only processed in the forward direction. The Blocks are only processed in the forward direction.
prefetcher only provides a non-None mempool when caught up.
''' '''
blocks, mempool_hashes = await self.prefetcher.get_blocks() blocks = await self.prefetcher.get_blocks()
if not blocks:
await self.first_caught_up()
return
'''Strip the unspendable genesis coinbase.''' '''Strip the unspendable genesis coinbase.'''
if self.height == -1: if self.height == -1:
blocks[0] = blocks[0][:self.coin.HEADER_LEN] + bytes(1) blocks[0] = blocks[0][:self.coin.HEADER_LEN] + bytes(1)
caught_up = mempool_hashes is not None
try: try:
for block in blocks: for block in blocks:
self.advance_block(block, caught_up) self.advance_block(block, self.caught_up)
if not caught_up and time.time() > self.next_cache_check:
self.check_cache_size()
self.next_cache_check = time.time() + 60
await asyncio.sleep(0) # Yield await asyncio.sleep(0) # Yield
if caught_up:
await self.caught_up(mempool_hashes)
self.touched = set()
except ChainReorg: except ChainReorg:
await self.handle_chain_reorg() await self.handle_chain_reorg()
async def caught_up(self, mempool_hashes): if self.caught_up:
# Flush everything as queries are performed on the DB and
# not in-memory.
self.flush(True)
self.notify(self.touched)
elif time.time() > self.next_cache_check:
self.check_cache_size()
self.next_cache_check = time.time() + 60
self.touched = set()
async def first_caught_up(self):
'''Called after each deamon poll if caught up.''' '''Called after each deamon poll if caught up.'''
# Caught up to daemon height. Flush everything as queries self.caught_up = True
# are performed on the DB and not in-memory.
if self.first_sync: if self.first_sync:
self.first_sync = False self.first_sync = False
self.logger.info('{} synced to height {:,d}. DB version:' self.logger.info('{} synced to height {:,d}. DB version:'
.format(VERSION, self.height, self.db_version)) .format(VERSION, self.height, self.db_version))
self.flush(True) self.flush(True)
def notify(self, touched):
'''Called with list of touched addresses by new blocks.
Only called for blocks found after first_caught_up is called.
Intended to be overridden in derived classes.'''
async def handle_chain_reorg(self): async def handle_chain_reorg(self):
# First get all state on disk # First get all state on disk
self.logger.info('chain reorg detected') self.logger.info('chain reorg detected')

42
server/protocol.py

@ -38,15 +38,16 @@ class BlockServer(BlockProcessor):
super().__init__(env) super().__init__(env)
self.server_mgr = ServerManager(self, env) self.server_mgr = ServerManager(self, env)
self.mempool = MemPool(self) self.mempool = MemPool(self)
self.caught_up_yet = False
async def caught_up(self, mempool_hashes): async def first_caught_up(self):
# Call the base class to flush before doing anything else. # Call the base class to flush and log first
await super().caught_up(mempool_hashes) await super().first_caught_up()
if not self.caught_up_yet:
await self.server_mgr.start_servers() await self.server_mgr.start_servers()
self.caught_up_yet = True self.futures.append(self.mempool.start())
self.touched.update(await self.mempool.update(mempool_hashes))
def notify(self, touched):
'''Called when addresses are touched by new blocks or mempool
updates.'''
self.server_mgr.notify(self.height, self.touched) self.server_mgr.notify(self.height, self.touched)
def on_cancel(self): def on_cancel(self):
@ -97,13 +98,29 @@ class MemPool(LoggedClass):
self.bp = bp self.bp = bp
self.count = -1 self.count = -1
async def update(self, hex_hashes): def start(self):
'''Starts the mempool synchronization mainloop. Return a future.'''
return asyncio.ensure_future(self.main_loop())
async def main_loop(self):
'''Asynchronously maintain mempool status with daemon.'''
self.logger.info('maintaining state with daemon...')
while True:
try:
await self.update()
await asyncio.sleep(5)
except DaemonError as e:
self.logger.info('ignoring daemon error: {}'.format(e))
except asyncio.CancelledError:
break
async def update(self):
'''Update state given the current mempool to the passed set of hashes. '''Update state given the current mempool to the passed set of hashes.
Remove transactions that are no longer in our mempool. Remove transactions that are no longer in our mempool.
Request new transactions we don't have then add to our mempool. Request new transactions we don't have then add to our mempool.
''' '''
hex_hashes = set(hex_hashes) hex_hashes = set(await self.bp.daemon.mempool_hashes())
touched = set() touched = set()
missing_utxos = [] missing_utxos = []
@ -210,8 +227,7 @@ class MemPool(LoggedClass):
self.logger.info('{:,d} txs touching {:,d} addresses' self.logger.info('{:,d} txs touching {:,d} addresses'
.format(len(self.txs), len(self.hash168s))) .format(len(self.txs), len(self.hash168s)))
# Might include a None self.bp.notify(touched)
return touched
def transactions(self, hash168): def transactions(self, hash168):
'''Generate (hex_hash, tx_fee, unconfirmed) tuples for mempool '''Generate (hex_hash, tx_fee, unconfirmed) tuples for mempool
@ -295,7 +311,6 @@ class ServerManager(LoggedClass):
await self.start_server('SSL', env.host, env.ssl_port, ssl=sslc) await self.start_server('SSL', env.host, env.ssl_port, ssl=sslc)
if env.irc: if env.irc:
self.logger.info('starting IRC coroutine')
self.irc_future = asyncio.ensure_future(self.irc.start()) self.irc_future = asyncio.ensure_future(self.irc.start())
else: else:
self.logger.info('IRC disabled') self.logger.info('IRC disabled')
@ -310,11 +325,12 @@ class ServerManager(LoggedClass):
def stop(self): def stop(self):
'''Close listening servers.''' '''Close listening servers.'''
self.logger.info('cleanly closing client sessions, please wait...')
for server in self.servers: for server in self.servers:
server.close() server.close()
if self.irc_future: if self.irc_future:
self.irc_future.cancel() self.irc_future.cancel()
if self.sessions:
self.logger.info('cleanly closing client sessions, please wait...')
for session in self.sessions: for session in self.sessions:
self.close_session(session) self.close_session(session)

Loading…
Cancel
Save