Browse Source

Don't manually lock EncryptedJsonCodec

There is a utility in `futures` for splitting anything that implements
`AsyncRead` + `AsyncWrite` into two components. It also uses a lock
interally but at least we don't have to implement it ourselves. In this
case, it also fixes using the wrong mutex (std::sync). A side-effect of
this change is that our codec now needs to have two type parameters
because we only construct it once and we have two different messages for
sending and receiving. The implications of this are minimal though.

Might help with #759.
update-blockstream-electrum-server-url
Thomas Eizinger 3 years ago
parent
commit
9f2209c99c
No known key found for this signature in database GPG Key ID: 651AC83A6C6C8B96
  1. 15
      daemon/src/connection.rs
  2. 17
      daemon/src/maker_inc_connections.rs
  3. 33
      daemon/src/send_to_socket.rs
  4. 39
      daemon/src/wire.rs

15
daemon/src/connection.rs

@ -26,14 +26,11 @@ use futures::SinkExt;
use futures::StreamExt;
use futures::TryStreamExt;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use std::time::SystemTime;
use tokio::net::TcpStream;
use tokio::sync::watch;
use tokio_util::codec::FramedRead;
use tokio_util::codec::FramedWrite;
use tokio_util::codec::Framed;
use xtra::prelude::MessageChannel;
use xtra::KeepRunning;
use xtra_productivity::xtra_productivity;
@ -49,7 +46,7 @@ struct ConnectedState {
pub struct Actor {
status_sender: watch::Sender<ConnectionStatus>,
send_to_maker: Box<dyn MessageChannel<wire::TakerToMaker>>,
send_to_maker_ctx: xtra::Context<send_to_socket::Actor<wire::TakerToMaker>>,
send_to_maker_ctx: xtra::Context<send_to_socket::Actor<wire::MakerToTaker, wire::TakerToMaker>>,
identity_sk: x25519_dalek::StaticSecret,
current_order: Box<dyn MessageChannel<CurrentOrder>>,
/// Max duration since the last heartbeat until we die.
@ -238,7 +235,7 @@ impl Actor {
) -> Result<()> {
tracing::debug!(address = %maker_addr, "Connecting to maker");
let (read, write, noise) = {
let (mut write, mut read) = {
let mut connection = TcpStream::connect(&maker_addr)
.timeout(self.connect_timeout)
.await
@ -257,13 +254,9 @@ impl Actor {
)
.await?;
let (read, write) = connection.into_split();
(read, write, Arc::new(Mutex::new(noise)))
Framed::new(connection, EncryptedJsonCodec::new(noise)).split()
};
let mut read = FramedRead::new(read, wire::EncryptedJsonCodec::new(noise.clone()));
let mut write = FramedWrite::new(write, EncryptedJsonCodec::new(noise));
let our_version = Version::current();
write.send(TakerToMaker::Hello(our_version.clone())).await?;

17
daemon/src/maker_inc_connections.rs

@ -26,16 +26,14 @@ use anyhow::bail;
use anyhow::Context;
use anyhow::Result;
use futures::SinkExt;
use futures::StreamExt;
use futures::TryStreamExt;
use std::collections::HashMap;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio_util::codec::FramedRead;
use tokio_util::codec::FramedWrite;
use tokio_util::codec::Framed;
use xtra::prelude::*;
use xtra::KeepRunning;
use xtra_productivity::xtra_productivity;
@ -99,7 +97,8 @@ pub enum ListenerMessage {
}
pub struct Actor {
write_connections: HashMap<Identity, Address<send_to_socket::Actor<wire::MakerToTaker>>>,
write_connections:
HashMap<Identity, Address<send_to_socket::Actor<wire::TakerToMaker, wire::MakerToTaker>>>,
taker_connected_channel: Box<dyn MessageChannel<TakerConnected>>,
taker_disconnected_channel: Box<dyn MessageChannel<TakerDisconnected>>,
taker_msg_channel: Box<dyn MessageChannel<FromTaker>>,
@ -175,12 +174,8 @@ impl Actor {
let transport_state = noise::responder_handshake(&mut stream, &self.noise_priv_key).await?;
let taker_id = Identity::new(transport_state.get_remote_public_key()?);
let transport_state = Arc::new(Mutex::new(transport_state));
let (read, write) = stream.into_split();
let mut read =
FramedRead::new(read, wire::EncryptedJsonCodec::new(transport_state.clone()));
let mut write = FramedWrite::new(write, EncryptedJsonCodec::new(transport_state));
let (mut write, mut read) =
Framed::new(stream, EncryptedJsonCodec::new(transport_state)).split();
match read
.try_next()

33
daemon/src/send_to_socket.rs

@ -1,34 +1,30 @@
use crate::wire;
use crate::wire::EncryptedJsonCodec;
use futures::stream::SplitSink;
use futures::SinkExt;
use serde::Serialize;
use std::fmt;
use tokio::io::AsyncWriteExt;
use tokio::net::tcp::OwnedWriteHalf;
use tokio_util::codec::FramedWrite;
use tokio::net::TcpStream;
use tokio_util::codec::Framed;
use xtra::Handler;
use xtra::Message;
pub struct Actor<T> {
write: FramedWrite<OwnedWriteHalf, EncryptedJsonCodec<T>>,
pub struct Actor<D, E> {
write: SplitSink<Framed<TcpStream, wire::EncryptedJsonCodec<D, E>>, E>,
}
impl<T> Actor<T> {
pub fn new(write: FramedWrite<OwnedWriteHalf, EncryptedJsonCodec<T>>) -> Self {
impl<D, E> Actor<D, E> {
pub fn new(write: SplitSink<Framed<TcpStream, wire::EncryptedJsonCodec<D, E>>, E>) -> Self {
Self { write }
}
pub async fn shutdown(self) {
let _ = self.write.into_inner().shutdown().await;
}
}
#[async_trait::async_trait]
impl<T> Handler<T> for Actor<T>
impl<D, E> Handler<E> for Actor<D, E>
where
T: Message<Result = ()> + Serialize + fmt::Display + Sync,
D: Send + Sync + 'static,
E: Message<Result = ()> + Serialize + fmt::Display + Sync,
{
async fn handle(&mut self, message: T, ctx: &mut xtra::Context<Self>) {
async fn handle(&mut self, message: E, ctx: &mut xtra::Context<Self>) {
let message_name = message.to_string(); // send consumes the message, avoid a clone just in case it errors by getting the name here
tracing::trace!("Sending '{}'", message_name);
@ -40,7 +36,12 @@ where
}
}
impl<T: 'static + Send> xtra::Actor for Actor<T> {}
impl<D, E> xtra::Actor for Actor<D, E>
where
D: 'static + Send,
E: 'static + Send,
{
}
impl xtra::Message for wire::MakerToTaker {
type Result = ();

39
daemon/src/wire.rs

@ -27,8 +27,6 @@ use std::collections::HashMap;
use std::fmt;
use std::marker::PhantomData;
use std::ops::RangeInclusive;
use std::sync::Arc;
use std::sync::Mutex;
use tokio_util::codec::Decoder;
use tokio_util::codec::Encoder;
use tokio_util::codec::LengthDelimitedCodec;
@ -168,14 +166,15 @@ impl fmt::Display for MakerToTaker {
}
}
pub struct EncryptedJsonCodec<T> {
_type: PhantomData<T>,
/// A codec that can decode encrypted JSON into the type `D` and encode `E` to encrypted JSON.
pub struct EncryptedJsonCodec<D, E> {
_type: PhantomData<(D, E)>,
inner: LengthDelimitedCodec,
transport_state: Arc<Mutex<TransportState>>,
transport_state: TransportState,
}
impl<T> EncryptedJsonCodec<T> {
pub fn new(transport_state: Arc<Mutex<TransportState>>) -> Self {
impl<D, E> EncryptedJsonCodec<D, E> {
pub fn new(transport_state: TransportState) -> Self {
Self {
_type: PhantomData,
inner: LengthDelimitedCodec::new(),
@ -184,11 +183,11 @@ impl<T> EncryptedJsonCodec<T> {
}
}
impl<T> Decoder for EncryptedJsonCodec<T>
impl<D, E> Decoder for EncryptedJsonCodec<D, E>
where
T: DeserializeOwned,
D: DeserializeOwned,
{
type Item = T;
type Item = D;
type Error = anyhow::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
@ -197,16 +196,11 @@ where
Some(bytes) => bytes,
};
let mut transport = self
.transport_state
.lock()
.expect("acquired mutex lock on Noise object to encrypt message");
let decrypted = bytes
.chunks(NOISE_MAX_MSG_LEN as usize)
.map(|chunk| {
let mut buf = vec![0u8; chunk.len() - NOISE_TAG_LEN as usize];
transport.read_message(chunk, &mut *buf)?;
self.transport_state.read_message(chunk, &mut *buf)?;
Ok(buf)
})
.collect::<Result<Vec<Vec<u8>>>>()?
@ -220,25 +214,20 @@ where
}
}
impl<T> Encoder<T> for EncryptedJsonCodec<T>
impl<D, E> Encoder<E> for EncryptedJsonCodec<D, E>
where
T: Serialize,
E: Serialize,
{
type Error = anyhow::Error;
fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
fn encode(&mut self, item: E, dst: &mut BytesMut) -> Result<(), Self::Error> {
let bytes = serde_json::to_vec(&item)?;
let mut transport = self
.transport_state
.lock()
.expect("acquired mutex lock on Noise object to encrypt message");
let encrypted = bytes
.chunks((NOISE_MAX_MSG_LEN - NOISE_TAG_LEN) as usize)
.map(|chunk| {
let mut buf = vec![0u8; chunk.len() + NOISE_TAG_LEN as usize];
transport.write_message(chunk, &mut *buf)?;
self.transport_state.write_message(chunk, &mut *buf)?;
Ok(buf)
})
.collect::<Result<Vec<Vec<u8>>>>()?

Loading…
Cancel
Save