From 9f2209c99c3a9fceb8ed1936c943d7db9e7b7714 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Sun, 19 Dec 2021 16:47:03 +1100 Subject: [PATCH] Don't manually lock EncryptedJsonCodec There is a utility in `futures` for splitting anything that implements `AsyncRead` + `AsyncWrite` into two components. It also uses a lock interally but at least we don't have to implement it ourselves. In this case, it also fixes using the wrong mutex (std::sync). A side-effect of this change is that our codec now needs to have two type parameters because we only construct it once and we have two different messages for sending and receiving. The implications of this are minimal though. Might help with #759. --- daemon/src/connection.rs | 15 +++-------- daemon/src/maker_inc_connections.rs | 17 +++++-------- daemon/src/send_to_socket.rs | 33 ++++++++++++------------ daemon/src/wire.rs | 39 +++++++++++------------------ 4 files changed, 41 insertions(+), 63 deletions(-) diff --git a/daemon/src/connection.rs b/daemon/src/connection.rs index 855db38..4f1b1fd 100644 --- a/daemon/src/connection.rs +++ b/daemon/src/connection.rs @@ -26,14 +26,11 @@ use futures::SinkExt; use futures::StreamExt; use futures::TryStreamExt; use std::net::SocketAddr; -use std::sync::Arc; -use std::sync::Mutex; use std::time::Duration; use std::time::SystemTime; use tokio::net::TcpStream; use tokio::sync::watch; -use tokio_util::codec::FramedRead; -use tokio_util::codec::FramedWrite; +use tokio_util::codec::Framed; use xtra::prelude::MessageChannel; use xtra::KeepRunning; use xtra_productivity::xtra_productivity; @@ -49,7 +46,7 @@ struct ConnectedState { pub struct Actor { status_sender: watch::Sender, send_to_maker: Box>, - send_to_maker_ctx: xtra::Context>, + send_to_maker_ctx: xtra::Context>, identity_sk: x25519_dalek::StaticSecret, current_order: Box>, /// Max duration since the last heartbeat until we die. @@ -238,7 +235,7 @@ impl Actor { ) -> Result<()> { tracing::debug!(address = %maker_addr, "Connecting to maker"); - let (read, write, noise) = { + let (mut write, mut read) = { let mut connection = TcpStream::connect(&maker_addr) .timeout(self.connect_timeout) .await @@ -257,13 +254,9 @@ impl Actor { ) .await?; - let (read, write) = connection.into_split(); - (read, write, Arc::new(Mutex::new(noise))) + Framed::new(connection, EncryptedJsonCodec::new(noise)).split() }; - let mut read = FramedRead::new(read, wire::EncryptedJsonCodec::new(noise.clone())); - let mut write = FramedWrite::new(write, EncryptedJsonCodec::new(noise)); - let our_version = Version::current(); write.send(TakerToMaker::Hello(our_version.clone())).await?; diff --git a/daemon/src/maker_inc_connections.rs b/daemon/src/maker_inc_connections.rs index c7b84ee..d2d9f70 100644 --- a/daemon/src/maker_inc_connections.rs +++ b/daemon/src/maker_inc_connections.rs @@ -26,16 +26,14 @@ use anyhow::bail; use anyhow::Context; use anyhow::Result; use futures::SinkExt; +use futures::StreamExt; use futures::TryStreamExt; use std::collections::HashMap; use std::io; use std::net::SocketAddr; -use std::sync::Arc; -use std::sync::Mutex; use std::time::Duration; use tokio::net::TcpStream; -use tokio_util::codec::FramedRead; -use tokio_util::codec::FramedWrite; +use tokio_util::codec::Framed; use xtra::prelude::*; use xtra::KeepRunning; use xtra_productivity::xtra_productivity; @@ -99,7 +97,8 @@ pub enum ListenerMessage { } pub struct Actor { - write_connections: HashMap>>, + write_connections: + HashMap>>, taker_connected_channel: Box>, taker_disconnected_channel: Box>, taker_msg_channel: Box>, @@ -175,12 +174,8 @@ impl Actor { let transport_state = noise::responder_handshake(&mut stream, &self.noise_priv_key).await?; let taker_id = Identity::new(transport_state.get_remote_public_key()?); - let transport_state = Arc::new(Mutex::new(transport_state)); - - let (read, write) = stream.into_split(); - let mut read = - FramedRead::new(read, wire::EncryptedJsonCodec::new(transport_state.clone())); - let mut write = FramedWrite::new(write, EncryptedJsonCodec::new(transport_state)); + let (mut write, mut read) = + Framed::new(stream, EncryptedJsonCodec::new(transport_state)).split(); match read .try_next() diff --git a/daemon/src/send_to_socket.rs b/daemon/src/send_to_socket.rs index 3ba2ba9..8c5170c 100644 --- a/daemon/src/send_to_socket.rs +++ b/daemon/src/send_to_socket.rs @@ -1,34 +1,30 @@ use crate::wire; -use crate::wire::EncryptedJsonCodec; +use futures::stream::SplitSink; use futures::SinkExt; use serde::Serialize; use std::fmt; -use tokio::io::AsyncWriteExt; -use tokio::net::tcp::OwnedWriteHalf; -use tokio_util::codec::FramedWrite; +use tokio::net::TcpStream; +use tokio_util::codec::Framed; use xtra::Handler; use xtra::Message; -pub struct Actor { - write: FramedWrite>, +pub struct Actor { + write: SplitSink>, E>, } -impl Actor { - pub fn new(write: FramedWrite>) -> Self { +impl Actor { + pub fn new(write: SplitSink>, E>) -> Self { Self { write } } - - pub async fn shutdown(self) { - let _ = self.write.into_inner().shutdown().await; - } } #[async_trait::async_trait] -impl Handler for Actor +impl Handler for Actor where - T: Message + Serialize + fmt::Display + Sync, + D: Send + Sync + 'static, + E: Message + Serialize + fmt::Display + Sync, { - async fn handle(&mut self, message: T, ctx: &mut xtra::Context) { + async fn handle(&mut self, message: E, ctx: &mut xtra::Context) { let message_name = message.to_string(); // send consumes the message, avoid a clone just in case it errors by getting the name here tracing::trace!("Sending '{}'", message_name); @@ -40,7 +36,12 @@ where } } -impl xtra::Actor for Actor {} +impl xtra::Actor for Actor +where + D: 'static + Send, + E: 'static + Send, +{ +} impl xtra::Message for wire::MakerToTaker { type Result = (); diff --git a/daemon/src/wire.rs b/daemon/src/wire.rs index 24b0f30..2828e51 100644 --- a/daemon/src/wire.rs +++ b/daemon/src/wire.rs @@ -27,8 +27,6 @@ use std::collections::HashMap; use std::fmt; use std::marker::PhantomData; use std::ops::RangeInclusive; -use std::sync::Arc; -use std::sync::Mutex; use tokio_util::codec::Decoder; use tokio_util::codec::Encoder; use tokio_util::codec::LengthDelimitedCodec; @@ -168,14 +166,15 @@ impl fmt::Display for MakerToTaker { } } -pub struct EncryptedJsonCodec { - _type: PhantomData, +/// A codec that can decode encrypted JSON into the type `D` and encode `E` to encrypted JSON. +pub struct EncryptedJsonCodec { + _type: PhantomData<(D, E)>, inner: LengthDelimitedCodec, - transport_state: Arc>, + transport_state: TransportState, } -impl EncryptedJsonCodec { - pub fn new(transport_state: Arc>) -> Self { +impl EncryptedJsonCodec { + pub fn new(transport_state: TransportState) -> Self { Self { _type: PhantomData, inner: LengthDelimitedCodec::new(), @@ -184,11 +183,11 @@ impl EncryptedJsonCodec { } } -impl Decoder for EncryptedJsonCodec +impl Decoder for EncryptedJsonCodec where - T: DeserializeOwned, + D: DeserializeOwned, { - type Item = T; + type Item = D; type Error = anyhow::Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { @@ -197,16 +196,11 @@ where Some(bytes) => bytes, }; - let mut transport = self - .transport_state - .lock() - .expect("acquired mutex lock on Noise object to encrypt message"); - let decrypted = bytes .chunks(NOISE_MAX_MSG_LEN as usize) .map(|chunk| { let mut buf = vec![0u8; chunk.len() - NOISE_TAG_LEN as usize]; - transport.read_message(chunk, &mut *buf)?; + self.transport_state.read_message(chunk, &mut *buf)?; Ok(buf) }) .collect::>>>()? @@ -220,25 +214,20 @@ where } } -impl Encoder for EncryptedJsonCodec +impl Encoder for EncryptedJsonCodec where - T: Serialize, + E: Serialize, { type Error = anyhow::Error; - fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> { + fn encode(&mut self, item: E, dst: &mut BytesMut) -> Result<(), Self::Error> { let bytes = serde_json::to_vec(&item)?; - let mut transport = self - .transport_state - .lock() - .expect("acquired mutex lock on Noise object to encrypt message"); - let encrypted = bytes .chunks((NOISE_MAX_MSG_LEN - NOISE_TAG_LEN) as usize) .map(|chunk| { let mut buf = vec![0u8; chunk.len() + NOISE_TAG_LEN as usize]; - transport.write_message(chunk, &mut *buf)?; + self.transport_state.write_message(chunk, &mut *buf)?; Ok(buf) }) .collect::>>>()?