
Completely overhaul mempool sync logic

- highly concurrent and much more efficient than before (see the sketch below)
- initial mempool sync should be much faster (feedback please)
- mempool processing no longer blocks client session handling
- uses less memory to store the mempool
- fixes an obscure bug where txs were sometimes dropped
- more robust, cleaner and easier to understand

Fixes #433
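
The headline bullet shows up in _process_mempool below: instead of one loop fetching 800 raw transactions at a time, unseen tx hashes are split into chunks of 2,000 and each chunk is fetched and accepted in its own task. Here is a minimal, runnable sketch of that fan-out pattern, not the committed code: fetch_chunk is a hypothetical stand-in for _fetch_and_accept, the divisible-by-7 test fakes transactions that cannot be accepted yet, and asyncio.gather is used where the real code collects task results with asyncio.wait.

import asyncio

def chunks(items, size):
    '''Yield successive size-sized slices of items (mirrors
    electrumx.lib.util.chunks).'''
    for i in range(0, len(items), size):
        yield items[i:i + size]

async def fetch_chunk(hashes):
    # Hypothetical stand-in for _fetch_and_accept(): fetch raw txs from
    # the daemon, deserialize them, and return the ones that cannot be
    # accepted yet because an input is not resolvable from this chunk alone
    await asyncio.sleep(0)  # pretend to do network I/O
    return {h: 'deferred' for h in hashes if h % 7 == 0}

async def process_mempool(new_hashes):
    tasks = [asyncio.create_task(fetch_chunk(chunk))
             for chunk in chunks(new_hashes, 2000)]
    deferred = {}
    for result in await asyncio.gather(*tasks):
        deferred.update(result)  # stragglers, retried after the fan-out
    return deferred

print(len(asyncio.run(process_mempool(list(range(10_000))))))  # 1429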
Neil Booth, 7 years ago (branch patch-2, commit 0963ce5230)
Changed files:
  electrumx/server/mempool.py  (319)
  electrumx/server/session.py  (6)

electrumx/server/mempool.py

@@ -15,13 +15,12 @@ from collections import defaultdict
 import attr
 
 from electrumx.lib.hash import hash_to_hex_str, hex_str_to_hash
-from electrumx.lib.util import class_logger
+from electrumx.lib.util import class_logger, chunks
 from electrumx.server.db import UTXO
 
 
 @attr.s(slots=True)
 class MemPoolTx(object):
-    hash = attr.ib()
     in_pairs = attr.ib()
     out_pairs = attr.ib()
     fee = attr.ib()
@@ -36,10 +35,10 @@ class MemPool(object):
     To that end we maintain the following maps:
 
        tx_hash -> MemPoolTx
        hashX   -> set of all tx hashes in which the hashX appears
 
-    A pair is a (hashX, value) tuple.  tx hashes are hex strings.
+    A pair is a (hashX, value) tuple.  tx hashes are binary, not strings.
     '''
 
     def __init__(self, coin, tasks, daemon, notifications, lookup_utxos):
@@ -61,77 +60,81 @@ class MemPool(object):
                              f'touching {len(self.hashXs):,d} addresses')
             await asyncio.sleep(120)
 
-    async def _synchronize_forever(self):
-        while True:
-            await asyncio.sleep(5)
-            await self._synchronize(False)
-
-    async def _refresh_hashes(self):
-        '''Return a (hash set, height) pair when we're sure which height they
-        are for.'''
-        while True:
-            height = self.daemon.cached_height()
-            hashes = await self.daemon.mempool_hashes()
-            if height == await self.daemon.height():
-                return set(hashes), height
-
-    async def _synchronize(self, first_time):
-        '''Asynchronously maintain mempool status with daemon.
-
-        Processes the mempool each time the mempool refresh event is
-        signalled.
-        '''
-        unprocessed = {}
-        unfetched = set()
-        touched = set()
-        txs = self.txs
-        next_refresh = 0
-        fetch_size = 800
-        process_some = self._async_process_some(fetch_size // 2)
-
-        while True:
-            now = time.time()
-            # If processing a large mempool, a block being found might
-            # shrink our work considerably, so refresh our view every 20s
-            if now > next_refresh:
-                hashes, height = await self._refresh_hashes()
-                self._resync_hashes(hashes, unprocessed, unfetched, touched)
-                next_refresh = time.time() + 20
-
-            # Log progress of initial sync
-            todo = len(unfetched) + len(unprocessed)
-            if first_time:
-                pct = (len(txs) - todo) * 100 // len(txs) if txs else 0
-                self.logger.info(f'catchup {pct:d}% complete '
-                                 f'({todo:,d} txs left)')
-            if not todo:
-                break
-
-            # FIXME: parallelize
-            if unfetched:
-                count = min(len(unfetched), fetch_size)
-                hex_hashes = [unfetched.pop() for n in range(count)]
-                unprocessed.update(await self._fetch_raw_txs(hex_hashes))
-            if unprocessed:
-                await process_some(unprocessed, touched)
-
-        await self.notifications.on_mempool(touched, height)
-
-    def _resync_hashes(self, hashes, unprocessed, unfetched, touched):
-        '''Re-sync self.txs with the list of hashes in the daemon's mempool.
-
-        Additionally, remove gone hashes from unprocessed and
-        unfetched.  Add new ones to unprocessed.
-        '''
+    def _accept_transactions(self, tx_map, utxo_map, touched):
+        '''Accept transactions in tx_map to the mempool if all their inputs
+        can be found in the existing mempool or a utxo_map from the
+        DB.
+
+        Returns an (unprocessed tx_map, unspent utxo_map) pair.
+        '''
+        hashXs = self.hashXs
+        txs = self.txs
+        fee_hist = self.fee_histogram
+        init_count = len(utxo_map)
+
+        deferred = {}
+        unspent = set(utxo_map)
+        # Try to find all previns so we can accept the TX
+        for hash, tx in tx_map.items():
+            in_pairs = []
+            try:
+                for previn in tx.in_pairs:
+                    utxo = utxo_map.get(previn)
+                    if not utxo:
+                        prev_hash, prev_index = previn
+                        # Raises KeyError if prev_hash is not in txs
+                        utxo = txs[prev_hash].out_pairs[prev_index]
+                    in_pairs.append(utxo)
+            except KeyError:
+                deferred[hash] = tx
+                continue
+
+            # Spend the previns
+            unspent.difference_update(tx.in_pairs)
+
+            # Convert in_pairs and add the TX to txs
+            tx.in_pairs = in_pairs
+            # Compute fee
+            tx.fee = (sum(v for hashX, v in tx.in_pairs) -
+                      sum(v for hashX, v in tx.out_pairs))
+            fee_rate = tx.fee // tx.size
+            fee_hist[fee_rate] += tx.size
+            txs[hash] = tx
+            for hashX, value in itertools.chain(tx.in_pairs, tx.out_pairs):
+                touched.add(hashX)
+                hashXs[hashX].add(hash)
+
+        return deferred, {previn: utxo_map[previn] for previn in unspent}
+
+    async def _refresh_hashes(self, single_pass):
+        '''Return a (hash set, height) pair when we're sure which height they
+        are for.'''
+        refresh_event = asyncio.Event()
+        loop = self.tasks.loop
+        while True:
+            height = self.daemon.cached_height()
+            hex_hashes = await self.daemon.mempool_hashes()
+            if height != await self.daemon.height():
+                continue
+            loop.call_later(5, refresh_event.set)
+            hashes = set(hex_str_to_hash(hh) for hh in hex_hashes)
+            touched = await self._process_mempool(hashes)
+            await self.notifications.on_mempool(touched, height)
+            if single_pass:
+                return
+            await refresh_event.wait()
+            refresh_event.clear()
+
+    async def _process_mempool(self, all_hashes):
+        # Re-sync with the new set of hashes
         txs = self.txs
         hashXs = self.hashXs
+        touched = set()
         fee_hist = self.fee_histogram
-        gone = set(txs).difference(hashes)
-        for hex_hash in gone:
-            unfetched.discard(hex_hash)
-            unprocessed.pop(hex_hash, None)
-            tx = txs.pop(hex_hash)
-            if tx:
-                fee_rate = tx.fee // tx.size
-                fee_hist[fee_rate] -= tx.size
-                if fee_hist[fee_rate] == 0:
+
+        # First handle txs that have disappeared
+        for tx_hash in set(txs).difference(all_hashes):
+            tx = txs.pop(tx_hash)
+            fee_rate = tx.fee // tx.size
+            fee_hist[fee_rate] -= tx.size
+            if fee_hist[fee_rate] == 0:
@@ -139,123 +142,82 @@ class MemPool(object):
                 del fee_hist[fee_rate]
             tx_hashXs = set(hashX for hashX, value in tx.in_pairs)
             tx_hashXs.update(hashX for hashX, value in tx.out_pairs)
             for hashX in tx_hashXs:
-                hashXs[hashX].remove(hex_hash)
+                hashXs[hashX].remove(tx_hash)
                 if not hashXs[hashX]:
                     del hashXs[hashX]
             touched.update(tx_hashXs)
 
-        new = hashes.difference(txs)
-        unfetched.update(new)
-        for hex_hash in new:
-            txs[hex_hash] = None
-
-    def _async_process_some(self, limit):
-        pending = []
-        txs = self.txs
-
-        async def process(unprocessed, touched):
-            nonlocal pending
-
-            raw_txs = {}
-            while unprocessed and len(raw_txs) < limit:
-                hex_hash, raw_tx = unprocessed.popitem()
-                raw_txs[hex_hash] = raw_tx
-
-            if unprocessed:
-                deferred = []
-            else:
-                deferred = pending
-                pending = []
-
-            deferred = await self._process_raw_txs(raw_txs, deferred, touched)
-            pending.extend(deferred)
-
-        return process
-
-    async def _fetch_raw_txs(self, hex_hashes):
-        '''Fetch a list of mempool transactions.'''
-        raw_txs = await self.daemon.getrawtransactions(hex_hashes)
-
-        # Skip hashes the daemon has dropped.  Either they were
-        # evicted or they got in a block.
-        return {hh: raw for hh, raw in zip(hex_hashes, raw_txs) if raw}
-
-    async def _process_raw_txs(self, raw_tx_map, pending, touched):
-        '''Process the dictionary of raw transactions and return a dictionary
-        of updates to apply to self.txs.
-        '''
+        # Process new transactions
+        new_hashes = list(all_hashes.difference(txs))
+        jobs = [self.tasks.create_task(self._fetch_and_accept
+                                       (hashes, all_hashes, touched))
+                for hashes in chunks(new_hashes, 2000)]
+        if jobs:
+            await asyncio.wait(jobs)
+            tx_map = {}
+            utxo_map = {}
+            for job in jobs:
+                deferred, unspent = job.result()
+                tx_map.update(deferred)
+                utxo_map.update(unspent)
+            # Handle the stragglers
+            if len(tx_map) >= 10:
+                self.logger.info(f'{len(tx_map)} stragglers')
+            prior_count = 0
+            # FIXME: this is not particularly efficient
+            while tx_map and len(tx_map) != prior_count:
+                prior_count = len(tx_map)
+                tx_map, utxo_map = self._accept_transactions(tx_map, utxo_map,
+                                                             touched)
+            if tx_map:
+                self.logger.info(f'{len(tx_map)} txs dropped')
+
+        return touched
+
+    async def _fetch_and_accept(self, hashes, all_hashes, touched):
+        '''Fetch a list of mempool transactions.'''
+        hex_hashes = [hash_to_hex_str(hash) for hash in hashes]
+        raw_txs = await self.daemon.getrawtransactions(hex_hashes)
+        count = len([raw_tx for raw_tx in raw_txs if raw_tx])
 
         def deserialize_txs():
+            # This function is pure
             script_hashX = self.coin.hashX_from_script
             deserializer = self.coin.DESERIALIZER
 
-            # Deserialize each tx and put it in a pending list
-            for tx_hash, raw_tx in raw_tx_map.items():
+            txs = {}
+            for hash, raw_tx in zip(hashes, raw_txs):
+                # The daemon may have evicted the tx from its
+                # mempool or it may have gotten in a block
+                if not raw_tx:
+                    continue
                 tx, tx_size = deserializer(raw_tx).read_tx_and_vsize()
                 # Convert the tx outputs into (hashX, value) pairs
                 txout_pairs = [(script_hashX(txout.pk_script), txout.value)
                                for txout in tx.outputs]
-                # Convert the tx inputs to (prev_hex_hash, prev_idx) pairs
-                txin_pairs = [(hash_to_hex_str(txin.prev_hash), txin.prev_idx)
+                # Convert the tx inputs to (prev_hash, prev_idx) pairs
+                txin_pairs = [(txin.prev_hash, txin.prev_idx)
                               for txin in tx.inputs]
-                pending.append(MemPoolTx(tx_hash, txin_pairs, txout_pairs,
-                                         0, tx_size))
+                txs[hash] = MemPoolTx(txin_pairs, txout_pairs, 0, tx_size)
+            return txs
 
-        # Do this potentially slow operation in a thread so as not to
-        # block
-        await self.tasks.run_in_thread(deserialize_txs)
-
-        # The transaction inputs can be from other mempool
-        # transactions (which may or may not be processed yet) or are
-        # otherwise presumably in the DB.
-        txs = self.txs
-        db_prevouts = [(hex_str_to_hash(prev_hash), prev_idx)
-                       for tx in pending
-                       for (prev_hash, prev_idx) in tx.in_pairs
-                       if prev_hash not in txs]
-        # If a lookup fails, it returns a None entry
-        db_utxos = await self.lookup_utxos(db_prevouts)
-        db_utxo_map = {(hash_to_hex_str(prev_hash), prev_idx): db_utxo
-                       for (prev_hash, prev_idx), db_utxo
-                       in zip(db_prevouts, db_utxos)}
-
-        deferred = []
-        hashXs = self.hashXs
-        fee_hist = self.fee_histogram
-
-        for tx in pending:
-            if tx.hash not in txs:
-                continue
-
-            in_pairs = []
-            try:
-                for previn in tx.in_pairs:
-                    utxo = db_utxo_map.get(previn)
-                    if not utxo:
-                        prev_hash, prev_index = previn
-                        # This can raise a KeyError or TypeError
-                        utxo = txs[prev_hash][1][prev_index]
-                    in_pairs.append(utxo)
-            except (KeyError, TypeError):
-                deferred.append(tx)
-                continue
-
-            tx.in_pairs = in_pairs
-            # Compute fee
-            tx.fee = (sum(v for hashX, v in tx.in_pairs) -
-                      sum(v for hashX, v in tx.out_pairs))
-            fee_rate = tx.fee // tx.size
-            fee_hist[fee_rate] += tx.size
-            txs[tx.hash] = tx
-            for hashX, value in itertools.chain(tx.in_pairs, tx.out_pairs):
-                touched.add(hashX)
-                hashXs[hashX].add(tx.hash)
-
-        return deferred
+        # Thread this potentially slow operation so as not to block
+        tx_map = await self.tasks.run_in_thread(deserialize_txs)
+
+        # Determine all prevouts not in the mempool, and fetch the
+        # UTXO information from the database.  Failed prevout lookups
+        # return None - concurrent database updates happen
+        prevouts = [tx_in for tx in tx_map.values() for tx_in in tx.in_pairs
+                    if tx_in[0] not in all_hashes]
+        utxos = await self.lookup_utxos(prevouts)
+        utxo_map = {prevout: utxo for prevout, utxo in zip(prevouts, utxos)}
+
+        # Attempt to complete processing of txs
+        return self._accept_transactions(tx_map, utxo_map, touched)
 
     async def _raw_transactions(self, hashX):
         '''Returns an iterable of (hex_hash, raw_tx) pairs for all
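
The straggler loop in _process_mempool above exists because a mempool transaction may spend the output of another unconfirmed transaction that was handled in a different chunk; acceptance is re-run over the leftover tx_map until a pass makes no progress, and whatever remains is dropped. A toy model of that fixpoint iteration, not the committed code, with invented names (deps, db_utxos) standing in for the real in_pairs and utxo_map plumbing:

def accept_transactions(tx_map, accepted, db_utxos):
    # One pass of a toy _accept_transactions: accept every tx whose
    # dependencies are all satisfied, return the rest as deferred
    deferred = {}
    for tx_hash, deps in tx_map.items():
        if all(d in accepted or d in db_utxos for d in deps):
            accepted.add(tx_hash)
        else:
            deferred[tx_hash] = deps
    return deferred

def resolve(tx_map, db_utxos):
    accepted = set()
    prior_count = 0
    # Iterate to a fixpoint: stop once a pass accepts nothing new
    while tx_map and len(tx_map) != prior_count:
        prior_count = len(tx_map)
        tx_map = accept_transactions(tx_map, accepted, db_utxos)
    return accepted, tx_map  # leftovers are unresolvable and dropped

# 'b' spends 'a' (in the DB), 'c' spends 'b'; 'x' spends an unknown output
accepted, dropped = resolve({'b': ['a'], 'c': ['b'], 'x': ['missing']},
                            db_utxos={'a'})
print(accepted, dropped)  # {'b', 'c'} {'x': ['missing']}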
@@ -267,9 +229,10 @@ class MemPool(object):
         if hashX not in self.hashXs:
             return []
-        hex_hashes = self.hashXs[hashX]
+        hashes = self.hashXs[hashX]
+        hex_hashes = [hash_to_hex_str(hash) for hash in hashes]
         raw_txs = await self.daemon.getrawtransactions(hex_hashes)
-        return zip(hex_hashes, raw_txs)
+        return zip(hashes, raw_txs)
 
     def _calc_compact_histogram(self):
         # For efficiency, get_fees returns a compact histogram with
@@ -300,9 +263,12 @@ class MemPool(object):
         '''
         self.logger.info('beginning processing of daemon mempool.  '
                          'This can take some time...')
-        await self._synchronize(True)
+        start = time.time()
+        await self._refresh_hashes(True)
+        elapsed = time.time() - start
+        self.logger.info(f'synced in {elapsed:.2f}s')
         self.tasks.create_task(self._log_stats())
-        self.tasks.create_task(self._synchronize_forever())
+        self.tasks.create_task(self._refresh_hashes(False))
 
     async def balance_delta(self, hashX):
         '''Return the unconfirmed amount in the mempool for hashX.
@@ -312,8 +278,8 @@ class MemPool(object):
         value = 0
         # hashXs is a defaultdict
         if hashX in self.hashXs:
-            for hex_hash in self.hashXs[hashX]:
-                tx = self.txs[hex_hash]
+            for hash in self.hashXs[hashX]:
+                tx = self.txs[hash]
                 value -= sum(v for h168, v in tx.in_pairs if h168 == hashX)
                 value += sum(v for h168, v in tx.out_pairs if h168 == hashX)
         return value
@@ -335,7 +301,7 @@ class MemPool(object):
         deserializer = self.coin.DESERIALIZER
         pairs = await self._raw_transactions(hashX)
         result = set()
-        for hex_hash, raw_tx in pairs:
+        for hash, raw_tx in pairs:
             if not raw_tx:
                 continue
             tx = deserializer(raw_tx).read_tx()
@@ -344,7 +310,7 @@ class MemPool(object):
         return result
 
     async def transaction_summaries(self, hashX):
-        '''Return a list of (tx_hex_hash, tx_fee, unconfirmed) tuples for
+        '''Return a list of (tx_hash, tx_fee, unconfirmed) tuples for
         mempool entries for the hashX.
 
         unconfirmed is True if any txin is unconfirmed.
@@ -352,14 +318,15 @@ class MemPool(object):
         deserializer = self.coin.DESERIALIZER
         pairs = await self._raw_transactions(hashX)
         result = []
-        for hex_hash, raw_tx in pairs:
-            mempool_tx = self.txs.get(hex_hash)
+        for tx_hash, raw_tx in pairs:
+            mempool_tx = self.txs.get(tx_hash)
             if not mempool_tx or not raw_tx:
                 continue
             tx = deserializer(raw_tx).read_tx()
-            unconfirmed = any(hash_to_hex_str(txin.prev_hash) in self.txs
+            # FIXME: use all_hashes not self.txs
+            unconfirmed = any(txin.prev_hash in self.txs
                               for txin in tx.inputs)
-            result.append((hex_hash, mempool_tx.fee, unconfirmed))
+            result.append((tx_hash, mempool_tx.fee, unconfirmed))
         return result
 
     async def unordered_UTXOs(self, hashX):
@@ -371,13 +338,11 @@ class MemPool(object):
         '''
         utxos = []
         # hashXs is a defaultdict, so use get() to query
-        for hex_hash in self.hashXs.get(hashX, []):
-            tx = self.txs.get(hex_hash)
+        for tx_hash in self.hashXs.get(hashX, []):
+            tx = self.txs.get(tx_hash)
            if not tx:
                 continue
             for pos, (hX, value) in enumerate(tx.out_pairs):
                 if hX == hashX:
-                    # Unfortunately UTXO holds a binary hash
-                    utxos.append(UTXO(-1, pos, hex_str_to_hash(hex_hash),
-                                      0, value))
+                    utxos.append(UTXO(-1, pos, tx_hash, 0, value))
         return utxos
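
One piece of bookkeeping the diff leaves intact in spirit is self.fee_histogram: transactions are bucketed by integer fee rate (fee // size), each bucket accumulates total transaction size, and a bucket is deleted when it drains to zero so the histogram stays compact. A self-contained sketch of just that mechanism with made-up fee and size numbers; this is a distillation of the lines above, not ElectrumX code:

from collections import defaultdict

fee_histogram = defaultdict(int)  # fee rate -> total tx size in bucket

def add_tx(fee, size):
    # Distilled from the diff above, not ElectrumX code
    fee_histogram[fee // size] += size

def remove_tx(fee, size):
    fee_rate = fee // size
    fee_histogram[fee_rate] -= size
    if fee_histogram[fee_rate] == 0:
        del fee_histogram[fee_rate]  # keep the histogram compact

add_tx(fee=452, size=226)   # lands in the 2 sat/vbyte bucket
add_tx(fee=1130, size=226)  # lands in the 5 sat/vbyte bucket
remove_tx(fee=452, size=226)
print(dict(fee_histogram))  # {5: 226}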

electrumx/server/session.py

@@ -733,7 +733,8 @@ class ElectrumX(SessionBase):
         status = ''.join('{}:{:d}:'.format(hash_to_hex_str(tx_hash), height)
                          for tx_hash, height in history)
-        status += ''.join('{}:{:d}:'.format(hex_hash, -unconfirmed)
+        status += ''.join('{}:{:d}:'.format(hash_to_hex_str(hex_hash),
+                                            -unconfirmed)
                           for hex_hash, tx_fee, unconfirmed in mempool)
 
         if status:
             status = sha256(status.encode()).hex()
@@ -821,7 +822,8 @@ class ElectrumX(SessionBase):
         # Note unconfirmed history is unordered in electrum-server
         # Height is -1 if unconfirmed txins, otherwise 0
         mempool = await self.mempool.transaction_summaries(hashX)
-        return [{'tx_hash': tx_hash, 'height': -unconfirmed, 'fee': fee}
+        return [{'tx_hash': hash_to_hex_str(tx_hash), 'height': -unconfirmed,
+                 'fee': fee}
                 for tx_hash, fee, unconfirmed in mempool]
 
     async def confirmed_and_unconfirmed_history(self, hashX):
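
These session.py changes exist because the mempool now stores and returns binary 32-byte hashes, while the Electrum protocol speaks displayed hex, which is the byte-reversed hash; the conversion therefore moves to the RPC boundary. A minimal stand-in for electrumx.lib.hash.hash_to_hex_str, written from its documented behavior rather than copied from the library:

def hash_to_hex_str(h: bytes) -> str:
    # Stand-in, not the library function: convert a binary hash to its
    # displayed form, which is the bytes reversed and hex-encoded
    return bytes(reversed(h)).hex()

# A hash whose final byte is 0xff displays with 'ff' first
print(hash_to_hex_str(bytes.fromhex('00' * 31 + 'ff')))  # 'ff' + 62 zeros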
