From 88658f9c2c68e758272d646b353ed6cd29dd5131 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Mon, 24 Feb 2020 18:26:49 +0100 Subject: [PATCH] WalletDB: add type hints, and also corresponding asserts for sanity --- electrum/wallet.py | 2 +- electrum/wallet_db.py | 144 +++++++++++++++++++++++++++++------------- 2 files changed, 101 insertions(+), 45 deletions(-) diff --git a/electrum/wallet.py b/electrum/wallet.py index 56c5f5d1a..711131c87 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -1983,7 +1983,7 @@ class Imported_Wallet(Simple_Wallet): else: raise BitcoinException(str(bad_addr[0][1])) - def delete_address(self, address): + def delete_address(self, address: str): if not self.db.has_imported_address(address): return transactions_to_remove = set() # only referred to by this address diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index 730651256..b6a9bdfb6 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -28,7 +28,7 @@ import json import copy import threading from collections import defaultdict -from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence, TYPE_CHECKING +from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence, TYPE_CHECKING, Union import binascii from . import util, bitcoin @@ -666,29 +666,39 @@ class WalletDB(JsonDB): raise WalletFileException(msg) @locked - def get_txi_addresses(self, tx_hash) -> List[str]: + def get_txi_addresses(self, tx_hash: str) -> List[str]: """Returns list of is_mine addresses that appear as inputs in tx.""" + assert isinstance(tx_hash, str) return list(self.txi.get(tx_hash, {}).keys()) @locked - def get_txo_addresses(self, tx_hash) -> List[str]: + def get_txo_addresses(self, tx_hash: str) -> List[str]: """Returns list of is_mine addresses that appear as outputs in tx.""" + assert isinstance(tx_hash, str) return list(self.txo.get(tx_hash, {}).keys()) @locked - def get_txi_addr(self, tx_hash, address) -> Iterable[Tuple[str, int]]: + def get_txi_addr(self, tx_hash: str, address: str) -> Iterable[Tuple[str, int]]: """Returns an iterable of (prev_outpoint, value).""" + assert isinstance(tx_hash, str) + assert isinstance(address, str) d = self.txi.get(tx_hash, {}).get(address, {}) return list(d.items()) @locked - def get_txo_addr(self, tx_hash, address) -> Iterable[Tuple[int, int, bool]]: + def get_txo_addr(self, tx_hash: str, address: str) -> Iterable[Tuple[int, int, bool]]: """Returns an iterable of (output_index, value, is_coinbase).""" + assert isinstance(tx_hash, str) + assert isinstance(address, str) d = self.txo.get(tx_hash, {}).get(address, {}) return [(int(n), v, cb) for (n, (v, cb)) in d.items()] @modifier - def add_txi_addr(self, tx_hash, addr, ser, v): + def add_txi_addr(self, tx_hash: str, addr: str, ser: str, v: int) -> None: + assert isinstance(tx_hash, str) + assert isinstance(addr, str) + assert isinstance(ser, str) + assert isinstance(v, int) if tx_hash not in self.txi: self.txi[tx_hash] = {} d = self.txi[tx_hash] @@ -697,7 +707,13 @@ class WalletDB(JsonDB): d[addr][ser] = v @modifier - def add_txo_addr(self, tx_hash, addr, n, v, is_coinbase): + def add_txo_addr(self, tx_hash: str, addr: str, n: Union[int, str], v: int, is_coinbase: bool) -> None: + n = str(n) + assert isinstance(tx_hash, str) + assert isinstance(addr, str) + assert isinstance(n, str) + assert isinstance(v, int) + assert isinstance(is_coinbase, bool) if tx_hash not in self.txo: self.txo[tx_hash] = {} d = self.txo[tx_hash] @@ -706,46 +722,53 @@ class WalletDB(JsonDB): d[addr][n] = (v, is_coinbase) @locked - def list_txi(self): + def list_txi(self) -> Sequence[str]: return list(self.txi.keys()) @locked - def list_txo(self): + def list_txo(self) -> Sequence[str]: return list(self.txo.keys()) @modifier - def remove_txi(self, tx_hash): + def remove_txi(self, tx_hash: str) -> None: + assert isinstance(tx_hash, str) self.txi.pop(tx_hash, None) @modifier - def remove_txo(self, tx_hash): + def remove_txo(self, tx_hash: str) -> None: + assert isinstance(tx_hash, str) self.txo.pop(tx_hash, None) @locked - def list_spent_outpoints(self): + def list_spent_outpoints(self) -> Sequence[Tuple[str, str]]: return [(h, n) for h in self.spent_outpoints.keys() for n in self.get_spent_outpoints(h) ] @locked - def get_spent_outpoints(self, prevout_hash): + def get_spent_outpoints(self, prevout_hash: str) -> Sequence[str]: + assert isinstance(prevout_hash, str) return list(self.spent_outpoints.get(prevout_hash, {}).keys()) @locked - def get_spent_outpoint(self, prevout_hash, prevout_n): + def get_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str]) -> Optional[str]: + assert isinstance(prevout_hash, str) prevout_n = str(prevout_n) return self.spent_outpoints.get(prevout_hash, {}).get(prevout_n) @modifier - def remove_spent_outpoint(self, prevout_hash, prevout_n): + def remove_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str]) -> None: + assert isinstance(prevout_hash, str) prevout_n = str(prevout_n) self.spent_outpoints[prevout_hash].pop(prevout_n, None) if not self.spent_outpoints[prevout_hash]: self.spent_outpoints.pop(prevout_hash) @modifier - def set_spent_outpoint(self, prevout_hash, prevout_n, tx_hash): + def set_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str], tx_hash: str) -> None: + assert isinstance(prevout_hash, str) + assert isinstance(tx_hash, str) prevout_n = str(prevout_n) if prevout_hash not in self.spent_outpoints: self.spent_outpoints[prevout_hash] = {} @@ -753,25 +776,31 @@ class WalletDB(JsonDB): @modifier def add_prevout_by_scripthash(self, scripthash: str, *, prevout: TxOutpoint, value: int) -> None: + assert isinstance(scripthash, str) assert isinstance(prevout, TxOutpoint) + assert isinstance(value, int) if scripthash not in self._prevouts_by_scripthash: self._prevouts_by_scripthash[scripthash] = set() self._prevouts_by_scripthash[scripthash].add((prevout.to_str(), value)) @modifier def remove_prevout_by_scripthash(self, scripthash: str, *, prevout: TxOutpoint, value: int) -> None: + assert isinstance(scripthash, str) assert isinstance(prevout, TxOutpoint) + assert isinstance(value, int) self._prevouts_by_scripthash[scripthash].discard((prevout.to_str(), value)) if not self._prevouts_by_scripthash[scripthash]: self._prevouts_by_scripthash.pop(scripthash) @locked def get_prevouts_by_scripthash(self, scripthash: str) -> Set[Tuple[TxOutpoint, int]]: + assert isinstance(scripthash, str) prevouts_and_values = self._prevouts_by_scripthash.get(scripthash, set()) return {(TxOutpoint.from_str(prevout), value) for prevout, value in prevouts_and_values} @modifier def add_transaction(self, tx_hash: str, tx: Transaction) -> None: + assert isinstance(tx_hash, str) assert isinstance(tx, Transaction), tx # note that tx might be a PartialTransaction if not tx_hash: @@ -784,43 +813,50 @@ class WalletDB(JsonDB): self.transactions[tx_hash] = tx @modifier - def remove_transaction(self, tx_hash) -> Optional[Transaction]: + def remove_transaction(self, tx_hash: str) -> Optional[Transaction]: + assert isinstance(tx_hash, str) return self.transactions.pop(tx_hash, None) @locked def get_transaction(self, tx_hash: str) -> Optional[Transaction]: + assert isinstance(tx_hash, str) return self.transactions.get(tx_hash) @locked - def list_transactions(self): + def list_transactions(self) -> Sequence[str]: return list(self.transactions.keys()) @locked - def get_history(self): + def get_history(self) -> Sequence[str]: return list(self.history.keys()) - def is_addr_in_history(self, addr): + def is_addr_in_history(self, addr: str) -> bool: # does not mean history is non-empty! + assert isinstance(addr, str) return addr in self.history @locked - def get_addr_history(self, addr): + def get_addr_history(self, addr: str) -> Sequence[Tuple[str, int]]: + assert isinstance(addr, str) return self.history.get(addr, []) @modifier - def set_addr_history(self, addr, hist): + def set_addr_history(self, addr: str, hist) -> None: + assert isinstance(addr, str) self.history[addr] = hist @modifier - def remove_addr_history(self, addr): + def remove_addr_history(self, addr: str) -> None: + assert isinstance(addr, str) self.history.pop(addr, None) @locked - def list_verified_tx(self): + def list_verified_tx(self) -> Sequence[str]: return list(self.verified_tx.keys()) @locked - def get_verified_tx(self, txid): + def get_verified_tx(self, txid: str) -> Optional[TxMinedInfo]: + assert isinstance(txid, str) if txid not in self.verified_tx: return None height, timestamp, txpos, header_hash = self.verified_tx[txid] @@ -831,18 +867,23 @@ class WalletDB(JsonDB): header_hash=header_hash) @modifier - def add_verified_tx(self, txid, info): + def add_verified_tx(self, txid: str, info: TxMinedInfo): + assert isinstance(txid, str) + assert isinstance(info, TxMinedInfo) self.verified_tx[txid] = (info.height, info.timestamp, info.txpos, info.header_hash) @modifier - def remove_verified_tx(self, txid): + def remove_verified_tx(self, txid: str): + assert isinstance(txid, str) self.verified_tx.pop(txid, None) - def is_in_verified_tx(self, txid): + def is_in_verified_tx(self, txid: str) -> bool: + assert isinstance(txid, str) return txid in self.verified_tx @modifier def add_tx_fee_from_server(self, txid: str, fee_sat: Optional[int]) -> None: + assert isinstance(txid, str) # note: when called with (fee_sat is None), rm currently saved value if txid not in self.tx_fees: self.tx_fees[txid] = TxFeesValue() @@ -853,14 +894,17 @@ class WalletDB(JsonDB): @modifier def add_tx_fee_we_calculated(self, txid: str, fee_sat: Optional[int]) -> None: + assert isinstance(txid, str) if fee_sat is None: return + assert isinstance(fee_sat, int) if txid not in self.tx_fees: self.tx_fees[txid] = TxFeesValue() self.tx_fees[txid] = self.tx_fees[txid]._replace(fee=fee_sat, is_calculated_by_us=True) @locked - def get_tx_fee(self, txid: str, *, trust_server=False) -> Optional[int]: + def get_tx_fee(self, txid: str, *, trust_server: bool = False) -> Optional[int]: + assert isinstance(txid, str) """Returns tx_fee.""" tx_fees_value = self.tx_fees.get(txid) if tx_fees_value is None: @@ -871,12 +915,15 @@ class WalletDB(JsonDB): @modifier def add_num_inputs_to_tx(self, txid: str, num_inputs: int) -> None: + assert isinstance(txid, str) + assert isinstance(num_inputs, int) if txid not in self.tx_fees: self.tx_fees[txid] = TxFeesValue() self.tx_fees[txid] = self.tx_fees[txid]._replace(num_inputs=num_inputs) @locked def get_num_all_inputs_of_tx(self, txid: str) -> Optional[int]: + assert isinstance(txid, str) tx_fees_value = self.tx_fees.get(txid) if tx_fees_value is None: return None @@ -884,11 +931,13 @@ class WalletDB(JsonDB): @locked def get_num_ismine_inputs_of_tx(self, txid: str) -> int: + assert isinstance(txid, str) txins = self.txi.get(txid, {}) return sum([len(tupls) for addr, tupls in txins.items()]) @modifier - def remove_tx_fee(self, txid): + def remove_tx_fee(self, txid: str) -> None: + assert isinstance(txid, str) self.tx_fees.pop(txid, None) @locked @@ -900,47 +949,53 @@ class WalletDB(JsonDB): return self.data[name] @locked - def num_change_addresses(self): + def num_change_addresses(self) -> int: return len(self.change_addresses) @locked - def num_receiving_addresses(self): + def num_receiving_addresses(self) -> int: return len(self.receiving_addresses) @locked - def get_change_addresses(self, *, slice_start=None, slice_stop=None): + def get_change_addresses(self, *, slice_start=None, slice_stop=None) -> List[str]: # note: slicing makes a shallow copy return self.change_addresses[slice_start:slice_stop] @locked - def get_receiving_addresses(self, *, slice_start=None, slice_stop=None): + def get_receiving_addresses(self, *, slice_start=None, slice_stop=None) -> List[str]: # note: slicing makes a shallow copy return self.receiving_addresses[slice_start:slice_stop] @modifier - def add_change_address(self, addr): + def add_change_address(self, addr: str) -> None: + assert isinstance(addr, str) self._addr_to_addr_index[addr] = (1, len(self.change_addresses)) self.change_addresses.append(addr) @modifier - def add_receiving_address(self, addr): + def add_receiving_address(self, addr: str) -> None: + assert isinstance(addr, str) self._addr_to_addr_index[addr] = (0, len(self.receiving_addresses)) self.receiving_addresses.append(addr) @locked - def get_address_index(self, address) -> Optional[Sequence[int]]: + def get_address_index(self, address: str) -> Optional[Sequence[int]]: + assert isinstance(address, str) return self._addr_to_addr_index.get(address) @modifier - def add_imported_address(self, addr, d): + def add_imported_address(self, addr: str, d: dict) -> None: + assert isinstance(addr, str) self.imported_addresses[addr] = d @modifier - def remove_imported_address(self, addr): + def remove_imported_address(self, addr: str) -> None: + assert isinstance(addr, str) self.imported_addresses.pop(addr) @locked def has_imported_address(self, addr: str) -> bool: + assert isinstance(addr, str) return addr in self.imported_addresses @locked @@ -948,7 +1003,8 @@ class WalletDB(JsonDB): return list(sorted(self.imported_addresses.keys())) @locked - def get_imported_address(self, addr): + def get_imported_address(self, addr: str) -> Optional[dict]: + assert isinstance(addr, str) return self.imported_addresses.get(addr) def load_addresses(self, wallet_type): @@ -973,10 +1029,10 @@ class WalletDB(JsonDB): self.data = StoredDict(self.data, self, []) # references in self.data # TODO make all these private - # txid -> address -> set of (prev_outpoint, value) - self.txi = self.get_dict('txi') # type: Dict[str, Dict[str, Set[Tuple[str, int]]]] - # txid -> address -> set of (output_index, value, is_coinbase) - self.txo = self.get_dict('txo') # type: Dict[str, Dict[str, Set[Tuple[int, int, bool]]]] + # txid -> address -> prev_outpoint -> value + self.txi = self.get_dict('txi') # type: Dict[str, Dict[str, Dict[str, int]]] + # txid -> address -> output_index -> (value, is_coinbase) + self.txo = self.get_dict('txo') # type: Dict[str, Dict[str, Dict[str, Tuple[int, bool]]]] self.transactions = self.get_dict('transactions') # type: Dict[str, Transaction] self.spent_outpoints = self.get_dict('spent_outpoints') # txid -> output_index -> next_txid self.history = self.get_dict('addr_history') # address -> list of (txid, height)