diff --git a/daemon/tests/happy_path.rs b/daemon/tests/happy_path.rs index a68002c..a989fc3 100644 --- a/daemon/tests/happy_path.rs +++ b/daemon/tests/happy_path.rs @@ -1,26 +1,34 @@ -use anyhow::Result; +use anyhow::{Context, Result}; use bdk::bitcoin::util::psbt::PartiallySignedTransaction; use bdk::bitcoin::{ecdsa, Txid}; use cfd_protocol::secp256k1_zkp::{schnorrsig, Secp256k1}; use cfd_protocol::PartyParams; -use daemon::model::cfd::Order; +use daemon::maker_cfd::CfdAction; +use daemon::model::cfd::{Cfd, CfdState, Order}; use daemon::model::{Price, Usd, WalletInfo}; -use daemon::{connection, db, logger, maker_cfd, maker_inc_connections, monitor, oracle, wallet}; +use daemon::tokio_ext::FutureExt; +use daemon::{ + connection, db, maker_cfd, maker_inc_connections, monitor, oracle, taker_cfd, wallet, +}; use rand::thread_rng; use rust_decimal_macros::dec; use sqlx::SqlitePool; use std::net::SocketAddr; use std::str::FromStr; use std::task::Poll; -use std::time::SystemTime; +use std::time::{Duration, SystemTime}; use tokio::sync::watch; +use tracing::subscriber::DefaultGuard; use tracing_subscriber::filter::LevelFilter; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::EnvFilter; use xtra::spawn::TokioGlobalSpawnExt; use xtra::Actor; use xtra_productivity::xtra_productivity; #[tokio::test] async fn taker_receives_order_from_maker_on_publication() { + let _guard = init_tracing(); let (mut maker, mut taker) = start_both().await; assert!(is_next_none(&mut taker.order_feed).await); @@ -36,6 +44,42 @@ async fn taker_receives_order_from_maker_on_publication() { assert_eq!(published.id, received.id); } +#[tokio::test] +async fn taker_takes_order_and_maker_rejects() { + let _guard = init_tracing(); + let (mut maker, mut taker) = start_both().await; + + // TODO: Why is this needed? For the cfd stream it is not needed + is_next_none(&mut taker.order_feed).await; + + maker.publish_order(new_dummy_order()); + + let (_, received) = next_order(&mut maker.order_feed, &mut taker.order_feed).await; + + taker.take_order(received.clone(), Usd::new(dec!(10))); + + let (taker_cfd, maker_cfd) = next_cfd(&mut taker.cfd_feed, &mut maker.cfd_feed).await; + assert_eq!(taker_cfd.order.id, received.id); + assert_eq!(maker_cfd.order.id, received.id); + assert!(matches!( + taker_cfd.state, + CfdState::OutgoingOrderRequest { .. } + )); + assert!(matches!( + maker_cfd.state, + CfdState::IncomingOrderRequest { .. } + )); + + maker.reject_take_request(received.clone()); + + let (taker_cfd, maker_cfd) = next_cfd(&mut taker.cfd_feed, &mut maker.cfd_feed).await; + // TODO: More elaborate Cfd assertions + assert_eq!(taker_cfd.order.id, received.id); + assert_eq!(maker_cfd.order.id, received.id); + assert!(matches!(taker_cfd.state, CfdState::Rejected { .. })); + assert!(matches!(maker_cfd.state, CfdState::Rejected { .. })); +} + fn new_dummy_order() -> maker_cfd::NewOrder { maker_cfd::NewOrder { price: Price::new(dec!(50_000)).expect("unexpected failure"), @@ -44,6 +88,30 @@ fn new_dummy_order() -> maker_cfd::NewOrder { } } +/// Returns the first `Cfd` from both channels +/// +/// Ensures that there is only one `Cfd` present in both channels. +async fn next_cfd( + rx_a: &mut watch::Receiver>, + rx_b: &mut watch::Receiver>, +) -> (Cfd, Cfd) { + let (a, b) = tokio::join!(next(rx_a), next(rx_b)); + + assert_eq!(a.len(), 1); + assert_eq!(b.len(), 1); + + (a.first().unwrap().clone(), b.first().unwrap().clone()) +} + +async fn next_order( + rx_a: &mut watch::Receiver>, + rx_b: &mut watch::Receiver>, +) -> (Order, Order) { + let (a, b) = tokio::join!(next_some(rx_a), next_some(rx_b)); + + (a, b) +} + /// Returns the value if the next Option received on the stream is Some /// /// Panics if None is received on the stream. @@ -73,13 +141,38 @@ async fn next(rx: &mut watch::Receiver) -> T where T: Clone, { - rx.changed().await.unwrap(); + rx.changed() + .timeout(Duration::from_secs(5)) + .await + .context("Waiting for next element in channel is taking too long, aborting") + .unwrap() + .unwrap(); rx.borrow().clone() } -fn init_tracing() { - logger::init(LevelFilter::DEBUG, false).unwrap(); +fn init_tracing() -> DefaultGuard { + let filter = EnvFilter::from_default_env() + // apply warning level globally + .add_directive(format!("{}", LevelFilter::WARN).parse().unwrap()) + // log traces from test itself + .add_directive( + format!("happy_path={}", LevelFilter::DEBUG) + .parse() + .unwrap(), + ) + .add_directive(format!("taker={}", LevelFilter::DEBUG).parse().unwrap()) + .add_directive(format!("maker={}", LevelFilter::DEBUG).parse().unwrap()) + .add_directive(format!("daemon={}", LevelFilter::DEBUG).parse().unwrap()) + .add_directive(format!("rocket={}", LevelFilter::WARN).parse().unwrap()); + + let guard = tracing_subscriber::fmt() + .with_env_filter(filter) + .with_test_writer() + .set_default(); + tracing::info!("Running version: {}", env!("VERGEN_GIT_SEMVER_LIGHTWEIGHT")); + + guard } /// Test Stub simulating the Oracle actor @@ -152,10 +245,12 @@ impl Wallet { } /// Maker Test Setup +#[derive(Clone)] struct Maker { cfd_actor_addr: xtra::Address>, order_feed: watch::Receiver>, + cfd_feed: watch::Receiver>, #[allow(dead_code)] // we need to keep the xtra::Address for refcounting inc_conn_actor_addr: xtra::Address, listen_addr: SocketAddr, @@ -202,6 +297,7 @@ impl Maker { Self { cfd_actor_addr: maker.cfd_actor_addr, order_feed: maker.order_feed_receiver, + cfd_feed: maker.cfd_feed_receiver, inc_conn_actor_addr: maker.inc_conn_addr, listen_addr: address, } @@ -210,11 +306,20 @@ impl Maker { fn publish_order(&mut self, new_order_params: maker_cfd::NewOrder) { self.cfd_actor_addr.do_send(new_order_params).unwrap(); } + + fn reject_take_request(&self, order: Order) { + self.cfd_actor_addr + .do_send(CfdAction::RejectOrder { order_id: order.id }) + .unwrap(); + } } /// Taker Test Setup +#[derive(Clone)] struct Taker { order_feed: watch::Receiver>, + cfd_feed: watch::Receiver>, + cfd_actor_addr: xtra::Address>, } impl Taker { @@ -242,12 +347,22 @@ impl Taker { Self { order_feed: taker.order_feed_receiver, + cfd_feed: taker.cfd_feed_receiver, + cfd_actor_addr: taker.cfd_actor_addr, } } + + fn take_order(&self, order: Order, quantity: Usd) { + self.cfd_actor_addr + .do_send(taker_cfd::TakeOffer { + order_id: order.id, + quantity, + }) + .unwrap(); + } } async fn start_both() -> (Maker, Taker) { - init_tracing(); let oracle_pk: schnorrsig::PublicKey = schnorrsig::PublicKey::from_str( "ddd4636845a90185991826be5a494cde9f4a6947b1727217afedc6292fa4caf7", )