diff --git a/Cargo.lock b/Cargo.lock index 7fb0d80..e76b98a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -555,6 +555,7 @@ dependencies = [ "uuid", "vergen", "xtra", + "xtra_productivity", ] [[package]] diff --git a/daemon/Cargo.toml b/daemon/Cargo.toml index 0122e99..082b479 100644 --- a/daemon/Cargo.toml +++ b/daemon/Cargo.toml @@ -41,6 +41,7 @@ tracing = { version = "0.1" } tracing-subscriber = { version = "0.2", default-features = false, features = ["fmt", "ansi", "env-filter", "chrono", "tracing-log", "json"] } uuid = { version = "0.8", features = ["serde", "v4"] } xtra = { version = "0.6", features = ["with-tokio-1"] } +xtra_productivity = { path = "../xtra_productivity" } [[bin]] name = "taker" diff --git a/daemon/src/maker_cfd.rs b/daemon/src/maker_cfd.rs index 04b4fb0..f5dc89d 100644 --- a/daemon/src/maker_cfd.rs +++ b/daemon/src/maker_cfd.rs @@ -583,7 +583,7 @@ where taker_id, command: TakerCommand::NotifyOrderAccepted { id: order_id }, }) - .await?; + .await??; Ok(()) } }); @@ -770,7 +770,7 @@ where oracle_event_id, }, }) - .await?; + .await??; Ok(()) } }); diff --git a/daemon/src/maker_inc_connections.rs b/daemon/src/maker_inc_connections.rs index 4e2ef44..6bed66d 100644 --- a/daemon/src/maker_inc_connections.rs +++ b/daemon/src/maker_inc_connections.rs @@ -1,7 +1,7 @@ use crate::maker_cfd::{FromTaker, NewTakerOnline}; use crate::model::cfd::{Order, OrderId}; use crate::model::{BitMexPriceEventId, TakerId}; -use crate::{forward_only_ok, log_error, maker_cfd, send_to_socket, wire}; +use crate::{forward_only_ok, maker_cfd, send_to_socket, wire}; use anyhow::{Context as AnyhowContext, Result}; use async_trait::async_trait; use futures::{StreamExt, TryStreamExt}; @@ -13,6 +13,7 @@ use tokio_util::codec::FramedRead; use xtra::prelude::*; use xtra::spawn::TokioGlobalSpawnExt; use xtra::{Actor as _, KeepRunning}; +use xtra_productivity::xtra_productivity; pub struct BroadcastOrder(pub Option); @@ -92,6 +93,52 @@ impl Actor { Ok(()) } + async fn handle_new_connection_impl( + &mut self, + stream: TcpStream, + address: SocketAddr, + _: &mut Context, + ) { + let taker_id = TakerId::default(); + + tracing::info!("New taker {} connected on {}", taker_id, address); + + let (read, write) = stream.into_split(); + let read = FramedRead::new(read, wire::JsonCodec::default()) + .map_ok(move |msg| FromTaker { taker_id, msg }) + .map(forward_only_ok::Message); + + let (out_msg_actor_address, mut out_msg_actor_context) = xtra::Context::new(None); + + let forward_to_cfd = forward_only_ok::Actor::new(self.taker_msg_channel.clone_channel()) + .create(None) + .spawn_global(); + + // only allow outgoing messages while we are successfully reading incoming ones + tokio::spawn(async move { + let mut actor = send_to_socket::Actor::new(write); + + out_msg_actor_context + .handle_while(&mut actor, forward_to_cfd.attach_stream(read)) + .await; + + tracing::error!("Closing connection to taker {}", taker_id); + + actor.shutdown().await; + }); + + self.write_connections + .insert(taker_id, out_msg_actor_address); + + let _ = self + .new_taker_channel + .send(maker_cfd::NewTakerOnline { id: taker_id }) + .await; + } +} + +#[xtra_productivity] +impl Actor { async fn handle_broadcast_order(&mut self, msg: BroadcastOrder) -> Result<()> { let order = msg.0; @@ -161,78 +208,10 @@ impl Actor { Ok(()) } - async fn handle_new_connection( - &mut self, - stream: TcpStream, - address: SocketAddr, - _: &mut Context, - ) { - let taker_id = TakerId::default(); - - tracing::info!("New taker {} connected on {}", taker_id, address); - - let (read, write) = stream.into_split(); - let read = FramedRead::new(read, wire::JsonCodec::default()) - .map_ok(move |msg| FromTaker { taker_id, msg }) - .map(forward_only_ok::Message); - - let (out_msg_actor_address, mut out_msg_actor_context) = xtra::Context::new(None); - - let forward_to_cfd = forward_only_ok::Actor::new(self.taker_msg_channel.clone_channel()) - .create(None) - .spawn_global(); - - // only allow outgoing messages while we are successfully reading incoming ones - tokio::spawn(async move { - let mut actor = send_to_socket::Actor::new(write); - - out_msg_actor_context - .handle_while(&mut actor, forward_to_cfd.attach_stream(read)) - .await; - - tracing::error!("Closing connection to taker {}", taker_id); - - actor.shutdown().await; - }); - - self.write_connections - .insert(taker_id, out_msg_actor_address); - - let _ = self - .new_taker_channel - .send(maker_cfd::NewTakerOnline { id: taker_id }) - .await; - } -} - -macro_rules! log_error { - ($future:expr) => { - if let Err(e) = $future.await { - tracing::error!(%e); - } - }; -} - -#[async_trait] -impl Handler for Actor { - async fn handle(&mut self, msg: BroadcastOrder, _ctx: &mut Context) { - log_error!(self.handle_broadcast_order(msg)); - } -} - -#[async_trait] -impl Handler for Actor { - async fn handle(&mut self, msg: TakerMessage, _ctx: &mut Context) { - log_error!(self.handle_taker_message(msg)); - } -} - -#[async_trait] -impl Handler for Actor { async fn handle(&mut self, msg: ListenerMessage, ctx: &mut Context) -> KeepRunning { match msg { ListenerMessage::NewConnection { stream, address } => { - self.handle_new_connection(stream, address, ctx).await; + self.handle_new_connection_impl(stream, address, ctx).await; KeepRunning::Yes } @@ -247,16 +226,4 @@ impl Handler for Actor { } } -impl Message for BroadcastOrder { - type Result = (); -} - -impl Message for TakerMessage { - type Result = (); -} - -impl Message for ListenerMessage { - type Result = KeepRunning; -} - impl xtra::Actor for Actor {}