Browse Source

Refactor maker_inc_connections::Actor

To allow us to control where to forward each message from the taker.
chore/leaner-release-process
Lucas Soriano del Pino 3 years ago
parent
commit
4245c22b44
No known key found for this signature in database GPG Key ID: EE611E973A1530E7
  1. 58
      daemon/src/forward_only_ok.rs
  2. 1
      daemon/src/lib.rs
  3. 8
      daemon/src/maker_cfd.rs
  4. 108
      daemon/src/maker_inc_connections.rs

58
daemon/src/forward_only_ok.rs

@ -1,58 +0,0 @@
use std::fmt;
use xtra::prelude::MessageChannel;
use xtra::{Handler, KeepRunning};
/// A forwarding actor that only forwards [`Result::Ok`] values and shuts itself down upon the first
/// error.
pub struct Actor<M> {
forward: Box<dyn MessageChannel<M>>,
}
impl<M> Actor<M> {
pub fn new(forward: Box<dyn MessageChannel<M>>) -> Self {
Self { forward }
}
}
pub struct Message<TOk, TErr>(pub Result<TOk, TErr>);
impl<TOk, TErr> xtra::Message for Message<TOk, TErr>
where
TOk: Send + 'static,
TErr: Send + 'static,
{
type Result = KeepRunning;
}
#[async_trait::async_trait]
impl<TOk, TErr> Handler<Message<TOk, TErr>> for Actor<TOk>
where
TOk: xtra::Message<Result = ()> + Send + 'static,
TErr: fmt::Display + Send + 'static,
{
async fn handle(
&mut self,
Message(result): Message<TOk, TErr>,
_: &mut xtra::Context<Self>,
) -> KeepRunning {
let ok = match result {
Ok(ok) => ok,
Err(e) => {
tracing::error!("Stopping forwarding due to error: {}", e);
return KeepRunning::StopSelf;
}
};
if let Err(xtra::Disconnected) = self.forward.send(ok).await {
tracing::info!("Target actor disappeared, stopping");
return KeepRunning::StopSelf;
}
KeepRunning::Yes
}
}
impl<T: 'static + Send> xtra::Actor for Actor<T> {}

1
daemon/src/lib.rs

@ -26,7 +26,6 @@ pub mod cfd_actors;
pub mod connection; pub mod connection;
pub mod db; pub mod db;
pub mod fan_out; pub mod fan_out;
pub mod forward_only_ok;
pub mod housekeeping; pub mod housekeeping;
pub mod keypair; pub mod keypair;
pub mod logger; pub mod logger;

8
daemon/src/maker_cfd.rs

@ -439,9 +439,15 @@ where
// state change. Once we know that we go for either an accept/reject scenario we // state change. Once we know that we go for either an accept/reject scenario we
// have to remove the current order. // have to remove the current order.
self.current_order_id = None; self.current_order_id = None;
// Need to use `do_send_async` here because invoking the
// corresponding handler can result in a deadlock with another
// invocation in `maker_inc_connections.rs`
#[allow(clippy::disallowed_method)]
self.takers self.takers
.send(maker_inc_connections::BroadcastOrder(None)) .do_send_async(maker_inc_connections::BroadcastOrder(None))
.await?; .await?;
self.projection_actor.send(projection::Update(None)).await?; self.projection_actor.send(projection::Update(None)).await?;
// 3. Insert CFD in DB // 3. Insert CFD in DB

108
daemon/src/maker_inc_connections.rs

@ -2,10 +2,9 @@ use crate::maker_cfd::{FromTaker, TakerConnected, TakerDisconnected};
use crate::model::cfd::Order; use crate::model::cfd::Order;
use crate::model::Identity; use crate::model::Identity;
use crate::noise::TransportStateExt; use crate::noise::TransportStateExt;
use crate::tokio_ext::FutureExt; use crate::{maker_cfd, noise, send_to_socket, wire, Tasks};
use crate::{forward_only_ok, maker_cfd, noise, send_to_socket, wire, Tasks};
use anyhow::Result; use anyhow::Result;
use futures::{StreamExt, TryStreamExt}; use futures::TryStreamExt;
use std::collections::HashMap; use std::collections::HashMap;
use std::io; use std::io;
use std::net::SocketAddr; use std::net::SocketAddr;
@ -14,11 +13,12 @@ use std::time::Duration;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_util::codec::FramedRead; use tokio_util::codec::FramedRead;
use xtra::prelude::*; use xtra::prelude::*;
use xtra::{Actor as _, KeepRunning}; use xtra::KeepRunning;
use xtra_productivity::xtra_productivity; use xtra_productivity::xtra_productivity;
pub struct BroadcastOrder(pub Option<Order>); pub struct BroadcastOrder(pub Option<Order>);
#[derive(Debug)]
pub struct TakerMessage { pub struct TakerMessage {
pub taker_id: Identity, pub taker_id: Identity,
pub msg: wire::MakerToTaker, pub msg: wire::MakerToTaker,
@ -41,7 +41,7 @@ pub struct Actor {
taker_msg_channel: Box<dyn MessageChannel<FromTaker>>, taker_msg_channel: Box<dyn MessageChannel<FromTaker>>,
noise_priv_key: x25519_dalek::StaticSecret, noise_priv_key: x25519_dalek::StaticSecret,
heartbeat_interval: Duration, heartbeat_interval: Duration,
tasks: Tasks, connection_tasks: HashMap<Identity, Tasks>,
} }
impl Actor { impl Actor {
@ -59,7 +59,18 @@ impl Actor {
taker_msg_channel: taker_msg_channel.clone_channel(), taker_msg_channel: taker_msg_channel.clone_channel(),
noise_priv_key, noise_priv_key,
heartbeat_interval, heartbeat_interval,
tasks: Tasks::default(), connection_tasks: HashMap::new(),
}
}
async fn drop_taker_connection(&mut self, taker_id: &Identity) {
if self.write_connections.remove(taker_id).is_some() {
tracing::info!(%taker_id, "Dropping connection");
let _ = self
.taker_disconnected_channel
.send(maker_cfd::TakerDisconnected { id: *taker_id })
.await;
let _ = self.connection_tasks.remove(taker_id);
} }
} }
@ -76,13 +87,9 @@ impl Actor {
let msg_str = msg.to_string(); let msg_str = msg.to_string();
if conn.send(msg).await.is_err() { if conn.send(msg).await.is_err() {
tracing::info!(%taker_id, "Failed to send {} to taker, removing connection", msg_str); tracing::error!(%taker_id, "Failed to send message to taker: {}", msg_str);
if self.write_connections.remove(taker_id).is_some() { self.drop_taker_connection(taker_id).await;
let _ = self return Err(NoConnection(*taker_id));
.taker_disconnected_channel
.send(maker_cfd::TakerDisconnected { id: *taker_id })
.await;
}
} }
Ok(()) Ok(())
@ -92,7 +99,7 @@ impl Actor {
&mut self, &mut self,
mut stream: TcpStream, mut stream: TcpStream,
taker_address: SocketAddr, taker_address: SocketAddr,
_: &mut Context<Self>, ctx: &mut Context<Self>,
) -> Result<()> { ) -> Result<()> {
let transport_state = noise::responder_handshake(&mut stream, &self.noise_priv_key).await?; let transport_state = noise::responder_handshake(&mut stream, &self.noise_priv_key).await?;
let taker_id = Identity::new(transport_state.get_remote_public_key()?); let taker_id = Identity::new(transport_state.get_remote_public_key()?);
@ -102,43 +109,39 @@ impl Actor {
let transport_state = Arc::new(Mutex::new(transport_state)); let transport_state = Arc::new(Mutex::new(transport_state));
let (read, write) = stream.into_split(); let (read, write) = stream.into_split();
let read = FramedRead::new(read, wire::EncryptedJsonCodec::new(transport_state.clone())) let mut read =
.map_ok(move |msg| FromTaker { taker_id, msg }) FramedRead::new(read, wire::EncryptedJsonCodec::new(transport_state.clone()));
.map(forward_only_ok::Message);
let (out_msg_actor_address, mut out_msg_actor_context) = xtra::Context::new(None);
let (forward_to_cfd, forward_to_cfd_fut) =
forward_only_ok::Actor::new(self.taker_msg_channel.clone_channel())
.create(None)
.run();
self.tasks.add(forward_to_cfd_fut);
// only allow outgoing messages while we are successfully reading incoming ones
let heartbeat_interval = self.heartbeat_interval;
let taker_disconnected_channel = self.taker_disconnected_channel.clone_channel();
self.tasks.add(async move {
let mut actor = send_to_socket::Actor::new(write, transport_state.clone());
let _heartbeat_handle = out_msg_actor_context
.notify_interval(heartbeat_interval, || wire::MakerToTaker::Heartbeat)
.expect("actor not to shutdown")
.spawn_with_handle();
out_msg_actor_context
.handle_while(&mut actor, forward_to_cfd.attach_stream(read))
.await;
tracing::error!("Closing connection to taker {}", taker_id); let this = ctx.address().expect("self to be alive");
let _ = taker_disconnected_channel let taker_msg_channel = self.taker_msg_channel.clone_channel();
.send(maker_cfd::TakerDisconnected { id: taker_id }) let read_fut = async move {
.await; while let Ok(Some(msg)) = read.try_next().await {
let res = taker_msg_channel.send(FromTaker { taker_id, msg }).await;
if res.is_err() {
break;
}
}
let _ = this.send(ReadFail(taker_id)).await;
};
let (out_msg, mut out_msg_actor_context) = xtra::Context::new(None);
let send_to_socket_actor = send_to_socket::Actor::new(write, transport_state.clone());
actor.shutdown().await; let heartbeat_fut = out_msg_actor_context
}); .notify_interval(self.heartbeat_interval, || wire::MakerToTaker::Heartbeat)
.expect("actor not to shutdown");
self.write_connections let write_fut = out_msg_actor_context.run(send_to_socket_actor);
.insert(taker_id, out_msg_actor_address);
self.write_connections.insert(taker_id, out_msg);
let mut tasks = Tasks::default();
tasks.add(read_fut);
tasks.add(heartbeat_fut);
tasks.add(write_fut);
self.connection_tasks.insert(taker_id, tasks);
let _ = self let _ = self
.taker_connected_channel .taker_connected_channel
@ -186,6 +189,15 @@ impl Actor {
} }
} }
} }
async fn handle_read_fail(&mut self, msg: ReadFail) {
let taker_id = msg.0;
tracing::error!(%taker_id, "Failed to read incoming messages from taker");
self.drop_taker_connection(&taker_id).await;
}
} }
struct ReadFail(Identity);
impl xtra::Actor for Actor {} impl xtra::Actor for Actor {}

Loading…
Cancel
Save