Browse Source

wallet: better (outgoing) invoice "paid" detection

- no more passing around "invoice" in GUIs, invoice "paid" detection is now handled by wallet logic
- a tx can now pay for multiple invoices
- an invoice can now be paid by multiple txs (through partial payments)
- new data structure in storage: prevouts_by_scripthash
  - type: scripthash -> set of (outpoint, value)
  - also, storage upgrade to build this for existing wallets
hard-fail-on-bad-server-string
SomberNight 5 years ago
parent
commit
8dbbc21aff
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 19
      electrum/address_synchronizer.py
  2. 8
      electrum/gui/kivy/main_window.py
  3. 6
      electrum/gui/kivy/uix/screens.py
  4. 3
      electrum/gui/qt/invoice_list.py
  5. 37
      electrum/gui/qt/main_window.py
  6. 17
      electrum/gui/qt/transaction_dialog.py
  7. 51
      electrum/json_db.py
  8. 77
      electrum/wallet.py

19
electrum/address_synchronizer.py

@ -26,7 +26,7 @@ import threading
import asyncio import asyncio
import itertools import itertools
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence, List
from . import bitcoin from . import bitcoin
from .bitcoin import COINBASE_MATURITY from .bitcoin import COINBASE_MATURITY
@ -207,7 +207,7 @@ class AddressSynchronizer(Logger):
conflicting_txns -= {tx_hash} conflicting_txns -= {tx_hash}
return conflicting_txns return conflicting_txns
def add_transaction(self, tx: Transaction, allow_unrelated=False) -> bool: def add_transaction(self, tx: Transaction, *, allow_unrelated=False) -> bool:
"""Returns whether the tx was successfully added to the wallet history.""" """Returns whether the tx was successfully added to the wallet history."""
assert tx, tx assert tx, tx
assert tx.is_complete() assert tx.is_complete()
@ -283,6 +283,8 @@ class AddressSynchronizer(Logger):
for n, txo in enumerate(tx.outputs()): for n, txo in enumerate(tx.outputs()):
v = txo.value v = txo.value
ser = tx_hash + ':%d'%n ser = tx_hash + ':%d'%n
scripthash = bitcoin.script_to_scripthash(txo.scriptpubkey.hex())
self.db.add_prevout_by_scripthash(scripthash, prevout=TxOutpoint.from_str(ser), value=v)
addr = self.get_txout_address(txo) addr = self.get_txout_address(txo)
if addr and self.is_mine(addr): if addr and self.is_mine(addr):
self.db.add_txo_addr(tx_hash, addr, n, v, is_coinbase) self.db.add_txo_addr(tx_hash, addr, n, v, is_coinbase)
@ -299,7 +301,7 @@ class AddressSynchronizer(Logger):
self.db.add_num_inputs_to_tx(tx_hash, len(tx.inputs())) self.db.add_num_inputs_to_tx(tx_hash, len(tx.inputs()))
return True return True
def remove_transaction(self, tx_hash): def remove_transaction(self, tx_hash: str) -> None:
def remove_from_spent_outpoints(): def remove_from_spent_outpoints():
# undo spends in spent_outpoints # undo spends in spent_outpoints
if tx is not None: if tx is not None:
@ -317,7 +319,7 @@ class AddressSynchronizer(Logger):
if spending_txid == tx_hash: if spending_txid == tx_hash:
self.db.remove_spent_outpoint(prevout_hash, prevout_n) self.db.remove_spent_outpoint(prevout_hash, prevout_n)
with self.transaction_lock: with self.lock, self.transaction_lock:
self.logger.info(f"removing tx from history {tx_hash}") self.logger.info(f"removing tx from history {tx_hash}")
tx = self.db.remove_transaction(tx_hash) tx = self.db.remove_transaction(tx_hash)
remove_from_spent_outpoints() remove_from_spent_outpoints()
@ -327,6 +329,13 @@ class AddressSynchronizer(Logger):
self.db.remove_txi(tx_hash) self.db.remove_txi(tx_hash)
self.db.remove_txo(tx_hash) self.db.remove_txo(tx_hash)
self.db.remove_tx_fee(tx_hash) self.db.remove_tx_fee(tx_hash)
self.db.remove_verified_tx(tx_hash)
self.unverified_tx.pop(tx_hash, None)
if tx:
for idx, txo in enumerate(tx.outputs()):
scripthash = bitcoin.script_to_scripthash(txo.scriptpubkey.hex())
prevout = TxOutpoint(bfh(tx_hash), idx)
self.db.remove_prevout_by_scripthash(scripthash, prevout=prevout, value=txo.value)
def get_depending_transactions(self, tx_hash): def get_depending_transactions(self, tx_hash):
"""Returns all (grand-)children of tx_hash in this wallet.""" """Returns all (grand-)children of tx_hash in this wallet."""
@ -338,7 +347,7 @@ class AddressSynchronizer(Logger):
children |= self.get_depending_transactions(other_hash) children |= self.get_depending_transactions(other_hash)
return children return children
def receive_tx_callback(self, tx_hash, tx, tx_height): def receive_tx_callback(self, tx_hash: str, tx: Transaction, tx_height: int) -> None:
self.add_unverified_tx(tx_hash, tx_height) self.add_unverified_tx(tx_hash, tx_height)
self.add_transaction(tx, allow_unrelated=True) self.add_transaction(tx, allow_unrelated=True)

8
electrum/gui/kivy/main_window.py

@ -1028,18 +1028,12 @@ class ElectrumWindow(App):
status, msg = True, tx.txid() status, msg = True, tx.txid()
Clock.schedule_once(lambda dt: on_complete(status, msg)) Clock.schedule_once(lambda dt: on_complete(status, msg))
def broadcast(self, tx, invoice=None): def broadcast(self, tx):
def on_complete(ok, msg): def on_complete(ok, msg):
if ok: if ok:
self.show_info(_('Payment sent.')) self.show_info(_('Payment sent.'))
if self.send_screen: if self.send_screen:
self.send_screen.do_clear() self.send_screen.do_clear()
if invoice:
key = invoice['id']
txid = tx.txid()
self.wallet.set_label(txid, invoice['message'])
self.wallet.set_paid(key, txid)
self.update_tab('invoices')
else: else:
msg = msg or '' msg = msg or ''
self.show_error(msg) self.show_error(msg)

6
electrum/gui/kivy/uix/screens.py

@ -380,14 +380,14 @@ class SendScreen(CScreen):
if fee > feerate_warning * tx.estimated_size() / 1000: if fee > feerate_warning * tx.estimated_size() / 1000:
msg.append(_('Warning') + ': ' + _("The fee for this transaction seems unusually high.")) msg.append(_('Warning') + ': ' + _("The fee for this transaction seems unusually high."))
msg.append(_("Enter your PIN code to proceed")) msg.append(_("Enter your PIN code to proceed"))
self.app.protected('\n'.join(msg), self.send_tx, (tx, invoice)) self.app.protected('\n'.join(msg), self.send_tx, (tx,))
def send_tx(self, tx, invoice, password): def send_tx(self, tx, password):
if self.app.wallet.has_password() and password is None: if self.app.wallet.has_password() and password is None:
return return
def on_success(tx): def on_success(tx):
if tx.is_complete(): if tx.is_complete():
self.app.broadcast(tx, invoice) self.app.broadcast(tx)
else: else:
self.app.tx_dialog(tx) self.app.tx_dialog(tx)
def on_failure(error): def on_failure(error):

3
electrum/gui/qt/invoice_list.py

@ -96,7 +96,8 @@ class InvoiceList(MyTreeView):
_list = self.parent.wallet.get_invoices() _list = self.parent.wallet.get_invoices()
# filter out paid invoices unless we have the log # filter out paid invoices unless we have the log
lnworker_logs = self.parent.wallet.lnworker.logs if self.parent.wallet.lnworker else {} lnworker_logs = self.parent.wallet.lnworker.logs if self.parent.wallet.lnworker else {}
_list = [x for x in _list if x and x.get('status') != PR_PAID or x.get('rhash') in lnworker_logs] _list = [x for x in _list
if x and (x.get('status') != PR_PAID or x.get('rhash') in lnworker_logs)]
self.model().clear() self.model().clear()
self.update_headers(self.__class__.headers) self.update_headers(self.__class__.headers)
for idx, item in enumerate(_list): for idx, item in enumerate(_list):

37
electrum/gui/qt/main_window.py

@ -928,9 +928,9 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
d = address_dialog.AddressDialog(self, addr) d = address_dialog.AddressDialog(self, addr)
d.exec_() d.exec_()
def show_transaction(self, tx, *, invoice=None, tx_desc=None): def show_transaction(self, tx, *, tx_desc=None):
'''tx_desc is set only for txs created in the Send tab''' '''tx_desc is set only for txs created in the Send tab'''
show_transaction(tx, parent=self, invoice=invoice, desc=tx_desc) show_transaction(tx, parent=self, desc=tx_desc)
def create_receive_tab(self): def create_receive_tab(self):
# A 4-column grid layout. All the stretch is in the last column. # A 4-column grid layout. All the stretch is in the last column.
@ -1472,7 +1472,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
self.pay_lightning_invoice(invoice['invoice'], amount_sat=invoice['amount']) self.pay_lightning_invoice(invoice['invoice'], amount_sat=invoice['amount'])
elif invoice['type'] == PR_TYPE_ONCHAIN: elif invoice['type'] == PR_TYPE_ONCHAIN:
outputs = invoice['outputs'] outputs = invoice['outputs']
self.pay_onchain_dialog(self.get_coins(), outputs, invoice=invoice) self.pay_onchain_dialog(self.get_coins(), outputs)
else: else:
raise Exception('unknown invoice type') raise Exception('unknown invoice type')
@ -1492,7 +1492,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
def pay_onchain_dialog(self, inputs: Sequence[PartialTxInput], def pay_onchain_dialog(self, inputs: Sequence[PartialTxInput],
outputs: List[PartialTxOutput], *, outputs: List[PartialTxOutput], *,
invoice=None, external_keypairs=None) -> None: external_keypairs=None) -> None:
# trustedcoin requires this # trustedcoin requires this
if run_hook('abort_send', self): if run_hook('abort_send', self):
return return
@ -1508,8 +1508,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
return return
if self.config.get('advanced_preview'): if self.config.get('advanced_preview'):
self.preview_tx_dialog(make_tx=make_tx, self.preview_tx_dialog(make_tx=make_tx,
external_keypairs=external_keypairs, external_keypairs=external_keypairs)
invoice=invoice)
return return
output_value = '!' if '!' in output_values else sum(output_values) output_value = '!' if '!' in output_values else sum(output_values)
@ -1524,27 +1523,26 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
if is_send: if is_send:
def sign_done(success): def sign_done(success):
if success: if success:
self.broadcast_or_show(tx, invoice=invoice) self.broadcast_or_show(tx)
self.sign_tx_with_password(tx, callback=sign_done, password=password, self.sign_tx_with_password(tx, callback=sign_done, password=password,
external_keypairs=external_keypairs) external_keypairs=external_keypairs)
else: else:
self.preview_tx_dialog(make_tx=make_tx, self.preview_tx_dialog(make_tx=make_tx,
external_keypairs=external_keypairs, external_keypairs=external_keypairs)
invoice=invoice)
def preview_tx_dialog(self, *, make_tx, external_keypairs=None, invoice=None): def preview_tx_dialog(self, *, make_tx, external_keypairs=None):
d = PreviewTxDialog(make_tx=make_tx, external_keypairs=external_keypairs, d = PreviewTxDialog(make_tx=make_tx, external_keypairs=external_keypairs,
window=self, invoice=invoice) window=self)
d.show() d.show()
def broadcast_or_show(self, tx, *, invoice=None): def broadcast_or_show(self, tx: Transaction):
if not self.network: if not self.network:
self.show_error(_("You can't broadcast a transaction without a live network connection.")) self.show_error(_("You can't broadcast a transaction without a live network connection."))
self.show_transaction(tx, invoice=invoice) self.show_transaction(tx)
elif not tx.is_complete(): elif not tx.is_complete():
self.show_transaction(tx, invoice=invoice) self.show_transaction(tx)
else: else:
self.broadcast_transaction(tx, invoice=invoice) self.broadcast_transaction(tx)
@protected @protected
def sign_tx(self, tx, *, callback, external_keypairs, password): def sign_tx(self, tx, *, callback, external_keypairs, password):
@ -1568,7 +1566,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
msg = _('Signing transaction...') msg = _('Signing transaction...')
WaitingDialog(self, msg, task, on_success, on_failure) WaitingDialog(self, msg, task, on_success, on_failure)
def broadcast_transaction(self, tx: Transaction, *, invoice=None, tx_desc=None): def broadcast_transaction(self, tx: Transaction):
def broadcast_thread(): def broadcast_thread():
# non-GUI thread # non-GUI thread
@ -1584,11 +1582,6 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
return False, repr(e) return False, repr(e)
# success # success
txid = tx.txid() txid = tx.txid()
if tx_desc:
self.wallet.set_label(txid, tx_desc)
if invoice:
self.wallet.set_paid(invoice['id'], txid)
self.wallet.set_label(txid, invoice['message'])
if pr: if pr:
self.payment_request = None self.payment_request = None
refund_address = self.wallet.get_receiving_address() refund_address = self.wallet.get_receiving_address()
@ -2709,7 +2702,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
scriptpubkey = bfh(bitcoin.address_to_script(addr)) scriptpubkey = bfh(bitcoin.address_to_script(addr))
outputs = [PartialTxOutput(scriptpubkey=scriptpubkey, value='!')] outputs = [PartialTxOutput(scriptpubkey=scriptpubkey, value='!')]
self.warn_if_watching_only() self.warn_if_watching_only()
self.pay_onchain_dialog(coins, outputs, invoice=None, external_keypairs=keypairs) self.pay_onchain_dialog(coins, outputs, external_keypairs=keypairs)
def _do_import(self, title, header_layout, func): def _do_import(self, title, header_layout, func):
text = text_dialog(self, title, header_layout, _('Import'), allow_multi=True) text = text_dialog(self, title, header_layout, _('Import'), allow_multi=True)

17
electrum/gui/qt/transaction_dialog.py

@ -75,9 +75,9 @@ _logger = get_logger(__name__)
dialogs = [] # Otherwise python randomly garbage collects the dialogs... dialogs = [] # Otherwise python randomly garbage collects the dialogs...
def show_transaction(tx: Transaction, *, parent: 'ElectrumWindow', invoice=None, desc=None, prompt_if_unsaved=False): def show_transaction(tx: Transaction, *, parent: 'ElectrumWindow', desc=None, prompt_if_unsaved=False):
try: try:
d = TxDialog(tx, parent=parent, invoice=invoice, desc=desc, prompt_if_unsaved=prompt_if_unsaved) d = TxDialog(tx, parent=parent, desc=desc, prompt_if_unsaved=prompt_if_unsaved)
except SerializationError as e: except SerializationError as e:
_logger.exception('unable to deserialize the transaction') _logger.exception('unable to deserialize the transaction')
parent.show_critical(_("Electrum was unable to deserialize the transaction:") + "\n" + str(e)) parent.show_critical(_("Electrum was unable to deserialize the transaction:") + "\n" + str(e))
@ -88,7 +88,7 @@ def show_transaction(tx: Transaction, *, parent: 'ElectrumWindow', invoice=None,
class BaseTxDialog(QDialog, MessageBoxMixin): class BaseTxDialog(QDialog, MessageBoxMixin):
def __init__(self, *, parent: 'ElectrumWindow', invoice, desc, prompt_if_unsaved, finalized: bool, external_keypairs=None): def __init__(self, *, parent: 'ElectrumWindow', desc, prompt_if_unsaved, finalized: bool, external_keypairs=None):
'''Transactions in the wallet will show their description. '''Transactions in the wallet will show their description.
Pass desc to give a description for txs not yet in the wallet. Pass desc to give a description for txs not yet in the wallet.
''' '''
@ -103,7 +103,6 @@ class BaseTxDialog(QDialog, MessageBoxMixin):
self.prompt_if_unsaved = prompt_if_unsaved self.prompt_if_unsaved = prompt_if_unsaved
self.saved = False self.saved = False
self.desc = desc self.desc = desc
self.invoice = invoice
self.setMinimumWidth(950) self.setMinimumWidth(950)
self.set_title() self.set_title()
@ -213,7 +212,7 @@ class BaseTxDialog(QDialog, MessageBoxMixin):
def do_broadcast(self): def do_broadcast(self):
self.main_window.push_top_level_window(self) self.main_window.push_top_level_window(self)
try: try:
self.main_window.broadcast_transaction(self.tx, invoice=self.invoice, tx_desc=self.desc) self.main_window.broadcast_transaction(self.tx)
finally: finally:
self.main_window.pop_top_level_window(self) self.main_window.pop_top_level_window(self)
self.saved = True self.saved = True
@ -592,8 +591,8 @@ class TxDetailLabel(QLabel):
class TxDialog(BaseTxDialog): class TxDialog(BaseTxDialog):
def __init__(self, tx: Transaction, *, parent: 'ElectrumWindow', invoice, desc, prompt_if_unsaved): def __init__(self, tx: Transaction, *, parent: 'ElectrumWindow', desc, prompt_if_unsaved):
BaseTxDialog.__init__(self, parent=parent, invoice=invoice, desc=desc, prompt_if_unsaved=prompt_if_unsaved, finalized=True) BaseTxDialog.__init__(self, parent=parent, desc=desc, prompt_if_unsaved=prompt_if_unsaved, finalized=True)
self.set_tx(tx) self.set_tx(tx)
self.update() self.update()
@ -601,9 +600,9 @@ class TxDialog(BaseTxDialog):
class PreviewTxDialog(BaseTxDialog, TxEditor): class PreviewTxDialog(BaseTxDialog, TxEditor):
def __init__(self, *, make_tx, external_keypairs, window: 'ElectrumWindow', invoice): def __init__(self, *, make_tx, external_keypairs, window: 'ElectrumWindow'):
TxEditor.__init__(self, window=window, make_tx=make_tx, is_sweep=bool(external_keypairs)) TxEditor.__init__(self, window=window, make_tx=make_tx, is_sweep=bool(external_keypairs))
BaseTxDialog.__init__(self, parent=window, invoice=invoice, desc='', prompt_if_unsaved=False, BaseTxDialog.__init__(self, parent=window, desc='', prompt_if_unsaved=False,
finalized=False, external_keypairs=external_keypairs) finalized=False, external_keypairs=external_keypairs)
self.update_tx() self.update_tx()
self.update() self.update()

51
electrum/json_db.py

@ -31,16 +31,16 @@ from collections import defaultdict
from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence
from . import util, bitcoin from . import util, bitcoin
from .util import profiler, WalletFileException, multisig_type, TxMinedInfo from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, bfh
from .keystore import bip44_derivation from .keystore import bip44_derivation
from .transaction import Transaction from .transaction import Transaction, TxOutpoint
from .logging import Logger from .logging import Logger
# seed_version is now used for the version of the wallet file # seed_version is now used for the version of the wallet file
OLD_SEED_VERSION = 4 # electrum versions < 2.0 OLD_SEED_VERSION = 4 # electrum versions < 2.0
NEW_SEED_VERSION = 11 # electrum versions >= 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0
FINAL_SEED_VERSION = 21 # electrum >= 2.7 will set this to prevent FINAL_SEED_VERSION = 22 # electrum >= 2.7 will set this to prevent
# old versions from overwriting new format # old versions from overwriting new format
@ -215,6 +215,7 @@ class JsonDB(Logger):
self._convert_version_19() self._convert_version_19()
self._convert_version_20() self._convert_version_20()
self._convert_version_21() self._convert_version_21()
self._convert_version_22()
self.put('seed_version', FINAL_SEED_VERSION) # just to be sure self.put('seed_version', FINAL_SEED_VERSION) # just to be sure
self._after_upgrade_tasks() self._after_upgrade_tasks()
@ -496,6 +497,24 @@ class JsonDB(Logger):
self.put('channels', channels) self.put('channels', channels)
self.put('seed_version', 21) self.put('seed_version', 21)
def _convert_version_22(self):
# construct prevouts_by_scripthash
if not self._is_upgrade_method_needed(21, 21):
return
from .bitcoin import script_to_scripthash
transactions = self.get('transactions', {}) # txid -> raw_tx
prevouts_by_scripthash = defaultdict(list)
for txid, raw_tx in transactions.items():
tx = Transaction(raw_tx)
for idx, txout in enumerate(tx.outputs()):
outpoint = f"{txid}:{idx}"
scripthash = script_to_scripthash(txout.scriptpubkey.hex())
prevouts_by_scripthash[scripthash].append((outpoint, txout.value))
self.put('prevouts_by_scripthash', prevouts_by_scripthash)
self.put('seed_version', 22)
def _convert_imported(self): def _convert_imported(self):
if not self._is_upgrade_method_needed(0, 13): if not self._is_upgrade_method_needed(0, 13):
return return
@ -660,6 +679,25 @@ class JsonDB(Logger):
self.spent_outpoints[prevout_hash] = {} self.spent_outpoints[prevout_hash] = {}
self.spent_outpoints[prevout_hash][prevout_n] = tx_hash self.spent_outpoints[prevout_hash][prevout_n] = tx_hash
@modifier
def add_prevout_by_scripthash(self, scripthash: str, *, prevout: TxOutpoint, value: int) -> None:
assert isinstance(prevout, TxOutpoint)
if scripthash not in self._prevouts_by_scripthash:
self._prevouts_by_scripthash[scripthash] = set()
self._prevouts_by_scripthash[scripthash].add((prevout.to_str(), value))
@modifier
def remove_prevout_by_scripthash(self, scripthash: str, *, prevout: TxOutpoint, value: int) -> None:
assert isinstance(prevout, TxOutpoint)
self._prevouts_by_scripthash[scripthash].discard((prevout.to_str(), value))
if not self._prevouts_by_scripthash[scripthash]:
self._prevouts_by_scripthash.pop(scripthash)
@locked
def get_prevouts_by_scripthash(self, scripthash: str) -> Set[Tuple[TxOutpoint, int]]:
prevouts_and_values = self._prevouts_by_scripthash.get(scripthash, set())
return {(TxOutpoint.from_str(prevout), value) for prevout, value in prevouts_and_values}
@modifier @modifier
def add_transaction(self, tx_hash: str, tx: Transaction) -> None: def add_transaction(self, tx_hash: str, tx: Transaction) -> None:
assert isinstance(tx, Transaction) assert isinstance(tx, Transaction)
@ -863,14 +901,19 @@ class JsonDB(Logger):
self.history = self.get_data_ref('addr_history') # address -> list of (txid, height) self.history = self.get_data_ref('addr_history') # address -> list of (txid, height)
self.verified_tx = self.get_data_ref('verified_tx3') # txid -> (height, timestamp, txpos, header_hash) self.verified_tx = self.get_data_ref('verified_tx3') # txid -> (height, timestamp, txpos, header_hash)
self.tx_fees = self.get_data_ref('tx_fees') # type: Dict[str, TxFeesValue] self.tx_fees = self.get_data_ref('tx_fees') # type: Dict[str, TxFeesValue]
# scripthash -> set of (outpoint, value)
self._prevouts_by_scripthash = self.get_data_ref('prevouts_by_scripthash') # type: Dict[str, Set[Tuple[str, int]]]
# convert raw hex transactions to Transaction objects # convert raw hex transactions to Transaction objects
for tx_hash, raw_tx in self.transactions.items(): for tx_hash, raw_tx in self.transactions.items():
self.transactions[tx_hash] = Transaction(raw_tx) self.transactions[tx_hash] = Transaction(raw_tx)
# convert list to set # convert txi, txo: list to set
for t in self.txi, self.txo: for t in self.txi, self.txo:
for d in t.values(): for d in t.values():
for addr, lst in d.items(): for addr, lst in d.items():
d[addr] = set([tuple(x) for x in lst]) d[addr] = set([tuple(x) for x in lst])
# convert prevouts_by_scripthash: list to set, list to tuple
for scripthash, lst in self._prevouts_by_scripthash.items():
self._prevouts_by_scripthash[scripthash] = {(prevout, value) for prevout, value in lst}
# remove unreferenced tx # remove unreferenced tx
for tx_hash in list(self.transactions.keys()): for tx_hash in list(self.transactions.keys()):
if not self.get_txi_addresses(tx_hash) and not self.get_txo_addresses(tx_hash): if not self.get_txi_addresses(tx_hash) and not self.get_txo_addresses(tx_hash):

77
electrum/wallet.py

@ -36,9 +36,10 @@ import errno
import traceback import traceback
import operator import operator
from functools import partial from functools import partial
from collections import defaultdict
from numbers import Number from numbers import Number
from decimal import Decimal from decimal import Decimal
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, NamedTuple, Sequence, Dict, Any from typing import TYPE_CHECKING, List, Optional, Tuple, Union, NamedTuple, Sequence, Dict, Any, Set
from .i18n import _ from .i18n import _
from .bip32 import BIP32Node from .bip32 import BIP32Node
@ -249,6 +250,7 @@ class Abstract_Wallet(AddressSynchronizer):
if invoice.get('type') == PR_TYPE_ONCHAIN: if invoice.get('type') == PR_TYPE_ONCHAIN:
outputs = [PartialTxOutput.from_legacy_tuple(*output) for output in invoice.get('outputs')] outputs = [PartialTxOutput.from_legacy_tuple(*output) for output in invoice.get('outputs')]
invoice['outputs'] = outputs invoice['outputs'] = outputs
self._prepare_onchain_invoice_paid_detection()
self.calc_unused_change_addresses() self.calc_unused_change_addresses()
# save wallet type the first time # save wallet type the first time
if self.storage.get('wallet_type') is None: if self.storage.get('wallet_type') is None:
@ -611,7 +613,10 @@ class Abstract_Wallet(AddressSynchronizer):
elif invoice_type == PR_TYPE_ONCHAIN: elif invoice_type == PR_TYPE_ONCHAIN:
key = bh2u(sha256(repr(invoice))[0:16]) key = bh2u(sha256(repr(invoice))[0:16])
invoice['id'] = key invoice['id'] = key
invoice['txid'] = None outputs = invoice['outputs'] # type: List[PartialTxOutput]
with self.transaction_lock:
for txout in outputs:
self._invoices_from_scriptpubkey_map[txout.scriptpubkey].add(key)
else: else:
raise Exception('Unsupported invoice type') raise Exception('Unsupported invoice type')
self.invoices[key] = invoice self.invoices[key] = invoice
@ -629,27 +634,73 @@ class Abstract_Wallet(AddressSynchronizer):
out.sort(key=operator.itemgetter('time')) out.sort(key=operator.itemgetter('time'))
return out return out
def set_paid(self, key, txid):
if key not in self.invoices:
return
invoice = self.invoices[key]
assert invoice.get('type') == PR_TYPE_ONCHAIN
invoice['txid'] = txid
self.storage.put('invoices', self.invoices)
def get_invoice(self, key): def get_invoice(self, key):
if key not in self.invoices: if key not in self.invoices:
return return
item = copy.copy(self.invoices[key]) item = copy.copy(self.invoices[key])
request_type = item.get('type') request_type = item.get('type')
if request_type == PR_TYPE_ONCHAIN: if request_type == PR_TYPE_ONCHAIN:
item['status'] = PR_PAID if item.get('txid') is not None else PR_UNPAID item['status'] = PR_PAID if self._is_onchain_invoice_paid(item)[0] else PR_UNPAID
elif self.lnworker and request_type == PR_TYPE_LN: elif self.lnworker and request_type == PR_TYPE_LN:
item['status'] = self.lnworker.get_payment_status(bfh(item['rhash'])) item['status'] = self.lnworker.get_payment_status(bfh(item['rhash']))
else: else:
return return
return item return item
def _get_relevant_invoice_keys_for_tx(self, tx: Transaction) -> Set[str]:
relevant_invoice_keys = set()
for txout in tx.outputs():
for invoice_key in self._invoices_from_scriptpubkey_map.get(txout.scriptpubkey, set()):
relevant_invoice_keys.add(invoice_key)
return relevant_invoice_keys
def _prepare_onchain_invoice_paid_detection(self):
# scriptpubkey -> list(invoice_keys)
self._invoices_from_scriptpubkey_map = defaultdict(set) # type: Dict[bytes, Set[str]]
for invoice_key, invoice in self.invoices.items():
if invoice.get('type') == PR_TYPE_ONCHAIN:
outputs = invoice['outputs'] # type: List[PartialTxOutput]
for txout in outputs:
self._invoices_from_scriptpubkey_map[txout.scriptpubkey].add(invoice_key)
def _is_onchain_invoice_paid(self, invoice) -> Tuple[bool, Sequence[str]]:
"""Returns whether on-chain invoice is satisfied, and list of relevant TXIDs."""
assert invoice.get('type') == PR_TYPE_ONCHAIN
invoice_amounts = defaultdict(int) # type: Dict[bytes, int] # scriptpubkey -> value_sats
for txo in invoice['outputs']: # type: PartialTxOutput
invoice_amounts[txo.scriptpubkey] += 1 if txo.value == '!' else txo.value
relevant_txs = []
with self.transaction_lock:
for invoice_scriptpubkey, invoice_amt in invoice_amounts.items():
scripthash = bitcoin.script_to_scripthash(invoice_scriptpubkey.hex())
prevouts_and_values = self.db.get_prevouts_by_scripthash(scripthash)
relevant_txs += [prevout.txid.hex() for prevout, v in prevouts_and_values]
total_received = sum([v for prevout, v in prevouts_and_values])
if total_received < invoice_amt:
return False, []
return True, relevant_txs
def _maybe_set_tx_label_based_on_invoices(self, tx: Transaction) -> bool:
tx_hash = tx.txid()
with self.transaction_lock:
labels = []
for invoice_key in self._get_relevant_invoice_keys_for_tx(tx):
invoice = self.invoices.get(invoice_key)
if invoice is None: continue
assert invoice.get('type') == PR_TYPE_ONCHAIN
if invoice['message']:
labels.append(invoice['message'])
if labels:
self.set_label(tx_hash, "; ".join(labels))
return bool(labels)
def add_transaction(self, tx, *, allow_unrelated=False):
tx_was_added = super().add_transaction(tx, allow_unrelated=allow_unrelated)
if tx_was_added:
self._maybe_set_tx_label_based_on_invoices(tx)
return tx_was_added
@profiler @profiler
def get_full_history(self, fx=None, *, onchain_domain=None, include_lightning=True): def get_full_history(self, fx=None, *, onchain_domain=None, include_lightning=True):
transactions = OrderedDictWithIndex() transactions = OrderedDictWithIndex()
@ -1868,10 +1919,6 @@ class Imported_Wallet(Simple_Wallet):
self.db.remove_addr_history(address) self.db.remove_addr_history(address)
for tx_hash in transactions_to_remove: for tx_hash in transactions_to_remove:
self.remove_transaction(tx_hash) self.remove_transaction(tx_hash)
self.db.remove_tx_fee(tx_hash)
self.db.remove_verified_tx(tx_hash)
self.unverified_tx.pop(tx_hash, None)
self.db.remove_transaction(tx_hash)
self.set_label(address, None) self.set_label(address, None)
self.remove_payment_request(address) self.remove_payment_request(address)
self.set_frozen_state_of_addresses([address], False) self.set_frozen_state_of_addresses([address], False)

Loading…
Cancel
Save