Browse Source

wallet: use abstract base classes

hard-fail-on-bad-server-string
SomberNight 5 years ago
parent
commit
869a728317
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 6
      electrum/json_db.py
  2. 2
      electrum/keystore.py
  3. 141
      electrum/wallet.py

6
electrum/json_db.py

@ -860,11 +860,11 @@ class JsonDB(Logger):
self.imported_addresses.pop(addr) self.imported_addresses.pop(addr)
@locked @locked
def has_imported_address(self, addr): def has_imported_address(self, addr: str) -> bool:
return addr in self.imported_addresses return addr in self.imported_addresses
@locked @locked
def get_imported_addresses(self): def get_imported_addresses(self) -> Sequence[str]:
return list(sorted(self.imported_addresses.keys())) return list(sorted(self.imported_addresses.keys()))
@locked @locked
@ -874,7 +874,7 @@ class JsonDB(Logger):
def load_addresses(self, wallet_type): def load_addresses(self, wallet_type):
""" called from Abstract_Wallet.__init__ """ """ called from Abstract_Wallet.__init__ """
if wallet_type == 'imported': if wallet_type == 'imported':
self.imported_addresses = self.get_data_ref('addresses') self.imported_addresses = self.get_data_ref('addresses') # type: Dict[str, dict]
else: else:
self.get_data_ref('addresses') self.get_data_ref('addresses')
for name in ['receiving', 'change']: for name in ['receiving', 'change']:

2
electrum/keystore.py

@ -624,6 +624,8 @@ class Old_KeyStore(MasterPublicKeyMixin, Deterministic_KeyStore):
return public_key.get_public_key_hex(compressed=False) return public_key.get_public_key_hex(compressed=False)
def derive_pubkey(self, for_change, n) -> str: def derive_pubkey(self, for_change, n) -> str:
for_change = int(for_change)
assert for_change in (0, 1)
return self.get_pubkey_from_mpk(self.mpk, for_change, n) return self.get_pubkey_from_mpk(self.mpk, for_change, n)
def _get_private_key_from_stretched_exponent(self, for_change, n, secexp): def _get_private_key_from_stretched_exponent(self, for_change, n, secexp):

141
electrum/wallet.py

@ -22,9 +22,9 @@
# SOFTWARE. # SOFTWARE.
# Wallet classes: # Wallet classes:
# - Imported_Wallet: imported address, no keystore # - Imported_Wallet: imported addresses or single keys, 0 or 1 keystore
# - Standard_Wallet: one keystore, P2PKH # - Standard_Wallet: one HD keystore, P2PKH-like scripts
# - Multisig_Wallet: several keystores, P2SH # - Multisig_Wallet: several HD keystores, M-of-N OP_CHECKMULTISIG scripts
import os import os
import sys import sys
@ -40,6 +40,8 @@ from collections import defaultdict
from numbers import Number from numbers import Number
from decimal import Decimal from decimal import Decimal
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, NamedTuple, Sequence, Dict, Any, Set from typing import TYPE_CHECKING, List, Optional, Tuple, Union, NamedTuple, Sequence, Dict, Any, Set
from abc import ABC, abstractmethod
import itertools
from .i18n import _ from .i18n import _
from .bip32 import BIP32Node from .bip32 import BIP32Node
@ -210,7 +212,7 @@ class TxWalletDetails(NamedTuple):
mempool_depth_bytes: Optional[int] mempool_depth_bytes: Optional[int]
class Abstract_Wallet(AddressSynchronizer): class Abstract_Wallet(AddressSynchronizer, ABC):
""" """
Wallet classes are created to handle various address generation methods. Wallet classes are created to handle various address generation methods.
Completion states (watching-only, single account, no seed, etc) are handled inside classes. Completion states (watching-only, single account, no seed, etc) are handled inside classes.
@ -314,8 +316,9 @@ class Abstract_Wallet(AddressSynchronizer):
self.test_addresses_sanity() self.test_addresses_sanity()
super().load_and_cleanup() super().load_and_cleanup()
@abstractmethod
def load_keystore(self) -> None: def load_keystore(self) -> None:
raise NotImplementedError() # implemented by subclasses pass
def diagnostic_name(self): def diagnostic_name(self):
return self.basename() return self.basename()
@ -332,7 +335,7 @@ class Abstract_Wallet(AddressSynchronizer):
def basename(self) -> str: def basename(self) -> str:
return os.path.basename(self.storage.path) return os.path.basename(self.storage.path)
def test_addresses_sanity(self): def test_addresses_sanity(self) -> None:
addrs = self.get_receiving_addresses() addrs = self.get_receiving_addresses()
if len(addrs) > 0: if len(addrs) > 0:
addr = str(addrs[0]) addr = str(addrs[0])
@ -350,7 +353,7 @@ class Abstract_Wallet(AddressSynchronizer):
self._unused_change_addresses = [addr for addr in addrs if not self.is_used(addr)] self._unused_change_addresses = [addr for addr in addrs if not self.is_used(addr)]
return list(self._unused_change_addresses) return list(self._unused_change_addresses)
def is_deterministic(self): def is_deterministic(self) -> bool:
return self.keystore.is_deterministic() return self.keystore.is_deterministic()
def set_label(self, name, text = None): def set_label(self, name, text = None):
@ -417,26 +420,22 @@ class Abstract_Wallet(AddressSynchronizer):
return False return False
return self.get_address_index(address)[0] == 1 return self.get_address_index(address)[0] == 1
@abstractmethod
def get_address_index(self, address): def get_address_index(self, address):
raise NotImplementedError() pass
@abstractmethod
def get_redeem_script(self, address: str) -> Optional[str]: def get_redeem_script(self, address: str) -> Optional[str]:
txin_type = self.get_txin_type(address) pass
if txin_type in ('p2pkh', 'p2wpkh', 'p2pk'):
return None
if txin_type == 'p2wpkh-p2sh':
pubkey = self.get_public_key(address)
return bitcoin.p2wpkh_nested_script(pubkey)
if txin_type == 'address':
return None
raise UnknownTxinType(f'unexpected txin_type {txin_type}')
@abstractmethod
def get_witness_script(self, address: str) -> Optional[str]: def get_witness_script(self, address: str) -> Optional[str]:
return None pass
@abstractmethod
def get_txin_type(self, address: str) -> str: def get_txin_type(self, address: str) -> str:
"""Return script type of wallet address.""" """Return script type of wallet address."""
raise NotImplementedError() pass
def export_private_key(self, address, password) -> str: def export_private_key(self, address, password) -> str:
if self.is_watching_only(): if self.is_watching_only():
@ -451,17 +450,14 @@ class Abstract_Wallet(AddressSynchronizer):
serialized_privkey = bitcoin.serialize_privkey(pk, compressed, txin_type) serialized_privkey = bitcoin.serialize_privkey(pk, compressed, txin_type)
return serialized_privkey return serialized_privkey
def get_public_keys(self, address): @abstractmethod
return [self.get_public_key(address)] def get_public_keys(self, address: str) -> Sequence[str]:
pass
def get_public_keys_with_deriv_info(self, address: str) -> Dict[str, Tuple[KeyStoreWithMPK, Sequence[int]]]: def get_public_keys_with_deriv_info(self, address: str) -> Dict[str, Tuple[KeyStoreWithMPK, Sequence[int]]]:
"""Returns a map: pubkey_hex -> (keystore, derivation_suffix)""" """Returns a map: pubkey_hex -> (keystore, derivation_suffix)"""
return {} return {}
def is_found(self):
return True
#return self.history.values() != [[]] * len(self.history)
def get_tx_info(self, tx) -> TxWalletDetails: def get_tx_info(self, tx) -> TxWalletDetails:
is_relevant, is_mine, v, fee = self.get_wallet_delta(tx) is_relevant, is_mine, v, fee = self.get_wallet_delta(tx)
if fee is None and isinstance(tx, PartialTransaction): if fee is None and isinstance(tx, PartialTransaction):
@ -536,11 +532,13 @@ class Abstract_Wallet(AddressSynchronizer):
utxos = [utxo for utxo in utxos if not self.is_frozen_coin(utxo)] utxos = [utxo for utxo in utxos if not self.is_frozen_coin(utxo)]
return utxos return utxos
@abstractmethod
def get_receiving_addresses(self, *, slice_start=None, slice_stop=None) -> Sequence[str]: def get_receiving_addresses(self, *, slice_start=None, slice_stop=None) -> Sequence[str]:
raise NotImplementedError() # implemented by subclasses pass
@abstractmethod
def get_change_addresses(self, *, slice_start=None, slice_stop=None) -> Sequence[str]: def get_change_addresses(self, *, slice_start=None, slice_stop=None) -> Sequence[str]:
raise NotImplementedError() # implemented by subclasses pass
def dummy_address(self): def dummy_address(self):
# first receiving address # first receiving address
@ -1304,8 +1302,9 @@ class Abstract_Wallet(AddressSynchronizer):
locktime = get_locktime_for_new_transaction(self.network) locktime = get_locktime_for_new_transaction(self.network)
return PartialTransaction.from_io(inputs, outputs, locktime=locktime) return PartialTransaction.from_io(inputs, outputs, locktime=locktime)
@abstractmethod
def _add_input_sig_info(self, txin: PartialTxInput, address: str, *, only_der_suffix: bool = True) -> None: def _add_input_sig_info(self, txin: PartialTxInput, address: str, *, only_der_suffix: bool = True) -> None:
raise NotImplementedError() # implemented by subclasses pass
def _add_txinout_derivation_info(self, txinout: Union[PartialTxInput, PartialTxOutput], def _add_txinout_derivation_info(self, txinout: Union[PartialTxInput, PartialTxOutput],
address: str, *, only_der_suffix: bool = True) -> None: address: str, *, only_der_suffix: bool = True) -> None:
@ -1439,10 +1438,10 @@ class Abstract_Wallet(AddressSynchronizer):
tx.add_info_from_wallet(self, include_xpubs_and_full_paths=False) tx.add_info_from_wallet(self, include_xpubs_and_full_paths=False)
return tx return tx
def try_detecting_internal_addresses_corruption(self): def try_detecting_internal_addresses_corruption(self) -> None:
pass pass
def check_address(self, addr): def check_address(self, addr: str) -> None:
pass pass
def check_returned_address(func): def check_returned_address(func):
@ -1479,7 +1478,7 @@ class Abstract_Wallet(AddressSynchronizer):
choice = addr choice = addr
return choice return choice
def create_new_address(self, for_change=False): def create_new_address(self, for_change: bool = False):
raise Exception("this wallet cannot generate new addresses") raise Exception("this wallet cannot generate new addresses")
def get_payment_status(self, address, amount): def get_payment_status(self, address, amount):
@ -1650,8 +1649,9 @@ class Abstract_Wallet(AddressSynchronizer):
out.sort(key=operator.itemgetter('time')) out.sort(key=operator.itemgetter('time'))
return out return out
@abstractmethod
def get_fingerprint(self): def get_fingerprint(self):
raise NotImplementedError() pass
def can_import_privkey(self): def can_import_privkey(self):
return False return False
@ -1722,15 +1722,23 @@ class Abstract_Wallet(AddressSynchronizer):
self.storage.set_keystore_encryption(bool(new_pw) and encrypt_keystore) self.storage.set_keystore_encryption(bool(new_pw) and encrypt_keystore)
self.storage.write() self.storage.write()
@abstractmethod
def _update_password_for_keystore(self, old_pw: Optional[str], new_pw: Optional[str]) -> None:
pass
def sign_message(self, address, message, password): def sign_message(self, address, message, password):
index = self.get_address_index(address) index = self.get_address_index(address)
return self.keystore.sign_message(index, message, password) return self.keystore.sign_message(index, message, password)
def decrypt_message(self, pubkey, message, password) -> bytes: def decrypt_message(self, pubkey, message, password) -> bytes:
addr = self.pubkeys_to_address(pubkey) addr = self.pubkeys_to_address([pubkey])
index = self.get_address_index(addr) index = self.get_address_index(addr)
return self.keystore.decrypt_message(index, message, password) return self.keystore.decrypt_message(index, message, password)
@abstractmethod
def pubkeys_to_address(self, pubkeys: Sequence[str]) -> Optional[str]:
pass
def txin_value(self, txin: TxInput) -> Optional[int]: def txin_value(self, txin: TxInput) -> Optional[int]:
if isinstance(txin, PartialTxInput): if isinstance(txin, PartialTxInput):
v = txin.value_sats() v = txin.value_sats()
@ -1799,8 +1807,9 @@ class Abstract_Wallet(AddressSynchronizer):
# overridden for TrustedCoin wallets # overridden for TrustedCoin wallets
return False return False
@abstractmethod
def is_watching_only(self) -> bool: def is_watching_only(self) -> bool:
raise NotImplementedError() pass
def get_keystore(self) -> Optional[KeyStore]: def get_keystore(self) -> Optional[KeyStore]:
return self.keystore return self.keystore
@ -1808,14 +1817,17 @@ class Abstract_Wallet(AddressSynchronizer):
def get_keystores(self) -> Sequence[KeyStore]: def get_keystores(self) -> Sequence[KeyStore]:
return [self.keystore] if self.keystore else [] return [self.keystore] if self.keystore else []
@abstractmethod
def save_keystore(self): def save_keystore(self):
raise NotImplementedError() pass
@abstractmethod
def has_seed(self) -> bool: def has_seed(self) -> bool:
raise NotImplementedError() pass
@abstractmethod
def is_beyond_limit(self, address: str) -> bool: def is_beyond_limit(self, address: str) -> bool:
raise NotImplementedError() pass
class Simple_Wallet(Abstract_Wallet): class Simple_Wallet(Abstract_Wallet):
@ -1832,6 +1844,27 @@ class Simple_Wallet(Abstract_Wallet):
def save_keystore(self): def save_keystore(self):
self.storage.put('keystore', self.keystore.dump()) self.storage.put('keystore', self.keystore.dump())
@abstractmethod
def get_public_key(self, address: str) -> Optional[str]:
pass
def get_public_keys(self, address: str) -> Sequence[str]:
return [self.get_public_key(address)]
def get_redeem_script(self, address: str) -> Optional[str]:
txin_type = self.get_txin_type(address)
if txin_type in ('p2pkh', 'p2wpkh', 'p2pk'):
return None
if txin_type == 'p2wpkh-p2sh':
pubkey = self.get_public_key(address)
return bitcoin.p2wpkh_nested_script(pubkey)
if txin_type == 'address':
return None
raise UnknownTxinType(f'unexpected txin_type {txin_type}')
def get_witness_script(self, address: str) -> Optional[str]:
return None
class Imported_Wallet(Simple_Wallet): class Imported_Wallet(Simple_Wallet):
# wallet made of imported addresses # wallet made of imported addresses
@ -2005,10 +2038,13 @@ class Imported_Wallet(Simple_Wallet):
raise Exception(f'Unexpected script type: {txin.script_type}. ' raise Exception(f'Unexpected script type: {txin.script_type}. '
f'Imported wallets are not implemented to handle this.') f'Imported wallets are not implemented to handle this.')
def pubkeys_to_address(self, pubkey): def pubkeys_to_address(self, pubkeys):
pubkey = pubkeys[0]
for addr in self.db.get_imported_addresses(): for addr in self.db.get_imported_addresses():
if self.db.get_imported_address(addr)['pubkey'] == pubkey: if self.db.get_imported_address(addr)['pubkey'] == pubkey:
return addr return addr
return None
class Deterministic_Wallet(Abstract_Wallet): class Deterministic_Wallet(Abstract_Wallet):
@ -2047,7 +2083,7 @@ class Deterministic_Wallet(Abstract_Wallet):
# sample2: a few more randomly selected # sample2: a few more randomly selected
addresses_rand = addresses_all[10:] addresses_rand = addresses_all[10:]
addresses_sample2 = random.sample(addresses_rand, min(len(addresses_rand), 10)) addresses_sample2 = random.sample(addresses_rand, min(len(addresses_rand), 10))
for addr_found in addresses_sample1 + addresses_sample2: for addr_found in itertools.chain(addresses_sample1, addresses_sample2):
self.check_address(addr_found) self.check_address(addr_found)
def check_address(self, addr): def check_address(self, addr):
@ -2058,9 +2094,6 @@ class Deterministic_Wallet(Abstract_Wallet):
def get_seed(self, password): def get_seed(self, password):
return self.keystore.get_seed(password) return self.keystore.get_seed(password)
def add_seed(self, seed, pw):
self.keystore.add_seed(seed, pw)
def change_gap_limit(self, value): def change_gap_limit(self, value):
'''This method is not called in the code, it is kept for console use''' '''This method is not called in the code, it is kept for console use'''
if value >= self.min_acceptable_gap(): if value >= self.min_acceptable_gap():
@ -2093,9 +2126,14 @@ class Deterministic_Wallet(Abstract_Wallet):
nmax = max(nmax, n) nmax = max(nmax, n)
return nmax + 1 return nmax + 1
def derive_address(self, for_change, n): @abstractmethod
x = self.derive_pubkeys(for_change, n) def derive_pubkeys(self, c: int, i: int) -> Sequence[str]:
return self.pubkeys_to_address(x) pass
def derive_address(self, for_change: int, n: int) -> str:
for_change = int(for_change)
pubkeys = self.derive_pubkeys(for_change, n)
return self.pubkeys_to_address(pubkeys)
def get_public_keys_with_deriv_info(self, address: str): def get_public_keys_with_deriv_info(self, address: str):
der_suffix = self.get_address_index(address) der_suffix = self.get_address_index(address)
@ -2117,11 +2155,11 @@ class Deterministic_Wallet(Abstract_Wallet):
only_der_suffix=only_der_suffix) only_der_suffix=only_der_suffix)
txinout.bip32_paths[bfh(pubkey_hex)] = (fp_bytes, der_full) txinout.bip32_paths[bfh(pubkey_hex)] = (fp_bytes, der_full)
def create_new_address(self, for_change=False): def create_new_address(self, for_change: bool = False):
assert type(for_change) is bool assert type(for_change) is bool
with self.lock: with self.lock:
n = self.db.num_change_addresses() if for_change else self.db.num_receiving_addresses() n = self.db.num_change_addresses() if for_change else self.db.num_receiving_addresses()
address = self.derive_address(for_change, n) address = self.derive_address(int(for_change), n)
self.db.add_change_address(address) if for_change else self.db.add_receiving_address(address) self.db.add_change_address(address) if for_change else self.db.add_receiving_address(address)
self.add_address(address) self.add_address(address)
if for_change: if for_change:
@ -2197,8 +2235,8 @@ class Simple_Deterministic_Wallet(Simple_Wallet, Deterministic_Wallet):
def get_public_key(self, address): def get_public_key(self, address):
sequence = self.get_address_index(address) sequence = self.get_address_index(address)
pubkey = self.derive_pubkeys(*sequence) pubkeys = self.derive_pubkeys(*sequence)
return pubkey return pubkeys[0]
def load_keystore(self): def load_keystore(self):
self.keystore = load_keystore(self.storage, 'keystore') self.keystore = load_keystore(self.storage, 'keystore')
@ -2212,7 +2250,7 @@ class Simple_Deterministic_Wallet(Simple_Wallet, Deterministic_Wallet):
return self.keystore.get_master_public_key() return self.keystore.get_master_public_key()
def derive_pubkeys(self, c, i): def derive_pubkeys(self, c, i):
return self.keystore.derive_pubkey(c, i) return [self.keystore.derive_pubkey(c, i)]
@ -2222,7 +2260,8 @@ class Simple_Deterministic_Wallet(Simple_Wallet, Deterministic_Wallet):
class Standard_Wallet(Simple_Deterministic_Wallet): class Standard_Wallet(Simple_Deterministic_Wallet):
wallet_type = 'standard' wallet_type = 'standard'
def pubkeys_to_address(self, pubkey): def pubkeys_to_address(self, pubkeys):
pubkey = pubkeys[0]
return bitcoin.pubkey_to_address(self.txin_type, pubkey) return bitcoin.pubkey_to_address(self.txin_type, pubkey)

Loading…
Cancel
Save