Browse Source

coinchooser: refactor so that penalty_func has access to change outputs

SomberNight 6 years ago
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 164


@ -24,7 +24,7 @@
from collections import defaultdict
from math import floor, log10
from typing import NamedTuple, List
from typing import NamedTuple, List, Callable
from .bitcoin import sha256, COIN, TYPE_ADDRESS, is_address
from .transaction import Transaction, TxOutput
@ -79,6 +79,12 @@ class Bucket(NamedTuple):
witness: bool # whether any coin uses segwit
class ScoredCandidate(NamedTuple):
    """A candidate selection of buckets together with its score.

    Instances are produced by the scoring callable returned from
    ``penalty_func`` and compared by ``penalty`` when a winner is chosen.
    """
    # Badness score of this candidate; the candidate with the smallest
    # penalty wins.
    penalty: float
    # Transaction constructed from the selected buckets.
    tx: Transaction
    # The buckets that were selected for this candidate.
    buckets: List[Bucket]
def strip_unneeded(bkts, sufficient_funds):
'''Remove buckets that are unnecessary in achieving the spend amount'''
if sufficient_funds([], bucket_value_sum=0):
@ -121,12 +127,10 @@ class CoinChooserBase(Logger):
return list(map(make_Bucket, buckets.keys(), buckets.values()))
def penalty_func(self, tx, *, fee_for_buckets):
def penalty(candidate):
return 0
return penalty
def penalty_func(self, base_tx, *, tx_from_buckets) -> Callable[[List[Bucket]], ScoredCandidate]:
    """Return a callable that scores a candidate list of buckets.

    base_tx is the transaction before any coins have been selected.
    tx_from_buckets is a callable that, given a list of buckets, returns
    a ``(tx, change_outputs)`` pair, so the scoring function can inspect
    the change outputs a candidate would produce.

    The returned callable maps a list of buckets to a ScoredCandidate.

    Raises NotImplementedError: the base class defines no scoring policy;
    subclasses must override this method.
    """
    raise NotImplementedError
def change_amounts(self, tx, count, fee_estimator, dust_threshold):
def _change_amounts(self, tx, count, fee_estimator):
# Break change up if bigger than max_change
output_amounts = [o.value for o in tx.outputs()]
# Don't split change of less than 0.02 BTC
@ -180,22 +184,60 @@ class CoinChooserBase(Logger):
return amounts
def change_outputs(self, tx, change_addrs, fee_estimator, dust_threshold):
amounts = self.change_amounts(tx, len(change_addrs), fee_estimator,
def _change_outputs(self, tx, change_addrs, fee_estimator, dust_threshold):
    """Return the list of change outputs for tx.

    The change amounts come from ``_change_amounts`` (fee_estimator is
    passed through to it unchanged). Amounts below dust_threshold are
    dropped; each remaining amount is paired, in order, with an address
    from change_addrs.

    Returns a list of TxOutput; it may be empty if every amount is dust.
    """
    amounts = self._change_amounts(tx, len(change_addrs), fee_estimator)
    assert min(amounts) >= 0
    assert len(change_addrs) >= len(amounts)
    # Only change that is at or above the dust threshold (after accounting
    # for the size of the change output) is kept; the rest is discarded.
    dust = sum(amount for amount in amounts if amount < dust_threshold)
    amounts = [amount for amount in amounts if amount >= dust_threshold]
    change = [TxOutput(TYPE_ADDRESS, addr, amount)
              for addr, amount in zip(change_addrs, amounts)]
    # NOTE(review): the logging callee is assumed to be self.logger.info
    # (provided by the Logger base class) -- confirm.
    self.logger.info(f'change: {change}')
    if dust:
        self.logger.info(f'not keeping dust {dust}')
    return change
def _construct_tx_from_selected_buckets(self, *, buckets, base_tx, change_addrs,
                                        fee_estimator_w, dust_threshold, base_weight):
    """Build a transaction that spends the given buckets on top of base_tx.

    buckets: the selected buckets whose coins become additional inputs.
    base_tx: transaction holding the fixed inputs/outputs; it is copied,
        not mutated.
    change_addrs: candidate change addresses; if empty, the address of the
        first input is used instead.
    fee_estimator_w: callable mapping a transaction weight to a fee.
    dust_threshold: change amounts below this value are not kept.
    base_weight: weight of base_tx, forwarded to ``_get_tx_weight``.

    Returns a ``(tx, change)`` tuple: the new transaction and the list of
    change outputs computed for it.
    """
    # make a copy of base_tx so it won't get mutated
    tx = Transaction.from_io(base_tx.inputs()[:], base_tx.outputs()[:])

    tx.add_inputs([coin for b in buckets for coin in b.coins])
    tx_weight = self._get_tx_weight(buckets, base_weight=base_weight)

    # change is sent back to sending address unless specified
    if not change_addrs:
        change_addrs = [tx.inputs()[0]['address']]
        # note: this is not necessarily the final "first input address"
        # because the inputs had not been sorted at this point
    assert is_address(change_addrs[0])

    # Weight of one change output, estimated from the first change address.
    output_weight = 4 * Transaction.estimated_output_size(change_addrs[0])

    def fee(count):
        # This takes a count of change outputs and returns a tx fee
        return fee_estimator_w(tx_weight + count * output_weight)

    change = self._change_outputs(tx, change_addrs, fee, dust_threshold)
    # NOTE(review): the change outputs are returned but are not attached to
    # tx in the code visible here -- confirm whether they must be added
    # before tx is handed back to the caller.
    return tx, change
def _get_tx_weight(self, buckets, *, base_weight) -> int:
    """Given a collection of buckets, return the total weight of the
    resulting transaction.

    base_weight is the weight of the tx that includes the fixed (non-change)
    outputs and potentially some fixed inputs. Note that the change outputs
    at this point are not yet known so they are NOT accounted for.
    """
    total_weight = base_weight + sum(bucket.weight for bucket in buckets)
    is_segwit_tx = any(bucket.witness for bucket in buckets)
    if is_segwit_tx:
        total_weight += 2  # marker and flag
        # non-segwit inputs were previously assumed to have
        # a witness of '' instead of '00' (hex)
        # note that mixed legacy/segwit buckets are already ok
        num_legacy_inputs = sum(len(bucket.coins)
                                for bucket in buckets
                                if not bucket.witness)
        total_weight += num_legacy_inputs
    return total_weight
def make_tx(self, coins, inputs, outputs, change_addrs, fee_estimator,
"""Select unspent coins to spend to pay outputs. If the change is
@ -211,34 +253,20 @@ class CoinChooserBase(Logger):
self.p = PRNG(''.join(sorted(utxos)))
# Copy the outputs so when adding change we don't modify "outputs"
tx = Transaction.from_io(inputs[:], outputs[:])
input_value = tx.input_value()
base_tx = Transaction.from_io(inputs[:], outputs[:])
input_value = base_tx.input_value()
# Weight of the transaction with no inputs and no change
# Note: this will use legacy tx serialization as the need for "segwit"
# would be detected from inputs. The only side effect should be that the
# marker and flag are excluded, which is compensated in get_tx_weight()
# FIXME calculation will be off by this (2 wu) in case of RBF batching
base_weight = tx.estimated_weight()
spent_amount = tx.output_value()
base_weight = base_tx.estimated_weight()
spent_amount = base_tx.output_value()
def fee_estimator_w(weight):
return fee_estimator(Transaction.virtual_size_from_weight(weight))
def get_tx_weight(buckets):
total_weight = base_weight + sum(bucket.weight for bucket in buckets)
is_segwit_tx = any(bucket.witness for bucket in buckets)
if is_segwit_tx:
total_weight += 2 # marker and flag
# non-segwit inputs were previously assumed to have
# a witness of '' instead of '00' (hex)
# note that mixed legacy/segwit buckets are already ok
num_legacy_inputs = sum((not bucket.witness) * len(bucket.coins)
for bucket in buckets)
total_weight += num_legacy_inputs
return total_weight
def sufficient_funds(buckets, *, bucket_value_sum):
'''Given a list of buckets, return True if it has enough
value to pay for the transaction'''
@ -248,45 +276,30 @@ class CoinChooserBase(Logger):
return False
# note re performance: so far this was constant time
# what follows is linear in len(buckets)
total_weight = get_tx_weight(buckets)
total_weight = self._get_tx_weight(buckets, base_weight=base_weight)
return total_input >= spent_amount + fee_estimator_w(total_weight)
def fee_for_buckets(buckets) -> int:
"""Given a list of buckets, return the total fee paid by the
transaction, in satoshis.
Note that the change output(s) are not yet known here,
so fees for those are excluded and hence this is a lower bound.
total_weight = get_tx_weight(buckets)
return fee_estimator_w(total_weight)
def tx_from_buckets(buckets):
return self._construct_tx_from_selected_buckets(buckets=buckets,
# Collect the coins into buckets, choose a subset of the buckets
buckets = self.bucketize_coins(coins)
buckets = self.choose_buckets(buckets, sufficient_funds,
self.penalty_func(tx, fee_for_buckets=fee_for_buckets))
tx.add_inputs([coin for b in buckets for coin in b.coins])
tx_weight = get_tx_weight(buckets)
# change is sent back to sending address unless specified
if not change_addrs:
change_addrs = [tx.inputs()[0]['address']]
# note: this is not necessarily the final "first input address"
# because the inputs had not been sorted at this point
assert is_address(change_addrs[0])
# This takes a count of change outputs and returns a tx fee
output_weight = 4 * Transaction.estimated_output_size(change_addrs[0])
fee = lambda count: fee_estimator_w(tx_weight + count * output_weight)
change = self.change_outputs(tx, change_addrs, fee, dust_threshold)
all_buckets = self.bucketize_coins(coins)
scored_candidate = self.choose_buckets(all_buckets, sufficient_funds,
self.penalty_func(base_tx, tx_from_buckets=tx_from_buckets))
tx = scored_candidate.tx"using {len(tx.inputs())} inputs")"using buckets: {[bucket.desc for bucket in buckets]}")"using buckets: {[bucket.desc for bucket in scored_candidate.buckets]}")
return tx
def choose_buckets(self, buckets, sufficient_funds, penalty_func):
def choose_buckets(self, buckets, sufficient_funds,
                   penalty_func: Callable[[List[Bucket]], ScoredCandidate]) -> ScoredCandidate:
    """Select a subset of buckets and return it as a ScoredCandidate.

    buckets: all buckets available for selection.
    sufficient_funds: predicate telling whether a list of buckets has
        enough value to pay for the transaction.
    penalty_func: callable that scores a list of buckets.

    Raises NotImplementedError: the base class has no selection strategy;
    subclasses must override this method.
    """
    # NotImplemented is a comparison sentinel, not an exception: calling it
    # would raise TypeError instead of signalling an abstract method.
    raise NotImplementedError('To be subclassed')
@ -368,12 +381,14 @@ class CoinChooserRandom(CoinChooserBase):
def choose_buckets(self, buckets, sufficient_funds, penalty_func):
candidates = self.bucket_candidates_prefer_confirmed(buckets, sufficient_funds)
penalties = [penalty_func(cand) for cand in candidates]
winner = candidates[penalties.index(min(penalties))]"Bucket sets: {len(buckets)}")"Winning penalty: {min(penalties)}")
scored_candidates = [penalty_func(cand) for cand in candidates]
winner = min(scored_candidates, key=lambda x: x.penalty)"Total number of buckets: {len(buckets)}")"Num candidates considered: {len(candidates)}. "
f"Winning penalty: {winner.penalty}")
return winner
class CoinChooserPrivacy(CoinChooserRandom):
"""Attempts to better preserve user privacy.
First, if any coin is spent from a user address, all coins are.
@ -388,18 +403,15 @@ class CoinChooserPrivacy(CoinChooserRandom):
def keys(self, coins):
    """Return the key for each coin, in input order.

    The key of a coin is its ``'address'`` entry, so coins that belong to
    the same address receive equal keys.
    """
    return [coin['address'] for coin in coins]
def penalty_func(self, tx, *, fee_for_buckets):
min_change = min(o.value for o in tx.outputs()) * 0.75
max_change = max(o.value for o in tx.outputs()) * 1.33
spent_amount = sum(o.value for o in tx.outputs())
def penalty_func(self, base_tx, *, tx_from_buckets):
min_change = min(o.value for o in base_tx.outputs()) * 0.75
max_change = max(o.value for o in base_tx.outputs()) * 1.33
def penalty(buckets):
def penalty(buckets) -> ScoredCandidate:
# Penalize using many buckets (~inputs)
badness = len(buckets) - 1
total_input = sum(bucket.value for bucket in buckets)
# FIXME fee_for_buckets does not include fees needed to cover the change output(s)
# so fee here is a lower bound
fee = fee_for_buckets(buckets)
change = float(total_input - spent_amount - fee)
tx, change_outputs = tx_from_buckets(buckets)
change = sum(o.value for o in change_outputs)
# Penalize change not roughly in output range
if change < min_change:
badness += (min_change - change) / (min_change + 10000)
@ -407,7 +419,7 @@ class CoinChooserPrivacy(CoinChooserRandom):
badness += (change - max_change) / (max_change + 10000)
# Penalize large change; 5 BTC excess ~= using 1 more input
badness += change / (COIN * 5)
return badness
return ScoredCandidate(badness, tx, buckets)
return penalty
