Browse Source

Merge branch 'mempool-tests' into devel

patch-2
Neil Booth 7 years ago
parent
commit
878976dbee
  1. 14
      electrumx/lib/coins.py
  2. 14
      electrumx/server/controller.py
  3. 153
      electrumx/server/mempool.py
  4. 25
      electrumx/server/session.py
  5. 502
      tests/server/test_mempool.py

14
electrumx/lib/coins.py

@ -211,6 +211,14 @@ class Coin(object):
'''
return ScriptPubKey.P2PK_script(pubkey)
@classmethod
def hash160_to_P2PKH_script(cls, hash160):
return ScriptPubKey.P2PKH_script(hash160)
@classmethod
def hash160_to_P2PKH_hashX(cls, hash160):
return cls.hashX_from_script(cls.hash160_to_P2PKH_script(hash160))
@classmethod
def pay_to_address_script(cls, address):
'''Return a pubkey script that pays to a pubkey hash.
@ -223,12 +231,12 @@ class Coin(object):
verbyte = -1
verlen = len(raw) - 20
if verlen > 0:
verbyte, hash_bytes = raw[:verlen], raw[verlen:]
verbyte, hash160 = raw[:verlen], raw[verlen:]
if verbyte == cls.P2PKH_VERBYTE:
return ScriptPubKey.P2PKH_script(hash_bytes)
return cls.hash160_to_P2PKH_script(hash160)
if verbyte in cls.P2SH_VERBYTES:
return ScriptPubKey.P2SH_script(hash_bytes)
return ScriptPubKey.P2SH_script(hash160)
raise CoinError('invalid address: {}'.format(address))

14
electrumx/server/controller.py

@ -14,7 +14,7 @@ from electrumx.lib.server_base import ServerBase
from electrumx.lib.util import version_string
from electrumx.server.chain_state import ChainState
from electrumx.server.db import DB
from electrumx.server.mempool import MemPool
from electrumx.server.mempool import MemPool, MemPoolAPI
from electrumx.server.session import SessionManager
@ -97,8 +97,18 @@ class Controller(ServerBase):
db = DB(env)
BlockProcessor = env.coin.BLOCK_PROCESSOR
bp = BlockProcessor(env, db, daemon, notifications)
mempool = MemPool(env.coin, daemon, notifications, db.lookup_utxos)
chain_state = ChainState(env, db, daemon, bp)
# Set ourselves up to implement the MemPoolAPI
self.height = daemon.height
self.cached_height = daemon.cached_height
self.mempool_hashes = daemon.mempool_hashes
self.raw_transactions = daemon.getrawtransactions
self.lookup_utxos = db.lookup_utxos
self.on_mempool = notifications.on_mempool
MemPoolAPI.register(Controller)
mempool = MemPool(env.coin, self)
session_mgr = SessionManager(env, chain_state, mempool,
notifications, shutdown_event)

153
electrumx/server/mempool.py

@ -7,13 +7,14 @@
'''Mempool handling.'''
import asyncio
import itertools
import time
from abc import ABC, abstractmethod
from asyncio import Lock
from collections import defaultdict
import attr
from aiorpcx import TaskGroup, run_in_thread
from aiorpcx import TaskGroup, run_in_thread, sleep
from electrumx.lib.hash import hash_to_hex_str, hex_str_to_hash
from electrumx.lib.util import class_logger, chunks
@ -30,9 +31,60 @@ class MemPoolTx(object):
size = attr.ib()
@attr.s(slots=True)
class MemPoolTxSummary(object):
hash = attr.ib()
fee = attr.ib()
has_unconfirmed_inputs = attr.ib()
class MemPoolAPI(ABC):
'''A concrete instance of this class is passed to the MemPool object
and used by it to query DB and blockchain state.'''
@abstractmethod
async def height(self):
'''Query bitcoind for its height.'''
@abstractmethod
def cached_height(self):
'''Return the height of bitcoind the last time it was queried,
for any reason, without actually querying it.
'''
@abstractmethod
async def mempool_hashes(self):
'''Query bitcoind for the hashes of all transactions in its
mempool, returned as a list.'''
@abstractmethod
async def raw_transactions(self, hex_hashes):
'''Query bitcoind for the serialized raw transactions with the given
hashes. Missing transactions are returned as None.
hex_hashes is an iterable of hexadecimal hash strings.'''
@abstractmethod
async def lookup_utxos(self, prevouts):
'''Return a list of (hashX, value) pairs each prevout if unspent,
otherwise return None if spent or not found.
prevouts - an iterable of (hash, index) pairs
'''
@abstractmethod
async def on_mempool(self, touched, height):
'''Called each time the mempool is synchronized. touched is a set of
hashXs touched since the previous call. height is the
daemon's height at the time the mempool was obtained.'''
class MemPool(object):
'''Representation of the daemon's mempool.
coin - a coin class from coins.py
api - an object implementing MemPoolAPI
Updated regularly in caught-up state. Goal is to enable efficient
response to the calls in the external interface. To that end we
maintain the following maps:
@ -41,23 +93,42 @@ class MemPool(object):
hashXs: hashX -> set of all hashes of txs touching the hashX
'''
def __init__(self, coin, daemon, notifications, lookup_utxos):
self.logger = class_logger(__name__, self.__class__.__name__)
def __init__(self, coin, api, refresh_secs=5.0, log_status_secs=120.0):
assert isinstance(api, MemPoolAPI)
self.coin = coin
self.lookup_utxos = lookup_utxos
self.daemon = daemon
self.notifications = notifications
self.api = api
self.logger = class_logger(__name__, self.__class__.__name__)
self.txs = {}
self.hashXs = defaultdict(set) # None can be a key
self.cached_compact_histogram = []
self.refresh_secs = refresh_secs
self.log_status_secs = log_status_secs
# Prevents mempool refreshes during fee histogram calculation
self.lock = Lock()
async def _log_stats(self):
async def _logging(self, synchronized_event):
'''Print regular logs of mempool stats.'''
self.logger.info('beginning processing of daemon mempool. '
'This can take some time...')
start = time.time()
await synchronized_event.wait()
elapsed = time.time() - start
self.logger.info(f'synced in {elapsed:.2f}s')
while True:
self.logger.info(f'{len(self.txs):,d} txs '
f'touching {len(self.hashXs):,d} addresses')
await asyncio.sleep(120)
await sleep(self.log_status_secs)
await synchronized_event.wait()
async def _refresh_histogram(self, synchronized_event):
while True:
await synchronized_event.wait()
async with self.lock:
# Threaded as can be expensive
await run_in_thread(self._update_histogram, 100_000)
await sleep(self.coin.MEMPOOL_HISTOGRAM_REFRESH_SECS)
def _update_histogram(self):
def _update_histogram(self, bin_size):
# Build a histogram by fee rate
histogram = defaultdict(int)
for tx in self.txs.values():
@ -74,7 +145,6 @@ class MemPool(object):
compact = []
cum_size = 0
r = 0 # ?
bin_size = 100 * 1000
for fee_rate, size in sorted(histogram.items(), reverse=True):
cum_size += size
if cum_size + r > bin_size:
@ -129,21 +199,18 @@ class MemPool(object):
async def _refresh_hashes(self, synchronized_event):
'''Refresh our view of the daemon's mempool.'''
sleep = 5
histogram_refresh = self.coin.MEMPOOL_HISTOGRAM_REFRESH_SECS // sleep
for loop_count in itertools.count():
height = self.daemon.cached_height()
hex_hashes = await self.daemon.mempool_hashes()
if height != await self.daemon.height():
while True:
height = self.api.cached_height()
hex_hashes = await self.api.mempool_hashes()
if height != await self.api.height():
continue
hashes = set(hex_str_to_hash(hh) for hh in hex_hashes)
touched = await self._process_mempool(hashes)
async with self.lock:
touched = await self._process_mempool(hashes)
synchronized_event.set()
await self.notifications.on_mempool(touched, height)
# Thread mempool histogram refreshes - they can be expensive
if loop_count % histogram_refresh == 0:
await run_in_thread(self._update_histogram)
await asyncio.sleep(sleep)
synchronized_event.clear()
await self.api.on_mempool(touched, height)
await sleep(self.refresh_secs)
async def _process_mempool(self, all_hashes):
# Re-sync with the new set of hashes
@ -176,9 +243,6 @@ class MemPool(object):
tx_map.update(deferred)
utxo_map.update(unspent)
# Handle the stragglers
if len(tx_map) >= 10:
self.logger.info(f'{len(tx_map)} stragglers')
prior_count = 0
# FIXME: this is not particularly efficient
while tx_map and len(tx_map) != prior_count:
@ -193,7 +257,7 @@ class MemPool(object):
async def _fetch_and_accept(self, hashes, all_hashes, touched):
'''Fetch a list of mempool transactions.'''
hex_hashes_iter = (hash_to_hex_str(hash) for hash in hashes)
raw_txs = await self.daemon.getrawtransactions(hex_hashes_iter)
raw_txs = await self.api.raw_transactions(hex_hashes_iter)
def deserialize_txs(): # This function is pure
to_hashX = self.coin.hashX_from_script
@ -225,7 +289,7 @@ class MemPool(object):
prevouts = tuple(prevout for tx in tx_map.values()
for prevout in tx.prevouts
if prevout[0] not in all_hashes)
utxos = await self.lookup_utxos(prevouts)
utxos = await self.api.lookup_utxos(prevouts)
utxo_map = {prevout: utxo for prevout, utxo in zip(prevouts, utxos)}
return self._accept_transactions(tx_map, utxo_map, touched)
@ -235,19 +299,11 @@ class MemPool(object):
#
async def keep_synchronized(self, synchronized_event):
'''Starts the mempool synchronizer.
Waits for an initial synchronization before returning.
'''
self.logger.info('beginning processing of daemon mempool. '
'This can take some time...')
async with TaskGroup() as group:
'''Keep the mempool synchronized with the daemon.'''
async with TaskGroup(wait=any) as group:
await group.spawn(self._refresh_hashes(synchronized_event))
start = time.time()
await synchronized_event.wait()
elapsed = time.time() - start
self.logger.info(f'synced in {elapsed:.2f}s')
await group.spawn(self._log_stats())
await group.spawn(self._refresh_histogram(synchronized_event))
await group.spawn(self._logging(synchronized_event))
async def balance_delta(self, hashX):
'''Return the unconfirmed amount in the mempool for hashX.
@ -255,7 +311,6 @@ class MemPool(object):
Can be positive or negative.
'''
value = 0
# hashXs is a defaultdict
if hashX in self.hashXs:
for hash in self.hashXs[hashX]:
tx = self.txs[hash]
@ -271,7 +326,8 @@ class MemPool(object):
'''Return a set of (prev_hash, prev_idx) pairs from mempool
transactions that touch hashX.
None, some or all of these may be spends of the hashX.
None, some or all of these may be spends of the hashX, but all
actual spends of it (in the DB or mempool) will be included.
'''
result = set()
for tx_hash in self.hashXs.get(hashX, ()):
@ -280,18 +336,12 @@ class MemPool(object):
return result
async def transaction_summaries(self, hashX):
'''Return a list of (tx_hash, tx_fee, unconfirmed) tuples for
mempool entries for the hashX.
unconfirmed is True if any txin is unconfirmed.
'''
# hashXs is a defaultdict, so use get() to query
'''Return a list of MemPoolTxSummary objects for the hashX.'''
result = []
for tx_hash in self.hashXs.get(hashX, ()):
tx = self.txs[tx_hash]
unconfirmed = any(prev_hash in self.txs
for prev_hash, prev_idx in tx.prevouts)
result.append((tx_hash, tx.fee, unconfirmed))
has_ui = any(hash in self.txs for hash, idx in tx.prevouts)
result.append(MemPoolTxSummary(tx_hash, tx.fee, has_ui))
return result
async def unordered_UTXOs(self, hashX):
@ -302,7 +352,6 @@ class MemPool(object):
the outputs.
'''
utxos = []
# hashXs is a defaultdict, so use get() to query
for tx_hash in self.hashXs.get(hashX, ()):
tx = self.txs.get(tx_hash)
for pos, (hX, value) in enumerate(tx.out_pairs):

25
electrumx/server/session.py

@ -779,15 +779,16 @@ class ElectrumX(SessionBase):
Status is a hex string, but must be None if there is no history.
'''
# Note history is ordered and mempool unordered in electrum-server
# For mempool, height is -1 if unconfirmed txins, otherwise 0
history = await self.session_mgr.limited_history(hashX)
# For mempool, height is -1 if it has unconfirmed inputs, otherwise 0
db_history = await self.session_mgr.limited_history(hashX)
mempool = await self.mempool.transaction_summaries(hashX)
status = ''.join('{}:{:d}:'.format(hash_to_hex_str(tx_hash), height)
for tx_hash, height in history)
status += ''.join('{}:{:d}:'.format(hash_to_hex_str(hex_hash),
-unconfirmed)
for hex_hash, tx_fee, unconfirmed in mempool)
status = ''.join(f'{hash_to_hex_str(tx_hash)}:'
f'{height:d}:'
for tx_hash, height in db_history)
status += ''.join(f'{hash_to_hex_str(tx.hash)}:'
f'{-tx.has_unconfirmed_inputs:d}:'
for tx in mempool)
if status:
status = sha256(status.encode()).hex()
else:
@ -872,11 +873,11 @@ class ElectrumX(SessionBase):
async def unconfirmed_history(self, hashX):
# Note unconfirmed history is unordered in electrum-server
# Height is -1 if unconfirmed txins, otherwise 0
mempool = await self.mempool.transaction_summaries(hashX)
return [{'tx_hash': hash_to_hex_str(tx_hash), 'height': -unconfirmed,
'fee': fee}
for tx_hash, fee, unconfirmed in mempool]
# height is -1 if it has unconfirmed inputs, otherwise 0
return [{'tx_hash': hash_to_hex_str(tx.hash),
'height': -tx.has_unconfirmed_inputs,
'fee': tx.fee}
for tx in await self.mempool.transaction_summaries(hashX)]
async def confirmed_and_unconfirmed_history(self, hashX):
# Note history is ordered but unconfirmed is unordered in e-s

502
tests/server/test_mempool.py

@ -0,0 +1,502 @@
import logging
import os
from collections import defaultdict
from functools import partial
from random import randrange, choice
import pytest
from aiorpcx import Event, TaskGroup, sleep, spawn, ignore_after
from electrumx.server.mempool import MemPool, MemPoolAPI
from electrumx.lib.coins import BitcoinCash
from electrumx.lib.hash import HASHX_LEN, hex_str_to_hash, hash_to_hex_str
from electrumx.lib.tx import Tx, TxInput, TxOutput
from electrumx.lib.util import make_logger
coin = BitcoinCash
tx_hash_fn = coin.DESERIALIZER.TX_HASH_FN
def random_tx(hash160s, utxos):
'''Create a random TX paying to some of the hash160s using some of the
UTXOS. Return the TX. UTXOs is updated for the effects of the TX.
'''
inputs = []
n_inputs = min(randrange(1, 4), len(utxos))
input_value = 0
# Create inputs spending random UTXOs. total the inpu
for n in range(n_inputs):
prevout = choice(list(utxos))
hashX, value = utxos.pop(prevout)
inputs.append(TxInput(prevout[0], prevout[1], b'', 4294967295))
input_value += value
fee = min(input_value, randrange(500))
input_value -= fee
outputs = []
n_outputs = randrange(1, 4)
for n in range(n_outputs):
value = randrange(input_value)
input_value -= value
pk_script = coin.hash160_to_P2PKH_script(choice(hash160s))
outputs.append(TxOutput(value, pk_script))
tx = Tx(2, inputs, outputs, 0)
tx_bytes = tx.serialize()
tx_hash = tx_hash_fn(tx_bytes)
for n, output in enumerate(tx.outputs):
utxos[(tx_hash, n)] = (coin.hashX_from_script(output.pk_script),
output.value)
return tx, tx_hash, tx_bytes
class API(MemPoolAPI):
def __init__(self):
self._height = 0
self._cached_height = self._height
# Create a pool of hash160s. Map them to their script hashes
# Create a bunch of UTXOs paying to those script hashes
# Create a bunch of TXs that spend from the UTXO set and create
# new outpus, which are added to the UTXO set for later TXs to
# spend
self.db_utxos = {}
self.on_mempool_calls = []
self.hashXs = []
# Maps of mempool txs from tx_hash to raw and Tx object forms
self.raw_txs = {}
self.txs = {}
self.ordered_adds = []
def initialize(self, addr_count=100, db_utxo_count=100, mempool_size=50):
hash160s = [os.urandom(20) for n in range(addr_count)]
self.hashXs = [coin.hash160_to_P2PKH_hashX(hash160)
for hash160 in hash160s]
prevouts = [(os.urandom(32), randrange(0, 10))
for n in range (db_utxo_count)]
random_value = partial(randrange, coin.VALUE_PER_COIN * 10)
self.db_utxos = {prevout: (choice(self.hashXs), random_value())
for prevout in prevouts}
unspent_utxos = self.db_utxos.copy()
for n in range(mempool_size):
tx, tx_hash, raw_tx = random_tx(hash160s, unspent_utxos)
self.raw_txs[tx_hash] = raw_tx
self.txs[tx_hash] = tx
self.ordered_adds.append(tx_hash)
def mempool_utxos(self):
utxos = {}
for tx_hash, tx in self.txs.items():
for n, output in enumerate(tx.outputs):
hashX = coin.hashX_from_script(output.pk_script)
utxos[(tx_hash, n)] = (hashX, output.value)
return utxos
def mempool_spends(self):
return [(input.prev_hash, input.prev_idx)
for tx in self.txs.values() for input in tx.inputs]
def balance_deltas(self):
# Return mempool balance deltas indexed by hashX
deltas = defaultdict(int)
utxos = self.mempool_utxos()
for tx_hash, tx in self.txs.items():
for n, input in enumerate(tx.inputs):
prevout = (input.prev_hash, input.prev_idx)
if prevout in utxos:
utxos.pop(prevout)
else:
hashX, value = self.db_utxos[prevout]
deltas[hashX] -= value
for hashX, value in utxos.values():
deltas[hashX] += value
return deltas
def spends(self):
# Return spends indexed by hashX
spends = defaultdict(list)
utxos = self.mempool_utxos()
for tx_hash, tx in self.txs.items():
for n, input in enumerate(tx.inputs):
prevout = (input.prev_hash, input.prev_idx)
if prevout in utxos:
hashX, value = utxos.pop(prevout)
else:
hashX, value = self.db_utxos[prevout]
spends[hashX].append(prevout)
return spends
def summaries(self):
# Return lists of (tx_hash, fee, has_unconfirmed_inputs) by hashX
summaries = defaultdict(list)
utxos = self.mempool_utxos()
for tx_hash, tx in self.txs.items():
fee = 0
hashXs = set()
has_ui = False
for n, input in enumerate(tx.inputs):
has_ui = has_ui or (input.prev_hash in self.txs)
prevout = (input.prev_hash, input.prev_idx)
if prevout in utxos:
hashX, value = utxos[prevout]
else:
hashX, value = self.db_utxos[prevout]
hashXs.add(hashX)
fee += value
for output in tx.outputs:
hashXs.add(coin.hashX_from_script(output.pk_script))
fee -= output.value
summary = (tx_hash, fee, has_ui)
for hashX in hashXs:
summaries[hashX].append(summary)
return summaries
def touched(self, tx_hashes):
touched = set()
utxos = self.mempool_utxos()
for tx_hash in tx_hashes:
tx = self.txs[tx_hash]
for n, input in enumerate(tx.inputs):
prevout = (input.prev_hash, input.prev_idx)
if prevout in utxos:
hashX, value = utxos[prevout]
else:
hashX, value = self.db_utxos[prevout]
touched.add(hashX)
for output in tx.outputs:
touched.add(coin.hashX_from_script(output.pk_script))
return touched
def UTXOs(self):
# Return lists of UTXO 5-tuples by hashX
utxos = defaultdict(list)
for tx_hash, tx in self.txs.items():
for n, output in enumerate(tx.outputs):
hashX = coin.hashX_from_script(output.pk_script)
utxos[hashX].append((-1, n, tx_hash, 0, output.value))
return utxos
async def height(self):
await sleep(0)
self._cached_height = self._height
return self._height
def cached_height(self):
return self._cached_height
async def mempool_hashes(self):
'''Query bitcoind for the hashes of all transactions in its
mempool, returned as a list.'''
await sleep(0)
return [hash_to_hex_str(hash) for hash in self.txs]
async def raw_transactions(self, hex_hashes):
'''Query bitcoind for the serialized raw transactions with the given
hashes. Missing transactions are returned as None.
hex_hashes is an iterable of hexadecimal hash strings.'''
await sleep(0)
hashes = [hex_str_to_hash(hex_hash) for hex_hash in hex_hashes]
return [self.raw_txs.get(hash) for hash in hashes]
async def lookup_utxos(self, prevouts):
'''Return a list of (hashX, value) pairs each prevout if unspent,
otherwise return None if spent or not found.
prevouts - an iterable of (hash, index) pairs
'''
await sleep(0)
return [self.db_utxos.get(prevout) for prevout in prevouts]
async def on_mempool(self, touched, height):
'''Called each time the mempool is synchronized. touched is a set of
hashXs touched since the previous call. height is the
daemon's height at the time the mempool was obtained.'''
self.on_mempool_calls.append((touched, height))
await sleep(0)
class DropAPI(API):
def __init__(self, drop_count):
super().__init__()
self.drop_count = drop_count
self.dropped = False
async def raw_transactions(self, hex_hashes):
if not self.dropped:
self.dropped = True
for hash in self.ordered_adds[-self.drop_count:]:
del self.raw_txs[hash]
del self.txs[hash]
return await super().raw_transactions(hex_hashes)
def in_caplog(caplog, message):
return any(message in record.message for record in caplog.records)
@pytest.mark.asyncio
async def test_keep_synchronized(caplog):
api = API()
mempool = MemPool(coin, api)
event = Event()
with caplog.at_level(logging.INFO):
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await group.cancel_remaining()
assert in_caplog(caplog, 'beginning processing of daemon mempool')
assert in_caplog(caplog, 'compact fee histogram')
assert in_caplog(caplog, 'synced in ')
assert in_caplog(caplog, '0 txs touching 0 addresses')
assert not in_caplog(caplog, 'txs dropped')
@pytest.mark.asyncio
async def test_balance_delta():
api = API()
api.initialize()
mempool = MemPool(coin, api)
event = Event()
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await group.cancel_remaining()
# Check the default dict is handled properly
prior_len = len(mempool.hashXs)
assert await mempool.balance_delta(os.urandom(HASHX_LEN)) == 0
assert prior_len == len(mempool.hashXs)
# Test all hashXs
deltas = api.balance_deltas()
for hashX in api.hashXs:
expected = deltas.get(hashX, 0)
assert await mempool.balance_delta(hashX) == expected
@pytest.mark.asyncio
async def test_compact_fee_histogram():
api = API()
api.initialize()
mempool = MemPool(coin, api)
event = Event()
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await group.cancel_remaining()
histogram = await mempool.compact_fee_histogram()
assert histogram == []
bin_size = 1000
mempool._update_histogram(bin_size)
histogram = await mempool.compact_fee_histogram()
assert len(histogram) > 0
rates, sizes = zip(*histogram)
assert all(rates[n] < rates[n - 1] for n in range(1, len(rates)))
assert all(size > bin_size * 0.95 for size in sizes)
@pytest.mark.asyncio
async def test_potential_spends():
api = API()
api.initialize()
mempool = MemPool(coin, api)
event = Event()
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await group.cancel_remaining()
# Check the default dict is handled properly
prior_len = len(mempool.hashXs)
assert await mempool.potential_spends(os.urandom(HASHX_LEN)) == set()
assert prior_len == len(mempool.hashXs)
# Test all hashXs
spends = api.spends()
for hashX in api.hashXs:
ps = await mempool.potential_spends(hashX)
assert all(spend in ps for spend in spends[hashX])
async def _test_summaries(mempool, api):
# Test all hashXs
summaries = api.summaries()
for hashX in api.hashXs:
mempool_result = await mempool.transaction_summaries(hashX)
mempool_result = [(item.hash, item.fee, item.has_unconfirmed_inputs)
for item in mempool_result]
our_result = summaries.get(hashX, [])
assert set(our_result) == set(mempool_result)
@pytest.mark.asyncio
async def test_transaction_summaries(caplog):
api = API()
api.initialize()
mempool = MemPool(coin, api)
event = Event()
with caplog.at_level(logging.INFO):
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await group.cancel_remaining()
# Check the default dict is handled properly
prior_len = len(mempool.hashXs)
assert await mempool.transaction_summaries(os.urandom(HASHX_LEN)) == []
assert prior_len == len(mempool.hashXs)
await _test_summaries(mempool, api)
assert not in_caplog(caplog, 'txs dropped')
@pytest.mark.asyncio
async def test_unordered_UTXOs():
api = API()
api.initialize()
mempool = MemPool(coin, api)
event = Event()
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await group.cancel_remaining()
# Check the default dict is handled properly
prior_len = len(mempool.hashXs)
assert await mempool.unordered_UTXOs(os.urandom(HASHX_LEN)) == []
assert prior_len == len(mempool.hashXs)
# Test all hashXs
utxos = api.UTXOs()
for hashX in api.hashXs:
mempool_result = await mempool.unordered_UTXOs(hashX)
our_result = utxos.get(hashX, [])
assert set(our_result) == set(mempool_result)
@pytest.mark.asyncio
async def test_mempool_removals():
api = API()
api.initialize()
mempool = MemPool(coin, api, refresh_secs=0.01)
event = Event()
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
# Remove half the TXs from the mempool
start = len(api.ordered_adds) // 2
for tx_hash in api.ordered_adds[start:]:
del api.txs[tx_hash]
del api.raw_txs[tx_hash]
await event.wait()
await _test_summaries(mempool, api)
# Removed hashXs should have key destroyed
assert all(mempool.hashXs.values())
# Remove the rest
api.txs.clear()
api.raw_txs.clear()
await event.wait()
await _test_summaries(mempool, api)
assert not mempool.hashXs
assert not mempool.txs
await group.cancel_remaining()
@pytest.mark.asyncio
async def test_daemon_drops_txs():
# Tests things work if the daemon drops some transactions between
# returning their hashes and the mempool requesting the raw txs
api = DropAPI(10)
api.initialize()
mempool = MemPool(coin, api, refresh_secs=0.01)
event = Event()
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await _test_summaries(mempool, api)
await group.cancel_remaining()
@pytest.mark.asyncio
async def test_notifications():
# Tests notifications over a cycle of:
# 1) A first batch of txs come in
# 2) A second batch of txs come in
# 3) A block comes in confirming the first batch only
api = API()
api.initialize()
mempool = MemPool(coin, api, refresh_secs=0.001, log_status_secs=0)
event = Event()
n = len(api.ordered_adds) // 2
raw_txs = api.raw_txs.copy()
txs = api.txs.copy()
first_hashes = api.ordered_adds[:n]
first_touched = api.touched(first_hashes)
second_hashes = api.ordered_adds[n:]
second_touched = api.touched(second_hashes)
async with TaskGroup() as group:
# First batch enters the mempool
api.raw_txs = {hash: raw_txs[hash] for hash in first_hashes}
api.txs = {hash: txs[hash] for hash in first_hashes}
first_utxos = api.mempool_utxos()
first_spends = api.mempool_spends()
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
assert len(api.on_mempool_calls) == 1
touched, height = api.on_mempool_calls[0]
assert height == api._height == api._cached_height
assert touched == first_touched
# Second batch enters the mempool
api.raw_txs = raw_txs
api.txs = txs
await event.wait()
assert len(api.on_mempool_calls) == 2
touched, height = api.on_mempool_calls[1]
assert height == api._height == api._cached_height
# Touched is incremental
assert touched == second_touched
# Block found; first half confirm
new_height = 2
api._height = new_height
api.db_utxos.update(first_utxos)
for spend in first_spends:
del api.db_utxos[spend]
api.raw_txs = {hash: raw_txs[hash] for hash in second_hashes}
api.txs = {hash: txs[hash] for hash in second_hashes}
await event.wait()
assert len(api.on_mempool_calls) == 3
touched, height = api.on_mempool_calls[2]
assert height == api._height == api._cached_height == new_height
assert touched == first_touched
await group.cancel_remaining()
@pytest.mark.asyncio
async def test_dropped_txs(caplog):
api = API()
api.initialize()
mempool = MemPool(coin, api)
event = Event()
# Remove a single TX_HASH that is used in another mempool tx
for prev_hash, prev_idx in api.mempool_spends():
if prev_hash in api.txs:
del api.txs[prev_hash]
with caplog.at_level(logging.INFO):
async with TaskGroup() as group:
await group.spawn(mempool.keep_synchronized, event)
await event.wait()
await group.cancel_remaining()
assert in_caplog(caplog, 'txs dropped')
Loading…
Cancel
Save