From 90dcf87536a79f67b139d4647cfb191fdbc2190b Mon Sep 17 00:00:00 2001 From: Neil Booth Date: Sun, 5 Aug 2018 15:03:15 +0900 Subject: [PATCH] Create MemPoolAPI and use it --- electrumx/server/controller.py | 14 ++++++- electrumx/server/mempool.py | 69 ++++++++++++++++++++++++++++------ 2 files changed, 69 insertions(+), 14 deletions(-) diff --git a/electrumx/server/controller.py b/electrumx/server/controller.py index 665d39c..9061f93 100644 --- a/electrumx/server/controller.py +++ b/electrumx/server/controller.py @@ -14,7 +14,7 @@ from electrumx.lib.server_base import ServerBase from electrumx.lib.util import version_string from electrumx.server.chain_state import ChainState from electrumx.server.db import DB -from electrumx.server.mempool import MemPool +from electrumx.server.mempool import MemPool, MemPoolAPI from electrumx.server.session import SessionManager @@ -97,8 +97,18 @@ class Controller(ServerBase): db = DB(env) BlockProcessor = env.coin.BLOCK_PROCESSOR bp = BlockProcessor(env, db, daemon, notifications) - mempool = MemPool(env.coin, daemon, notifications, db.lookup_utxos) chain_state = ChainState(env, db, daemon, bp) + + # Set ourselves up to implement the MemPoolAPI + self.height = daemon.height + self.cached_height = daemon.cached_height + self.mempool_hashes = daemon.mempool_hashes + self.raw_transactions = daemon.getrawtransactions + self.lookup_utxos = db.lookup_utxos + self.on_mempool = notifications.on_mempool + MemPoolAPI.register(Controller) + mempool = MemPool(env.coin, self) + session_mgr = SessionManager(env, chain_state, mempool, notifications, shutdown_event) diff --git a/electrumx/server/mempool.py b/electrumx/server/mempool.py index d3198f5..d71a943 100644 --- a/electrumx/server/mempool.py +++ b/electrumx/server/mempool.py @@ -10,6 +10,7 @@ import asyncio import itertools import time +from abc import ABC, abstractmethod from collections import defaultdict import attr @@ -30,9 +31,53 @@ class MemPoolTx(object): size = attr.ib() +class MemPoolAPI(ABC): + '''A concrete instance of this class is passed to the MemPool object + and used by it to query DB and blockchain state.''' + + @abstractmethod + async def height(self): + '''Query bitcoind for its height.''' + + @abstractmethod + def cached_height(self): + '''Return the height of bitcoind the last time it was queried, + for any reason, without actually querying it. + ''' + + @abstractmethod + async def mempool_hashes(self): + '''Query bitcoind for the hashes of all transactions in its + mempool, returned as a list.''' + + @abstractmethod + async def raw_transactions(self, hex_hashes): + '''Query bitcoind for the serialized raw transactions with the given + hashes. Missing transactions are returned as None. + + hex_hashes is an iterable of hexadecimal hash strings.''' + + @abstractmethod + async def lookup_utxos(self, prevouts): + '''Return a list of (hashX, value) pairs each prevout if unspent, + otherwise return None if spent or not found. + + prevouts - an iterable of (hash, index) pairs + ''' + + @abstractmethod + async def on_mempool(self, touched, height): + '''Called each time the mempool is synchronized. touched is a set of + hashXs touched since the previous call. height is the + daemon's height at the time the mempool was obtained.''' + + class MemPool(object): '''Representation of the daemon's mempool. + coin - a coin class from coins.py + api - an object implementing MemPoolAPI + Updated regularly in caught-up state. Goal is to enable efficient response to the calls in the external interface. To that end we maintain the following maps: @@ -41,12 +86,11 @@ class MemPool(object): hashXs: hashX -> set of all hashes of txs touching the hashX ''' - def __init__(self, coin, daemon, notifications, lookup_utxos): - self.logger = class_logger(__name__, self.__class__.__name__) + def __init__(self, coin, api): + assert isinstance(api, MemPoolAPI) self.coin = coin - self.lookup_utxos = lookup_utxos - self.daemon = daemon - self.notifications = notifications + self.api = api + self.logger = class_logger(__name__, self.__class__.__name__) self.txs = {} self.hashXs = defaultdict(set) # None can be a key self.cached_compact_histogram = [] @@ -132,14 +176,14 @@ class MemPool(object): sleep = 5 histogram_refresh = self.coin.MEMPOOL_HISTOGRAM_REFRESH_SECS // sleep for loop_count in itertools.count(): - height = self.daemon.cached_height() - hex_hashes = await self.daemon.mempool_hashes() - if height != await self.daemon.height(): + height = self.api.cached_height() + hex_hashes = await self.api.mempool_hashes() + if height != await self.api.height(): continue hashes = set(hex_str_to_hash(hh) for hh in hex_hashes) touched = await self._process_mempool(hashes) synchronized_event.set() - await self.notifications.on_mempool(touched, height) + await self.api.on_mempool(touched, height) # Thread mempool histogram refreshes - they can be expensive if loop_count % histogram_refresh == 0: await run_in_thread(self._update_histogram) @@ -193,7 +237,7 @@ class MemPool(object): async def _fetch_and_accept(self, hashes, all_hashes, touched): '''Fetch a list of mempool transactions.''' hex_hashes_iter = (hash_to_hex_str(hash) for hash in hashes) - raw_txs = await self.daemon.getrawtransactions(hex_hashes_iter) + raw_txs = await self.api.raw_transactions(hex_hashes_iter) def deserialize_txs(): # This function is pure to_hashX = self.coin.hashX_from_script @@ -225,7 +269,7 @@ class MemPool(object): prevouts = tuple(prevout for tx in tx_map.values() for prevout in tx.prevouts if prevout[0] not in all_hashes) - utxos = await self.lookup_utxos(prevouts) + utxos = await self.api.lookup_utxos(prevouts) utxo_map = {prevout: utxo for prevout, utxo in zip(prevouts, utxos)} return self._accept_transactions(tx_map, utxo_map, touched) @@ -271,7 +315,8 @@ class MemPool(object): '''Return a set of (prev_hash, prev_idx) pairs from mempool transactions that touch hashX. - None, some or all of these may be spends of the hashX. + None, some or all of these may be spends of the hashX, but all + actual spends of it (in the DB or mempool) will be included. ''' result = set() for tx_hash in self.hashXs.get(hashX, ()):