Browse Source

Convert `send_wire_message_actor` to xtra::Actor

This will give us a generalized `Sink` interface for writing to the
connection.
fix-bad-api-calls
Thomas Eizinger 3 years ago
parent
commit
b6d0fc6c2f
No known key found for this signature in database GPG Key ID: 651AC83A6C6C8B96
  1. 15
      daemon/src/maker.rs
  2. 42
      daemon/src/maker_inc_connections.rs
  3. 35
      daemon/src/send_to_socket.rs
  4. 31
      daemon/src/send_wire_message_actor.rs
  5. 15
      daemon/src/taker.rs
  6. 19
      daemon/src/taker_cfd.rs
  7. 22
      daemon/src/wire.rs

15
daemon/src/maker.rs

@ -29,7 +29,7 @@ mod model;
mod routes;
mod routes_maker;
mod seed;
mod send_wire_message_actor;
mod send_to_socket;
mod setup_contract_actor;
mod to_sse_event;
mod wallet;
@ -194,16 +194,17 @@ async fn main() -> Result<()> {
cfd_maker_actor_inbox.clone(),
taker_id,
);
let (out_msg_actor, out_msg_actor_inbox) =
send_wire_message_actor::new::<wire::MakerToTaker>(write);
let out_msg_actor = send_to_socket::Actor::new(write)
.create(None)
.spawn_global();
tokio::spawn(in_taker_actor);
tokio::spawn(out_msg_actor);
maker_inc_connections_address
.do_send_async(maker_inc_connections::NewTakerOnline {
taker_id,
out_msg_actor_inbox,
out_msg_actor,
})
.await
.unwrap();
@ -238,3 +239,7 @@ async fn main() -> Result<()> {
Ok(())
}
impl xtra::Message for wire::MakerToTaker {
type Result = ();
}

42
daemon/src/maker_inc_connections.rs

@ -2,18 +2,15 @@ use crate::actors::log_error;
use crate::model::cfd::{Order, OrderId};
use crate::model::TakerId;
use crate::wire::SetupMsg;
use crate::{maker_cfd, wire};
use crate::{maker_cfd, send_to_socket, wire};
use anyhow::{Context as AnyhowContext, Result};
use async_trait::async_trait;
use futures::{Future, StreamExt};
use std::collections::HashMap;
use tokio::net::tcp::OwnedReadHalf;
use tokio::sync::mpsc;
use tokio_util::codec::{FramedRead, LengthDelimitedCodec};
use xtra::prelude::*;
type MakerToTakerSender = mpsc::UnboundedSender<wire::MakerToTaker>;
pub struct BroadcastOrder(pub Option<Order>);
impl Message for BroadcastOrder {
@ -40,7 +37,7 @@ impl Message for TakerMessage {
pub struct NewTakerOnline {
pub taker_id: TakerId,
pub out_msg_actor_inbox: MakerToTakerSender,
pub out_msg_actor: Address<send_to_socket::Actor>,
}
impl Message for NewTakerOnline {
@ -48,7 +45,7 @@ impl Message for NewTakerOnline {
}
pub struct Actor {
write_connections: HashMap<TakerId, MakerToTakerSender>,
write_connections: HashMap<TakerId, Address<send_to_socket::Actor>>,
cfd_actor: Address<maker_cfd::Actor>,
}
@ -57,44 +54,53 @@ impl xtra::Actor for Actor {}
impl Actor {
pub fn new(cfd_actor: Address<maker_cfd::Actor>) -> Self {
Self {
write_connections: HashMap::<TakerId, MakerToTakerSender>::new(),
write_connections: HashMap::new(),
cfd_actor,
}
}
fn send_to_taker(&self, taker_id: TakerId, msg: wire::MakerToTaker) -> Result<()> {
async fn send_to_taker(&self, taker_id: TakerId, msg: wire::MakerToTaker) -> Result<()> {
let conn = self
.write_connections
.get(&taker_id)
.context("no connection to taker_id")?;
conn.send(msg)?;
conn.do_send_async(msg).await?;
Ok(())
}
async fn handle_broadcast_order(&mut self, msg: BroadcastOrder) -> Result<()> {
let order = msg.0;
self.write_connections
.values()
.try_for_each(|conn| conn.send(wire::MakerToTaker::CurrentOrder(order.clone())))?;
for conn in self.write_connections.values() {
conn.do_send_async(wire::MakerToTaker::CurrentOrder(order.clone()))
.await?;
}
Ok(())
}
async fn handle_taker_message(&mut self, msg: TakerMessage) -> Result<()> {
match msg.command {
TakerCommand::SendOrder { order } => {
self.send_to_taker(msg.taker_id, wire::MakerToTaker::CurrentOrder(order))?;
self.send_to_taker(msg.taker_id, wire::MakerToTaker::CurrentOrder(order))
.await?;
}
TakerCommand::NotifyInvalidOrderId { id } => {
self.send_to_taker(msg.taker_id, wire::MakerToTaker::InvalidOrderId(id))?;
self.send_to_taker(msg.taker_id, wire::MakerToTaker::InvalidOrderId(id))
.await?;
}
TakerCommand::NotifyOrderAccepted { id } => {
self.send_to_taker(msg.taker_id, wire::MakerToTaker::ConfirmOrder(id))?;
self.send_to_taker(msg.taker_id, wire::MakerToTaker::ConfirmOrder(id))
.await?;
}
TakerCommand::NotifyOrderRejected { id } => {
self.send_to_taker(msg.taker_id, wire::MakerToTaker::RejectOrder(id))?;
self.send_to_taker(msg.taker_id, wire::MakerToTaker::RejectOrder(id))
.await?;
}
TakerCommand::OutProtocolMsg { setup_msg } => {
self.send_to_taker(msg.taker_id, wire::MakerToTaker::Protocol(setup_msg))?;
self.send_to_taker(msg.taker_id, wire::MakerToTaker::Protocol(setup_msg))
.await?;
}
}
Ok(())
@ -106,7 +112,7 @@ impl Actor {
.await?;
self.write_connections
.insert(msg.taker_id, msg.out_msg_actor_inbox);
.insert(msg.taker_id, msg.out_msg_actor);
Ok(())
}
}

35
daemon/src/send_to_socket.rs

@ -0,0 +1,35 @@
use futures::SinkExt;
use serde::Serialize;
use std::fmt;
use tokio::net::tcp::OwnedWriteHalf;
use tokio_util::codec::{FramedWrite, LengthDelimitedCodec};
use xtra::{Handler, Message};
pub struct Actor {
write: FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
}
impl Actor {
pub fn new(write: OwnedWriteHalf) -> Self {
Self {
write: FramedWrite::new(write, LengthDelimitedCodec::new()),
}
}
}
#[async_trait::async_trait]
impl<T> Handler<T> for Actor
where
T: Message<Result = ()> + Serialize + fmt::Display,
{
async fn handle(&mut self, message: T, ctx: &mut xtra::Context<Self>) {
let bytes = serde_json::to_vec(&message).expect("serialization should never fail");
if let Err(e) = self.write.send(bytes.into()).await {
tracing::error!("Failed to write message {} to socket: {}", message, e);
ctx.stop();
}
}
}
impl xtra::Actor for Actor {}

31
daemon/src/send_wire_message_actor.rs

@ -1,31 +0,0 @@
use futures::{Future, SinkExt};
use serde::Serialize;
use tokio::net::tcp::OwnedWriteHalf;
use tokio::sync::mpsc;
use tokio_util::codec::{FramedWrite, LengthDelimitedCodec};
pub fn new<T>(write: OwnedWriteHalf) -> (impl Future<Output = ()>, mpsc::UnboundedSender<T>)
where
T: Serialize,
{
let (sender, mut receiver) = mpsc::unbounded_channel::<T>();
let actor = async move {
let mut framed_write = FramedWrite::new(write, LengthDelimitedCodec::new());
while let Some(message) = receiver.recv().await {
match framed_write
.send(serde_json::to_vec(&message).unwrap().into())
.await
{
Ok(_) => {}
Err(_) => {
tracing::error!("TCP connection error");
break;
}
}
}
};
(actor, sender)
}

15
daemon/src/taker.rs

@ -14,6 +14,7 @@ use std::thread::sleep;
use std::time::Duration;
use tokio::sync::watch;
use tracing_subscriber::filter::LevelFilter;
use wire::TakerToMaker;
use xtra::spawn::TokioGlobalSpawnExt;
use xtra::Actor;
@ -26,7 +27,7 @@ mod model;
mod routes;
mod routes_taker;
mod seed;
mod send_wire_message_actor;
mod send_to_socket;
mod setup_contract_actor;
mod taker_cfd;
mod taker_inc_message_actor;
@ -157,8 +158,9 @@ async fn main() -> Result<()> {
None => return Err(rocket),
};
let (out_maker_messages_actor, out_maker_actor_inbox) =
send_wire_message_actor::new(write);
let send_to_maker = send_to_socket::Actor::new(write)
.create(None)
.spawn_global();
let cfd_actor_inbox = taker_cfd::Actor::new(
db,
@ -166,7 +168,7 @@ async fn main() -> Result<()> {
schnorrsig::PublicKey::from_keypair(SECP256K1, &oracle),
cfd_feed_sender,
order_feed_sender,
out_maker_actor_inbox,
send_to_maker,
)
.await
.unwrap()
@ -178,7 +180,6 @@ async fn main() -> Result<()> {
tokio::spawn(wallet_sync::new(wallet, wallet_feed_sender));
tokio::spawn(inc_maker_messages_actor);
tokio::spawn(out_maker_messages_actor);
Ok(rocket.manage(cfd_actor_inbox))
},
@ -201,3 +202,7 @@ async fn main() -> Result<()> {
Ok(())
}
impl xtra::Message for TakerToMaker {
type Result = ();
}

19
daemon/src/taker_cfd.rs

@ -8,7 +8,7 @@ use crate::model::cfd::{Cfd, CfdState, CfdStateCommon, Dlc, Order, OrderId};
use crate::model::Usd;
use crate::wallet::Wallet;
use crate::wire::SetupMsg;
use crate::{setup_contract_actor, wire};
use crate::{send_to_socket, setup_contract_actor, wire};
use anyhow::Result;
use async_trait::async_trait;
use bdk::bitcoin::secp256k1::schnorrsig;
@ -38,7 +38,7 @@ pub struct Actor {
oracle_pk: schnorrsig::PublicKey,
cfd_feed_actor_inbox: watch::Sender<Vec<Cfd>>,
order_feed_actor_inbox: watch::Sender<Option<Order>>,
out_msg_maker_inbox: mpsc::UnboundedSender<wire::TakerToMaker>,
send_to_maker: Address<send_to_socket::Actor>,
current_contract_setup: Option<mpsc::UnboundedSender<SetupMsg>>,
// TODO: Move the contract setup into a dedicated actor and send messages to that actor that
// manages the state instead of this ugly buffer
@ -52,7 +52,7 @@ impl Actor {
oracle_pk: schnorrsig::PublicKey,
cfd_feed_actor_inbox: watch::Sender<Vec<Cfd>>,
order_feed_actor_inbox: watch::Sender<Option<Order>>,
out_msg_maker_inbox: mpsc::UnboundedSender<wire::TakerToMaker>,
send_to_maker: Address<send_to_socket::Actor>,
) -> Result<Self> {
let mut conn = db.acquire().await?;
cfd_feed_actor_inbox.send(load_all_cfds(&mut conn).await?)?;
@ -63,7 +63,7 @@ impl Actor {
oracle_pk,
cfd_feed_actor_inbox,
order_feed_actor_inbox,
out_msg_maker_inbox,
send_to_maker,
current_contract_setup: None,
contract_setup_message_buffer: vec![],
})
@ -90,8 +90,9 @@ impl Actor {
self.cfd_feed_actor_inbox
.send(load_all_cfds(&mut conn).await?)?;
self.out_msg_maker_inbox
.send(wire::TakerToMaker::TakeOrder { order_id, quantity })?;
self.send_to_maker
.do_send_async(wire::TakerToMaker::TakeOrder { order_id, quantity })
.await?;
Ok(())
}
@ -141,8 +142,10 @@ impl Actor {
let (actor, inbox) = setup_contract_actor::new(
{
let inbox = self.out_msg_maker_inbox.clone();
move |msg| inbox.send(wire::TakerToMaker::Protocol(msg)).unwrap()
let inbox = self.send_to_maker.clone();
move |msg| {
tokio::spawn(inbox.do_send_async(wire::TakerToMaker::Protocol(msg)));
}
},
setup_contract_actor::OwnParams::Taker(taker_params),
sk,

22
daemon/src/wire.rs

@ -7,6 +7,7 @@ use bdk::bitcoin::{Address, Amount, PublicKey};
use cfd_protocol::secp256k1_zkp::EcdsaAdaptorSignature;
use cfd_protocol::{CfdTransactions, PartyParams, PunishParams};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::ops::RangeInclusive;
#[derive(Debug, Serialize, Deserialize)]
@ -17,6 +18,15 @@ pub enum TakerToMaker {
Protocol(SetupMsg),
}
impl fmt::Display for TakerToMaker {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TakerToMaker::TakeOrder { .. } => write!(f, "TakeOrder"),
TakerToMaker::Protocol(_) => write!(f, "Protocol"),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", content = "payload")]
#[allow(clippy::large_enum_variant)]
@ -28,6 +38,18 @@ pub enum MakerToTaker {
Protocol(SetupMsg),
}
impl fmt::Display for MakerToTaker {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
MakerToTaker::CurrentOrder(_) => write!(f, "CurrentOrder"),
MakerToTaker::ConfirmOrder(_) => write!(f, "ConfirmOrder"),
MakerToTaker::RejectOrder(_) => write!(f, "RejectOrder"),
MakerToTaker::InvalidOrderId(_) => write!(f, "InvalidOrderId"),
MakerToTaker::Protocol(_) => write!(f, "Protocol"),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", content = "payload")]
pub enum SetupMsg {

Loading…
Cancel
Save