diff --git a/electrum/coinchooser.py b/electrum/coinchooser.py index e911a56df..c8604afb2 100644 --- a/electrum/coinchooser.py +++ b/electrum/coinchooser.py @@ -24,7 +24,7 @@ # SOFTWARE. from collections import defaultdict from math import floor, log10 -from typing import NamedTuple, List +from typing import NamedTuple, List, Callable from .bitcoin import sha256, COIN, TYPE_ADDRESS, is_address from .transaction import Transaction, TxOutput @@ -79,6 +79,12 @@ class Bucket(NamedTuple): witness: bool # whether any coin uses segwit +class ScoredCandidate(NamedTuple): + penalty: float + tx: Transaction + buckets: List[Bucket] + + def strip_unneeded(bkts, sufficient_funds): '''Remove buckets that are unnecessary in achieving the spend amount''' if sufficient_funds([], bucket_value_sum=0): @@ -121,12 +127,10 @@ class CoinChooserBase(Logger): return list(map(make_Bucket, buckets.keys(), buckets.values())) - def penalty_func(self, tx, *, fee_for_buckets): - def penalty(candidate): - return 0 - return penalty + def penalty_func(self, base_tx, *, tx_from_buckets) -> Callable[[List[Bucket]], ScoredCandidate]: + raise NotImplementedError - def change_amounts(self, tx, count, fee_estimator, dust_threshold): + def _change_amounts(self, tx, count, fee_estimator): # Break change up if bigger than max_change output_amounts = [o.value for o in tx.outputs()] # Don't split change of less than 0.02 BTC @@ -180,22 +184,60 @@ class CoinChooserBase(Logger): return amounts - def change_outputs(self, tx, change_addrs, fee_estimator, dust_threshold): - amounts = self.change_amounts(tx, len(change_addrs), fee_estimator, - dust_threshold) + def _change_outputs(self, tx, change_addrs, fee_estimator, dust_threshold): + amounts = self._change_amounts(tx, len(change_addrs), fee_estimator) assert min(amounts) >= 0 assert len(change_addrs) >= len(amounts) # If change is above dust threshold after accounting for the # size of the change output, add it to the transaction. - dust = sum(amount for amount in amounts if amount < dust_threshold) amounts = [amount for amount in amounts if amount >= dust_threshold] change = [TxOutput(TYPE_ADDRESS, addr, amount) for addr, amount in zip(change_addrs, amounts)] - self.logger.info(f'change: {change}') - if dust: - self.logger.info(f'not keeping dust {dust}') return change + def _construct_tx_from_selected_buckets(self, *, buckets, base_tx, change_addrs, + fee_estimator_w, dust_threshold, base_weight): + # make a copy of base_tx so it won't get mutated + tx = Transaction.from_io(base_tx.inputs()[:], base_tx.outputs()[:]) + + tx.add_inputs([coin for b in buckets for coin in b.coins]) + tx_weight = self._get_tx_weight(buckets, base_weight=base_weight) + + # change is sent back to sending address unless specified + if not change_addrs: + change_addrs = [tx.inputs()[0]['address']] + # note: this is not necessarily the final "first input address" + # because the inputs had not been sorted at this point + assert is_address(change_addrs[0]) + + # This takes a count of change outputs and returns a tx fee + output_weight = 4 * Transaction.estimated_output_size(change_addrs[0]) + fee = lambda count: fee_estimator_w(tx_weight + count * output_weight) + change = self._change_outputs(tx, change_addrs, fee, dust_threshold) + tx.add_outputs(change) + + return tx, change + + def _get_tx_weight(self, buckets, *, base_weight) -> int: + """Given a collection of buckets, return the total weight of the + resulting transaction. + base_weight is the weight of the tx that includes the fixed (non-change) + outputs and potentially some fixed inputs. Note that the change outputs + at this point are not yet known so they are NOT accounted for. + """ + total_weight = base_weight + sum(bucket.weight for bucket in buckets) + is_segwit_tx = any(bucket.witness for bucket in buckets) + if is_segwit_tx: + total_weight += 2 # marker and flag + # non-segwit inputs were previously assumed to have + # a witness of '' instead of '00' (hex) + # note that mixed legacy/segwit buckets are already ok + num_legacy_inputs = sum((not bucket.witness) * len(bucket.coins) + for bucket in buckets) + total_weight += num_legacy_inputs + + return total_weight + def make_tx(self, coins, inputs, outputs, change_addrs, fee_estimator, dust_threshold): """Select unspent coins to spend to pay outputs. If the change is @@ -211,34 +253,20 @@ class CoinChooserBase(Logger): self.p = PRNG(''.join(sorted(utxos))) # Copy the outputs so when adding change we don't modify "outputs" - tx = Transaction.from_io(inputs[:], outputs[:]) - input_value = tx.input_value() + base_tx = Transaction.from_io(inputs[:], outputs[:]) + input_value = base_tx.input_value() # Weight of the transaction with no inputs and no change # Note: this will use legacy tx serialization as the need for "segwit" # would be detected from inputs. The only side effect should be that the # marker and flag are excluded, which is compensated in get_tx_weight() # FIXME calculation will be off by this (2 wu) in case of RBF batching - base_weight = tx.estimated_weight() - spent_amount = tx.output_value() + base_weight = base_tx.estimated_weight() + spent_amount = base_tx.output_value() def fee_estimator_w(weight): return fee_estimator(Transaction.virtual_size_from_weight(weight)) - def get_tx_weight(buckets): - total_weight = base_weight + sum(bucket.weight for bucket in buckets) - is_segwit_tx = any(bucket.witness for bucket in buckets) - if is_segwit_tx: - total_weight += 2 # marker and flag - # non-segwit inputs were previously assumed to have - # a witness of '' instead of '00' (hex) - # note that mixed legacy/segwit buckets are already ok - num_legacy_inputs = sum((not bucket.witness) * len(bucket.coins) - for bucket in buckets) - total_weight += num_legacy_inputs - - return total_weight - def sufficient_funds(buckets, *, bucket_value_sum): '''Given a list of buckets, return True if it has enough value to pay for the transaction''' @@ -248,45 +276,30 @@ class CoinChooserBase(Logger): return False # note re performance: so far this was constant time # what follows is linear in len(buckets) - total_weight = get_tx_weight(buckets) + total_weight = self._get_tx_weight(buckets, base_weight=base_weight) return total_input >= spent_amount + fee_estimator_w(total_weight) - def fee_for_buckets(buckets) -> int: - """Given a list of buckets, return the total fee paid by the - transaction, in satoshis. - Note that the change output(s) are not yet known here, - so fees for those are excluded and hence this is a lower bound. - """ - total_weight = get_tx_weight(buckets) - return fee_estimator_w(total_weight) + def tx_from_buckets(buckets): + return self._construct_tx_from_selected_buckets(buckets=buckets, + base_tx=base_tx, + change_addrs=change_addrs, + fee_estimator_w=fee_estimator_w, + dust_threshold=dust_threshold, + base_weight=base_weight) # Collect the coins into buckets, choose a subset of the buckets - buckets = self.bucketize_coins(coins) - buckets = self.choose_buckets(buckets, sufficient_funds, - self.penalty_func(tx, fee_for_buckets=fee_for_buckets)) - - tx.add_inputs([coin for b in buckets for coin in b.coins]) - tx_weight = get_tx_weight(buckets) - - # change is sent back to sending address unless specified - if not change_addrs: - change_addrs = [tx.inputs()[0]['address']] - # note: this is not necessarily the final "first input address" - # because the inputs had not been sorted at this point - assert is_address(change_addrs[0]) - - # This takes a count of change outputs and returns a tx fee - output_weight = 4 * Transaction.estimated_output_size(change_addrs[0]) - fee = lambda count: fee_estimator_w(tx_weight + count * output_weight) - change = self.change_outputs(tx, change_addrs, fee, dust_threshold) - tx.add_outputs(change) + all_buckets = self.bucketize_coins(coins) + scored_candidate = self.choose_buckets(all_buckets, sufficient_funds, + self.penalty_func(base_tx, tx_from_buckets=tx_from_buckets)) + tx = scored_candidate.tx self.logger.info(f"using {len(tx.inputs())} inputs") - self.logger.info(f"using buckets: {[bucket.desc for bucket in buckets]}") + self.logger.info(f"using buckets: {[bucket.desc for bucket in scored_candidate.buckets]}") return tx - def choose_buckets(self, buckets, sufficient_funds, penalty_func): + def choose_buckets(self, buckets, sufficient_funds, + penalty_func: Callable[[List[Bucket]], ScoredCandidate]) -> ScoredCandidate: raise NotImplemented('To be subclassed') @@ -368,12 +381,14 @@ class CoinChooserRandom(CoinChooserBase): def choose_buckets(self, buckets, sufficient_funds, penalty_func): candidates = self.bucket_candidates_prefer_confirmed(buckets, sufficient_funds) - penalties = [penalty_func(cand) for cand in candidates] - winner = candidates[penalties.index(min(penalties))] - self.logger.info(f"Bucket sets: {len(buckets)}") - self.logger.info(f"Winning penalty: {min(penalties)}") + scored_candidates = [penalty_func(cand) for cand in candidates] + winner = min(scored_candidates, key=lambda x: x.penalty) + self.logger.info(f"Total number of buckets: {len(buckets)}") + self.logger.info(f"Num candidates considered: {len(candidates)}. " + f"Winning penalty: {winner.penalty}") return winner + class CoinChooserPrivacy(CoinChooserRandom): """Attempts to better preserve user privacy. First, if any coin is spent from a user address, all coins are. @@ -388,18 +403,15 @@ class CoinChooserPrivacy(CoinChooserRandom): def keys(self, coins): return [coin['address'] for coin in coins] - def penalty_func(self, tx, *, fee_for_buckets): - min_change = min(o.value for o in tx.outputs()) * 0.75 - max_change = max(o.value for o in tx.outputs()) * 1.33 - spent_amount = sum(o.value for o in tx.outputs()) + def penalty_func(self, base_tx, *, tx_from_buckets): + min_change = min(o.value for o in base_tx.outputs()) * 0.75 + max_change = max(o.value for o in base_tx.outputs()) * 1.33 - def penalty(buckets): + def penalty(buckets) -> ScoredCandidate: + # Penalize using many buckets (~inputs) badness = len(buckets) - 1 - total_input = sum(bucket.value for bucket in buckets) - # FIXME fee_for_buckets does not include fees needed to cover the change output(s) - # so fee here is a lower bound - fee = fee_for_buckets(buckets) - change = float(total_input - spent_amount - fee) + tx, change_outputs = tx_from_buckets(buckets) + change = sum(o.value for o in change_outputs) # Penalize change not roughly in output range if change < min_change: badness += (min_change - change) / (min_change + 10000) @@ -407,7 +419,7 @@ class CoinChooserPrivacy(CoinChooserRandom): badness += (change - max_change) / (max_change + 10000) # Penalize large change; 5 BTC excess ~= using 1 more input badness += change / (COIN * 5) - return badness + return ScoredCandidate(badness, tx, buckets) return penalty