Browse Source

Merge branch 'mempool-tests' into devel

patch-2
Neil Booth 7 years ago
parent
commit
878976dbee
  1. 14
      electrumx/lib/coins.py
  2. 14
      electrumx/server/controller.py
  3. 153
      electrumx/server/mempool.py
  4. 25
      electrumx/server/session.py
  5. 502
      tests/server/test_mempool.py

14
electrumx/lib/coins.py

@ -211,6 +211,14 @@ class Coin(object):
''' '''
return ScriptPubKey.P2PK_script(pubkey) return ScriptPubKey.P2PK_script(pubkey)
@classmethod
def hash160_to_P2PKH_script(cls, hash160):
return ScriptPubKey.P2PKH_script(hash160)
@classmethod
def hash160_to_P2PKH_hashX(cls, hash160):
return cls.hashX_from_script(cls.hash160_to_P2PKH_script(hash160))
@classmethod @classmethod
def pay_to_address_script(cls, address): def pay_to_address_script(cls, address):
'''Return a pubkey script that pays to a pubkey hash. '''Return a pubkey script that pays to a pubkey hash.
@ -223,12 +231,12 @@ class Coin(object):
verbyte = -1 verbyte = -1
verlen = len(raw) - 20 verlen = len(raw) - 20
if verlen > 0: if verlen > 0:
verbyte, hash_bytes = raw[:verlen], raw[verlen:] verbyte, hash160 = raw[:verlen], raw[verlen:]
if verbyte == cls.P2PKH_VERBYTE: if verbyte == cls.P2PKH_VERBYTE:
return ScriptPubKey.P2PKH_script(hash_bytes) return cls.hash160_to_P2PKH_script(hash160)
if verbyte in cls.P2SH_VERBYTES: if verbyte in cls.P2SH_VERBYTES:
return ScriptPubKey.P2SH_script(hash_bytes) return ScriptPubKey.P2SH_script(hash160)
raise CoinError('invalid address: {}'.format(address)) raise CoinError('invalid address: {}'.format(address))

14
electrumx/server/controller.py

@ -14,7 +14,7 @@ from electrumx.lib.server_base import ServerBase
from electrumx.lib.util import version_string from electrumx.lib.util import version_string
from electrumx.server.chain_state import ChainState from electrumx.server.chain_state import ChainState
from electrumx.server.db import DB from electrumx.server.db import DB
from electrumx.server.mempool import MemPool from electrumx.server.mempool import MemPool, MemPoolAPI
from electrumx.server.session import SessionManager from electrumx.server.session import SessionManager
@ -97,8 +97,18 @@ class Controller(ServerBase):
db = DB(env) db = DB(env)
BlockProcessor = env.coin.BLOCK_PROCESSOR BlockProcessor = env.coin.BLOCK_PROCESSOR
bp = BlockProcessor(env, db, daemon, notifications) bp = BlockProcessor(env, db, daemon, notifications)
mempool = MemPool(env.coin, daemon, notifications, db.lookup_utxos)
chain_state = ChainState(env, db, daemon, bp) chain_state = ChainState(env, db, daemon, bp)
# Set ourselves up to implement the MemPoolAPI
self.height = daemon.height
self.cached_height = daemon.cached_height
self.mempool_hashes = daemon.mempool_hashes
self.raw_transactions = daemon.getrawtransactions
self.lookup_utxos = db.lookup_utxos
self.on_mempool = notifications.on_mempool
MemPoolAPI.register(Controller)
mempool = MemPool(env.coin, self)
session_mgr = SessionManager(env, chain_state, mempool, session_mgr = SessionManager(env, chain_state, mempool,
notifications, shutdown_event) notifications, shutdown_event)

153
electrumx/server/mempool.py

@ -7,13 +7,14 @@
'''Mempool handling.''' '''Mempool handling.'''
import asyncio
import itertools import itertools
import time import time
from abc import ABC, abstractmethod
from asyncio import Lock
from collections import defaultdict from collections import defaultdict
import attr import attr
from aiorpcx import TaskGroup, run_in_thread from aiorpcx import TaskGroup, run_in_thread, sleep
from electrumx.lib.hash import hash_to_hex_str, hex_str_to_hash from electrumx.lib.hash import hash_to_hex_str, hex_str_to_hash
from electrumx.lib.util import class_logger, chunks from electrumx.lib.util import class_logger, chunks
@ -30,9 +31,60 @@ class MemPoolTx(object):
size = attr.ib() size = attr.ib()
@attr.s(slots=True)
class MemPoolTxSummary(object):
hash = attr.ib()
fee = attr.ib()
has_unconfirmed_inputs = attr.ib()
class MemPoolAPI(ABC):
'''A concrete instance of this class is passed to the MemPool object
and used by it to query DB and blockchain state.'''
@abstractmethod
async def height(self):
'''Query bitcoind for its height.'''
@abstractmethod
def cached_height(self):
'''Return the height of bitcoind the last time it was queried,
for any reason, without actually querying it.
'''
@abstractmethod
async def mempool_hashes(self):
'''Query bitcoind for the hashes of all transactions in its
mempool, returned as a list.'''
@abstractmethod
async def raw_transactions(self, hex_hashes):
'''Query bitcoind for the serialized raw transactions with the given
hashes. Missing transactions are returned as None.
hex_hashes is an iterable of hexadecimal hash strings.'''
@abstractmethod
async def lookup_utxos(self, prevouts):
'''Return a list of (hashX, value) pairs each prevout if unspent,
otherwise return None if spent or not found.
prevouts - an iterable of (hash, index) pairs
'''
@abstractmethod
async def on_mempool(self, touched, height):
'''Called each time the mempool is synchronized. touched is a set of
hashXs touched since the previous call. height is the
daemon's height at the time the mempool was obtained.'''
class MemPool(object): class MemPool(object):
'''Representation of the daemon's mempool. '''Representation of the daemon's mempool.
coin - a coin class from coins.py
api - an object implementing MemPoolAPI
Updated regularly in caught-up state. Goal is to enable efficient Updated regularly in caught-up state. Goal is to enable efficient
response to the calls in the external interface. To that end we response to the calls in the external interface. To that end we
maintain the following maps: maintain the following maps:
@ -41,23 +93,42 @@ class MemPool(object):
hashXs: hashX -> set of all hashes of txs touching the hashX hashXs: hashX -> set of all hashes of txs touching the hashX
''' '''
def __init__(self, coin, daemon, notifications, lookup_utxos): def __init__(self, coin, api, refresh_secs=5.0, log_status_secs=120.0):
self.logger = class_logger(__name__, self.__class__.__name__) assert isinstance(api, MemPoolAPI)
self.coin = coin self.coin = coin
self.lookup_utxos = lookup_utxos self.api = api
self.daemon = daemon self.logger = class_logger(__name__, self.__class__.__name__)
self.notifications = notifications
self.txs = {} self.txs = {}
self.hashXs = defaultdict(set) # None can be a key self.hashXs = defaultdict(set) # None can be a key
self.cached_compact_histogram = [] self.cached_compact_histogram = []
self.refresh_secs = refresh_secs
self.log_status_secs = log_status_secs
# Prevents mempool refreshes during fee histogram calculation
self.lock = Lock()
async def _log_stats(self): async def _logging(self, synchronized_event):
'''Print regular logs of mempool stats.'''
self.logger.info('beginning processing of daemon mempool. '
'This can take some time...')
start = time.time()
await synchronized_event.wait()
elapsed = time.time() - start
self.logger.info(f'synced in {elapsed:.2f}s')
while True: while True:
self.logger.info(f'{len(self.txs):,d} txs ' self.logger.info(f'{len(self.txs):,d} txs '
f'touching {len(self.hashXs):,d} addresses') f'touching {len(self.hashXs):,d} addresses')
await asyncio.sleep(120) await sleep(self.log_status_secs)
await synchronized_event.wait()
async def _refresh_histogram(self, synchronized_event):
while True:
await synchronized_event.wait()
async with self.lock:
# Threaded as can be expensive
await run_in_thread(self._update_histogram, 100_000)
await sleep(self.coin.MEMPOOL_HISTOGRAM_REFRESH_SECS)
def _update_histogram(self): def _update_histogram(self, bin_size):
# Build a histogram by fee rate # Build a histogram by fee rate
histogram = defaultdict(int) histogram = defaultdict(int)
for tx in self.txs.values(): for tx in self.txs.values():
@ -74,7 +145,6 @@ class MemPool(object):
compact = [] compact = []
cum_size = 0 cum_size = 0
r = 0 # ? r = 0 # ?
bin_size = 100 * 1000
for fee_rate, size in sorted(histogram.items(), reverse=True): for fee_rate, size in sorted(histogram.items(), reverse=True):
cum_size += size cum_size += size
if cum_size + r > bin_size: if cum_size + r > bin_size:
@ -129,21 +199,18 @@ class MemPool(object):
async def _refresh_hashes(self, synchronized_event): async def _refresh_hashes(self, synchronized_event):
'''Refresh our view of the daemon's mempool.''' '''Refresh our view of the daemon's mempool.'''
sleep = 5 while True:
histogram_refresh = self.coin.MEMPOOL_HISTOGRAM_REFRESH_SECS // sleep height = self.api.cached_height()
for loop_count in itertools.count(): hex_hashes = await self.api.mempool_hashes()
height = self.daemon.cached_height() if height != await self.api.height():
hex_hashes = await self.daemon.mempool_hashes()
if height != await self.daemon.height():
continue continue
hashes = set(hex_str_to_hash(hh) for hh in hex_hashes) hashes = set(hex_str_to_hash(hh) for hh in hex_hashes)
touched = await self._process_mempool(hashes) async with self.lock:
touched = await self._process_mempool(hashes)
synchronized_event.set() synchronized_event.set()
await self.notifications.on_mempool(touched, height) synchronized_event.clear()
# Thread mempool histogram refreshes - they can be expensive await self.api.on_mempool(touched, height)
if loop_count % histogram_refresh == 0: await sleep(self.refresh_secs)
await run_in_thread(self._update_histogram)
await asyncio.sleep(sleep)
async def _process_mempool(self, all_hashes): async def _process_mempool(self, all_hashes):
# Re-sync with the new set of hashes # Re-sync with the new set of hashes
@ -176,9 +243,6 @@ class MemPool(object):
tx_map.update(deferred) tx_map.update(deferred)
utxo_map.update(unspent) utxo_map.update(unspent)
# Handle the stragglers
if len(tx_map) >= 10:
self.logger.info(f'{len(tx_map)} stragglers')
prior_count = 0 prior_count = 0
# FIXME: this is not particularly efficient # FIXME: this is not particularly efficient
while tx_map and len(tx_map) != prior_count: while tx_map and len(tx_map) != prior_count:
@ -193,7 +257,7 @@ class MemPool(object):
async def _fetch_and_accept(self, hashes, all_hashes, touched): async def _fetch_and_accept(self, hashes, all_hashes, touched):
'''Fetch a list of mempool transactions.''' '''Fetch a list of mempool transactions.'''
hex_hashes_iter = (hash_to_hex_str(hash) for hash in hashes) hex_hashes_iter = (hash_to_hex_str(hash) for hash in hashes)
raw_txs = await self.daemon.getrawtransactions(hex_hashes_iter) raw_txs = await self.api.raw_transactions(hex_hashes_iter)
def deserialize_txs(): # This function is pure def deserialize_txs(): # This function is pure
to_hashX = self.coin.hashX_from_script to_hashX = self.coin.hashX_from_script
@ -225,7 +289,7 @@ class MemPool(object):
prevouts = tuple(prevout for tx in tx_map.values() prevouts = tuple(prevout for tx in tx_map.values()
for prevout in tx.prevouts for prevout in tx.prevouts
if prevout[0] not in all_hashes) if prevout[0] not in all_hashes)
utxos = await self.lookup_utxos(prevouts) utxos = await self.api.lookup_utxos(prevouts)
utxo_map = {prevout: utxo for prevout, utxo in zip(prevouts, utxos)} utxo_map = {prevout: utxo for prevout, utxo in zip(prevouts, utxos)}
return self._accept_transactions(tx_map, utxo_map, touched) return self._accept_transactions(tx_map, utxo_map, touched)
@ -235,19 +299,11 @@ class MemPool(object):
# #
async def keep_synchronized(self, synchronized_event): async def keep_synchronized(self, synchronized_event):
'''Starts the mempool synchronizer. '''Keep the mempool synchronized with the daemon.'''
async with TaskGroup(wait=any) as group:
Waits for an initial synchronization before returning.
'''
self.logger.info('beginning processing of daemon mempool. '
'This can take some time...')
async with TaskGroup() as group:
await group.spawn(self._refresh_hashes(synchronized_event)) await group.spawn(self._refresh_hashes(synchronized_event))
start = time.time() await group.spawn(self._refresh_histogram(synchronized_event))
await synchronized_event.wait() await group.spawn(self._logging(synchronized_event))
elapsed = time.time() - start
self.logger.info(f'synced in {elapsed:.2f}s')
await group.spawn(self._log_stats())
async def balance_delta(self, hashX): async def balance_delta(self, hashX):
'''Return the unconfirmed amount in the mempool for hashX. '''Return the unconfirmed amount in the mempool for hashX.
@ -255,7 +311,6 @@ class MemPool(object):
Can be positive or negative. Can be positive or negative.
''' '''
value = 0 value = 0
# hashXs is a defaultdict
if hashX in self.hashXs: if hashX in self.hashXs:
for hash in self.hashXs[hashX]: for hash in self.hashXs[hashX]:
tx = self.txs[hash] tx = self.txs[hash]
@ -271,7 +326,8 @@ class MemPool(object):
'''Return a set of (prev_hash, prev_idx) pairs from mempool '''Return a set of (prev_hash, prev_idx) pairs from mempool
transactions that touch hashX. transactions that touch hashX.
None, some or all of these may be spends of the hashX. None, some or all of these may be spends of the hashX, but all
actual spends of it (in the DB or mempool) will be included.
''' '''
result = set() result = set()
for tx_hash in self.hashXs.get(hashX, ()): for tx_hash in self.hashXs.get(hashX, ()):
@ -280,18 +336,12 @@ class MemPool(object):
return result return result
async def transaction_summaries(self, hashX): async def transaction_summaries(self, hashX):
'''Return a list of (tx_hash, tx_fee, unconfirmed) tuples for '''Return a list of MemPoolTxSummary objects for the hashX.'''
mempool entries for the hashX.
unconfirmed is True if any txin is unconfirmed.
'''
# hashXs is a defaultdict, so use get() to query
result = [] result = []
for tx_hash in self.hashXs.get(hashX, ()): for tx_hash in self.hashXs.get(hashX, ()):
tx = self.txs[tx_hash] tx = self.txs[tx_hash]
unconfirmed = any(prev_hash in self.txs has_ui = any(hash in self.txs for hash, idx in tx.prevouts)
for prev_hash, prev_idx in tx.prevouts) result.append(MemPoolTxSummary(tx_hash, tx.fee, has_ui))
result.append((tx_hash, tx.fee, unconfirmed))
return result return result
async def unordered_UTXOs(self, hashX): async def unordered_UTXOs(self, hashX):
@ -302,7 +352,6 @@ class MemPool(object):
the outputs. the outputs.
''' '''
utxos = [] utxos = []
# hashXs is a defaultdict, so use get() to query
for tx_hash in self.hashXs.get(hashX, ()): for tx_hash in self.hashXs.get(hashX, ()):
tx = self.txs.get(tx_hash) tx = self.txs.get(tx_hash)
for pos, (hX, value) in enumerate(tx.out_pairs): for pos, (hX, value) in enumerate(tx.out_pairs):

25
electrumx/server/session.py

@ -779,15 +779,16 @@ class ElectrumX(SessionBase):
Status is a hex string, but must be None if there is no history. Status is a hex string, but must be None if there is no history.
''' '''
# Note history is ordered and mempool unordered in electrum-server # Note history is ordered and mempool unordered in electrum-server
# For mempool, height is -1 if unconfirmed txins, otherwise 0 # For mempool, height is -1 if it has unconfirmed inputs, otherwise 0
history = await self.session_mgr.limited_history(hashX) db_history = await self.session_mgr.limited_history(hashX)
mempool = await self.mempool.transaction_summaries(hashX) mempool = await self.mempool.transaction_summaries(hashX)
status = ''.join('{}:{:d}:'.format(hash_to_hex_str(tx_hash), height) status = ''.join(f'{hash_to_hex_str(tx_hash)}:'
for tx_hash, height in history) f'{height:d}:'
status += ''.join('{}:{:d}:'.format(hash_to_hex_str(hex_hash), for tx_hash, height in db_history)
-unconfirmed) status += ''.join(f'{hash_to_hex_str(tx.hash)}:'
for hex_hash, tx_fee, unconfirmed in mempool) f'{-tx.has_unconfirmed_inputs:d}:'
for tx in mempool)
if status: if status:
status = sha256(status.encode()).hex() status = sha256(status.encode()).hex()
else: else:
@ -872,11 +873,11 @@ class ElectrumX(SessionBase):
async def unconfirmed_history(self, hashX): async def unconfirmed_history(self, hashX):
# Note unconfirmed history is unordered in electrum-server # Note unconfirmed history is unordered in electrum-server
# Height is -1 if unconfirmed txins, otherwise 0 # height is -1 if it has unconfirmed inputs, otherwise 0
mempool = await self.mempool.transaction_summaries(hashX) return [{'tx_hash': hash_to_hex_str(tx.hash),
return [{'tx_hash': hash_to_hex_str(tx_hash), 'height': -unconfirmed, 'height': -tx.has_unconfirmed_inputs,
'fee': fee} 'fee': tx.fee}
for tx_hash, fee, unconfirmed in mempool] for tx in await self.mempool.transaction_summaries(hashX)]
async def confirmed_and_unconfirmed_history(self, hashX): async def confirmed_and_unconfirmed_history(self, hashX):
# Note history is ordered but unconfirmed is unordered in e-s # Note history is ordered but unconfirmed is unordered in e-s

502
tests/server/test_mempool.py

@ -0,0 +1,502 @@
import logging
import os
from collections import defaultdict
from functools import partial
from random import randrange, choice
import pytest
from aiorpcx import Event, TaskGroup, sleep, spawn, ignore_after
from electrumx.server.mempool import MemPool, MemPoolAPI
from electrumx.lib.coins import BitcoinCash
from electrumx.lib.hash import HASHX_LEN, hex_str_to_hash, hash_to_hex_str
from electrumx.lib.tx import Tx, TxInput, TxOutput
from electrumx.lib.util import make_logger
coin = BitcoinCash
tx_hash_fn = coin.DESERIALIZER.TX_HASH_FN
def random_tx(hash160s, utxos):
'''Create a random TX paying to some of the hash160s using some of the
UTXOS. Return the TX. UTXOs is updated for the effects of the TX.
'''
inputs = []
n_inputs = min(randrange(1, 4), len(utxos))
input_value = 0
# Create inputs spending random UTXOs. total the inpu
for n in range(n_inputs):
prevout = choice(list(utxos))
hashX, value = utxos.pop(prevout)
inputs.append(TxInput(prevout[0], prevout[1], b'', 4294967295))
input_value += value
fee = min(input_value, randrange(500))
input_value -= fee
outputs = []
n_outputs = randrange(1, 4)
for n in range(n_outputs):
value = randrange(input_value)
input_value -= value
pk_script = coin.hash160_to_P2PKH_script(choice(hash160s))
outputs.append(TxOutput(value, pk_script))
tx = Tx(2, inputs, outputs, 0)
tx_bytes = tx.serialize()
tx_hash = tx_hash_fn(tx_bytes)
for n, output in enumerate(tx.outputs):
utxos[(tx_hash, n)] = (coin.hashX_from_script(output.pk_script),
output.value)
return tx, tx_hash, tx_bytes
class API(MemPoolAPI):
def __init__(self):
self._height = 0
self._cached_height = self._height
# Create a pool of hash160s. Map them to their script hashes
# Create a bunch of UTXOs paying to those script hashes
# Create a bunch of TXs that spend from the UTXO set and create
# new outpus, which are added to the UTXO set for later TXs to
# spend
self.db_utxos = {}
self.on_mempool_calls = []
self.hashXs = []
# Maps of mempool txs from tx_hash to raw and Tx object forms
self.raw_txs = {}
self.txs = {}
self.ordered_adds = []
def initialize(self, addr_count=100, db_utxo_count=100, mempool_size=50):
hash160s = [os.urandom(20) for n in range(addr_count)]
self.hashXs = [coin.hash160_to_P2PKH_hashX(hash160)
for hash160 in hash160s]
prevouts = [(os.urandom(32), randrange(0, 10))
for n in range (db_utxo_count)]
random_value = partial(randrange, coin.VALUE_PER_COIN * 10)
self.db_utxos = {prevout: (choice(self.hashXs), random_value())
for prevout in prevouts}
unspent_utxos = self.db_utxos.copy()
for n in range(mempool_size):
tx, tx_hash, raw_tx = random_tx(hash160s, unspent_utxos)
self.raw_txs[tx_hash] = raw_tx
self.txs[tx_hash] = tx
self.ordered_adds.append(tx_hash)
def mempool_utxos(self):
utxos = {}
for tx_hash, tx in self.txs.items():
for n, output in enumerate(tx.outputs):
hashX = coin.hashX_from_script(output.pk_script)
utxos[(tx_hash, n)] = (hashX, output.value)
return utxos
def mempool_spends(self):
return [(input.prev_hash, input.prev_idx)
for tx in self.txs.values() for input in tx.inputs]
def balance_deltas(self):
# Return mempool balance deltas indexed by hashX
deltas = defaultdict(int)
utxos = self.mempool_utxos()
for tx_hash, tx in self.txs.items():
for n, input in enumerate(tx.inputs):
prevout = (input.prev_hash, input.prev_idx)
if prevout in utxos:
utxos.pop(prevout)
else:
hashX, value = self.db_utxos[prevout]
deltas[hashX] -= value
for hashX, value in utxos.values():
deltas[hashX] += value
return deltas
def spends(self):
# Return spends indexed by hashX
spends = defaultdict(list)
utxos = self.mempool_utxos()
for tx_hash, tx in self.txs.items():
for n, input in enumerate(tx.inputs):
prevout = (input.prev_hash, input.prev_idx)
if prevout in utxos:
hashX, value = utxos.pop(prevout)
else:
hashX, value = self.db_utxos[prevout]
spends[hashX].append(prevout)
return spends
def summaries(self):
# Return lists of (tx_hash, fee, has_unconfirmed_inputs) by hashX
summaries = defaultdict(list)
utxos = self.mempool_utxos()
for tx_hash, tx in self.txs.items():
fee = 0
hashXs = set()
has_ui = False
for n, input in enumerate(tx.inputs):
has_ui = has_ui or (input.prev_hash in self.txs)
prevout = (input.prev_hash, input.prev_idx)
if prevout in utxos:
hashX, value = utxos[prevout]
else:
hashX, value = self.db_utxos[prevout]
hashXs.add(hashX)
fee += value
for output in tx.outputs:
hashXs.add(coin.hashX_from_script(output.pk_script))
fee -= output.value
summary = (tx_hash, fee, has_ui)
for hashX in hashXs:
summaries[hashX].append(summary)
return summaries
def touched(self, tx_hashes):
touched = set()
utxos = self.mempool_utxos()
for tx_hash in tx_hashes:
tx = self.txs[tx_hash]
for n, input in enumerate(tx.inputs):
prevout = (input.prev_hash, input.prev_idx)
if prevout in utxos:
hashX, value = utxos[prevout]
else:
hashX, value = self.db_utxos[prevout]
touched.add(hashX)
for output in tx.outputs:
touched.add(coin.hashX_from_script(output.pk_script))
return touched
def UTXOs(self):
# Return lists of UTXO 5-tuples by hashX
utxos = defaultdict(list)
for tx_hash, tx in self.txs.items():
for n, output in enumerate(tx.outputs):
hashX = coin.hashX_from_script(output.pk_script)
utxos[hashX].append((-1, n, tx_hash, 0, output.value))
return utxos
async def height(self):
await sleep(0)
self._cached_height = self._height
return self._height
def cached_height(self):
return self._cached_height
async def mempool_hashes(self):
'''Query bitcoind for the hashes of all transactions in its
mempool, returned as a list.'''
await sleep(0)
return [hash_to_hex_str(hash) for hash in self.txs]
async def raw_transactions(self, hex_hashes):
'''Query bitcoind for the serialized raw transactions with the given
hashes. Missing transactions are returned as None.
hex_hashes is an iterable of hexadecimal hash strings.'''
await sleep(0)
hashes = [hex_str_to_hash(hex_hash) for hex_hash in hex_hashes]
return [self.raw_txs.get(hash) for hash in hashes]
async def lookup_utxos(self, prevouts):
'''Return a list of (hashX, value) pairs each prevout if unspent,
otherwise return None if spent or not found.
prevouts - an iterable of (hash, index) pairs
'''
await sleep(0)
return [self.db_utxos.get(prevout) for prevout in prevouts]
async def on_mempool(self, touched, height):
'''Called each time the mempool is synchronized. touched is a set of
hashXs touched since the previous call. height is the
daemon's height at the time the mempool was obtained.'''
self.on_mempool_calls.append((touched, height))
await sleep(0)
class DropAPI(API):
def __init__(self, drop_count):
super().__init__()
self.drop_count = drop_count
self.dropped = False
async def raw_transactions(self, hex_hashes):
if not self.dropped:
self.dropped = True
for hash in self.ordered_adds[-self.drop_count:]:
del self.raw_txs[hash]
del self.txs[hash]
return await super().raw_transactions(hex_hashes)
def in_caplog(caplog, message):
return any(message in record.message for record in caplog.records)
@pytest.mark.asyncio
async def test_keep_synchronized(caplog):
api = API()
mempool = MemPool(coin, api)
event = Event()
with caplog.at_level(logging.INFO):
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await group.cancel_remaining()
assert in_caplog(caplog, 'beginning processing of daemon mempool')
assert in_caplog(caplog, 'compact fee histogram')
assert in_caplog(caplog, 'synced in ')
assert in_caplog(caplog, '0 txs touching 0 addresses')
assert not in_caplog(caplog, 'txs dropped')
@pytest.mark.asyncio
async def test_balance_delta():
api = API()
api.initialize()
mempool = MemPool(coin, api)
event = Event()
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await group.cancel_remaining()
# Check the default dict is handled properly
prior_len = len(mempool.hashXs)
assert await mempool.balance_delta(os.urandom(HASHX_LEN)) == 0
assert prior_len == len(mempool.hashXs)
# Test all hashXs
deltas = api.balance_deltas()
for hashX in api.hashXs:
expected = deltas.get(hashX, 0)
assert await mempool.balance_delta(hashX) == expected
@pytest.mark.asyncio
async def test_compact_fee_histogram():
api = API()
api.initialize()
mempool = MemPool(coin, api)
event = Event()
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await group.cancel_remaining()
histogram = await mempool.compact_fee_histogram()
assert histogram == []
bin_size = 1000
mempool._update_histogram(bin_size)
histogram = await mempool.compact_fee_histogram()
assert len(histogram) > 0
rates, sizes = zip(*histogram)
assert all(rates[n] < rates[n - 1] for n in range(1, len(rates)))
assert all(size > bin_size * 0.95 for size in sizes)
@pytest.mark.asyncio
async def test_potential_spends():
api = API()
api.initialize()
mempool = MemPool(coin, api)
event = Event()
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await group.cancel_remaining()
# Check the default dict is handled properly
prior_len = len(mempool.hashXs)
assert await mempool.potential_spends(os.urandom(HASHX_LEN)) == set()
assert prior_len == len(mempool.hashXs)
# Test all hashXs
spends = api.spends()
for hashX in api.hashXs:
ps = await mempool.potential_spends(hashX)
assert all(spend in ps for spend in spends[hashX])
async def _test_summaries(mempool, api):
# Test all hashXs
summaries = api.summaries()
for hashX in api.hashXs:
mempool_result = await mempool.transaction_summaries(hashX)
mempool_result = [(item.hash, item.fee, item.has_unconfirmed_inputs)
for item in mempool_result]
our_result = summaries.get(hashX, [])
assert set(our_result) == set(mempool_result)
@pytest.mark.asyncio
async def test_transaction_summaries(caplog):
api = API()
api.initialize()
mempool = MemPool(coin, api)
event = Event()
with caplog.at_level(logging.INFO):
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await group.cancel_remaining()
# Check the default dict is handled properly
prior_len = len(mempool.hashXs)
assert await mempool.transaction_summaries(os.urandom(HASHX_LEN)) == []
assert prior_len == len(mempool.hashXs)
await _test_summaries(mempool, api)
assert not in_caplog(caplog, 'txs dropped')
@pytest.mark.asyncio
async def test_unordered_UTXOs():
api = API()
api.initialize()
mempool = MemPool(coin, api)
event = Event()
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await group.cancel_remaining()
# Check the default dict is handled properly
prior_len = len(mempool.hashXs)
assert await mempool.unordered_UTXOs(os.urandom(HASHX_LEN)) == []
assert prior_len == len(mempool.hashXs)
# Test all hashXs
utxos = api.UTXOs()
for hashX in api.hashXs:
mempool_result = await mempool.unordered_UTXOs(hashX)
our_result = utxos.get(hashX, [])
assert set(our_result) == set(mempool_result)
@pytest.mark.asyncio
async def test_mempool_removals():
api = API()
api.initialize()
mempool = MemPool(coin, api, refresh_secs=0.01)
event = Event()
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
# Remove half the TXs from the mempool
start = len(api.ordered_adds) // 2
for tx_hash in api.ordered_adds[start:]:
del api.txs[tx_hash]
del api.raw_txs[tx_hash]
await event.wait()
await _test_summaries(mempool, api)
# Removed hashXs should have key destroyed
assert all(mempool.hashXs.values())
# Remove the rest
api.txs.clear()
api.raw_txs.clear()
await event.wait()
await _test_summaries(mempool, api)
assert not mempool.hashXs
assert not mempool.txs
await group.cancel_remaining()
@pytest.mark.asyncio
async def test_daemon_drops_txs():
# Tests things work if the daemon drops some transactions between
# returning their hashes and the mempool requesting the raw txs
api = DropAPI(10)
api.initialize()
mempool = MemPool(coin, api, refresh_secs=0.01)
event = Event()
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await _test_summaries(mempool, api)
await group.cancel_remaining()
@pytest.mark.asyncio
async def test_notifications():
# Tests notifications over a cycle of:
# 1) A first batch of txs come in
# 2) A second batch of txs come in
# 3) A block comes in confirming the first batch only
api = API()
api.initialize()
mempool = MemPool(coin, api, refresh_secs=0.001, log_status_secs=0)
event = Event()
n = len(api.ordered_adds) // 2
raw_txs = api.raw_txs.copy()
txs = api.txs.copy()
first_hashes = api.ordered_adds[:n]
first_touched = api.touched(first_hashes)
second_hashes = api.ordered_adds[n:]
second_touched = api.touched(second_hashes)
async with TaskGroup() as group:
# First batch enters the mempool
api.raw_txs = {hash: raw_txs[hash] for hash in first_hashes}
api.txs = {hash: txs[hash] for hash in first_hashes}
first_utxos = api.mempool_utxos()
first_spends = api.mempool_spends()
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
assert len(api.on_mempool_calls) == 1
touched, height = api.on_mempool_calls[0]
assert height == api._height == api._cached_height
assert touched == first_touched
# Second batch enters the mempool
api.raw_txs = raw_txs
api.txs = txs
await event.wait()
assert len(api.on_mempool_calls) == 2
touched, height = api.on_mempool_calls[1]
assert height == api._height == api._cached_height
# Touched is incremental
assert touched == second_touched
# Block found; first half confirm
new_height = 2
api._height = new_height
api.db_utxos.update(first_utxos)
for spend in first_spends:
del api.db_utxos[spend]
api.raw_txs = {hash: raw_txs[hash] for hash in second_hashes}
api.txs = {hash: txs[hash] for hash in second_hashes}
await event.wait()
assert len(api.on_mempool_calls) == 3
touched, height = api.on_mempool_calls[2]
assert height == api._height == api._cached_height == new_height
assert touched == first_touched
await group.cancel_remaining()
@pytest.mark.asyncio
async def test_dropped_txs(caplog):
api = API()
api.initialize()
mempool = MemPool(coin, api)
event = Event()
# Remove a single TX_HASH that is used in another mempool tx
for prev_hash, prev_idx in api.mempool_spends():
if prev_hash in api.txs:
del api.txs[prev_hash]
with caplog.at_level(logging.INFO):
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await group.cancel_remaining()
assert in_caplog(caplog, 'txs dropped')
Loading…
Cancel
Save