Browse Source

Convert `send_wire_message_actor` to xtra::Actor

This will give us a generalized `Sink` interface for writing to the
connection.
fix-bad-api-calls
Thomas Eizinger 3 years ago
parent
commit
b6d0fc6c2f
No known key found for this signature in database GPG Key ID: 651AC83A6C6C8B96
  1. 15
      daemon/src/maker.rs
  2. 42
      daemon/src/maker_inc_connections.rs
  3. 35
      daemon/src/send_to_socket.rs
  4. 31
      daemon/src/send_wire_message_actor.rs
  5. 15
      daemon/src/taker.rs
  6. 19
      daemon/src/taker_cfd.rs
  7. 22
      daemon/src/wire.rs

15
daemon/src/maker.rs

@ -29,7 +29,7 @@ mod model;
mod routes; mod routes;
mod routes_maker; mod routes_maker;
mod seed; mod seed;
mod send_wire_message_actor; mod send_to_socket;
mod setup_contract_actor; mod setup_contract_actor;
mod to_sse_event; mod to_sse_event;
mod wallet; mod wallet;
@ -194,16 +194,17 @@ async fn main() -> Result<()> {
cfd_maker_actor_inbox.clone(), cfd_maker_actor_inbox.clone(),
taker_id, taker_id,
); );
let (out_msg_actor, out_msg_actor_inbox) =
send_wire_message_actor::new::<wire::MakerToTaker>(write); let out_msg_actor = send_to_socket::Actor::new(write)
.create(None)
.spawn_global();
tokio::spawn(in_taker_actor); tokio::spawn(in_taker_actor);
tokio::spawn(out_msg_actor);
maker_inc_connections_address maker_inc_connections_address
.do_send_async(maker_inc_connections::NewTakerOnline { .do_send_async(maker_inc_connections::NewTakerOnline {
taker_id, taker_id,
out_msg_actor_inbox, out_msg_actor,
}) })
.await .await
.unwrap(); .unwrap();
@ -238,3 +239,7 @@ async fn main() -> Result<()> {
Ok(()) Ok(())
} }
impl xtra::Message for wire::MakerToTaker {
type Result = ();
}

42
daemon/src/maker_inc_connections.rs

@ -2,18 +2,15 @@ use crate::actors::log_error;
use crate::model::cfd::{Order, OrderId}; use crate::model::cfd::{Order, OrderId};
use crate::model::TakerId; use crate::model::TakerId;
use crate::wire::SetupMsg; use crate::wire::SetupMsg;
use crate::{maker_cfd, wire}; use crate::{maker_cfd, send_to_socket, wire};
use anyhow::{Context as AnyhowContext, Result}; use anyhow::{Context as AnyhowContext, Result};
use async_trait::async_trait; use async_trait::async_trait;
use futures::{Future, StreamExt}; use futures::{Future, StreamExt};
use std::collections::HashMap; use std::collections::HashMap;
use tokio::net::tcp::OwnedReadHalf; use tokio::net::tcp::OwnedReadHalf;
use tokio::sync::mpsc;
use tokio_util::codec::{FramedRead, LengthDelimitedCodec}; use tokio_util::codec::{FramedRead, LengthDelimitedCodec};
use xtra::prelude::*; use xtra::prelude::*;
type MakerToTakerSender = mpsc::UnboundedSender<wire::MakerToTaker>;
pub struct BroadcastOrder(pub Option<Order>); pub struct BroadcastOrder(pub Option<Order>);
impl Message for BroadcastOrder { impl Message for BroadcastOrder {
@ -40,7 +37,7 @@ impl Message for TakerMessage {
pub struct NewTakerOnline { pub struct NewTakerOnline {
pub taker_id: TakerId, pub taker_id: TakerId,
pub out_msg_actor_inbox: MakerToTakerSender, pub out_msg_actor: Address<send_to_socket::Actor>,
} }
impl Message for NewTakerOnline { impl Message for NewTakerOnline {
@ -48,7 +45,7 @@ impl Message for NewTakerOnline {
} }
pub struct Actor { pub struct Actor {
write_connections: HashMap<TakerId, MakerToTakerSender>, write_connections: HashMap<TakerId, Address<send_to_socket::Actor>>,
cfd_actor: Address<maker_cfd::Actor>, cfd_actor: Address<maker_cfd::Actor>,
} }
@ -57,44 +54,53 @@ impl xtra::Actor for Actor {}
impl Actor { impl Actor {
pub fn new(cfd_actor: Address<maker_cfd::Actor>) -> Self { pub fn new(cfd_actor: Address<maker_cfd::Actor>) -> Self {
Self { Self {
write_connections: HashMap::<TakerId, MakerToTakerSender>::new(), write_connections: HashMap::new(),
cfd_actor, cfd_actor,
} }
} }
fn send_to_taker(&self, taker_id: TakerId, msg: wire::MakerToTaker) -> Result<()> { async fn send_to_taker(&self, taker_id: TakerId, msg: wire::MakerToTaker) -> Result<()> {
let conn = self let conn = self
.write_connections .write_connections
.get(&taker_id) .get(&taker_id)
.context("no connection to taker_id")?; .context("no connection to taker_id")?;
conn.send(msg)?; conn.do_send_async(msg).await?;
Ok(()) Ok(())
} }
async fn handle_broadcast_order(&mut self, msg: BroadcastOrder) -> Result<()> { async fn handle_broadcast_order(&mut self, msg: BroadcastOrder) -> Result<()> {
let order = msg.0; let order = msg.0;
self.write_connections
.values() for conn in self.write_connections.values() {
.try_for_each(|conn| conn.send(wire::MakerToTaker::CurrentOrder(order.clone())))?; conn.do_send_async(wire::MakerToTaker::CurrentOrder(order.clone()))
.await?;
}
Ok(()) Ok(())
} }
async fn handle_taker_message(&mut self, msg: TakerMessage) -> Result<()> { async fn handle_taker_message(&mut self, msg: TakerMessage) -> Result<()> {
match msg.command { match msg.command {
TakerCommand::SendOrder { order } => { TakerCommand::SendOrder { order } => {
self.send_to_taker(msg.taker_id, wire::MakerToTaker::CurrentOrder(order))?; self.send_to_taker(msg.taker_id, wire::MakerToTaker::CurrentOrder(order))
.await?;
} }
TakerCommand::NotifyInvalidOrderId { id } => { TakerCommand::NotifyInvalidOrderId { id } => {
self.send_to_taker(msg.taker_id, wire::MakerToTaker::InvalidOrderId(id))?; self.send_to_taker(msg.taker_id, wire::MakerToTaker::InvalidOrderId(id))
.await?;
} }
TakerCommand::NotifyOrderAccepted { id } => { TakerCommand::NotifyOrderAccepted { id } => {
self.send_to_taker(msg.taker_id, wire::MakerToTaker::ConfirmOrder(id))?; self.send_to_taker(msg.taker_id, wire::MakerToTaker::ConfirmOrder(id))
.await?;
} }
TakerCommand::NotifyOrderRejected { id } => { TakerCommand::NotifyOrderRejected { id } => {
self.send_to_taker(msg.taker_id, wire::MakerToTaker::RejectOrder(id))?; self.send_to_taker(msg.taker_id, wire::MakerToTaker::RejectOrder(id))
.await?;
} }
TakerCommand::OutProtocolMsg { setup_msg } => { TakerCommand::OutProtocolMsg { setup_msg } => {
self.send_to_taker(msg.taker_id, wire::MakerToTaker::Protocol(setup_msg))?; self.send_to_taker(msg.taker_id, wire::MakerToTaker::Protocol(setup_msg))
.await?;
} }
} }
Ok(()) Ok(())
@ -106,7 +112,7 @@ impl Actor {
.await?; .await?;
self.write_connections self.write_connections
.insert(msg.taker_id, msg.out_msg_actor_inbox); .insert(msg.taker_id, msg.out_msg_actor);
Ok(()) Ok(())
} }
} }

35
daemon/src/send_to_socket.rs

@ -0,0 +1,35 @@
use futures::SinkExt;
use serde::Serialize;
use std::fmt;
use tokio::net::tcp::OwnedWriteHalf;
use tokio_util::codec::{FramedWrite, LengthDelimitedCodec};
use xtra::{Handler, Message};
pub struct Actor {
write: FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
}
impl Actor {
pub fn new(write: OwnedWriteHalf) -> Self {
Self {
write: FramedWrite::new(write, LengthDelimitedCodec::new()),
}
}
}
#[async_trait::async_trait]
impl<T> Handler<T> for Actor
where
T: Message<Result = ()> + Serialize + fmt::Display,
{
async fn handle(&mut self, message: T, ctx: &mut xtra::Context<Self>) {
let bytes = serde_json::to_vec(&message).expect("serialization should never fail");
if let Err(e) = self.write.send(bytes.into()).await {
tracing::error!("Failed to write message {} to socket: {}", message, e);
ctx.stop();
}
}
}
impl xtra::Actor for Actor {}

31
daemon/src/send_wire_message_actor.rs

@ -1,31 +0,0 @@
use futures::{Future, SinkExt};
use serde::Serialize;
use tokio::net::tcp::OwnedWriteHalf;
use tokio::sync::mpsc;
use tokio_util::codec::{FramedWrite, LengthDelimitedCodec};
pub fn new<T>(write: OwnedWriteHalf) -> (impl Future<Output = ()>, mpsc::UnboundedSender<T>)
where
T: Serialize,
{
let (sender, mut receiver) = mpsc::unbounded_channel::<T>();
let actor = async move {
let mut framed_write = FramedWrite::new(write, LengthDelimitedCodec::new());
while let Some(message) = receiver.recv().await {
match framed_write
.send(serde_json::to_vec(&message).unwrap().into())
.await
{
Ok(_) => {}
Err(_) => {
tracing::error!("TCP connection error");
break;
}
}
}
};
(actor, sender)
}

15
daemon/src/taker.rs

@ -14,6 +14,7 @@ use std::thread::sleep;
use std::time::Duration; use std::time::Duration;
use tokio::sync::watch; use tokio::sync::watch;
use tracing_subscriber::filter::LevelFilter; use tracing_subscriber::filter::LevelFilter;
use wire::TakerToMaker;
use xtra::spawn::TokioGlobalSpawnExt; use xtra::spawn::TokioGlobalSpawnExt;
use xtra::Actor; use xtra::Actor;
@ -26,7 +27,7 @@ mod model;
mod routes; mod routes;
mod routes_taker; mod routes_taker;
mod seed; mod seed;
mod send_wire_message_actor; mod send_to_socket;
mod setup_contract_actor; mod setup_contract_actor;
mod taker_cfd; mod taker_cfd;
mod taker_inc_message_actor; mod taker_inc_message_actor;
@ -157,8 +158,9 @@ async fn main() -> Result<()> {
None => return Err(rocket), None => return Err(rocket),
}; };
let (out_maker_messages_actor, out_maker_actor_inbox) = let send_to_maker = send_to_socket::Actor::new(write)
send_wire_message_actor::new(write); .create(None)
.spawn_global();
let cfd_actor_inbox = taker_cfd::Actor::new( let cfd_actor_inbox = taker_cfd::Actor::new(
db, db,
@ -166,7 +168,7 @@ async fn main() -> Result<()> {
schnorrsig::PublicKey::from_keypair(SECP256K1, &oracle), schnorrsig::PublicKey::from_keypair(SECP256K1, &oracle),
cfd_feed_sender, cfd_feed_sender,
order_feed_sender, order_feed_sender,
out_maker_actor_inbox, send_to_maker,
) )
.await .await
.unwrap() .unwrap()
@ -178,7 +180,6 @@ async fn main() -> Result<()> {
tokio::spawn(wallet_sync::new(wallet, wallet_feed_sender)); tokio::spawn(wallet_sync::new(wallet, wallet_feed_sender));
tokio::spawn(inc_maker_messages_actor); tokio::spawn(inc_maker_messages_actor);
tokio::spawn(out_maker_messages_actor);
Ok(rocket.manage(cfd_actor_inbox)) Ok(rocket.manage(cfd_actor_inbox))
}, },
@ -201,3 +202,7 @@ async fn main() -> Result<()> {
Ok(()) Ok(())
} }
impl xtra::Message for TakerToMaker {
type Result = ();
}

19
daemon/src/taker_cfd.rs

@ -8,7 +8,7 @@ use crate::model::cfd::{Cfd, CfdState, CfdStateCommon, Dlc, Order, OrderId};
use crate::model::Usd; use crate::model::Usd;
use crate::wallet::Wallet; use crate::wallet::Wallet;
use crate::wire::SetupMsg; use crate::wire::SetupMsg;
use crate::{setup_contract_actor, wire}; use crate::{send_to_socket, setup_contract_actor, wire};
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use bdk::bitcoin::secp256k1::schnorrsig; use bdk::bitcoin::secp256k1::schnorrsig;
@ -38,7 +38,7 @@ pub struct Actor {
oracle_pk: schnorrsig::PublicKey, oracle_pk: schnorrsig::PublicKey,
cfd_feed_actor_inbox: watch::Sender<Vec<Cfd>>, cfd_feed_actor_inbox: watch::Sender<Vec<Cfd>>,
order_feed_actor_inbox: watch::Sender<Option<Order>>, order_feed_actor_inbox: watch::Sender<Option<Order>>,
out_msg_maker_inbox: mpsc::UnboundedSender<wire::TakerToMaker>, send_to_maker: Address<send_to_socket::Actor>,
current_contract_setup: Option<mpsc::UnboundedSender<SetupMsg>>, current_contract_setup: Option<mpsc::UnboundedSender<SetupMsg>>,
// TODO: Move the contract setup into a dedicated actor and send messages to that actor that // TODO: Move the contract setup into a dedicated actor and send messages to that actor that
// manages the state instead of this ugly buffer // manages the state instead of this ugly buffer
@ -52,7 +52,7 @@ impl Actor {
oracle_pk: schnorrsig::PublicKey, oracle_pk: schnorrsig::PublicKey,
cfd_feed_actor_inbox: watch::Sender<Vec<Cfd>>, cfd_feed_actor_inbox: watch::Sender<Vec<Cfd>>,
order_feed_actor_inbox: watch::Sender<Option<Order>>, order_feed_actor_inbox: watch::Sender<Option<Order>>,
out_msg_maker_inbox: mpsc::UnboundedSender<wire::TakerToMaker>, send_to_maker: Address<send_to_socket::Actor>,
) -> Result<Self> { ) -> Result<Self> {
let mut conn = db.acquire().await?; let mut conn = db.acquire().await?;
cfd_feed_actor_inbox.send(load_all_cfds(&mut conn).await?)?; cfd_feed_actor_inbox.send(load_all_cfds(&mut conn).await?)?;
@ -63,7 +63,7 @@ impl Actor {
oracle_pk, oracle_pk,
cfd_feed_actor_inbox, cfd_feed_actor_inbox,
order_feed_actor_inbox, order_feed_actor_inbox,
out_msg_maker_inbox, send_to_maker,
current_contract_setup: None, current_contract_setup: None,
contract_setup_message_buffer: vec![], contract_setup_message_buffer: vec![],
}) })
@ -90,8 +90,9 @@ impl Actor {
self.cfd_feed_actor_inbox self.cfd_feed_actor_inbox
.send(load_all_cfds(&mut conn).await?)?; .send(load_all_cfds(&mut conn).await?)?;
self.out_msg_maker_inbox self.send_to_maker
.send(wire::TakerToMaker::TakeOrder { order_id, quantity })?; .do_send_async(wire::TakerToMaker::TakeOrder { order_id, quantity })
.await?;
Ok(()) Ok(())
} }
@ -141,8 +142,10 @@ impl Actor {
let (actor, inbox) = setup_contract_actor::new( let (actor, inbox) = setup_contract_actor::new(
{ {
let inbox = self.out_msg_maker_inbox.clone(); let inbox = self.send_to_maker.clone();
move |msg| inbox.send(wire::TakerToMaker::Protocol(msg)).unwrap() move |msg| {
tokio::spawn(inbox.do_send_async(wire::TakerToMaker::Protocol(msg)));
}
}, },
setup_contract_actor::OwnParams::Taker(taker_params), setup_contract_actor::OwnParams::Taker(taker_params),
sk, sk,

22
daemon/src/wire.rs

@ -7,6 +7,7 @@ use bdk::bitcoin::{Address, Amount, PublicKey};
use cfd_protocol::secp256k1_zkp::EcdsaAdaptorSignature; use cfd_protocol::secp256k1_zkp::EcdsaAdaptorSignature;
use cfd_protocol::{CfdTransactions, PartyParams, PunishParams}; use cfd_protocol::{CfdTransactions, PartyParams, PunishParams};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt;
use std::ops::RangeInclusive; use std::ops::RangeInclusive;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@ -17,6 +18,15 @@ pub enum TakerToMaker {
Protocol(SetupMsg), Protocol(SetupMsg),
} }
impl fmt::Display for TakerToMaker {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TakerToMaker::TakeOrder { .. } => write!(f, "TakeOrder"),
TakerToMaker::Protocol(_) => write!(f, "Protocol"),
}
}
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", content = "payload")] #[serde(tag = "type", content = "payload")]
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]
@ -28,6 +38,18 @@ pub enum MakerToTaker {
Protocol(SetupMsg), Protocol(SetupMsg),
} }
impl fmt::Display for MakerToTaker {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
MakerToTaker::CurrentOrder(_) => write!(f, "CurrentOrder"),
MakerToTaker::ConfirmOrder(_) => write!(f, "ConfirmOrder"),
MakerToTaker::RejectOrder(_) => write!(f, "RejectOrder"),
MakerToTaker::InvalidOrderId(_) => write!(f, "InvalidOrderId"),
MakerToTaker::Protocol(_) => write!(f, "Protocol"),
}
}
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", content = "payload")] #[serde(tag = "type", content = "payload")]
pub enum SetupMsg { pub enum SetupMsg {

Loading…
Cancel
Save