diff --git a/Cargo.lock b/Cargo.lock index 60b6907..4145bcd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -541,6 +541,7 @@ dependencies = [ "serde", "serde_json", "serde_plain", + "serde_test", "serde_with", "sha2", "sqlx", @@ -2245,7 +2246,6 @@ dependencies = [ "arrayvec", "num-traits", "serde", - "serde_json", ] [[package]] @@ -2478,6 +2478,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_test" +version = "1.0.130" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82178225dbdeae2d5d190e8649287db6a3a32c6d24da22ae3146325aa353e4c" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.0" diff --git a/daemon/Cargo.toml b/daemon/Cargo.toml index a979456..0122e99 100644 --- a/daemon/Cargo.toml +++ b/daemon/Cargo.toml @@ -24,7 +24,7 @@ reqwest = { version = "0.11", default-features = false, features = ["json", "rus rocket = { version = "0.5.0-rc.1", features = ["json"] } rocket-basicauth = { version = "2", default-features = false } rust-embed = "6.2" -rust_decimal = { version = "1.16", features = ["serde-float", "serde-arbitrary-precision"] } +rust_decimal = "1.16" rust_decimal_macros = "1.16" serde = { version = "1", features = ["derive"] } serde_json = "1" @@ -52,6 +52,7 @@ path = "src/maker.rs" [dev-dependencies] pretty_assertions = "1" +serde_test = "1" time = { version = "0.3", features = ["std"] } [build-dependencies] diff --git a/daemon/src/bitmex_price_feed.rs b/daemon/src/bitmex_price_feed.rs index 5c5c307..0c0a2f2 100644 --- a/daemon/src/bitmex_price_feed.rs +++ b/daemon/src/bitmex_price_feed.rs @@ -1,8 +1,7 @@ use crate::model::Usd; -use anyhow::{Context, Result}; +use anyhow::Result; use futures::{StreamExt, TryStreamExt}; use rust_decimal::Decimal; -use rust_decimal_macros::dec; use std::convert::TryFrom; use std::future::Future; use std::time::SystemTime; @@ -81,10 +80,10 @@ impl Quote { } fn mid_range(&self) -> Result { - Ok(Usd((self.bid.checked_add(self.ask))? - .0 - .checked_div(dec!(2)) - .context("division error")?)) + let sum = self.bid.checked_add(self.ask)?; + let half = sum.half(); + + Ok(half) } } @@ -120,7 +119,7 @@ mod tests { let quote = Quote::from_message(message).unwrap().unwrap(); - assert_eq!(quote.bid, Usd(dec!(42640.5))); - assert_eq!(quote.ask, Usd(dec!(42641))); + assert_eq!(quote.bid, Usd::new(dec!(42640.5))); + assert_eq!(quote.ask, Usd::new(dec!(42641))); } } diff --git a/daemon/src/db.rs b/daemon/src/db.rs index 7a3597c..c1cf5ce 100644 --- a/daemon/src/db.rs +++ b/daemon/src/db.rs @@ -101,11 +101,11 @@ pub async fn load_order_by_id( id: row.uuid, trading_pair: row.trading_pair, position: row.position, - price: Usd(Decimal::from_str(&row.initial_price)?), - min_quantity: Usd(Decimal::from_str(&row.min_quantity)?), - max_quantity: Usd(Decimal::from_str(&row.max_quantity)?), + price: Usd::new(Decimal::from_str(&row.initial_price)?), + min_quantity: Usd::new(Decimal::from_str(&row.min_quantity)?), + max_quantity: Usd::new(Decimal::from_str(&row.max_quantity)?), leverage: row.leverage, - liquidation_price: Usd(Decimal::from_str(&row.liquidation_price)?), + liquidation_price: Usd::new(Decimal::from_str(&row.liquidation_price)?), creation_timestamp: convert_to_system_time(row.ts_secs, row.ts_nanos)?, term: Duration::new(row.term_secs, row.term_nanos), origin: row.origin, @@ -317,11 +317,11 @@ pub async fn load_cfd_by_order_id( id: row.uuid, trading_pair: row.trading_pair, position: row.position, - price: Usd(Decimal::from_str(&row.initial_price)?), - min_quantity: Usd(Decimal::from_str(&row.min_quantity)?), - max_quantity: Usd(Decimal::from_str(&row.max_quantity)?), + price: Usd::new(Decimal::from_str(&row.initial_price)?), + min_quantity: Usd::new(Decimal::from_str(&row.min_quantity)?), + max_quantity: Usd::new(Decimal::from_str(&row.max_quantity)?), leverage: row.leverage, - liquidation_price: Usd(Decimal::from_str(&row.liquidation_price)?), + liquidation_price: Usd::new(Decimal::from_str(&row.liquidation_price)?), creation_timestamp: convert_to_system_time(row.ts_secs, row.ts_nanos)?, term: Duration::new(row.term_secs, row.term_nanos), origin: row.origin, @@ -333,7 +333,7 @@ pub async fn load_cfd_by_order_id( // via https://github.com/comit-network/hermes/issues/290 Ok(Cfd { order, - quantity_usd: Usd(Decimal::from_str(&row.quantity_usd)?), + quantity_usd: Usd::new(Decimal::from_str(&row.quantity_usd)?), state: serde_json::from_str(row.state.as_str())?, }) } @@ -419,11 +419,11 @@ pub async fn load_all_cfds(conn: &mut PoolConnection) -> anyhow::Result< id: row.uuid, trading_pair: row.trading_pair, position: row.position, - price: Usd(Decimal::from_str(&row.initial_price)?), - min_quantity: Usd(Decimal::from_str(&row.min_quantity)?), - max_quantity: Usd(Decimal::from_str(&row.max_quantity)?), + price: Usd::new(Decimal::from_str(&row.initial_price)?), + min_quantity: Usd::new(Decimal::from_str(&row.min_quantity)?), + max_quantity: Usd::new(Decimal::from_str(&row.max_quantity)?), leverage: row.leverage, - liquidation_price: Usd(Decimal::from_str(&row.liquidation_price)?), + liquidation_price: Usd::new(Decimal::from_str(&row.liquidation_price)?), creation_timestamp: convert_to_system_time(row.ts_secs, row.ts_nanos)?, term: Duration::new(row.term_secs, row.term_nanos), origin: row.origin, @@ -432,7 +432,7 @@ pub async fn load_all_cfds(conn: &mut PoolConnection) -> anyhow::Result< Ok(Cfd { order, - quantity_usd: Usd(Decimal::from_str(&row.quantity_usd)?), + quantity_usd: Usd::new(Decimal::from_str(&row.quantity_usd)?), state: serde_json::from_str(row.state.as_str())?, }) }) @@ -529,11 +529,11 @@ pub async fn load_cfds_by_oracle_event_id( id: row.uuid, trading_pair: row.trading_pair, position: row.position, - price: Usd(Decimal::from_str(&row.initial_price)?), - min_quantity: Usd(Decimal::from_str(&row.min_quantity)?), - max_quantity: Usd(Decimal::from_str(&row.max_quantity)?), + price: Usd::new(Decimal::from_str(&row.initial_price)?), + min_quantity: Usd::new(Decimal::from_str(&row.min_quantity)?), + max_quantity: Usd::new(Decimal::from_str(&row.max_quantity)?), leverage: row.leverage, - liquidation_price: Usd(Decimal::from_str(&row.liquidation_price)?), + liquidation_price: Usd::new(Decimal::from_str(&row.liquidation_price)?), creation_timestamp: convert_to_system_time(row.ts_secs, row.ts_nanos)?, term: Duration::new(row.term_secs, row.term_nanos), origin: row.origin, @@ -542,7 +542,7 @@ pub async fn load_cfds_by_oracle_event_id( Ok(Cfd { order, - quantity_usd: Usd(Decimal::from_str(&row.quantity_usd)?), + quantity_usd: Usd::new(Decimal::from_str(&row.quantity_usd)?), state: serde_json::from_str(row.state.as_str())?, }) }) @@ -787,7 +787,7 @@ mod tests { fn dummy() -> Self { Cfd::new( Order::dummy(), - Usd(dec!(1000)), + Usd::new(dec!(1000)), CfdState::outgoing_order_request(), ) } @@ -809,9 +809,9 @@ mod tests { impl Order { fn dummy() -> Self { Order::new( - Usd(dec!(1000)), - Usd(dec!(100)), - Usd(dec!(1000)), + Usd::new(dec!(1000)), + Usd::new(dec!(100)), + Usd::new(dec!(1000)), Origin::Theirs, BitMexPriceEventId::with_20_digits(OffsetDateTime::now_utc()), ) diff --git a/daemon/src/model.rs b/daemon/src/model.rs index bb65dc6..4d2e71a 100644 --- a/daemon/src/model.rs +++ b/daemon/src/model.rs @@ -4,6 +4,7 @@ use bdk::bitcoin::{Address, Amount}; use reqwest::Url; use rust_decimal::prelude::ToPrimitive; use rust_decimal::Decimal; +use rust_decimal_macros::dec; use serde::{Deserialize, Serialize}; use serde_with::{DeserializeFromStr, SerializeDisplay}; use std::time::SystemTime; @@ -13,10 +14,14 @@ use uuid::Uuid; pub mod cfd; -#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] -pub struct Usd(pub Decimal); +#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] +pub struct Usd(Decimal); impl Usd { + pub fn new(dec: Decimal) -> Self { + Self(dec) + } + pub fn checked_add(&self, other: Usd) -> Result { let result = self.0.checked_add(other.0).context("addition error")?; Ok(Usd(result)) @@ -27,6 +32,7 @@ impl Usd { Ok(Usd(result)) } + // TODO: Usd * Usd = Usd^2 not Usd !!! pub fn checked_mul(&self, other: Usd) -> Result { let result = self .0 @@ -43,6 +49,35 @@ impl Usd { pub fn try_into_u64(&self) -> Result { self.0.to_u64().context("could not fit decimal into u64") } + + pub fn half(&self) -> Usd { + let half = self + .0 + .checked_div(dec!(2)) + .expect("can always divide by two"); + + Usd(half) + } +} + +impl Serialize for Usd { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + ::serialize(&self.0.round_dp(2), serializer) + } +} + +impl<'de> Deserialize<'de> for Usd { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let dec = ::deserialize(deserializer)?.round_dp(2); + + Ok(Usd(dec)) + } } impl fmt::Display for Usd { @@ -57,8 +92,8 @@ impl From for Usd { } } -#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq)] -pub struct Percent(pub Decimal); +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct Percent(Decimal); impl fmt::Display for Percent { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -66,6 +101,26 @@ impl fmt::Display for Percent { } } +impl Serialize for Percent { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + ::serialize(&self.0.round_dp(2), serializer) + } +} + +impl<'de> Deserialize<'de> for Percent { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let dec = ::deserialize(deserializer)?.round_dp(2); + + Ok(Percent(dec)) + } +} + impl From for Percent { fn from(decimal: Decimal) -> Self { Percent(decimal) @@ -186,10 +241,25 @@ impl str::FromStr for BitMexPriceEventId { #[cfg(test)] mod tests { + use serde_test::{assert_de_tokens, assert_ser_tokens, Token}; use time::macros::datetime; use super::*; + #[test] + fn usd_serializes_with_only_cents() { + let usd = Usd::new(dec!(1000.12345)); + + assert_ser_tokens(&usd, &[Token::Str("1000.12")]); + } + + #[test] + fn usd_deserializes_trims_precision() { + let usd = Usd::new(dec!(1000.12)); + + assert_de_tokens(&usd, &[Token::Str("1000.12345")]); + } + #[test] fn to_olivia_url() { let url = BitMexPriceEventId::with_20_digits(datetime!(2021-09-23 10:00:00).assume_utc()) diff --git a/daemon/src/payout_curve.rs b/daemon/src/payout_curve.rs index 7b3bb5b..83cb5b1 100644 --- a/daemon/src/payout_curve.rs +++ b/daemon/src/payout_curve.rs @@ -503,9 +503,12 @@ mod tests { #[test] fn calculate_snapshot() { - let actual_payouts = - calculate_payout_parameters(Usd(dec!(54000.00)), Usd(dec!(3500.00)), Leverage(5)) - .unwrap(); + let actual_payouts = calculate_payout_parameters( + Usd::new(dec!(54000.00)), + Usd::new(dec!(3500.00)), + Leverage(5), + ) + .unwrap(); let expected_payouts = vec![ payout(0..=45000, 7777777, 0), @@ -716,9 +719,12 @@ mod tests { #[test] fn verfiy_tails() { - let actual_payouts = - calculate_payout_parameters(Usd(dec!(54000.00)), Usd(dec!(3500.00)), Leverage(5)) - .unwrap(); + let actual_payouts = calculate_payout_parameters( + Usd::new(dec!(54000.00)), + Usd::new(dec!(3500.00)), + Leverage(5), + ) + .unwrap(); let lower_tail = payout(0..=45000, 7777777, 0); let upper_tail = payout(107765..=108000, 3240740, 4537037);