Browse Source

WalletDB: add type hints, and also corresponding asserts for sanity

hard-fail-on-bad-server-string
SomberNight 5 years ago
parent
commit
88658f9c2c
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 2
      electrum/wallet.py
  2. 144
      electrum/wallet_db.py

2
electrum/wallet.py

@ -1983,7 +1983,7 @@ class Imported_Wallet(Simple_Wallet):
else: else:
raise BitcoinException(str(bad_addr[0][1])) raise BitcoinException(str(bad_addr[0][1]))
def delete_address(self, address): def delete_address(self, address: str):
if not self.db.has_imported_address(address): if not self.db.has_imported_address(address):
return return
transactions_to_remove = set() # only referred to by this address transactions_to_remove = set() # only referred to by this address

144
electrum/wallet_db.py

@ -28,7 +28,7 @@ import json
import copy import copy
import threading import threading
from collections import defaultdict from collections import defaultdict
from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence, TYPE_CHECKING from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence, TYPE_CHECKING, Union
import binascii import binascii
from . import util, bitcoin from . import util, bitcoin
@ -666,29 +666,39 @@ class WalletDB(JsonDB):
raise WalletFileException(msg) raise WalletFileException(msg)
@locked @locked
def get_txi_addresses(self, tx_hash) -> List[str]: def get_txi_addresses(self, tx_hash: str) -> List[str]:
"""Returns list of is_mine addresses that appear as inputs in tx.""" """Returns list of is_mine addresses that appear as inputs in tx."""
assert isinstance(tx_hash, str)
return list(self.txi.get(tx_hash, {}).keys()) return list(self.txi.get(tx_hash, {}).keys())
@locked @locked
def get_txo_addresses(self, tx_hash) -> List[str]: def get_txo_addresses(self, tx_hash: str) -> List[str]:
"""Returns list of is_mine addresses that appear as outputs in tx.""" """Returns list of is_mine addresses that appear as outputs in tx."""
assert isinstance(tx_hash, str)
return list(self.txo.get(tx_hash, {}).keys()) return list(self.txo.get(tx_hash, {}).keys())
@locked @locked
def get_txi_addr(self, tx_hash, address) -> Iterable[Tuple[str, int]]: def get_txi_addr(self, tx_hash: str, address: str) -> Iterable[Tuple[str, int]]:
"""Returns an iterable of (prev_outpoint, value).""" """Returns an iterable of (prev_outpoint, value)."""
assert isinstance(tx_hash, str)
assert isinstance(address, str)
d = self.txi.get(tx_hash, {}).get(address, {}) d = self.txi.get(tx_hash, {}).get(address, {})
return list(d.items()) return list(d.items())
@locked @locked
def get_txo_addr(self, tx_hash, address) -> Iterable[Tuple[int, int, bool]]: def get_txo_addr(self, tx_hash: str, address: str) -> Iterable[Tuple[int, int, bool]]:
"""Returns an iterable of (output_index, value, is_coinbase).""" """Returns an iterable of (output_index, value, is_coinbase)."""
assert isinstance(tx_hash, str)
assert isinstance(address, str)
d = self.txo.get(tx_hash, {}).get(address, {}) d = self.txo.get(tx_hash, {}).get(address, {})
return [(int(n), v, cb) for (n, (v, cb)) in d.items()] return [(int(n), v, cb) for (n, (v, cb)) in d.items()]
@modifier @modifier
def add_txi_addr(self, tx_hash, addr, ser, v): def add_txi_addr(self, tx_hash: str, addr: str, ser: str, v: int) -> None:
assert isinstance(tx_hash, str)
assert isinstance(addr, str)
assert isinstance(ser, str)
assert isinstance(v, int)
if tx_hash not in self.txi: if tx_hash not in self.txi:
self.txi[tx_hash] = {} self.txi[tx_hash] = {}
d = self.txi[tx_hash] d = self.txi[tx_hash]
@ -697,7 +707,13 @@ class WalletDB(JsonDB):
d[addr][ser] = v d[addr][ser] = v
@modifier @modifier
def add_txo_addr(self, tx_hash, addr, n, v, is_coinbase): def add_txo_addr(self, tx_hash: str, addr: str, n: Union[int, str], v: int, is_coinbase: bool) -> None:
n = str(n)
assert isinstance(tx_hash, str)
assert isinstance(addr, str)
assert isinstance(n, str)
assert isinstance(v, int)
assert isinstance(is_coinbase, bool)
if tx_hash not in self.txo: if tx_hash not in self.txo:
self.txo[tx_hash] = {} self.txo[tx_hash] = {}
d = self.txo[tx_hash] d = self.txo[tx_hash]
@ -706,46 +722,53 @@ class WalletDB(JsonDB):
d[addr][n] = (v, is_coinbase) d[addr][n] = (v, is_coinbase)
@locked @locked
def list_txi(self): def list_txi(self) -> Sequence[str]:
return list(self.txi.keys()) return list(self.txi.keys())
@locked @locked
def list_txo(self): def list_txo(self) -> Sequence[str]:
return list(self.txo.keys()) return list(self.txo.keys())
@modifier @modifier
def remove_txi(self, tx_hash): def remove_txi(self, tx_hash: str) -> None:
assert isinstance(tx_hash, str)
self.txi.pop(tx_hash, None) self.txi.pop(tx_hash, None)
@modifier @modifier
def remove_txo(self, tx_hash): def remove_txo(self, tx_hash: str) -> None:
assert isinstance(tx_hash, str)
self.txo.pop(tx_hash, None) self.txo.pop(tx_hash, None)
@locked @locked
def list_spent_outpoints(self): def list_spent_outpoints(self) -> Sequence[Tuple[str, str]]:
return [(h, n) return [(h, n)
for h in self.spent_outpoints.keys() for h in self.spent_outpoints.keys()
for n in self.get_spent_outpoints(h) for n in self.get_spent_outpoints(h)
] ]
@locked @locked
def get_spent_outpoints(self, prevout_hash): def get_spent_outpoints(self, prevout_hash: str) -> Sequence[str]:
assert isinstance(prevout_hash, str)
return list(self.spent_outpoints.get(prevout_hash, {}).keys()) return list(self.spent_outpoints.get(prevout_hash, {}).keys())
@locked @locked
def get_spent_outpoint(self, prevout_hash, prevout_n): def get_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str]) -> Optional[str]:
assert isinstance(prevout_hash, str)
prevout_n = str(prevout_n) prevout_n = str(prevout_n)
return self.spent_outpoints.get(prevout_hash, {}).get(prevout_n) return self.spent_outpoints.get(prevout_hash, {}).get(prevout_n)
@modifier @modifier
def remove_spent_outpoint(self, prevout_hash, prevout_n): def remove_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str]) -> None:
assert isinstance(prevout_hash, str)
prevout_n = str(prevout_n) prevout_n = str(prevout_n)
self.spent_outpoints[prevout_hash].pop(prevout_n, None) self.spent_outpoints[prevout_hash].pop(prevout_n, None)
if not self.spent_outpoints[prevout_hash]: if not self.spent_outpoints[prevout_hash]:
self.spent_outpoints.pop(prevout_hash) self.spent_outpoints.pop(prevout_hash)
@modifier @modifier
def set_spent_outpoint(self, prevout_hash, prevout_n, tx_hash): def set_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str], tx_hash: str) -> None:
assert isinstance(prevout_hash, str)
assert isinstance(tx_hash, str)
prevout_n = str(prevout_n) prevout_n = str(prevout_n)
if prevout_hash not in self.spent_outpoints: if prevout_hash not in self.spent_outpoints:
self.spent_outpoints[prevout_hash] = {} self.spent_outpoints[prevout_hash] = {}
@ -753,25 +776,31 @@ class WalletDB(JsonDB):
@modifier @modifier
def add_prevout_by_scripthash(self, scripthash: str, *, prevout: TxOutpoint, value: int) -> None: def add_prevout_by_scripthash(self, scripthash: str, *, prevout: TxOutpoint, value: int) -> None:
assert isinstance(scripthash, str)
assert isinstance(prevout, TxOutpoint) assert isinstance(prevout, TxOutpoint)
assert isinstance(value, int)
if scripthash not in self._prevouts_by_scripthash: if scripthash not in self._prevouts_by_scripthash:
self._prevouts_by_scripthash[scripthash] = set() self._prevouts_by_scripthash[scripthash] = set()
self._prevouts_by_scripthash[scripthash].add((prevout.to_str(), value)) self._prevouts_by_scripthash[scripthash].add((prevout.to_str(), value))
@modifier @modifier
def remove_prevout_by_scripthash(self, scripthash: str, *, prevout: TxOutpoint, value: int) -> None: def remove_prevout_by_scripthash(self, scripthash: str, *, prevout: TxOutpoint, value: int) -> None:
assert isinstance(scripthash, str)
assert isinstance(prevout, TxOutpoint) assert isinstance(prevout, TxOutpoint)
assert isinstance(value, int)
self._prevouts_by_scripthash[scripthash].discard((prevout.to_str(), value)) self._prevouts_by_scripthash[scripthash].discard((prevout.to_str(), value))
if not self._prevouts_by_scripthash[scripthash]: if not self._prevouts_by_scripthash[scripthash]:
self._prevouts_by_scripthash.pop(scripthash) self._prevouts_by_scripthash.pop(scripthash)
@locked @locked
def get_prevouts_by_scripthash(self, scripthash: str) -> Set[Tuple[TxOutpoint, int]]: def get_prevouts_by_scripthash(self, scripthash: str) -> Set[Tuple[TxOutpoint, int]]:
assert isinstance(scripthash, str)
prevouts_and_values = self._prevouts_by_scripthash.get(scripthash, set()) prevouts_and_values = self._prevouts_by_scripthash.get(scripthash, set())
return {(TxOutpoint.from_str(prevout), value) for prevout, value in prevouts_and_values} return {(TxOutpoint.from_str(prevout), value) for prevout, value in prevouts_and_values}
@modifier @modifier
def add_transaction(self, tx_hash: str, tx: Transaction) -> None: def add_transaction(self, tx_hash: str, tx: Transaction) -> None:
assert isinstance(tx_hash, str)
assert isinstance(tx, Transaction), tx assert isinstance(tx, Transaction), tx
# note that tx might be a PartialTransaction # note that tx might be a PartialTransaction
if not tx_hash: if not tx_hash:
@ -784,43 +813,50 @@ class WalletDB(JsonDB):
self.transactions[tx_hash] = tx self.transactions[tx_hash] = tx
@modifier @modifier
def remove_transaction(self, tx_hash) -> Optional[Transaction]: def remove_transaction(self, tx_hash: str) -> Optional[Transaction]:
assert isinstance(tx_hash, str)
return self.transactions.pop(tx_hash, None) return self.transactions.pop(tx_hash, None)
@locked @locked
def get_transaction(self, tx_hash: str) -> Optional[Transaction]: def get_transaction(self, tx_hash: str) -> Optional[Transaction]:
assert isinstance(tx_hash, str)
return self.transactions.get(tx_hash) return self.transactions.get(tx_hash)
@locked @locked
def list_transactions(self): def list_transactions(self) -> Sequence[str]:
return list(self.transactions.keys()) return list(self.transactions.keys())
@locked @locked
def get_history(self): def get_history(self) -> Sequence[str]:
return list(self.history.keys()) return list(self.history.keys())
def is_addr_in_history(self, addr): def is_addr_in_history(self, addr: str) -> bool:
# does not mean history is non-empty! # does not mean history is non-empty!
assert isinstance(addr, str)
return addr in self.history return addr in self.history
@locked @locked
def get_addr_history(self, addr): def get_addr_history(self, addr: str) -> Sequence[Tuple[str, int]]:
assert isinstance(addr, str)
return self.history.get(addr, []) return self.history.get(addr, [])
@modifier @modifier
def set_addr_history(self, addr, hist): def set_addr_history(self, addr: str, hist) -> None:
assert isinstance(addr, str)
self.history[addr] = hist self.history[addr] = hist
@modifier @modifier
def remove_addr_history(self, addr): def remove_addr_history(self, addr: str) -> None:
assert isinstance(addr, str)
self.history.pop(addr, None) self.history.pop(addr, None)
@locked @locked
def list_verified_tx(self): def list_verified_tx(self) -> Sequence[str]:
return list(self.verified_tx.keys()) return list(self.verified_tx.keys())
@locked @locked
def get_verified_tx(self, txid): def get_verified_tx(self, txid: str) -> Optional[TxMinedInfo]:
assert isinstance(txid, str)
if txid not in self.verified_tx: if txid not in self.verified_tx:
return None return None
height, timestamp, txpos, header_hash = self.verified_tx[txid] height, timestamp, txpos, header_hash = self.verified_tx[txid]
@ -831,18 +867,23 @@ class WalletDB(JsonDB):
header_hash=header_hash) header_hash=header_hash)
@modifier @modifier
def add_verified_tx(self, txid, info): def add_verified_tx(self, txid: str, info: TxMinedInfo):
assert isinstance(txid, str)
assert isinstance(info, TxMinedInfo)
self.verified_tx[txid] = (info.height, info.timestamp, info.txpos, info.header_hash) self.verified_tx[txid] = (info.height, info.timestamp, info.txpos, info.header_hash)
@modifier @modifier
def remove_verified_tx(self, txid): def remove_verified_tx(self, txid: str):
assert isinstance(txid, str)
self.verified_tx.pop(txid, None) self.verified_tx.pop(txid, None)
def is_in_verified_tx(self, txid): def is_in_verified_tx(self, txid: str) -> bool:
assert isinstance(txid, str)
return txid in self.verified_tx return txid in self.verified_tx
@modifier @modifier
def add_tx_fee_from_server(self, txid: str, fee_sat: Optional[int]) -> None: def add_tx_fee_from_server(self, txid: str, fee_sat: Optional[int]) -> None:
assert isinstance(txid, str)
# note: when called with (fee_sat is None), rm currently saved value # note: when called with (fee_sat is None), rm currently saved value
if txid not in self.tx_fees: if txid not in self.tx_fees:
self.tx_fees[txid] = TxFeesValue() self.tx_fees[txid] = TxFeesValue()
@ -853,14 +894,17 @@ class WalletDB(JsonDB):
@modifier @modifier
def add_tx_fee_we_calculated(self, txid: str, fee_sat: Optional[int]) -> None: def add_tx_fee_we_calculated(self, txid: str, fee_sat: Optional[int]) -> None:
assert isinstance(txid, str)
if fee_sat is None: if fee_sat is None:
return return
assert isinstance(fee_sat, int)
if txid not in self.tx_fees: if txid not in self.tx_fees:
self.tx_fees[txid] = TxFeesValue() self.tx_fees[txid] = TxFeesValue()
self.tx_fees[txid] = self.tx_fees[txid]._replace(fee=fee_sat, is_calculated_by_us=True) self.tx_fees[txid] = self.tx_fees[txid]._replace(fee=fee_sat, is_calculated_by_us=True)
@locked @locked
def get_tx_fee(self, txid: str, *, trust_server=False) -> Optional[int]: def get_tx_fee(self, txid: str, *, trust_server: bool = False) -> Optional[int]:
assert isinstance(txid, str)
"""Returns tx_fee.""" """Returns tx_fee."""
tx_fees_value = self.tx_fees.get(txid) tx_fees_value = self.tx_fees.get(txid)
if tx_fees_value is None: if tx_fees_value is None:
@ -871,12 +915,15 @@ class WalletDB(JsonDB):
@modifier @modifier
def add_num_inputs_to_tx(self, txid: str, num_inputs: int) -> None: def add_num_inputs_to_tx(self, txid: str, num_inputs: int) -> None:
assert isinstance(txid, str)
assert isinstance(num_inputs, int)
if txid not in self.tx_fees: if txid not in self.tx_fees:
self.tx_fees[txid] = TxFeesValue() self.tx_fees[txid] = TxFeesValue()
self.tx_fees[txid] = self.tx_fees[txid]._replace(num_inputs=num_inputs) self.tx_fees[txid] = self.tx_fees[txid]._replace(num_inputs=num_inputs)
@locked @locked
def get_num_all_inputs_of_tx(self, txid: str) -> Optional[int]: def get_num_all_inputs_of_tx(self, txid: str) -> Optional[int]:
assert isinstance(txid, str)
tx_fees_value = self.tx_fees.get(txid) tx_fees_value = self.tx_fees.get(txid)
if tx_fees_value is None: if tx_fees_value is None:
return None return None
@ -884,11 +931,13 @@ class WalletDB(JsonDB):
@locked @locked
def get_num_ismine_inputs_of_tx(self, txid: str) -> int: def get_num_ismine_inputs_of_tx(self, txid: str) -> int:
assert isinstance(txid, str)
txins = self.txi.get(txid, {}) txins = self.txi.get(txid, {})
return sum([len(tupls) for addr, tupls in txins.items()]) return sum([len(tupls) for addr, tupls in txins.items()])
@modifier @modifier
def remove_tx_fee(self, txid): def remove_tx_fee(self, txid: str) -> None:
assert isinstance(txid, str)
self.tx_fees.pop(txid, None) self.tx_fees.pop(txid, None)
@locked @locked
@ -900,47 +949,53 @@ class WalletDB(JsonDB):
return self.data[name] return self.data[name]
@locked @locked
def num_change_addresses(self): def num_change_addresses(self) -> int:
return len(self.change_addresses) return len(self.change_addresses)
@locked @locked
def num_receiving_addresses(self): def num_receiving_addresses(self) -> int:
return len(self.receiving_addresses) return len(self.receiving_addresses)
@locked @locked
def get_change_addresses(self, *, slice_start=None, slice_stop=None): def get_change_addresses(self, *, slice_start=None, slice_stop=None) -> List[str]:
# note: slicing makes a shallow copy # note: slicing makes a shallow copy
return self.change_addresses[slice_start:slice_stop] return self.change_addresses[slice_start:slice_stop]
@locked @locked
def get_receiving_addresses(self, *, slice_start=None, slice_stop=None): def get_receiving_addresses(self, *, slice_start=None, slice_stop=None) -> List[str]:
# note: slicing makes a shallow copy # note: slicing makes a shallow copy
return self.receiving_addresses[slice_start:slice_stop] return self.receiving_addresses[slice_start:slice_stop]
@modifier @modifier
def add_change_address(self, addr): def add_change_address(self, addr: str) -> None:
assert isinstance(addr, str)
self._addr_to_addr_index[addr] = (1, len(self.change_addresses)) self._addr_to_addr_index[addr] = (1, len(self.change_addresses))
self.change_addresses.append(addr) self.change_addresses.append(addr)
@modifier @modifier
def add_receiving_address(self, addr): def add_receiving_address(self, addr: str) -> None:
assert isinstance(addr, str)
self._addr_to_addr_index[addr] = (0, len(self.receiving_addresses)) self._addr_to_addr_index[addr] = (0, len(self.receiving_addresses))
self.receiving_addresses.append(addr) self.receiving_addresses.append(addr)
@locked @locked
def get_address_index(self, address) -> Optional[Sequence[int]]: def get_address_index(self, address: str) -> Optional[Sequence[int]]:
assert isinstance(address, str)
return self._addr_to_addr_index.get(address) return self._addr_to_addr_index.get(address)
@modifier @modifier
def add_imported_address(self, addr, d): def add_imported_address(self, addr: str, d: dict) -> None:
assert isinstance(addr, str)
self.imported_addresses[addr] = d self.imported_addresses[addr] = d
@modifier @modifier
def remove_imported_address(self, addr): def remove_imported_address(self, addr: str) -> None:
assert isinstance(addr, str)
self.imported_addresses.pop(addr) self.imported_addresses.pop(addr)
@locked @locked
def has_imported_address(self, addr: str) -> bool: def has_imported_address(self, addr: str) -> bool:
assert isinstance(addr, str)
return addr in self.imported_addresses return addr in self.imported_addresses
@locked @locked
@ -948,7 +1003,8 @@ class WalletDB(JsonDB):
return list(sorted(self.imported_addresses.keys())) return list(sorted(self.imported_addresses.keys()))
@locked @locked
def get_imported_address(self, addr): def get_imported_address(self, addr: str) -> Optional[dict]:
assert isinstance(addr, str)
return self.imported_addresses.get(addr) return self.imported_addresses.get(addr)
def load_addresses(self, wallet_type): def load_addresses(self, wallet_type):
@ -973,10 +1029,10 @@ class WalletDB(JsonDB):
self.data = StoredDict(self.data, self, []) self.data = StoredDict(self.data, self, [])
# references in self.data # references in self.data
# TODO make all these private # TODO make all these private
# txid -> address -> set of (prev_outpoint, value) # txid -> address -> prev_outpoint -> value
self.txi = self.get_dict('txi') # type: Dict[str, Dict[str, Set[Tuple[str, int]]]] self.txi = self.get_dict('txi') # type: Dict[str, Dict[str, Dict[str, int]]]
# txid -> address -> set of (output_index, value, is_coinbase) # txid -> address -> output_index -> (value, is_coinbase)
self.txo = self.get_dict('txo') # type: Dict[str, Dict[str, Set[Tuple[int, int, bool]]]] self.txo = self.get_dict('txo') # type: Dict[str, Dict[str, Dict[str, Tuple[int, bool]]]]
self.transactions = self.get_dict('transactions') # type: Dict[str, Transaction] self.transactions = self.get_dict('transactions') # type: Dict[str, Transaction]
self.spent_outpoints = self.get_dict('spent_outpoints') # txid -> output_index -> next_txid self.spent_outpoints = self.get_dict('spent_outpoints') # txid -> output_index -> next_txid
self.history = self.get_dict('addr_history') # address -> list of (txid, height) self.history = self.get_dict('addr_history') # address -> list of (txid, height)

Loading…
Cancel
Save