From 319e0ba8dd5164981bea2abecd590fedb54ec0d3 Mon Sep 17 00:00:00 2001 From: Mariusz Klochowicz Date: Wed, 13 Oct 2021 13:58:58 +1030 Subject: [PATCH] Prevent rounding errors when calculating fees Converting from f64 to u64 can make us lose precision. Create a new Fee type that encapsulates calculating fees, includding splitting it between two parties. --- cfd_protocol/src/protocol/transactions.rs | 120 ++++++++++++++++++++-- 1 file changed, 109 insertions(+), 11 deletions(-) diff --git a/cfd_protocol/src/protocol/transactions.rs b/cfd_protocol/src/protocol/transactions.rs index 7dd9c98..3cf065d 100644 --- a/cfd_protocol/src/protocol/transactions.rs +++ b/cfd_protocol/src/protocol/transactions.rs @@ -87,7 +87,7 @@ pub(crate) struct CommitTransaction { amount: Amount, sighash: SigHash, lock_descriptor: Descriptor, - fee: u64, + fee: Fee, } impl CommitTransaction { @@ -131,9 +131,9 @@ impl CommitTransaction { input: vec![lock_input], output: vec![output], }; - let fee = (Self::SIGNED_VBYTES * SATS_PER_VBYTE as f64) as u64; + let fee = Fee::new(Self::SIGNED_VBYTES); - let commit_tx_amount = lock_amount - fee as u64; + let commit_tx_amount = lock_amount - fee.as_u64(); inner.output[0].value = commit_tx_amount; let sighash = SigHashCache::new(&inner).signature_hash( @@ -185,7 +185,7 @@ impl CommitTransaction { } fn fee(&self) -> u64 { - self.fee + self.fee.as_u64() } } @@ -223,11 +223,10 @@ impl ContractExecutionTransaction { ..Default::default() }; - let mut fee = Self::SIGNED_VBYTES * SATS_PER_VBYTE; - fee += commit_tx.fee() as f64; + let fee = Fee::new(Self::SIGNED_VBYTES).add(commit_tx.fee() as f64); let output = payout .with_updated_fee( - Amount::from_sat(fee as u64), + Amount::from_sat(fee.as_u64()), maker_address.script_pubkey().dust_value(), taker_address.script_pubkey().dust_value(), )? @@ -372,15 +371,14 @@ pub fn close_transaction( // TODO: The fee could take into account the network state in this // case, since this transaction is to be broadcast immediately // after building and signing it - let fee = SIGNED_VBYTES * SATS_PER_VBYTE as f64; - let fee_per_party = (fee / 2.0) as u64; + let (maker_fee, taker_fee) = Fee::new(SIGNED_VBYTES).split(); let maker_output = TxOut { - value: maker_amount.as_sat() - fee_per_party, + value: maker_amount.as_sat() - maker_fee, script_pubkey: maker_address.script_pubkey(), }; let taker_output = TxOut { - value: taker_amount.as_sat() - fee_per_party, + value: taker_amount.as_sat() - taker_fee, script_pubkey: taker_address.script_pubkey(), }; @@ -496,3 +494,103 @@ pub fn punish_transaction( Ok(punish_tx) } + +#[derive(Clone, Debug)] +struct Fee { + fee: f64, +} + +impl Fee { + fn new(signed_vbytes: f64) -> Self { + let fee = signed_vbytes * SATS_PER_VBYTE; + Self { fee } + } + + #[must_use] + fn add(self, number: f64) -> Fee { + Fee { + fee: self.fee + number, + } + } + + fn as_u64(&self) -> u64 { + // Ceil to prevent going lower than the min relay fee + self.fee.ceil() as u64 + } + + fn split(&self) -> (u64, u64) { + let half = self.as_u64() / 2; + (half as u64, self.as_u64() - half) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + + proptest! { + #[test] + fn test_fee_always_above_min_relay_fee(signed_vbytes in 1.0f64..100_000_000.0f64) { + let fee = Fee::new(signed_vbytes); + let (maker_fee, taker_fee) = fee.split(); + + prop_assert!(signed_vbytes <= fee.as_u64() as f64); + prop_assert!(signed_vbytes <= (maker_fee + taker_fee) as f64); + } + } + + // A bunch of tests illustrating how fees are split + + #[test] + fn test_splitting_fee_1_0() { + const SIGNED_VBYTES_TEST: f64 = 1.0; + + let fee = Fee::new(SIGNED_VBYTES_TEST); + let (maker_fee, taker_fee) = fee.split(); + + assert_eq!(fee.as_u64(), 1); + assert_eq!(maker_fee, 0); + assert_eq!(taker_fee, 1); + assert!((maker_fee + taker_fee) as f64 >= SIGNED_VBYTES_TEST); + } + + #[test] + fn test_splitting_fee_2_0() { + const SIGNED_VBYTES_TEST: f64 = 2.0; + + let fee = Fee::new(SIGNED_VBYTES_TEST); + let (maker_fee, taker_fee) = fee.split(); + + assert_eq!(fee.as_u64(), 2); + assert_eq!(maker_fee, 1); + assert_eq!(taker_fee, 1); + assert!((maker_fee + taker_fee) as f64 >= SIGNED_VBYTES_TEST); + } + + #[test] + fn test_splitting_fee_2_1() { + const SIGNED_VBYTES_TEST: f64 = 2.1; + + let fee = Fee::new(SIGNED_VBYTES_TEST); + let (maker_fee, taker_fee) = fee.split(); + + assert_eq!(fee.as_u64(), 3); + assert_eq!(maker_fee, 1); + assert_eq!(taker_fee, 2); + assert!((maker_fee + taker_fee) as f64 >= SIGNED_VBYTES_TEST); + } + + #[test] + fn test_splitting_fee_2_6() { + const SIGNED_VBYTES_TEST: f64 = 2.6; + + let fee = Fee::new(SIGNED_VBYTES_TEST); + let (maker_fee, taker_fee) = fee.split(); + + assert_eq!(fee.as_u64(), 3); + assert_eq!(maker_fee, 1); + assert_eq!(taker_fee, 2); + assert!((maker_fee + taker_fee) as f64 >= SIGNED_VBYTES_TEST); + } +}