diff --git a/daemon/src/routes_taker.rs b/daemon/src/routes_taker.rs index 994322b..7e3e511 100644 --- a/daemon/src/routes_taker.rs +++ b/daemon/src/routes_taker.rs @@ -5,7 +5,7 @@ use daemon::model::{Leverage, Price, Usd, WalletInfo}; use daemon::projection::Feeds; use daemon::routes::EmbeddedFileExt; use daemon::to_sse_event::{CfdAction, CfdsWithAuxData, ToSseEvent}; -use daemon::{bitmex_price_feed, taker_cfd, wallet}; +use daemon::{bitmex_price_feed, monitor, oracle, taker_cfd, wallet}; use http_api_problem::{HttpApiProblem, StatusCode}; use rocket::http::{ContentType, Status}; use rocket::response::stream::EventStream; @@ -20,6 +20,8 @@ use tokio::select; use tokio::sync::watch; use xtra::prelude::*; +type Taker = xtra::Address>; + #[rocket::get("/feed")] pub async fn feed( rx: &State, @@ -114,9 +116,9 @@ pub struct CfdOrderRequest { #[rocket::post("/cfd/order", data = "")] pub async fn post_order_request( cfd_order_request: Json, - take_offer_channel: &State>>, + cfd_actor: &State, ) -> Result, HttpApiProblem> { - take_offer_channel + cfd_actor .send(taker_cfd::TakeOffer { order_id: cfd_order_request.order_id, quantity: cfd_order_request.quantity, @@ -136,10 +138,9 @@ pub async fn post_order_request( pub async fn post_cfd_action( id: OrderId, action: CfdAction, - cfd_action_channel: &State>>, + cfd_actor: &State, feeds: &State, ) -> Result, HttpApiProblem> { - use taker_cfd::CfdAction::*; let result = match action { CfdAction::AcceptOrder | CfdAction::RejectOrder @@ -150,20 +151,25 @@ pub async fn post_cfd_action( return Err(HttpApiProblem::new(StatusCode::BAD_REQUEST) .detail(format!("taker cannot invoke action {}", action))); } - CfdAction::Commit => cfd_action_channel.send(Commit { order_id: id }), + CfdAction::Commit => cfd_actor.send(taker_cfd::Commit { order_id: id }).await, CfdAction::Settle => { let quote: bitmex_price_feed::Quote = feeds.quote.borrow().clone().into(); let current_price = quote.for_taker(); - cfd_action_channel.send(ProposeSettlement { - order_id: id, - current_price, - }) + cfd_actor + .send(taker_cfd::ProposeSettlement { + order_id: id, + current_price, + }) + .await + } + CfdAction::RollOver => { + cfd_actor + .send(taker_cfd::ProposeRollOver { order_id: id }) + .await } - CfdAction::RollOver => cfd_action_channel.send(ProposeRollOver { order_id: id }), }; result - .await .unwrap_or_else(|e| anyhow::bail!(e.to_string())) .map_err(|e| { HttpApiProblem::new(StatusCode::INTERNAL_SERVER_ERROR) diff --git a/daemon/src/taker.rs b/daemon/src/taker.rs index a2db57f..0e57871 100644 --- a/daemon/src/taker.rs +++ b/daemon/src/taker.rs @@ -9,8 +9,8 @@ use daemon::model::WalletInfo; use daemon::seed::Seed; use daemon::tokio_ext::FutureExt; use daemon::{ - bitmex_price_feed, db, housekeeping, logger, monitor, oracle, projection, taker_cfd, wallet, - wallet_sync, TakerActorSystem, Tasks, HEARTBEAT_INTERVAL, N_PAYOUTS, SETTLEMENT_INTERVAL, + bitmex_price_feed, db, housekeeping, logger, monitor, oracle, projection, wallet, wallet_sync, + TakerActorSystem, Tasks, HEARTBEAT_INTERVAL, N_PAYOUTS, SETTLEMENT_INTERVAL, }; use sqlx::sqlite::SqliteConnectOptions; use sqlx::SqlitePool; @@ -19,7 +19,6 @@ use std::path::PathBuf; use std::str::FromStr; use tokio::sync::watch; use tracing_subscriber::filter::LevelFilter; -use xtra::prelude::MessageChannel; use xtra::Actor; mod routes_taker; @@ -269,13 +268,10 @@ async fn main() -> Result<()> { )); tasks.add(wallet_sync::new(wallet.clone(), wallet_feed_sender)); - let take_offer_channel = MessageChannel::::clone_channel(&cfd_actor_addr); - let cfd_action_channel = MessageChannel::::clone_channel(&cfd_actor_addr); let rocket = rocket::custom(figment) .manage(projection_feeds) - .manage(take_offer_channel) - .manage(cfd_action_channel) + .manage(cfd_actor_addr) .manage(wallet_feed_receiver) .manage(bitcoin_network) .manage(wallet) diff --git a/daemon/src/taker_cfd.rs b/daemon/src/taker_cfd.rs index 8a21dee..39756ea 100644 --- a/daemon/src/taker_cfd.rs +++ b/daemon/src/taker_cfd.rs @@ -23,23 +23,24 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; use xtra::prelude::*; use xtra::Actor as _; +use xtra_productivity::xtra_productivity; pub struct TakeOffer { pub order_id: OrderId, pub quantity: Usd, } -pub enum CfdAction { - ProposeSettlement { - order_id: OrderId, - current_price: Price, - }, - ProposeRollOver { - order_id: OrderId, - }, - Commit { - order_id: OrderId, - }, +pub struct ProposeSettlement { + pub order_id: OrderId, + pub current_price: Price, +} + +pub struct ProposeRollOver { + pub order_id: OrderId, +} + +pub struct Commit { + pub order_id: OrderId, } pub struct CfdRollOverCompleted { @@ -135,24 +136,59 @@ impl Actor { } } +#[xtra_productivity] impl Actor where - W: xtra::Handler - + xtra::Handler - + xtra::Handler, + W: xtra::Handler, { - async fn handle_commit(&mut self, order_id: OrderId) -> Result<()> { + async fn handle_commit(&mut self, msg: Commit) -> Result<()> { + let Commit { order_id } = msg; + let mut conn = self.db.acquire().await?; cfd_actors::handle_commit(order_id, &mut conn, &self.wallet, &self.projection_actor) .await?; Ok(()) } - async fn handle_propose_settlement( - &mut self, - order_id: OrderId, - current_price: Price, - ) -> Result<()> { + async fn handle_propose_roll_over(&mut self, msg: ProposeRollOver) -> Result<()> { + let ProposeRollOver { order_id } = msg; + + if self.current_pending_proposals.contains_key(&order_id) { + anyhow::bail!("An update for order id {} is already in progress", order_id) + } + + let proposal = RollOverProposal { + order_id, + timestamp: Timestamp::now(), + }; + + self.current_pending_proposals.insert( + proposal.order_id, + UpdateCfdProposal::RollOverProposal { + proposal: proposal.clone(), + direction: SettlementKind::Outgoing, + }, + ); + self.send_pending_update_proposals().await?; + + self.conn_actor + .send(wire::TakerToMaker::ProposeRollOver { + order_id: proposal.order_id, + timestamp: proposal.timestamp, + }) + .await?; + Ok(()) + } +} + +#[xtra_productivity] +impl Actor { + async fn handle_propose_settlement(&mut self, msg: ProposeSettlement) -> Result<()> { + let ProposeSettlement { + order_id, + current_price, + } = msg; + let mut conn = self.db.acquire().await?; let cfd = load_cfd_by_order_id(order_id, &mut conn).await?; @@ -196,7 +232,14 @@ where .await?; Ok(()) } +} +impl Actor +where + W: xtra::Handler + + xtra::Handler + + xtra::Handler, +{ async fn handle_settlement_rejected(&mut self, order_id: OrderId) -> Result<()> { tracing::info!(%order_id, "Settlement proposal got rejected"); @@ -318,34 +361,6 @@ where .await?; Ok(()) } - - async fn handle_propose_roll_over(&mut self, order_id: OrderId) -> Result<()> { - if self.current_pending_proposals.contains_key(&order_id) { - anyhow::bail!("An update for order id {} is already in progress", order_id) - } - - let proposal = RollOverProposal { - order_id, - timestamp: Timestamp::now(), - }; - - self.current_pending_proposals.insert( - proposal.order_id, - UpdateCfdProposal::RollOverProposal { - proposal: proposal.clone(), - direction: SettlementKind::Outgoing, - }, - ); - self.send_pending_update_proposals().await?; - - self.conn_actor - .send(wire::TakerToMaker::ProposeRollOver { - order_id: proposal.order_id, - timestamp: proposal.timestamp, - }) - .await?; - Ok(()) - } } impl Actor @@ -662,34 +677,6 @@ where } } -#[async_trait] -impl Handler for Actor -where - W: xtra::Handler - + xtra::Handler - + xtra::Handler, -{ - async fn handle(&mut self, msg: CfdAction, _ctx: &mut Context) -> Result<()> { - use CfdAction::*; - - if let Err(e) = match msg { - Commit { order_id } => self.handle_commit(order_id).await, - ProposeSettlement { - order_id, - current_price, - } => { - self.handle_propose_settlement(order_id, current_price) - .await - } - ProposeRollOver { order_id } => self.handle_propose_roll_over(order_id).await, - } { - tracing::error!("Message handler failed: {:#}", e); - anyhow::bail!(e) - } - Ok(()) - } -} - #[async_trait] impl Handler for Actor where @@ -790,10 +777,6 @@ impl Message for TakeOffer { type Result = Result<()>; } -impl Message for CfdAction { - type Result = Result<()>; -} - impl Message for CfdRollOverCompleted { type Result = (); }