diff --git a/Cargo.lock b/Cargo.lock index 8c8fdd6..e76b98a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -555,6 +555,7 @@ dependencies = [ "uuid", "vergen", "xtra", + "xtra_productivity", ] [[package]] @@ -3094,6 +3095,20 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" +[[package]] +name = "trybuild" +version = "1.0.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbaccfa9796293406a02ec790614628c88d0b3246249a620ac1ee7076274716b" +dependencies = [ + "glob", + "lazy_static", + "serde", + "serde_json", + "termcolor", + "toml", +] + [[package]] name = "tungstenite" version = "0.14.0" @@ -3502,6 +3517,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "xtra_productivity" +version = "0.1.0" +dependencies = [ + "async-trait", + "quote", + "syn", + "tokio", + "trybuild", + "xtra", +] + [[package]] name = "yansi" version = "0.5.0" diff --git a/Cargo.toml b/Cargo.toml index 761a767..1ba6d1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["cfd_protocol", "daemon"] +members = ["cfd_protocol", "daemon", "xtra_productivity"] resolver = "2" [patch.crates-io] diff --git a/daemon/Cargo.toml b/daemon/Cargo.toml index 0122e99..082b479 100644 --- a/daemon/Cargo.toml +++ b/daemon/Cargo.toml @@ -41,6 +41,7 @@ tracing = { version = "0.1" } tracing-subscriber = { version = "0.2", default-features = false, features = ["fmt", "ansi", "env-filter", "chrono", "tracing-log", "json"] } uuid = { version = "0.8", features = ["serde", "v4"] } xtra = { version = "0.6", features = ["with-tokio-1"] } +xtra_productivity = { path = "../xtra_productivity" } [[bin]] name = "taker" diff --git a/daemon/src/maker_cfd.rs b/daemon/src/maker_cfd.rs index 39fddea..f5dc89d 100644 --- a/daemon/src/maker_cfd.rs +++ b/daemon/src/maker_cfd.rs @@ -8,6 +8,7 @@ use crate::model::cfd::{ }; use crate::model::{TakerId, Usd}; use crate::monitor::MonitorParams; +use crate::tokio_ext::spawn_fallible; use crate::wallet::Wallet; use crate::{log_error, maker_inc_connections, monitor, oracle, setup_contract, wire}; use anyhow::{Context as _, Result}; @@ -574,12 +575,18 @@ where // Use `.send` here to ensure we only continue once the message has been sent // Nothing done after this call should be able to fail, otherwise we notified the taker, but // might not transition to `Active` ourselves! - self.takers - .send(maker_inc_connections::TakerMessage { - taker_id, - command: TakerCommand::NotifyOrderAccepted { id: order_id }, - }) - .await?; + spawn_fallible::<_, anyhow::Error>({ + let takers = self.takers.clone(); + async move { + takers + .send(maker_inc_connections::TakerMessage { + taker_id, + command: TakerCommand::NotifyOrderAccepted { id: order_id }, + }) + .await??; + Ok(()) + } + }); // 5. Spawn away the contract setup let (sender, receiver) = mpsc::unbounded(); @@ -751,15 +758,22 @@ where .await? .with_context(|| format!("Announcement {} not found", oracle_event_id))?; - self.takers - .send(maker_inc_connections::TakerMessage { - taker_id, - command: TakerCommand::NotifyRollOverAccepted { - id: proposal.order_id, - oracle_event_id, - }, - }) - .await?; + spawn_fallible::<_, anyhow::Error>({ + let takers = self.takers.clone(); + let order_id = proposal.order_id; + async move { + takers + .send(maker_inc_connections::TakerMessage { + taker_id, + command: TakerCommand::NotifyRollOverAccepted { + id: order_id, + oracle_event_id, + }, + }) + .await??; + Ok(()) + } + }); self.oracle_actor .do_send_async(oracle::MonitorAttestation { diff --git a/daemon/src/maker_inc_connections.rs b/daemon/src/maker_inc_connections.rs index 4e2ef44..6bed66d 100644 --- a/daemon/src/maker_inc_connections.rs +++ b/daemon/src/maker_inc_connections.rs @@ -1,7 +1,7 @@ use crate::maker_cfd::{FromTaker, NewTakerOnline}; use crate::model::cfd::{Order, OrderId}; use crate::model::{BitMexPriceEventId, TakerId}; -use crate::{forward_only_ok, log_error, maker_cfd, send_to_socket, wire}; +use crate::{forward_only_ok, maker_cfd, send_to_socket, wire}; use anyhow::{Context as AnyhowContext, Result}; use async_trait::async_trait; use futures::{StreamExt, TryStreamExt}; @@ -13,6 +13,7 @@ use tokio_util::codec::FramedRead; use xtra::prelude::*; use xtra::spawn::TokioGlobalSpawnExt; use xtra::{Actor as _, KeepRunning}; +use xtra_productivity::xtra_productivity; pub struct BroadcastOrder(pub Option); @@ -92,6 +93,52 @@ impl Actor { Ok(()) } + async fn handle_new_connection_impl( + &mut self, + stream: TcpStream, + address: SocketAddr, + _: &mut Context, + ) { + let taker_id = TakerId::default(); + + tracing::info!("New taker {} connected on {}", taker_id, address); + + let (read, write) = stream.into_split(); + let read = FramedRead::new(read, wire::JsonCodec::default()) + .map_ok(move |msg| FromTaker { taker_id, msg }) + .map(forward_only_ok::Message); + + let (out_msg_actor_address, mut out_msg_actor_context) = xtra::Context::new(None); + + let forward_to_cfd = forward_only_ok::Actor::new(self.taker_msg_channel.clone_channel()) + .create(None) + .spawn_global(); + + // only allow outgoing messages while we are successfully reading incoming ones + tokio::spawn(async move { + let mut actor = send_to_socket::Actor::new(write); + + out_msg_actor_context + .handle_while(&mut actor, forward_to_cfd.attach_stream(read)) + .await; + + tracing::error!("Closing connection to taker {}", taker_id); + + actor.shutdown().await; + }); + + self.write_connections + .insert(taker_id, out_msg_actor_address); + + let _ = self + .new_taker_channel + .send(maker_cfd::NewTakerOnline { id: taker_id }) + .await; + } +} + +#[xtra_productivity] +impl Actor { async fn handle_broadcast_order(&mut self, msg: BroadcastOrder) -> Result<()> { let order = msg.0; @@ -161,78 +208,10 @@ impl Actor { Ok(()) } - async fn handle_new_connection( - &mut self, - stream: TcpStream, - address: SocketAddr, - _: &mut Context, - ) { - let taker_id = TakerId::default(); - - tracing::info!("New taker {} connected on {}", taker_id, address); - - let (read, write) = stream.into_split(); - let read = FramedRead::new(read, wire::JsonCodec::default()) - .map_ok(move |msg| FromTaker { taker_id, msg }) - .map(forward_only_ok::Message); - - let (out_msg_actor_address, mut out_msg_actor_context) = xtra::Context::new(None); - - let forward_to_cfd = forward_only_ok::Actor::new(self.taker_msg_channel.clone_channel()) - .create(None) - .spawn_global(); - - // only allow outgoing messages while we are successfully reading incoming ones - tokio::spawn(async move { - let mut actor = send_to_socket::Actor::new(write); - - out_msg_actor_context - .handle_while(&mut actor, forward_to_cfd.attach_stream(read)) - .await; - - tracing::error!("Closing connection to taker {}", taker_id); - - actor.shutdown().await; - }); - - self.write_connections - .insert(taker_id, out_msg_actor_address); - - let _ = self - .new_taker_channel - .send(maker_cfd::NewTakerOnline { id: taker_id }) - .await; - } -} - -macro_rules! log_error { - ($future:expr) => { - if let Err(e) = $future.await { - tracing::error!(%e); - } - }; -} - -#[async_trait] -impl Handler for Actor { - async fn handle(&mut self, msg: BroadcastOrder, _ctx: &mut Context) { - log_error!(self.handle_broadcast_order(msg)); - } -} - -#[async_trait] -impl Handler for Actor { - async fn handle(&mut self, msg: TakerMessage, _ctx: &mut Context) { - log_error!(self.handle_taker_message(msg)); - } -} - -#[async_trait] -impl Handler for Actor { async fn handle(&mut self, msg: ListenerMessage, ctx: &mut Context) -> KeepRunning { match msg { ListenerMessage::NewConnection { stream, address } => { - self.handle_new_connection(stream, address, ctx).await; + self.handle_new_connection_impl(stream, address, ctx).await; KeepRunning::Yes } @@ -247,16 +226,4 @@ impl Handler for Actor { } } -impl Message for BroadcastOrder { - type Result = (); -} - -impl Message for TakerMessage { - type Result = (); -} - -impl Message for ListenerMessage { - type Result = KeepRunning; -} - impl xtra::Actor for Actor {} diff --git a/xtra_productivity/Cargo.toml b/xtra_productivity/Cargo.toml new file mode 100644 index 0000000..1839147 --- /dev/null +++ b/xtra_productivity/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "xtra_productivity" +version = "0.1.0" +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +quote = "1" +syn = { version = "1", features = ["full"] } + +[dev-dependencies] +async-trait = "0.1.51" +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +trybuild = "1" +xtra = { version = "0.6", features = ["with-tokio-1"] } diff --git a/xtra_productivity/src/lib.rs b/xtra_productivity/src/lib.rs new file mode 100644 index 0000000..1a608cb --- /dev/null +++ b/xtra_productivity/src/lib.rs @@ -0,0 +1,55 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{FnArg, ImplItem, ItemImpl, ReturnType}; + +#[proc_macro_attribute] +pub fn xtra_productivity(_attribute: TokenStream, item: TokenStream) -> TokenStream { + let block = syn::parse::(item).unwrap(); + + let actor = block.self_ty; + + let code = block + .items + .into_iter() + .filter_map(|block_item| match block_item { + ImplItem::Method(method) => Some(method), + _ => None, + }) + .map(|method| { + let message_arg = method.sig.inputs[1].clone(); + + let message_type = match message_arg { + // receiver represents self + FnArg::Receiver(_) => unreachable!("cannot have receiver on second position"), + FnArg::Typed(ref typed) => typed.ty.clone() + }; + + let method_return = method.sig.output; + let method_block = method.block; + + let context_arg = method.sig.inputs.iter().nth(2).map(|fn_arg| quote! { #fn_arg }).unwrap_or_else(|| quote! { + _ctx: &mut xtra::Context + }); + + let result_type = match method_return { + ReturnType::Default => quote! { () }, + ReturnType::Type(_, ref t) => quote! { #t } + }; + + quote! { + impl xtra::Message for #message_type { + type Result = #result_type; + } + + #[async_trait] + impl xtra::Handler<#message_type> for #actor { + async fn handle(&mut self, #message_arg, #context_arg) #method_return #method_block + } + } + }).collect::>(); + + (quote! { + #(#code)* + }) + .into() +} diff --git a/xtra_productivity/tests/pass/can_handle_message.rs b/xtra_productivity/tests/pass/can_handle_message.rs new file mode 100644 index 0000000..dabba5f --- /dev/null +++ b/xtra_productivity/tests/pass/can_handle_message.rs @@ -0,0 +1,43 @@ +use async_trait::async_trait; +use xtra::spawn::TokioGlobalSpawnExt; +use xtra::Actor; +use xtra_productivity::xtra_productivity; + +struct DummyActor; + +impl xtra::Actor for DummyActor {} + +#[derive(Clone)] +struct DummyMessage; + +struct DummyMessageWithContext; + +// Dummy actor, xtra::Handler and xtra::Message impls generated by xtra_productivity +#[xtra_productivity] +impl DummyActor { + pub fn handle_dummy_message(&mut self, message: DummyMessage) -> i32 { + let _ = message.clone(); + 0 + } + + pub fn handle_dummy_message_with_context( + &mut self, + _message: DummyMessageWithContext, + context: &mut xtra::Context, + ) { + let _ = context.address(); + } +} + +fn is_i32(_: i32) {} + +#[tokio::main] +async fn main() { + // Create dummy actor + let dummy_actor = DummyActor.create(None).spawn_global(); + + // Send message to dummy actor + let i32 = dummy_actor.send(DummyMessage).await.unwrap(); + is_i32(i32); + dummy_actor.send(DummyMessageWithContext).await.unwrap(); +} diff --git a/xtra_productivity/tests/ui.rs b/xtra_productivity/tests/ui.rs new file mode 100644 index 0000000..c13cfdf --- /dev/null +++ b/xtra_productivity/tests/ui.rs @@ -0,0 +1,6 @@ +#[test] +fn ui() { + let t = trybuild::TestCases::new(); + t.compile_fail("tests/ui/*.rs"); + t.pass("tests/pass/*.rs"); +}