From eda1b88b0a52d8ccfb9066591fe8100f806af6f0 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 27 Sep 2021 17:31:44 +1000 Subject: [PATCH] Introduce a dedicated codec --- Cargo.lock | 1 + daemon/Cargo.toml | 1 + daemon/src/maker_inc_connections.rs | 11 +++--- daemon/src/send_to_socket.rs | 23 ++++++------ daemon/src/taker_cfd.rs | 4 +-- daemon/src/taker_inc_message_actor.rs | 10 ++---- daemon/src/wire.rs | 52 +++++++++++++++++++++++++++ 7 files changed, 75 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 697552e..de84b6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -492,6 +492,7 @@ dependencies = [ "async-trait", "atty", "bdk", + "bytes", "cfd_protocol", "clap", "futures", diff --git a/daemon/Cargo.toml b/daemon/Cargo.toml index bee7d6a..eedbb2e 100644 --- a/daemon/Cargo.toml +++ b/daemon/Cargo.toml @@ -8,6 +8,7 @@ anyhow = "1" async-trait = "0.1.51" atty = "0.2" bdk = { git = "https://github.com/bitcoindevkit/bdk/" } +bytes = "1" cfd_protocol = { path = "../cfd_protocol" } clap = "3.0.0-beta.4" futures = { version = "0.3", default-features = false } diff --git a/daemon/src/maker_inc_connections.rs b/daemon/src/maker_inc_connections.rs index 2c08182..18f32ee 100644 --- a/daemon/src/maker_inc_connections.rs +++ b/daemon/src/maker_inc_connections.rs @@ -8,7 +8,7 @@ use async_trait::async_trait; use futures::{Future, StreamExt}; use std::collections::HashMap; use tokio::net::tcp::OwnedReadHalf; -use tokio_util::codec::{FramedRead, LengthDelimitedCodec}; +use tokio_util::codec::FramedRead; use xtra::prelude::*; pub struct BroadcastOrder(pub Option); @@ -37,7 +37,7 @@ impl Message for TakerMessage { pub struct NewTakerOnline { pub taker_id: TakerId, - pub out_msg_actor: Address, + pub out_msg_actor: Address>, } impl Message for NewTakerOnline { @@ -45,7 +45,7 @@ impl Message for NewTakerOnline { } pub struct Actor { - write_connections: HashMap>, + write_connections: HashMap>>, cfd_actor: Address, } @@ -156,10 +156,7 @@ pub fn in_taker_messages( cfd_actor_inbox: Address, taker_id: TakerId, ) -> impl Future { - let mut messages = FramedRead::new(read, LengthDelimitedCodec::new()).map(|result| { - let message = serde_json::from_slice::(&result?)?; - anyhow::Result::<_>::Ok(message) - }); + let mut messages = FramedRead::new(read, wire::JsonCodec::new()); async move { while let Some(message) = messages.next().await { diff --git a/daemon/src/send_to_socket.rs b/daemon/src/send_to_socket.rs index 1ea6146..ac82e32 100644 --- a/daemon/src/send_to_socket.rs +++ b/daemon/src/send_to_socket.rs @@ -1,35 +1,36 @@ +use crate::wire::JsonCodec; use futures::SinkExt; use serde::Serialize; use std::fmt; use tokio::net::tcp::OwnedWriteHalf; -use tokio_util::codec::{FramedWrite, LengthDelimitedCodec}; +use tokio_util::codec::FramedWrite; use xtra::{Handler, Message}; -pub struct Actor { - write: FramedWrite, +pub struct Actor { + write: FramedWrite>, } -impl Actor { +impl Actor { pub fn new(write: OwnedWriteHalf) -> Self { Self { - write: FramedWrite::new(write, LengthDelimitedCodec::new()), + write: FramedWrite::new(write, JsonCodec::new()), } } } #[async_trait::async_trait] -impl Handler for Actor +impl Handler for Actor where - T: Message + Serialize + fmt::Display, + T: Message + Serialize + fmt::Display + Sync, { async fn handle(&mut self, message: T, ctx: &mut xtra::Context) { - let bytes = serde_json::to_vec(&message).expect("serialization should never fail"); + let message_name = message.to_string(); // send consumes the message, avoid a clone just in case it errors by getting the name here - if let Err(e) = self.write.send(bytes.into()).await { - tracing::error!("Failed to write message {} to socket: {}", message, e); + if let Err(e) = self.write.send(message).await { + tracing::error!("Failed to write message {} to socket: {}", message_name, e); ctx.stop(); } } } -impl xtra::Actor for Actor {} +impl xtra::Actor for Actor {} diff --git a/daemon/src/taker_cfd.rs b/daemon/src/taker_cfd.rs index a619023..2a7339f 100644 --- a/daemon/src/taker_cfd.rs +++ b/daemon/src/taker_cfd.rs @@ -39,7 +39,7 @@ pub struct Actor { oracle_pk: schnorrsig::PublicKey, cfd_feed_actor_inbox: watch::Sender>, order_feed_actor_inbox: watch::Sender>, - send_to_maker: Address, + send_to_maker: Address>, current_contract_setup: Option>, // TODO: Move the contract setup into a dedicated actor and send messages to that actor that // manages the state instead of this ugly buffer @@ -54,7 +54,7 @@ impl Actor { oracle_pk: schnorrsig::PublicKey, cfd_feed_actor_inbox: watch::Sender>, order_feed_actor_inbox: watch::Sender>, - send_to_maker: Address, + send_to_maker: Address>, monitor_actor: Address>, ) -> Result { let mut conn = db.acquire().await?; diff --git a/daemon/src/taker_inc_message_actor.rs b/daemon/src/taker_inc_message_actor.rs index 75e0a3e..c82ef70 100644 --- a/daemon/src/taker_inc_message_actor.rs +++ b/daemon/src/taker_inc_message_actor.rs @@ -1,17 +1,13 @@ use crate::model::cfd::Origin; +use crate::wire::JsonCodec; use crate::{taker_cfd, wire}; use futures::{Future, StreamExt}; use tokio::net::tcp::OwnedReadHalf; -use tokio_util::codec::{FramedRead, LengthDelimitedCodec}; +use tokio_util::codec::FramedRead; use xtra::prelude::*; pub fn new(read: OwnedReadHalf, cfd_actor: Address) -> impl Future { - let frame_read = FramedRead::new(read, LengthDelimitedCodec::new()); - - let mut messages = frame_read.map(|result| { - let message = serde_json::from_slice::(&result?)?; - anyhow::Result::<_>::Ok(message) - }); + let mut messages = FramedRead::new(read, JsonCodec::new()); async move { while let Some(message) = messages.next().await { diff --git a/daemon/src/wire.rs b/daemon/src/wire.rs index 71faecb..0edcc22 100644 --- a/daemon/src/wire.rs +++ b/daemon/src/wire.rs @@ -4,11 +4,15 @@ use crate::Order; use bdk::bitcoin::secp256k1::Signature; use bdk::bitcoin::util::psbt::PartiallySignedTransaction; use bdk::bitcoin::{Address, Amount, PublicKey}; +use bytes::BytesMut; use cfd_protocol::secp256k1_zkp::EcdsaAdaptorSignature; use cfd_protocol::{CfdTransactions, PartyParams, PunishParams}; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::fmt; +use std::marker::PhantomData; use std::ops::RangeInclusive; +use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec}; #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", content = "payload")] @@ -50,6 +54,54 @@ impl fmt::Display for MakerToTaker { } } +pub struct JsonCodec { + _type: PhantomData, + inner: LengthDelimitedCodec, +} + +impl JsonCodec { + pub fn new() -> Self { + Self { + _type: PhantomData, + inner: LengthDelimitedCodec::new(), + } + } +} + +impl Decoder for JsonCodec +where + T: DeserializeOwned, +{ + type Item = T; + type Error = anyhow::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + let bytes = match self.inner.decode(src)? { + None => return Ok(None), + Some(bytes) => bytes, + }; + + let item = serde_json::from_slice(&bytes)?; + + Ok(Some(item)) + } +} + +impl Encoder for JsonCodec +where + T: Serialize, +{ + type Error = anyhow::Error; + + fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> { + let bytes = serde_json::to_vec(&item)?; + + self.inner.encode(bytes.into(), dst)?; + + Ok(()) + } +} + #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", content = "payload")] pub enum SetupMsg {