Browse Source

Merge pull request #7754 from SomberNight/202204_wallet_uptodate2

wallet: "up_to_date" to wait for SPV/Verifier
patch-4
ThomasV 3 years ago
committed by GitHub
parent
commit
f0a806ccf8
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 75
      electrum/address_synchronizer.py
  2. 4
      electrum/gui/stdio.py
  3. 6
      electrum/gui/text.py
  4. 21
      electrum/synchronizer.py
  5. 8
      electrum/util.py
  6. 7
      electrum/verifier.py
  7. 12
      electrum/wallet.py

75
electrum/address_synchronizer.py

@ -85,10 +85,12 @@ class AddressSynchronizer(Logger):
self.lock = threading.RLock() self.lock = threading.RLock()
self.transaction_lock = threading.RLock() self.transaction_lock = threading.RLock()
self.future_tx = {} # type: Dict[str, int] # txid -> wanted height self.future_tx = {} # type: Dict[str, int] # txid -> wanted height
# Transactions pending verification. txid -> tx_height. Access with self.lock. # Txs the server claims are mined but still pending verification:
self.unverified_tx = defaultdict(int) self.unverified_tx = defaultdict(int) # type: Dict[str, int] # txid -> height. Access with self.lock.
# Txs the server claims are in the mempool:
self.unconfirmed_tx = defaultdict(int) # type: Dict[str, int] # txid -> height. Access with self.lock.
# true when synchronized # true when synchronized
self._up_to_date = False self._up_to_date = False # considers both Synchronizer and Verifier
# thread local storage for caching stuff # thread local storage for caching stuff
self.threadlocal_cache = threading.local() self.threadlocal_cache = threading.local()
@ -176,7 +178,7 @@ class AddressSynchronizer(Logger):
hist = self.db.get_addr_history(addr) hist = self.db.get_addr_history(addr)
for tx_hash, tx_height in hist: for tx_hash, tx_height in hist:
# add it in case it was previously unconfirmed # add it in case it was previously unconfirmed
self.add_unverified_tx(tx_hash, tx_height) self.add_unverified_or_unconfirmed_tx(tx_hash, tx_height)
def start_network(self, network: Optional['Network']) -> None: def start_network(self, network: Optional['Network']) -> None:
self.network = network self.network = network
@ -379,6 +381,7 @@ class AddressSynchronizer(Logger):
self.db.remove_tx_fee(tx_hash) self.db.remove_tx_fee(tx_hash)
self.db.remove_verified_tx(tx_hash) self.db.remove_verified_tx(tx_hash)
self.unverified_tx.pop(tx_hash, None) self.unverified_tx.pop(tx_hash, None)
self.unconfirmed_tx.pop(tx_hash, None)
if tx: if tx:
for idx, txo in enumerate(tx.outputs()): for idx, txo in enumerate(tx.outputs()):
scripthash = bitcoin.script_to_scripthash(txo.scriptpubkey.hex()) scripthash = bitcoin.script_to_scripthash(txo.scriptpubkey.hex())
@ -396,7 +399,7 @@ class AddressSynchronizer(Logger):
return children return children
def receive_tx_callback(self, tx_hash: str, tx: Transaction, tx_height: int) -> None: def receive_tx_callback(self, tx_hash: str, tx: Transaction, tx_height: int) -> None:
self.add_unverified_tx(tx_hash, tx_height) self.add_unverified_or_unconfirmed_tx(tx_hash, tx_height)
self.add_transaction(tx, allow_unrelated=True) self.add_transaction(tx, allow_unrelated=True)
def receive_history_callback(self, addr: str, hist, tx_fees: Dict[str, int]): def receive_history_callback(self, addr: str, hist, tx_fees: Dict[str, int]):
@ -406,6 +409,7 @@ class AddressSynchronizer(Logger):
if (tx_hash, height) not in hist: if (tx_hash, height) not in hist:
# make tx local # make tx local
self.unverified_tx.pop(tx_hash, None) self.unverified_tx.pop(tx_hash, None)
self.unconfirmed_tx.pop(tx_hash, None)
self.db.remove_verified_tx(tx_hash) self.db.remove_verified_tx(tx_hash)
if self.verifier: if self.verifier:
self.verifier.remove_spv_proof_for_tx(tx_hash) self.verifier.remove_spv_proof_for_tx(tx_hash)
@ -413,7 +417,7 @@ class AddressSynchronizer(Logger):
for tx_hash, tx_height in hist: for tx_hash, tx_height in hist:
# add it in case it was previously unconfirmed # add it in case it was previously unconfirmed
self.add_unverified_tx(tx_hash, tx_height) self.add_unverified_or_unconfirmed_tx(tx_hash, tx_height)
# if addr is new, we have to recompute txi and txo # if addr is new, we have to recompute txi and txo
tx = self.db.get_transaction(tx_hash) tx = self.db.get_transaction(tx_hash)
if tx is None: if tx is None:
@ -459,17 +463,26 @@ class AddressSynchronizer(Logger):
self._history_local.clear() self._history_local.clear()
self._get_addr_balance_cache = {} # invalidate cache self._get_addr_balance_cache = {} # invalidate cache
def get_txpos(self, tx_hash): def get_txpos(self, tx_hash: str) -> Tuple[int, int]:
"""Returns (height, txpos) tuple, even if the tx is unverified.""" """Returns (height, txpos) tuple, even if the tx is unverified."""
with self.lock: with self.lock:
verified_tx_mined_info = self.db.get_verified_tx(tx_hash) verified_tx_mined_info = self.db.get_verified_tx(tx_hash)
if verified_tx_mined_info: if verified_tx_mined_info:
return verified_tx_mined_info.height, verified_tx_mined_info.txpos height = verified_tx_mined_info.height
txpos = verified_tx_mined_info.txpos
assert height > 0, height
assert txpos is not None
return height, txpos
elif tx_hash in self.unverified_tx: elif tx_hash in self.unverified_tx:
height = self.unverified_tx[tx_hash] height = self.unverified_tx[tx_hash]
return (height, -1) if height > 0 else ((1e9 - height), -1) assert height > 0, height
return height, -1
elif tx_hash in self.unconfirmed_tx:
height = self.unconfirmed_tx[tx_hash]
assert height <= 0, height
return (10**9 - height), -1
else: else:
return (1e9+1, -1) return (10**9 + 1), -1
def with_local_height_cached(func): def with_local_height_cached(func):
# get local height only once, as it's relatively expensive. # get local height only once, as it's relatively expensive.
@ -558,17 +571,21 @@ class AddressSynchronizer(Logger):
assert self.is_mine(addr), "address needs to be is_mine to be watched" assert self.is_mine(addr), "address needs to be is_mine to be watched"
await self._address_history_changed_events[addr].wait() await self._address_history_changed_events[addr].wait()
def add_unverified_tx(self, tx_hash, tx_height): def add_unverified_or_unconfirmed_tx(self, tx_hash, tx_height):
if self.db.is_in_verified_tx(tx_hash): if self.db.is_in_verified_tx(tx_hash):
if tx_height in (TX_HEIGHT_UNCONFIRMED, TX_HEIGHT_UNCONF_PARENT): if tx_height <= 0:
# tx was previously SPV-verified but now in mempool (probably reorg)
with self.lock: with self.lock:
self.db.remove_verified_tx(tx_hash) self.db.remove_verified_tx(tx_hash)
self.unconfirmed_tx[tx_hash] = tx_height
if self.verifier: if self.verifier:
self.verifier.remove_spv_proof_for_tx(tx_hash) self.verifier.remove_spv_proof_for_tx(tx_hash)
else: else:
with self.lock: with self.lock:
# tx will be verified only if height > 0 if tx_height > 0:
self.unverified_tx[tx_hash] = tx_height self.unverified_tx[tx_hash] = tx_height
else:
self.unconfirmed_tx[tx_hash] = tx_height
def remove_unverified_tx(self, tx_hash, tx_height): def remove_unverified_tx(self, tx_hash, tx_height):
with self.lock: with self.lock:
@ -584,7 +601,7 @@ class AddressSynchronizer(Logger):
tx_mined_status = self.get_tx_height(tx_hash) tx_mined_status = self.get_tx_height(tx_hash)
util.trigger_callback('verified', self, tx_hash, tx_mined_status) util.trigger_callback('verified', self, tx_hash, tx_mined_status)
def get_unverified_txs(self): def get_unverified_txs(self) -> Dict[str, int]:
'''Returns a map from tx hash to transaction height''' '''Returns a map from tx hash to transaction height'''
with self.lock: with self.lock:
return dict(self.unverified_tx) # copy return dict(self.unverified_tx) # copy
@ -638,6 +655,9 @@ class AddressSynchronizer(Logger):
elif tx_hash in self.unverified_tx: elif tx_hash in self.unverified_tx:
height = self.unverified_tx[tx_hash] height = self.unverified_tx[tx_hash]
return TxMinedInfo(height=height, conf=0) return TxMinedInfo(height=height, conf=0)
elif tx_hash in self.unconfirmed_tx:
height = self.unconfirmed_tx[tx_hash]
return TxMinedInfo(height=height, conf=0)
elif tx_hash in self.future_tx: elif tx_hash in self.future_tx:
num_blocks_remainining = self.future_tx[tx_hash] - self.get_local_height() num_blocks_remainining = self.future_tx[tx_hash] - self.get_local_height()
if num_blocks_remainining > 0: if num_blocks_remainining > 0:
@ -652,8 +672,14 @@ class AddressSynchronizer(Logger):
with self.lock: with self.lock:
status_changed = self._up_to_date != up_to_date status_changed = self._up_to_date != up_to_date
self._up_to_date = up_to_date self._up_to_date = up_to_date
if self.network: # reset sync state progress indicator
self.network.notify('status') if up_to_date:
if self.synchronizer:
self.synchronizer.reset_request_counters()
if self.verifier:
self.verifier.reset_request_counters()
# fire triggers
util.trigger_callback('status')
if status_changed: if status_changed:
self.logger.info(f'set_up_to_date: {up_to_date}') self.logger.info(f'set_up_to_date: {up_to_date}')
@ -661,10 +687,16 @@ class AddressSynchronizer(Logger):
return self._up_to_date return self._up_to_date
def get_history_sync_state_details(self) -> Tuple[int, int]: def get_history_sync_state_details(self) -> Tuple[int, int]:
nsent, nans = 0, 0
if self.synchronizer: if self.synchronizer:
return self.synchronizer.num_requests_sent_and_answered() n1, n2 = self.synchronizer.num_requests_sent_and_answered()
else: nsent += n1
return 0, 0 nans += n2
if self.verifier:
n1, n2 = self.verifier.num_requests_sent_and_answered()
nsent += n1
nans += n2
return nsent, nans
@with_transaction_lock @with_transaction_lock
def get_tx_delta(self, tx_hash: str, address: str) -> int: def get_tx_delta(self, tx_hash: str, address: str) -> int:
@ -902,5 +934,6 @@ class AddressSynchronizer(Logger):
c, u, x = self.get_addr_balance(address) c, u, x = self.get_addr_balance(address)
return c+u+x == 0 return c+u+x == 0
def synchronize(self): def synchronize(self) -> int:
pass """Returns the number of new addresses we generated."""
return 0

4
electrum/gui/stdio.py

@ -2,10 +2,12 @@ from decimal import Decimal
import getpass import getpass
import datetime import datetime
import logging import logging
from typing import Optional
from electrum.gui import BaseElectrumGui from electrum.gui import BaseElectrumGui
from electrum import util from electrum import util
from electrum import WalletStorage, Wallet from electrum import WalletStorage, Wallet
from electrum.wallet import Abstract_Wallet
from electrum.wallet_db import WalletDB from electrum.wallet_db import WalletDB
from electrum.util import format_satoshis from electrum.util import format_satoshis
from electrum.bitcoin import is_address, COIN from electrum.bitcoin import is_address, COIN
@ -41,7 +43,7 @@ class ElectrumGui(BaseElectrumGui):
self.str_amount = "" self.str_amount = ""
self.str_fee = "" self.str_fee = ""
self.wallet = Wallet(db, storage, config=config) self.wallet = Wallet(db, storage, config=config) # type: Optional[Abstract_Wallet]
self.wallet.start_network(self.network) self.wallet.start_network(self.network)
self.contacts = self.wallet.contacts self.contacts = self.wallet.contacts

6
electrum/gui/text.py

@ -6,7 +6,7 @@ import locale
from decimal import Decimal from decimal import Decimal
import getpass import getpass
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
import electrum import electrum
from electrum.gui import BaseElectrumGui from electrum.gui import BaseElectrumGui
@ -14,7 +14,7 @@ from electrum import util
from electrum.util import format_satoshis from electrum.util import format_satoshis
from electrum.bitcoin import is_address, COIN from electrum.bitcoin import is_address, COIN
from electrum.transaction import PartialTxOutput from electrum.transaction import PartialTxOutput
from electrum.wallet import Wallet from electrum.wallet import Wallet, Abstract_Wallet
from electrum.wallet_db import WalletDB from electrum.wallet_db import WalletDB
from electrum.storage import WalletStorage from electrum.storage import WalletStorage
from electrum.network import NetworkParameters, TxBroadcastError, BestEffortRequestFailed from electrum.network import NetworkParameters, TxBroadcastError, BestEffortRequestFailed
@ -42,7 +42,7 @@ class ElectrumGui(BaseElectrumGui):
password = getpass.getpass('Password:', stream=None) password = getpass.getpass('Password:', stream=None)
storage.decrypt(password) storage.decrypt(password)
db = WalletDB(storage.read(), manual_upgrades=False) db = WalletDB(storage.read(), manual_upgrades=False)
self.wallet = Wallet(db, storage, config=config) self.wallet = Wallet(db, storage, config=config) # type: Optional[Abstract_Wallet]
self.wallet.start_network(self.network) self.wallet.start_network(self.network)
self.contacts = self.wallet.contacts self.contacts = self.wallet.contacts

21
electrum/synchronizer.py

@ -60,7 +60,6 @@ class SynchronizerBase(NetworkJobOnDefaultServer):
""" """
def __init__(self, network: 'Network'): def __init__(self, network: 'Network'):
self.asyncio_loop = network.asyncio_loop self.asyncio_loop = network.asyncio_loop
self._reset_request_counters()
NetworkJobOnDefaultServer.__init__(self, network) NetworkJobOnDefaultServer.__init__(self, network)
@ -69,7 +68,6 @@ class SynchronizerBase(NetworkJobOnDefaultServer):
self.requested_addrs = set() self.requested_addrs = set()
self.scripthash_to_address = {} self.scripthash_to_address = {}
self._processed_some_notifications = False # so that we don't miss them self._processed_some_notifications = False # so that we don't miss them
self._reset_request_counters()
# Queues # Queues
self.add_queue = asyncio.Queue() self.add_queue = asyncio.Queue()
self.status_queue = asyncio.Queue() self.status_queue = asyncio.Queue()
@ -85,10 +83,6 @@ class SynchronizerBase(NetworkJobOnDefaultServer):
# we are being cancelled now # we are being cancelled now
self.session.unsubscribe(self.status_queue) self.session.unsubscribe(self.status_queue)
def _reset_request_counters(self):
self._requests_sent = 0
self._requests_answered = 0
def add(self, addr): def add(self, addr):
asyncio.run_coroutine_threadsafe(self._add_address(addr), self.asyncio_loop) asyncio.run_coroutine_threadsafe(self._add_address(addr), self.asyncio_loop)
@ -129,9 +123,6 @@ class SynchronizerBase(NetworkJobOnDefaultServer):
await self.taskgroup.spawn(self._on_address_status, addr, status) await self.taskgroup.spawn(self._on_address_status, addr, status)
self._processed_some_notifications = True self._processed_some_notifications = True
def num_requests_sent_and_answered(self) -> Tuple[int, int]:
return self._requests_sent, self._requests_answered
async def main(self): async def main(self):
raise NotImplementedError() # implemented by subclasses raise NotImplementedError() # implemented by subclasses
@ -260,13 +251,17 @@ class Synchronizer(SynchronizerBase):
# main loop # main loop
while True: while True:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
await run_in_thread(self.wallet.synchronize) # note: we only generate new HD addresses if the existing ones
up_to_date = self.is_up_to_date() # have history that are mined and SPV-verified. This inherently couples
# the Sychronizer and the Verifier.
hist_done = self.is_up_to_date()
spv_done = self.wallet.verifier.is_up_to_date() if self.wallet.verifier else True
num_new_addrs = await run_in_thread(self.wallet.synchronize)
up_to_date = hist_done and spv_done and num_new_addrs == 0
# see if status changed
if (up_to_date != self.wallet.is_up_to_date() if (up_to_date != self.wallet.is_up_to_date()
or up_to_date and self._processed_some_notifications): or up_to_date and self._processed_some_notifications):
self._processed_some_notifications = False self._processed_some_notifications = False
if up_to_date:
self._reset_request_counters()
self.wallet.set_up_to_date(up_to_date) self.wallet.set_up_to_date(up_to_date)
util.trigger_callback('wallet_updated', self.wallet) util.trigger_callback('wallet_updated', self.wallet)

8
electrum/util.py

@ -1326,6 +1326,7 @@ class NetworkJobOnDefaultServer(Logger, ABC):
server connection changes. server connection changes.
""" """
self.taskgroup = OldTaskGroup() self.taskgroup = OldTaskGroup()
self.reset_request_counters()
async def _start(self, interface: 'Interface'): async def _start(self, interface: 'Interface'):
self.interface = interface self.interface = interface
@ -1357,6 +1358,13 @@ class NetworkJobOnDefaultServer(Logger, ABC):
self._reset() self._reset()
await self._start(interface) await self._start(interface)
def reset_request_counters(self):
self._requests_sent = 0
self._requests_answered = 0
def num_requests_sent_and_answered(self) -> Tuple[int, int]:
return self._requests_sent, self._requests_answered
@property @property
def session(self): def session(self):
s = self.interface.session s = self.interface.session

7
electrum/verifier.py

@ -87,6 +87,7 @@ class SPV(NetworkJobOnDefaultServer):
header = self.blockchain.read_header(tx_height) header = self.blockchain.read_header(tx_height)
if header is None: if header is None:
if tx_height < constants.net.max_checkpoint(): if tx_height < constants.net.max_checkpoint():
# FIXME these requests are not counted (self._requests_sent += 1)
await self.taskgroup.spawn(self.interface.request_chunk(tx_height, None, can_return_early=True)) await self.taskgroup.spawn(self.interface.request_chunk(tx_height, None, can_return_early=True))
continue continue
# request now # request now
@ -96,6 +97,7 @@ class SPV(NetworkJobOnDefaultServer):
async def _request_and_verify_single_proof(self, tx_hash, tx_height): async def _request_and_verify_single_proof(self, tx_hash, tx_height):
try: try:
self._requests_sent += 1
async with self._network_request_semaphore: async with self._network_request_semaphore:
merkle = await self.interface.get_merkle_for_transaction(tx_hash, tx_height) merkle = await self.interface.get_merkle_for_transaction(tx_hash, tx_height)
except aiorpcx.jsonrpc.RPCError: except aiorpcx.jsonrpc.RPCError:
@ -103,6 +105,8 @@ class SPV(NetworkJobOnDefaultServer):
self.wallet.remove_unverified_tx(tx_hash, tx_height) self.wallet.remove_unverified_tx(tx_hash, tx_height)
self.requested_merkle.discard(tx_hash) self.requested_merkle.discard(tx_hash)
return return
finally:
self._requests_answered += 1
# Verify the hash of the server-provided merkle branch to a # Verify the hash of the server-provided merkle branch to a
# transaction matches the merkle root of its block # transaction matches the merkle root of its block
if tx_height != merkle.get('block_height'): if tx_height != merkle.get('block_height'):
@ -187,7 +191,8 @@ class SPV(NetworkJobOnDefaultServer):
self.requested_merkle.discard(tx_hash) self.requested_merkle.discard(tx_hash)
def is_up_to_date(self): def is_up_to_date(self):
return not self.requested_merkle return (not self.requested_merkle
and not self.wallet.unverified_tx)
def verify_tx_is_in_block(tx_hash: str, merkle_branch: Sequence[str], def verify_tx_is_in_block(tx_hash: str, merkle_branch: Sequence[str],

12
electrum/wallet.py

@ -3043,11 +3043,13 @@ class Deterministic_Wallet(Abstract_Wallet):
self._not_old_change_addresses.append(address) self._not_old_change_addresses.append(address)
return address return address
def synchronize_sequence(self, for_change): def synchronize_sequence(self, for_change: bool) -> int:
count = 0 # num new addresses we generated
limit = self.gap_limit_for_change if for_change else self.gap_limit limit = self.gap_limit_for_change if for_change else self.gap_limit
while True: while True:
num_addr = self.db.num_change_addresses() if for_change else self.db.num_receiving_addresses() num_addr = self.db.num_change_addresses() if for_change else self.db.num_receiving_addresses()
if num_addr < limit: if num_addr < limit:
count += 1
self.create_new_address(for_change) self.create_new_address(for_change)
continue continue
if for_change: if for_change:
@ -3055,15 +3057,19 @@ class Deterministic_Wallet(Abstract_Wallet):
else: else:
last_few_addresses = self.get_receiving_addresses(slice_start=-limit) last_few_addresses = self.get_receiving_addresses(slice_start=-limit)
if any(map(self.address_is_old, last_few_addresses)): if any(map(self.address_is_old, last_few_addresses)):
count += 1
self.create_new_address(for_change) self.create_new_address(for_change)
else: else:
break break
return count
@AddressSynchronizer.with_local_height_cached @AddressSynchronizer.with_local_height_cached
def synchronize(self): def synchronize(self):
count = 0
with self.lock: with self.lock:
self.synchronize_sequence(False) count += self.synchronize_sequence(False)
self.synchronize_sequence(True) count += self.synchronize_sequence(True)
return count
def get_all_known_addresses_beyond_gap_limit(self): def get_all_known_addresses_beyond_gap_limit(self):
# note that we don't stop at first large gap # note that we don't stop at first large gap

Loading…
Cancel
Save