You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
203 lines
6.8 KiB
203 lines
6.8 KiB
use crate::maker_cfd::{FromTaker, TakerConnected, TakerDisconnected};
|
|
use crate::model::cfd::Order;
|
|
use crate::model::Identity;
|
|
use crate::noise::TransportStateExt;
|
|
use crate::{maker_cfd, noise, send_to_socket, wire, Tasks};
|
|
use anyhow::Result;
|
|
use futures::TryStreamExt;
|
|
use std::collections::HashMap;
|
|
use std::io;
|
|
use std::net::SocketAddr;
|
|
use std::sync::{Arc, Mutex};
|
|
use std::time::Duration;
|
|
use tokio::net::TcpStream;
|
|
use tokio_util::codec::FramedRead;
|
|
use xtra::prelude::*;
|
|
use xtra::KeepRunning;
|
|
use xtra_productivity::xtra_productivity;
|
|
|
|
/// Message asking the connection actor to send `MakerToTaker::CurrentOrder` wrapping the
/// contained order (which may be `None`) to every currently connected taker.
pub struct BroadcastOrder(pub Option<Order>);
|
|
|
|
#[derive(Debug)]
|
|
pub struct TakerMessage {
|
|
pub taker_id: Identity,
|
|
pub msg: wire::MakerToTaker,
|
|
}
|
|
|
|
pub enum ListenerMessage {
|
|
NewConnection {
|
|
stream: TcpStream,
|
|
address: SocketAddr,
|
|
},
|
|
Error {
|
|
source: io::Error,
|
|
},
|
|
}
|
|
|
|
pub struct Actor {
|
|
write_connections: HashMap<Identity, Address<send_to_socket::Actor<wire::MakerToTaker>>>,
|
|
taker_connected_channel: Box<dyn MessageChannel<TakerConnected>>,
|
|
taker_disconnected_channel: Box<dyn MessageChannel<TakerDisconnected>>,
|
|
taker_msg_channel: Box<dyn MessageChannel<FromTaker>>,
|
|
noise_priv_key: x25519_dalek::StaticSecret,
|
|
heartbeat_interval: Duration,
|
|
connection_tasks: HashMap<Identity, Tasks>,
|
|
}
|
|
|
|
impl Actor {
|
|
pub fn new(
|
|
taker_connected_channel: Box<dyn MessageChannel<TakerConnected>>,
|
|
taker_disconnected_channel: Box<dyn MessageChannel<TakerDisconnected>>,
|
|
taker_msg_channel: Box<dyn MessageChannel<FromTaker>>,
|
|
noise_priv_key: x25519_dalek::StaticSecret,
|
|
heartbeat_interval: Duration,
|
|
) -> Self {
|
|
Self {
|
|
write_connections: HashMap::new(),
|
|
taker_connected_channel: taker_connected_channel.clone_channel(),
|
|
taker_disconnected_channel: taker_disconnected_channel.clone_channel(),
|
|
taker_msg_channel: taker_msg_channel.clone_channel(),
|
|
noise_priv_key,
|
|
heartbeat_interval,
|
|
connection_tasks: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
async fn drop_taker_connection(&mut self, taker_id: &Identity) {
|
|
if self.write_connections.remove(taker_id).is_some() {
|
|
tracing::info!(%taker_id, "Dropping connection");
|
|
let _ = self
|
|
.taker_disconnected_channel
|
|
.send(maker_cfd::TakerDisconnected { id: *taker_id })
|
|
.await;
|
|
let _ = self.connection_tasks.remove(taker_id);
|
|
}
|
|
}
|
|
|
|
async fn send_to_taker(
|
|
&mut self,
|
|
taker_id: &Identity,
|
|
msg: wire::MakerToTaker,
|
|
) -> Result<(), NoConnection> {
|
|
let conn = self
|
|
.write_connections
|
|
.get(taker_id)
|
|
.ok_or_else(|| NoConnection(*taker_id))?;
|
|
|
|
let msg_str = msg.to_string();
|
|
|
|
if conn.send(msg).await.is_err() {
|
|
tracing::error!(%taker_id, "Failed to send message to taker: {}", msg_str);
|
|
self.drop_taker_connection(taker_id).await;
|
|
return Err(NoConnection(*taker_id));
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_new_connection_impl(
|
|
&mut self,
|
|
mut stream: TcpStream,
|
|
taker_address: SocketAddr,
|
|
ctx: &mut Context<Self>,
|
|
) -> Result<()> {
|
|
let transport_state = noise::responder_handshake(&mut stream, &self.noise_priv_key).await?;
|
|
let taker_id = Identity::new(transport_state.get_remote_public_key()?);
|
|
|
|
tracing::info!(%taker_id, address = %taker_address, "New taker connected");
|
|
|
|
let transport_state = Arc::new(Mutex::new(transport_state));
|
|
|
|
let (read, write) = stream.into_split();
|
|
let mut read =
|
|
FramedRead::new(read, wire::EncryptedJsonCodec::new(transport_state.clone()));
|
|
|
|
let this = ctx.address().expect("self to be alive");
|
|
let taker_msg_channel = self.taker_msg_channel.clone_channel();
|
|
let read_fut = async move {
|
|
while let Ok(Some(msg)) = read.try_next().await {
|
|
let res = taker_msg_channel.send(FromTaker { taker_id, msg }).await;
|
|
|
|
if res.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
|
|
let _ = this.send(ReadFail(taker_id)).await;
|
|
};
|
|
|
|
let (out_msg, mut out_msg_actor_context) = xtra::Context::new(None);
|
|
let send_to_socket_actor = send_to_socket::Actor::new(write, transport_state.clone());
|
|
|
|
let heartbeat_fut = out_msg_actor_context
|
|
.notify_interval(self.heartbeat_interval, || wire::MakerToTaker::Heartbeat)
|
|
.expect("actor not to shutdown");
|
|
|
|
let write_fut = out_msg_actor_context.run(send_to_socket_actor);
|
|
|
|
self.write_connections.insert(taker_id, out_msg);
|
|
|
|
let mut tasks = Tasks::default();
|
|
tasks.add(read_fut);
|
|
tasks.add(heartbeat_fut);
|
|
tasks.add(write_fut);
|
|
self.connection_tasks.insert(taker_id, tasks);
|
|
|
|
let _ = self
|
|
.taker_connected_channel
|
|
.send(maker_cfd::TakerConnected { id: taker_id })
|
|
.await;
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
#[error("No connection to taker {0}")]
|
|
pub struct NoConnection(Identity);
|
|
|
|
#[xtra_productivity]
|
|
impl Actor {
|
|
async fn handle_broadcast_order(&mut self, msg: BroadcastOrder) {
|
|
let order = msg.0;
|
|
for taker_id in self.write_connections.clone().keys() {
|
|
self.send_to_taker(taker_id, wire::MakerToTaker::CurrentOrder(order.clone())).await.expect("send_to_taker only fails on missing hashmap entry and we are iterating over those entries");
|
|
tracing::trace!(%taker_id, "sent new order: {:?}", order.as_ref().map(|o| o.id));
|
|
}
|
|
}
|
|
|
|
async fn handle_taker_message(&mut self, msg: TakerMessage) -> Result<(), NoConnection> {
|
|
self.send_to_taker(&msg.taker_id, msg.msg).await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle(&mut self, msg: ListenerMessage, ctx: &mut Context<Self>) -> KeepRunning {
|
|
match msg {
|
|
ListenerMessage::NewConnection { stream, address } => {
|
|
if let Err(err) = self.handle_new_connection_impl(stream, address, ctx).await {
|
|
tracing::warn!("Maker was unable to negotiate a new connection: {}", err);
|
|
}
|
|
KeepRunning::Yes
|
|
}
|
|
ListenerMessage::Error { source } => {
|
|
tracing::warn!("TCP listener produced an error: {}", source);
|
|
|
|
// Maybe we should move the actual listening on the socket into here and restart the
|
|
// actor upon an error?
|
|
KeepRunning::Yes
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn handle_read_fail(&mut self, msg: ReadFail) {
|
|
let taker_id = msg.0;
|
|
tracing::error!(%taker_id, "Failed to read incoming messages from taker");
|
|
|
|
self.drop_taker_connection(&taker_id).await;
|
|
}
|
|
}
|
|
|
|
/// Internal notice from a connection's reader future that it stopped reading from the
/// contained taker, so the actor should drop that connection.
struct ReadFail(Identity);
|
|
|
|
// No custom lifecycle behaviour: the trait's default methods are used as-is.
impl xtra::Actor for Actor {
}
|
|
|