diff --git a/daemon/src/maker.rs b/daemon/src/maker.rs index 0ba0ab5..65c0c92 100644 --- a/daemon/src/maker.rs +++ b/daemon/src/maker.rs @@ -29,7 +29,7 @@ mod model; mod routes; mod routes_maker; mod seed; -mod send_wire_message_actor; +mod send_to_socket; mod setup_contract_actor; mod to_sse_event; mod wallet; @@ -194,16 +194,17 @@ async fn main() -> Result<()> { cfd_maker_actor_inbox.clone(), taker_id, ); - let (out_msg_actor, out_msg_actor_inbox) = - send_wire_message_actor::new::(write); + + let out_msg_actor = send_to_socket::Actor::new(write) + .create(None) + .spawn_global(); tokio::spawn(in_taker_actor); - tokio::spawn(out_msg_actor); maker_inc_connections_address .do_send_async(maker_inc_connections::NewTakerOnline { taker_id, - out_msg_actor_inbox, + out_msg_actor, }) .await .unwrap(); @@ -238,3 +239,7 @@ async fn main() -> Result<()> { Ok(()) } + +impl xtra::Message for wire::MakerToTaker { + type Result = (); +} diff --git a/daemon/src/maker_inc_connections.rs b/daemon/src/maker_inc_connections.rs index 22d7f73..2c08182 100644 --- a/daemon/src/maker_inc_connections.rs +++ b/daemon/src/maker_inc_connections.rs @@ -2,18 +2,15 @@ use crate::actors::log_error; use crate::model::cfd::{Order, OrderId}; use crate::model::TakerId; use crate::wire::SetupMsg; -use crate::{maker_cfd, wire}; +use crate::{maker_cfd, send_to_socket, wire}; use anyhow::{Context as AnyhowContext, Result}; use async_trait::async_trait; use futures::{Future, StreamExt}; use std::collections::HashMap; use tokio::net::tcp::OwnedReadHalf; -use tokio::sync::mpsc; use tokio_util::codec::{FramedRead, LengthDelimitedCodec}; use xtra::prelude::*; -type MakerToTakerSender = mpsc::UnboundedSender; - pub struct BroadcastOrder(pub Option); impl Message for BroadcastOrder { @@ -40,7 +37,7 @@ impl Message for TakerMessage { pub struct NewTakerOnline { pub taker_id: TakerId, - pub out_msg_actor_inbox: MakerToTakerSender, + pub out_msg_actor: Address, } impl Message for NewTakerOnline { @@ -48,7 +45,7 @@ impl Message for NewTakerOnline { } pub struct Actor { - write_connections: HashMap, + write_connections: HashMap>, cfd_actor: Address, } @@ -57,44 +54,53 @@ impl xtra::Actor for Actor {} impl Actor { pub fn new(cfd_actor: Address) -> Self { Self { - write_connections: HashMap::::new(), + write_connections: HashMap::new(), cfd_actor, } } - fn send_to_taker(&self, taker_id: TakerId, msg: wire::MakerToTaker) -> Result<()> { + async fn send_to_taker(&self, taker_id: TakerId, msg: wire::MakerToTaker) -> Result<()> { let conn = self .write_connections .get(&taker_id) .context("no connection to taker_id")?; - conn.send(msg)?; + conn.do_send_async(msg).await?; + Ok(()) } async fn handle_broadcast_order(&mut self, msg: BroadcastOrder) -> Result<()> { let order = msg.0; - self.write_connections - .values() - .try_for_each(|conn| conn.send(wire::MakerToTaker::CurrentOrder(order.clone())))?; + + for conn in self.write_connections.values() { + conn.do_send_async(wire::MakerToTaker::CurrentOrder(order.clone())) + .await?; + } + Ok(()) } async fn handle_taker_message(&mut self, msg: TakerMessage) -> Result<()> { match msg.command { TakerCommand::SendOrder { order } => { - self.send_to_taker(msg.taker_id, wire::MakerToTaker::CurrentOrder(order))?; + self.send_to_taker(msg.taker_id, wire::MakerToTaker::CurrentOrder(order)) + .await?; } TakerCommand::NotifyInvalidOrderId { id } => { - self.send_to_taker(msg.taker_id, wire::MakerToTaker::InvalidOrderId(id))?; + self.send_to_taker(msg.taker_id, wire::MakerToTaker::InvalidOrderId(id)) + .await?; } TakerCommand::NotifyOrderAccepted { id } => { - self.send_to_taker(msg.taker_id, wire::MakerToTaker::ConfirmOrder(id))?; + self.send_to_taker(msg.taker_id, wire::MakerToTaker::ConfirmOrder(id)) + .await?; } TakerCommand::NotifyOrderRejected { id } => { - self.send_to_taker(msg.taker_id, wire::MakerToTaker::RejectOrder(id))?; + self.send_to_taker(msg.taker_id, wire::MakerToTaker::RejectOrder(id)) + .await?; } TakerCommand::OutProtocolMsg { setup_msg } => { - self.send_to_taker(msg.taker_id, wire::MakerToTaker::Protocol(setup_msg))?; + self.send_to_taker(msg.taker_id, wire::MakerToTaker::Protocol(setup_msg)) + .await?; } } Ok(()) @@ -106,7 +112,7 @@ impl Actor { .await?; self.write_connections - .insert(msg.taker_id, msg.out_msg_actor_inbox); + .insert(msg.taker_id, msg.out_msg_actor); Ok(()) } } diff --git a/daemon/src/send_to_socket.rs b/daemon/src/send_to_socket.rs new file mode 100644 index 0000000..1ea6146 --- /dev/null +++ b/daemon/src/send_to_socket.rs @@ -0,0 +1,35 @@ +use futures::SinkExt; +use serde::Serialize; +use std::fmt; +use tokio::net::tcp::OwnedWriteHalf; +use tokio_util::codec::{FramedWrite, LengthDelimitedCodec}; +use xtra::{Handler, Message}; + +pub struct Actor { + write: FramedWrite, +} + +impl Actor { + pub fn new(write: OwnedWriteHalf) -> Self { + Self { + write: FramedWrite::new(write, LengthDelimitedCodec::new()), + } + } +} + +#[async_trait::async_trait] +impl Handler for Actor +where + T: Message + Serialize + fmt::Display, +{ + async fn handle(&mut self, message: T, ctx: &mut xtra::Context) { + let bytes = serde_json::to_vec(&message).expect("serialization should never fail"); + + if let Err(e) = self.write.send(bytes.into()).await { + tracing::error!("Failed to write message {} to socket: {}", message, e); + ctx.stop(); + } + } +} + +impl xtra::Actor for Actor {} diff --git a/daemon/src/send_wire_message_actor.rs b/daemon/src/send_wire_message_actor.rs deleted file mode 100644 index b127c32..0000000 --- a/daemon/src/send_wire_message_actor.rs +++ /dev/null @@ -1,31 +0,0 @@ -use futures::{Future, SinkExt}; -use serde::Serialize; -use tokio::net::tcp::OwnedWriteHalf; -use tokio::sync::mpsc; -use tokio_util::codec::{FramedWrite, LengthDelimitedCodec}; - -pub fn new(write: OwnedWriteHalf) -> (impl Future, mpsc::UnboundedSender) -where - T: Serialize, -{ - let (sender, mut receiver) = mpsc::unbounded_channel::(); - - let actor = async move { - let mut framed_write = FramedWrite::new(write, LengthDelimitedCodec::new()); - - while let Some(message) = receiver.recv().await { - match framed_write - .send(serde_json::to_vec(&message).unwrap().into()) - .await - { - Ok(_) => {} - Err(_) => { - tracing::error!("TCP connection error"); - break; - } - } - } - }; - - (actor, sender) -} diff --git a/daemon/src/taker.rs b/daemon/src/taker.rs index 4c851a5..79adfd5 100644 --- a/daemon/src/taker.rs +++ b/daemon/src/taker.rs @@ -14,6 +14,7 @@ use std::thread::sleep; use std::time::Duration; use tokio::sync::watch; use tracing_subscriber::filter::LevelFilter; +use wire::TakerToMaker; use xtra::spawn::TokioGlobalSpawnExt; use xtra::Actor; @@ -26,7 +27,7 @@ mod model; mod routes; mod routes_taker; mod seed; -mod send_wire_message_actor; +mod send_to_socket; mod setup_contract_actor; mod taker_cfd; mod taker_inc_message_actor; @@ -157,8 +158,9 @@ async fn main() -> Result<()> { None => return Err(rocket), }; - let (out_maker_messages_actor, out_maker_actor_inbox) = - send_wire_message_actor::new(write); + let send_to_maker = send_to_socket::Actor::new(write) + .create(None) + .spawn_global(); let cfd_actor_inbox = taker_cfd::Actor::new( db, @@ -166,7 +168,7 @@ async fn main() -> Result<()> { schnorrsig::PublicKey::from_keypair(SECP256K1, &oracle), cfd_feed_sender, order_feed_sender, - out_maker_actor_inbox, + send_to_maker, ) .await .unwrap() @@ -178,7 +180,6 @@ async fn main() -> Result<()> { tokio::spawn(wallet_sync::new(wallet, wallet_feed_sender)); tokio::spawn(inc_maker_messages_actor); - tokio::spawn(out_maker_messages_actor); Ok(rocket.manage(cfd_actor_inbox)) }, @@ -201,3 +202,7 @@ async fn main() -> Result<()> { Ok(()) } + +impl xtra::Message for TakerToMaker { + type Result = (); +} diff --git a/daemon/src/taker_cfd.rs b/daemon/src/taker_cfd.rs index e8c1ef2..ffae7f9 100644 --- a/daemon/src/taker_cfd.rs +++ b/daemon/src/taker_cfd.rs @@ -8,7 +8,7 @@ use crate::model::cfd::{Cfd, CfdState, CfdStateCommon, Dlc, Order, OrderId}; use crate::model::Usd; use crate::wallet::Wallet; use crate::wire::SetupMsg; -use crate::{setup_contract_actor, wire}; +use crate::{send_to_socket, setup_contract_actor, wire}; use anyhow::Result; use async_trait::async_trait; use bdk::bitcoin::secp256k1::schnorrsig; @@ -38,7 +38,7 @@ pub struct Actor { oracle_pk: schnorrsig::PublicKey, cfd_feed_actor_inbox: watch::Sender>, order_feed_actor_inbox: watch::Sender>, - out_msg_maker_inbox: mpsc::UnboundedSender, + send_to_maker: Address, current_contract_setup: Option>, // TODO: Move the contract setup into a dedicated actor and send messages to that actor that // manages the state instead of this ugly buffer @@ -52,7 +52,7 @@ impl Actor { oracle_pk: schnorrsig::PublicKey, cfd_feed_actor_inbox: watch::Sender>, order_feed_actor_inbox: watch::Sender>, - out_msg_maker_inbox: mpsc::UnboundedSender, + send_to_maker: Address, ) -> Result { let mut conn = db.acquire().await?; cfd_feed_actor_inbox.send(load_all_cfds(&mut conn).await?)?; @@ -63,7 +63,7 @@ impl Actor { oracle_pk, cfd_feed_actor_inbox, order_feed_actor_inbox, - out_msg_maker_inbox, + send_to_maker, current_contract_setup: None, contract_setup_message_buffer: vec![], }) @@ -90,8 +90,9 @@ impl Actor { self.cfd_feed_actor_inbox .send(load_all_cfds(&mut conn).await?)?; - self.out_msg_maker_inbox - .send(wire::TakerToMaker::TakeOrder { order_id, quantity })?; + self.send_to_maker + .do_send_async(wire::TakerToMaker::TakeOrder { order_id, quantity }) + .await?; Ok(()) } @@ -141,8 +142,10 @@ impl Actor { let (actor, inbox) = setup_contract_actor::new( { - let inbox = self.out_msg_maker_inbox.clone(); - move |msg| inbox.send(wire::TakerToMaker::Protocol(msg)).unwrap() + let inbox = self.send_to_maker.clone(); + move |msg| { + tokio::spawn(inbox.do_send_async(wire::TakerToMaker::Protocol(msg))); + } }, setup_contract_actor::OwnParams::Taker(taker_params), sk, diff --git a/daemon/src/wire.rs b/daemon/src/wire.rs index ff8f0f6..71faecb 100644 --- a/daemon/src/wire.rs +++ b/daemon/src/wire.rs @@ -7,6 +7,7 @@ use bdk::bitcoin::{Address, Amount, PublicKey}; use cfd_protocol::secp256k1_zkp::EcdsaAdaptorSignature; use cfd_protocol::{CfdTransactions, PartyParams, PunishParams}; use serde::{Deserialize, Serialize}; +use std::fmt; use std::ops::RangeInclusive; #[derive(Debug, Serialize, Deserialize)] @@ -17,6 +18,15 @@ pub enum TakerToMaker { Protocol(SetupMsg), } +impl fmt::Display for TakerToMaker { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TakerToMaker::TakeOrder { .. } => write!(f, "TakeOrder"), + TakerToMaker::Protocol(_) => write!(f, "Protocol"), + } + } +} + #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", content = "payload")] #[allow(clippy::large_enum_variant)] @@ -28,6 +38,18 @@ pub enum MakerToTaker { Protocol(SetupMsg), } +impl fmt::Display for MakerToTaker { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MakerToTaker::CurrentOrder(_) => write!(f, "CurrentOrder"), + MakerToTaker::ConfirmOrder(_) => write!(f, "ConfirmOrder"), + MakerToTaker::RejectOrder(_) => write!(f, "RejectOrder"), + MakerToTaker::InvalidOrderId(_) => write!(f, "InvalidOrderId"), + MakerToTaker::Protocol(_) => write!(f, "Protocol"), + } + } +} + #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", content = "payload")] pub enum SetupMsg {