diff --git a/server/block_processor.py b/server/block_processor.py index 157d226..4b9fcda 100644 --- a/server/block_processor.py +++ b/server/block_processor.py @@ -29,7 +29,6 @@ class Prefetcher(LoggedClass): def __init__(self, bp): super().__init__() self.bp = bp - self.caught_up = False # Access to fetched_height should be protected by the semaphore self.fetched_height = None self.semaphore = asyncio.Semaphore() @@ -84,7 +83,14 @@ class Prefetcher(LoggedClass): Repeats until the queue is full or caught up. ''' daemon = self.bp.daemon - daemon_height = await daemon.height(self.bp.caught_up_event.is_set()) + # If caught up, refresh the mempool before the current height + caught_up = self.bp.caught_up_event.is_set() + if caught_up: + mempool = await daemon.mempool_hashes() + else: + mempool = [] + + daemon_height = await daemon.height() with await self.semaphore: while self.cache_size < self.min_cache_size: # Try and catch up all blocks but limit to room in cache. @@ -94,14 +100,15 @@ class Prefetcher(LoggedClass): count = min(daemon_height - self.fetched_height, cache_room) count = min(500, max(count, 0)) if not count: - if not self.caught_up: - self.caught_up = True + if caught_up: + self.bp.set_mempool_hashes(mempool) + else: self.bp.on_prefetcher_first_caught_up() return False first = self.fetched_height + 1 hex_hashes = await daemon.block_hex_hashes(first, count) - if self.caught_up: + if caught_up: self.logger.info('new block height {:,d} hash {}' .format(first + count-1, hex_hashes[-1])) blocks = await daemon.raw_blocks(hex_hashes) @@ -121,7 +128,7 @@ class Prefetcher(LoggedClass): else: self.ave_size = (size + (10 - count) * self.ave_size) // 10 - self.bp.on_prefetched_blocks(blocks, first) + self.bp.on_prefetched_blocks(blocks, first, mempool) self.cache_size += size self.fetched_height += count @@ -188,9 +195,10 @@ class BlockProcessor(server.db.DB): '''Add the task to our task queue.''' self.task_queue.put_nowait(task) - def on_prefetched_blocks(self, blocks, first): + def on_prefetched_blocks(self, blocks, first, mempool): '''Called by the prefetcher when it has prefetched some blocks.''' - self.add_task(partial(self.check_and_advance_blocks, blocks, first)) + self.add_task(partial(self.check_and_advance_blocks, blocks, first, + mempool)) def on_prefetcher_first_caught_up(self): '''Called by the prefetcher when it first catches up.''' @@ -225,7 +233,10 @@ class BlockProcessor(server.db.DB): self.open_dbs() self.caught_up_event.set() - async def check_and_advance_blocks(self, blocks, first): + def set_mempool_hashes(self, mempool): + self.controller.mempool.set_hashes(mempool) + + async def check_and_advance_blocks(self, blocks, first, mempool): '''Process the list of blocks passed. Detects and handles reorgs.''' self.prefetcher.processing_blocks(blocks) if first != self.height + 1: @@ -251,6 +262,7 @@ class BlockProcessor(server.db.DB): self.logger.info('processed {:,d} block{} in {:.1f}s' .format(len(blocks), s, time.time() - start)) + self.set_mempool_hashes(mempool) elif hprevs[0] != chain[0]: await self.reorg_chain() else: diff --git a/server/daemon.py b/server/daemon.py index 5e9b311..23cebbf 100644 --- a/server/daemon.py +++ b/server/daemon.py @@ -38,8 +38,6 @@ class Daemon(util.LoggedClass): super().__init__() self.set_urls(urls) self._height = None - self._mempool_hashes = set() - self.mempool_refresh_event = asyncio.Event() # Limit concurrent RPC calls to this number. # See DEFAULT_HTTP_WORKQUEUE in bitcoind, which is typically 16 self.workqueue_semaphore = asyncio.Semaphore(value=10) @@ -210,7 +208,7 @@ class Daemon(util.LoggedClass): return [bytes.fromhex(block) for block in blocks] async def mempool_hashes(self): - '''Update our record of the daemon's mempool hashes.''' + '''Return a list of the daemon's mempool hashes.''' return await self._send_single('getrawmempool') async def estimatefee(self, params): @@ -245,18 +243,11 @@ class Daemon(util.LoggedClass): '''Broadcast a transaction to the network.''' return await self._send_single('sendrawtransaction', params) - async def height(self, mempool=False): + async def height(self): '''Query the daemon for its current height.''' self._height = await self._send_single('getblockcount') - if mempool: - self._mempool_hashes = set(await self.mempool_hashes()) - self.mempool_refresh_event.set() return self._height - def cached_mempool_hashes(self): - '''Return the cached mempool hashes.''' - return self._mempool_hashes - def cached_height(self): '''Return the cached daemon height. diff --git a/server/mempool.py b/server/mempool.py index 0a6c27b..075c29d 100644 --- a/server/mempool.py +++ b/server/mempool.py @@ -37,6 +37,8 @@ class MemPool(util.LoggedClass): self.controller = controller self.coin = bp.coin self.db = bp + self.hashes = set() + self.mempool_refresh_event = asyncio.Event() self.touched = bp.touched self.touched_event = asyncio.Event() self.prioritized = set() @@ -49,6 +51,11 @@ class MemPool(util.LoggedClass): initial mempool sync.''' self.prioritized.add(tx_hash) + def set_hashes(self, hashes): + '''Save the list of mempool hashes.''' + self.hashes = set(hashes) + self.mempool_refresh_event.set() + def resync_daemon_hashes(self, unprocessed, unfetched): '''Re-sync self.txs with the list of hashes in the daemon's mempool. @@ -59,8 +66,7 @@ class MemPool(util.LoggedClass): hashXs = self.hashXs touched = self.touched - hashes = self.daemon.cached_mempool_hashes() - gone = set(txs).difference(hashes) + gone = set(txs).difference(self.hashes) for hex_hash in gone: unfetched.discard(hex_hash) unprocessed.pop(hex_hash, None) @@ -75,7 +81,7 @@ class MemPool(util.LoggedClass): del hashXs[hashX] touched.update(tx_hashXs) - new = hashes.difference(txs) + new = self.hashes.difference(txs) unfetched.update(new) for hex_hash in new: txs[hex_hash] = None @@ -92,15 +98,14 @@ class MemPool(util.LoggedClass): fetch_size = 800 process_some = self.async_process_some(unfetched, fetch_size // 2) - await self.daemon.mempool_refresh_event.wait() + await self.mempool_refresh_event.wait() self.logger.info('beginning processing of daemon mempool. ' 'This can take some time...') next_log = 0 loops = -1 # Zero during initial catchup while True: - # Avoid double notifications if processing a block - if self.touched and not self.processing_new_block(): + if self.touched: self.touched_event.set() # Log progress / state @@ -120,10 +125,10 @@ class MemPool(util.LoggedClass): try: if not todo: self.prioritized.clear() - await self.daemon.mempool_refresh_event.wait() + await self.mempool_refresh_event.wait() self.resync_daemon_hashes(unprocessed, unfetched) - self.daemon.mempool_refresh_event.clear() + self.mempool_refresh_event.clear() if unfetched: count = min(len(unfetched), fetch_size) @@ -177,10 +182,6 @@ class MemPool(util.LoggedClass): return process - def processing_new_block(self): - '''Return True if we're processing a new block.''' - return self.daemon.cached_height() > self.db.db_height - async def fetch_raw_txs(self, hex_hashes): '''Fetch a list of mempool transactions.''' raw_txs = await self.daemon.getrawtransactions(hex_hashes)