diff --git a/gui/qt/transaction_dialog.py b/gui/qt/transaction_dialog.py index a7f754a82..5097af28a 100644 --- a/gui/qt/transaction_dialog.py +++ b/gui/qt/transaction_dialog.py @@ -179,10 +179,13 @@ class TxDialog(QDialog, MessageBoxMixin): self.main_window.sign_tx(self.tx, sign_done) def save(self): - self.wallet.add_transaction(self.tx.txid(), self.tx) + if not self.wallet.add_transaction(self.tx.txid(), self.tx): + self.show_error(_("Transaction could not be saved. It conflicts with current history.")) + return self.wallet.save_transactions(write=True) - self.main_window.history_list.update() + # need to update at least: history_list, utxo_list, address_list + self.main_window.need_update.set() self.save_button.setDisabled(True) self.show_message(_("Transaction saved successfully")) diff --git a/lib/commands.py b/lib/commands.py index 49760268e..c0bfb48ed 100644 --- a/lib/commands.py +++ b/lib/commands.py @@ -627,7 +627,8 @@ class Commands: def addtransaction(self, tx): """ Add a transaction to the wallet history """ tx = Transaction(tx) - self.wallet.add_transaction(tx.txid(), tx) + if not self.wallet.add_transaction(tx.txid(), tx): + return False self.wallet.save_transactions() return tx.txid() diff --git a/lib/transaction.py b/lib/transaction.py index 5688dd911..b23cf9cf2 100644 --- a/lib/transaction.py +++ b/lib/transaction.py @@ -797,6 +797,14 @@ class Transaction: def serialize_outpoint(self, txin): return bh2u(bfh(txin['prevout_hash'])[::-1]) + int_to_hex(txin['prevout_n'], 4) + @classmethod + def get_outpoint_from_txin(cls, txin): + if txin['type'] == 'coinbase': + return None + prevout_hash = txin['prevout_hash'] + prevout_n = txin['prevout_n'] + return prevout_hash + ':%d' % prevout_n + @classmethod def serialize_input(self, txin, script): # Prev hash and index diff --git a/lib/wallet.py b/lib/wallet.py index 72fd982f7..6a682debf 100644 --- a/lib/wallet.py +++ b/lib/wallet.py @@ -188,7 +188,7 @@ class Abstract_Wallet(PrintError): self.load_keystore() self.load_addresses() self.load_transactions() - self.build_reverse_history() + self.build_spent_outpoints() # load requests self.receive_requests = self.storage.get('payment_requests', {}) @@ -204,8 +204,10 @@ class Abstract_Wallet(PrintError): # interface.is_up_to_date() returns true when all requests have been answered and processed # wallet.up_to_date is true when the wallet is synchronized (stronger requirement) self.up_to_date = False + + # locks: if you need to take multiple ones, acquire them in the order they are defined here! self.lock = threading.Lock() - self.transaction_lock = threading.Lock() + self.transaction_lock = threading.RLock() self.check_history() @@ -238,7 +240,8 @@ class Abstract_Wallet(PrintError): for tx_hash, raw in tx_list.items(): tx = Transaction(raw) self.transactions[tx_hash] = tx - if self.txi.get(tx_hash) is None and self.txo.get(tx_hash) is None and (tx_hash not in self.pruned_txo.values()): + if self.txi.get(tx_hash) is None and self.txo.get(tx_hash) is None \ + and (tx_hash not in self.pruned_txo.values()): self.print_error("removing unreferenced tx", tx_hash) self.transactions.pop(tx_hash) @@ -258,24 +261,25 @@ class Abstract_Wallet(PrintError): self.storage.write() def clear_history(self): - with self.transaction_lock: - self.txi = {} - self.txo = {} - self.tx_fees = {} - self.pruned_txo = {} - self.save_transactions() with self.lock: - self.history = {} - self.tx_addr_hist = {} + with self.transaction_lock: + self.txi = {} + self.txo = {} + self.tx_fees = {} + self.pruned_txo = {} + self.spent_outpoints = {} + self.history = {} + self.save_transactions() @profiler - def build_reverse_history(self): - self.tx_addr_hist = {} - for addr, hist in self.history.items(): - for tx_hash, h in hist: - s = self.tx_addr_hist.get(tx_hash, set()) - s.add(addr) - self.tx_addr_hist[tx_hash] = s + def build_spent_outpoints(self): + self.spent_outpoints = {} + for txid, tx in self.transactions.items(): + for txi in tx.inputs(): + ser = Transaction.get_outpoint_from_txin(txi) + if ser is None: + continue + self.spent_outpoints[ser] = txid @profiler def check_history(self): @@ -415,7 +419,7 @@ class Abstract_Wallet(PrintError): return self.network.get_local_height() if self.network else self.storage.get('stored_height', 0) def get_tx_height(self, tx_hash): - """ return the height and timestamp of a transaction. """ + """ Given a transaction, returns (height, conf, timestamp) """ with self.lock: if tx_hash in self.verified_tx: height, timestamp, pos = self.verified_tx[tx_hash] @@ -682,10 +686,69 @@ class Abstract_Wallet(PrintError): self.print_error("found pay-to-pubkey address:", addr) return addr + def get_conflicting_transactions(self, tx): + """Returns a set of transaction hashes from the wallet history that are + directly conflicting with tx, i.e. they have common outpoints being + spent with tx. If the tx is already in wallet history, that will not be + reported as a conflict. + """ + conflicting_txns = set() + with self.transaction_lock: + for txi in tx.inputs(): + ser = Transaction.get_outpoint_from_txin(txi) + if ser is None: + continue + spending_tx_hash = self.spent_outpoints.get(ser, None) + if spending_tx_hash is None: + continue + # this outpoint (ser) has already been spent, by spending_tx + if spending_tx_hash not in self.transactions: + # can't find this txn: delete and ignore it + self.spent_outpoints.pop(ser) + continue + conflicting_txns |= {spending_tx_hash} + txid = tx.txid() + if txid in conflicting_txns: + # this tx is already in history, so it conflicts with itself + if len(conflicting_txns) > 1: + raise Exception('Found conflicting transactions already in wallet history.') + conflicting_txns -= {txid} + return conflicting_txns + def add_transaction(self, tx_hash, tx): is_coinbase = tx.inputs()[0]['type'] == 'coinbase' related = False with self.transaction_lock: + # Find all conflicting transactions. + # In case of a conflict, + # 1. confirmed > mempool > local + # 2. this new txn has priority over existing ones + # When this method exits, there must NOT be any conflict, so + # either keep this txn and remove all conflicting (along with dependencies) + # or drop this txn + conflicting_txns = self.get_conflicting_transactions(tx) + if conflicting_txns: + tx_height = self.get_tx_height(tx_hash)[0] + existing_mempool_txn = any( + self.get_tx_height(tx_hash2)[0] in (TX_HEIGHT_UNCONFIRMED, TX_HEIGHT_UNCONF_PARENT) + for tx_hash2 in conflicting_txns) + existing_confirmed_txn = any( + self.get_tx_height(tx_hash2)[0] > 0 + for tx_hash2 in conflicting_txns) + if existing_confirmed_txn and tx_height <= 0: + # this is a non-confirmed tx that conflicts with confirmed txns; drop. + return False + if existing_mempool_txn and tx_height == TX_HEIGHT_LOCAL: + # this is a local tx that conflicts with non-local txns; drop. + return False + # keep this txn and remove all conflicting + to_remove = set() + to_remove |= conflicting_txns + for conflicting_tx_hash in conflicting_txns: + to_remove |= self.get_depending_transactions(conflicting_tx_hash) + for tx_hash2 in to_remove: + self.remove_transaction(tx_hash2) + # add inputs self.txi[tx_hash] = d = {} for txi in tx.inputs(): @@ -694,6 +757,7 @@ class Abstract_Wallet(PrintError): prevout_hash = txi['prevout_hash'] prevout_n = txi['prevout_n'] ser = prevout_hash + ':%d'%prevout_n + self.spent_outpoints[ser] = tx_hash if addr == "(pubkey)": addr = self.find_pay_to_pubkey_address(prevout_hash, prevout_n) # find value from prev output @@ -739,14 +803,27 @@ class Abstract_Wallet(PrintError): # save self.transactions[tx_hash] = tx + return True def remove_transaction(self, tx_hash): + def undo_spend(outpoint_to_txid_map): + if tx: + # if we have the tx, this should often be faster + for txi in tx.inputs(): + ser = Transaction.get_outpoint_from_txin(txi) + outpoint_to_txid_map.pop(ser, None) + else: + for ser, hh in list(outpoint_to_txid_map.items()): + if hh == tx_hash: + outpoint_to_txid_map.pop(ser) + with self.transaction_lock: self.print_error("removing tx from history", tx_hash) #tx = self.transactions.pop(tx_hash) - for ser, hh in list(self.pruned_txo.items()): - if hh == tx_hash: - self.pruned_txo.pop(ser) + tx = self.transactions.get(tx_hash, None) + undo_spend(self.pruned_txo) + undo_spend(self.spent_outpoints) + # add tx to pruned_txo, and undo the txi addition for next_tx, dd in self.txi.items(): for addr, l in list(dd.items()): @@ -768,8 +845,8 @@ class Abstract_Wallet(PrintError): self.print_error("tx was not in history", tx_hash) def receive_tx_callback(self, tx_hash, tx, tx_height): - self.add_transaction(tx_hash, tx) self.add_unverified_tx(tx_hash, tx_height) + self.add_transaction(tx_hash, tx) def receive_history_callback(self, addr, hist, tx_fees): with self.lock: @@ -785,10 +862,6 @@ class Abstract_Wallet(PrintError): for tx_hash, tx_height in hist: # add it in case it was previously unconfirmed self.add_unverified_tx(tx_hash, tx_height) - # add reference in tx_addr_hist - s = self.tx_addr_hist.get(tx_hash, set()) - s.add(addr) - self.tx_addr_hist[tx_hash] = s # if addr is new, we have to recompute txi and txo tx = self.transactions.get(tx_hash) if tx is not None and self.txi.get(tx_hash, {}).get(addr) is None and self.txo.get(tx_hash, {}).get(addr) is None: