diff --git a/daemon/src/forward_only_ok.rs b/daemon/src/forward_only_ok.rs new file mode 100644 index 0000000..5ae8cd9 --- /dev/null +++ b/daemon/src/forward_only_ok.rs @@ -0,0 +1,58 @@ +use std::fmt; + +use xtra::prelude::MessageChannel; +use xtra::{Handler, KeepRunning}; + +/// A forwarding actor that only forwards [`Result::Ok`] values and shuts itself down upon the first +/// error. +pub struct Actor { + forward: Box>, +} + +impl Actor { + pub fn new(forward: Box>) -> Self { + Self { forward } + } +} + +pub struct Message(pub Result); + +impl xtra::Message for Message +where + TOk: Send + 'static, + TErr: Send + 'static, +{ + type Result = KeepRunning; +} + +#[async_trait::async_trait] +impl Handler> for Actor +where + TOk: xtra::Message + Send + 'static, + TErr: fmt::Display + Send + 'static, +{ + async fn handle( + &mut self, + Message(result): Message, + _: &mut xtra::Context, + ) -> KeepRunning { + let ok = match result { + Ok(ok) => ok, + Err(e) => { + tracing::error!("Received error: {}", e); + + return KeepRunning::StopSelf; + } + }; + + if let Err(xtra::Disconnected) = self.forward.send(ok).await { + tracing::info!("Target actor disappeared, stopping"); + + return KeepRunning::StopSelf; + } + + KeepRunning::Yes + } +} + +impl xtra::Actor for Actor {} diff --git a/daemon/src/lib.rs b/daemon/src/lib.rs index 7b23376..def8c6c 100644 --- a/daemon/src/lib.rs +++ b/daemon/src/lib.rs @@ -4,6 +4,7 @@ pub mod bitmex_price_feed; pub mod cfd_actors; pub mod db; pub mod fan_out; +pub mod forward_only_ok; pub mod housekeeping; pub mod keypair; pub mod logger; diff --git a/daemon/src/maker_cfd.rs b/daemon/src/maker_cfd.rs index f2dd833..2c4dfb5 100644 --- a/daemon/src/maker_cfd.rs +++ b/daemon/src/maker_cfd.rs @@ -20,7 +20,6 @@ use std::collections::HashMap; use std::time::SystemTime; use tokio::sync::watch; use xtra::prelude::*; -use xtra::KeepRunning; pub struct AcceptOrder { pub order_id: OrderId, @@ -70,9 +69,9 @@ pub struct CfdRollOverCompleted { pub dlc: Result, } -pub struct TakerStreamMessage { +pub struct FromTaker { pub taker_id: TakerId, - pub item: Result, + pub msg: wire::TakerToMaker, } pub struct Actor { @@ -971,21 +970,8 @@ impl Handler for Actor { } #[async_trait] -impl Handler for Actor { - async fn handle(&mut self, msg: TakerStreamMessage, _ctx: &mut Context) -> KeepRunning { - let TakerStreamMessage { taker_id, item } = msg; - let msg = match item { - Ok(msg) => msg, - Err(e) => { - tracing::warn!( - "Error while receiving message from taker {}: {:#}", - taker_id, - e - ); - return KeepRunning::Yes; - } - }; - +impl Handler for Actor { + async fn handle(&mut self, FromTaker { taker_id, msg }: FromTaker, _ctx: &mut Context) { match msg { wire::TakerToMaker::TakeOrder { order_id, quantity } => { log_error!(self.handle_take_order(taker_id, order_id, quantity)) @@ -1034,8 +1020,6 @@ impl Handler for Actor { log_error!(self.handle_inc_roll_over_protocol_msg(taker_id, msg)) } } - - KeepRunning::Yes } } @@ -1090,9 +1074,8 @@ impl Message for RejectRollOver { type Result = (); } -// this signature is a bit different because we use `Address::attach_stream` -impl Message for TakerStreamMessage { - type Result = KeepRunning; +impl Message for FromTaker { + type Result = (); } impl xtra::Actor for Actor {} diff --git a/daemon/src/maker_inc_connections.rs b/daemon/src/maker_inc_connections.rs index bc71321..9386413 100644 --- a/daemon/src/maker_inc_connections.rs +++ b/daemon/src/maker_inc_connections.rs @@ -1,10 +1,10 @@ -use crate::maker_cfd::{NewTakerOnline, TakerStreamMessage}; +use crate::maker_cfd::{FromTaker, NewTakerOnline}; use crate::model::cfd::{Order, OrderId}; use crate::model::{BitMexPriceEventId, TakerId}; -use crate::{log_error, maker_cfd, send_to_socket, wire}; +use crate::{forward_only_ok, log_error, maker_cfd, send_to_socket, wire}; use anyhow::{Context as AnyhowContext, Result}; use async_trait::async_trait; -use futures::StreamExt; +use futures::{StreamExt, TryStreamExt}; use std::collections::HashMap; use std::io; use std::net::SocketAddr; @@ -65,13 +65,13 @@ pub enum ListenerMessage { pub struct Actor { write_connections: HashMap>>, new_taker_channel: Box>, - taker_msg_channel: Box>, + taker_msg_channel: Box>, } impl Actor { pub fn new( new_taker_channel: &impl MessageChannel, - taker_msg_channel: &impl MessageChannel, + taker_msg_channel: &impl MessageChannel, ) -> Self { Self { write_connections: HashMap::new(), @@ -165,7 +165,7 @@ impl Actor { &mut self, stream: TcpStream, address: SocketAddr, - ctx: &mut Context, + _: &mut Context, ) { let taker_id = TakerId::default(); @@ -173,10 +173,14 @@ impl Actor { let (read, write) = stream.into_split(); let read = FramedRead::new(read, wire::JsonCodec::default()) - .map(move |item| maker_cfd::TakerStreamMessage { taker_id, item }); + .map_ok(move |msg| FromTaker { taker_id, msg }) + .map(forward_only_ok::Message); - let this = ctx.address().expect("self to be alive"); - tokio::spawn(this.attach_stream(Box::pin(read))); + let forward_to_cfd = forward_only_ok::Actor::new(self.taker_msg_channel.clone_channel()) + .create(None) + .spawn_global(); + + tokio::spawn(forward_to_cfd.attach_stream(read)); let out_msg_actor = send_to_socket::Actor::new(write) .create(None) @@ -233,14 +237,6 @@ impl Handler for Actor { } } -#[async_trait] -impl Handler for Actor { - async fn handle(&mut self, msg: TakerStreamMessage, _ctx: &mut Context) -> KeepRunning { - log_error!(self.taker_msg_channel.send(msg)); - KeepRunning::Yes - } -} - impl Message for BroadcastOrder { type Result = (); }