Browse Source

Merge #927

927: Remove `send_to_socket::Actor` r=thomaseizinger a=thomaseizinger

This removes a level of indirection for actually sending messages which
simplifies the overall handling of the connection and unifies how
heartbeats are sent to the taker. This is important because we don't
want heartbeats to be handled in a special way as we use them to detect
issues with the connection.

May help with https://github.com/itchysats/itchysats/issues/759.

Co-authored-by: Thomas Eizinger <thomas@eizinger.io>
update-blockstream-electrum-server-url
bors[bot] 3 years ago
committed by GitHub
parent
commit
84c70cac65
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 122
      daemon/src/connection.rs
  2. 1
      daemon/src/lib.rs
  3. 99
      daemon/src/maker_inc_connections.rs
  4. 52
      daemon/src/send_to_socket.rs
  5. 7
      daemon/src/wire.rs

122
daemon/src/connection.rs

@ -8,7 +8,6 @@ use crate::model::Timestamp;
use crate::model::Usd;
use crate::noise;
use crate::rollover_taker;
use crate::send_to_socket;
use crate::setup_taker;
use crate::taker_cfd::CurrentOrder;
use crate::tokio_ext::FutureExt;
@ -38,21 +37,75 @@ use xtra_productivity::xtra_productivity;
/// Time between reconnection attempts
const CONNECT_TO_MAKER_INTERVAL: Duration = Duration::from_secs(5);
struct ConnectedState {
last_heartbeat: SystemTime,
_tasks: Tasks,
/// The "Connected" state of our connection with the maker.
#[allow(clippy::large_enum_variant)]
enum State {
Connected {
last_heartbeat: SystemTime,
write: wire::Write<wire::MakerToTaker, wire::TakerToMaker>,
_tasks: Tasks,
},
Disconnected,
}
impl State {
async fn send(&mut self, msg: wire::TakerToMaker) -> Result<()> {
let msg_str = msg.to_string();
let write = match self {
State::Connected { write, .. } => write,
State::Disconnected => {
bail!("Cannot send {}, not connected to maker", msg_str);
}
};
tracing::trace!(target = "wire", "Sending {}", msg_str);
write
.send(msg)
.await
.with_context(|| format!("Failed to send message {} to maker", msg_str))?;
Ok(())
}
fn handle_incoming_heartbeat(&mut self) {
match self {
State::Connected { last_heartbeat, .. } => {
*last_heartbeat = SystemTime::now();
}
State::Disconnected => {
debug_assert!(false, "Received heartbeat in disconnected state")
}
}
}
fn disconnect_if_last_heartbeat_older_than(&mut self, timeout: Duration) -> bool {
let duration_since_last_heartbeat = match self {
State::Connected { last_heartbeat, .. } => SystemTime::now()
.duration_since(*last_heartbeat)
.expect("clock is monotonic"),
State::Disconnected => return false,
};
if duration_since_last_heartbeat < timeout {
return false;
}
*self = State::Disconnected;
true
}
}
pub struct Actor {
status_sender: watch::Sender<ConnectionStatus>,
send_to_maker: Box<dyn MessageChannel<wire::TakerToMaker>>,
send_to_maker_ctx: xtra::Context<send_to_socket::Actor<wire::MakerToTaker, wire::TakerToMaker>>,
identity_sk: x25519_dalek::StaticSecret,
current_order: Box<dyn MessageChannel<CurrentOrder>>,
/// Max duration since the last heartbeat until we die.
heartbeat_timeout: Duration,
connect_timeout: Duration,
connected_state: Option<ConnectedState>,
state: State,
setup_actors: AddressMap<OrderId, setup_taker::Actor>,
collab_settlement_actors: AddressMap<OrderId, collab_settlement_taker::Actor>,
rollover_actors: AddressMap<OrderId, rollover_taker::Actor>,
@ -122,16 +175,12 @@ impl Actor {
hearthbeat_timeout: Duration,
connect_timeout: Duration,
) -> Self {
let (send_to_maker_addr, send_to_maker_ctx) = xtra::Context::new(None);
Self {
status_sender,
send_to_maker: Box::new(send_to_maker_addr),
send_to_maker_ctx,
identity_sk,
current_order: current_order.clone_channel(),
heartbeat_timeout: hearthbeat_timeout,
connected_state: None,
state: State::Disconnected,
setup_actors: AddressMap::default(),
connect_timeout,
collab_settlement_actors: AddressMap::default(),
@ -142,13 +191,6 @@ impl Actor {
#[xtra_productivity(message_impl = false)]
impl Actor {
async fn handle_taker_to_maker(&mut self, message: wire::TakerToMaker) {
let msg_str = message.to_string();
if self.send_to_maker.send(message).await.is_err() {
tracing::warn!("Failed to send wire message {} to maker", msg_str);
}
}
async fn handle_collab_settlement_actor_stopping(
&mut self,
message: Stopping<collab_settlement_taker::Actor>,
@ -163,8 +205,14 @@ impl Actor {
#[xtra_productivity]
impl Actor {
async fn handle_taker_to_maker(&mut self, message: wire::TakerToMaker) {
if let Err(e) = self.state.send(message).await {
tracing::warn!("{:#}", e);
}
}
async fn handle_take_order(&mut self, msg: TakeOrder) -> Result<()> {
self.send_to_maker
self.state
.send(wire::TakerToMaker::TakeOrder {
order_id: msg.order_id,
quantity: msg.quantity,
@ -186,7 +234,7 @@ impl Actor {
address,
} = msg;
self.send_to_maker
self.state
.send(wire::TakerToMaker::Settlement {
order_id,
msg: wire::taker_to_maker::Settlement::Propose {
@ -210,7 +258,7 @@ impl Actor {
address,
} = msg;
self.send_to_maker
self.state
.send(wire::TakerToMaker::ProposeRollOver {
order_id,
timestamp,
@ -301,20 +349,18 @@ impl Actor {
let this = ctx.address().expect("self to be alive");
let send_to_socket = send_to_socket::Actor::new(write);
let mut tasks = Tasks::default();
tasks.add(self.send_to_maker_ctx.attach(send_to_socket));
tasks.add(this.attach_stream(read.map(move |item| MakerStreamMessage { item })));
tasks.add(
ctx.notify_interval(self.heartbeat_timeout, || MeasurePulse)
.expect("we just started"),
);
self.connected_state = Some(ConnectedState {
self.state = State::Connected {
last_heartbeat: SystemTime::now(),
write,
_tasks: tasks,
});
};
self.status_sender
.send(ConnectionStatus::Online)
.expect("receiver to outlive the actor");
@ -335,14 +381,11 @@ impl Actor {
}
};
tracing::trace!("Received '{}'", msg);
tracing::trace!(target = "wire", "Received {}", msg);
match msg {
wire::MakerToTaker::Heartbeat => {
self.connected_state
.as_mut()
.expect("wire messages only to arrive in connected state")
.last_heartbeat = SystemTime::now();
self.state.handle_incoming_heartbeat();
}
wire::MakerToTaker::ConfirmOrder(order_id) => {
if self
@ -440,20 +483,13 @@ impl Actor {
}
fn handle_measure_pulse(&mut self, _: MeasurePulse) {
let time_since_last_heartbeat = SystemTime::now()
.duration_since(
self.connected_state
.as_ref()
.expect("only run pulse measurements if connected")
.last_heartbeat,
)
.expect("now is always later than heartbeat");
if time_since_last_heartbeat > self.heartbeat_timeout {
if self
.state
.disconnect_if_last_heartbeat_older_than(self.heartbeat_timeout)
{
self.status_sender
.send(ConnectionStatus::Offline { reason: None })
.expect("watch receiver to outlive the actor");
self.connected_state = None;
}
}
}

1
daemon/src/lib.rs

@ -62,7 +62,6 @@ pub mod rollover_taker;
pub mod routes;
pub mod seed;
pub mod send_async_safe;
pub mod send_to_socket;
pub mod setup_contract;
pub mod setup_maker;
pub mod setup_taker;

99
daemon/src/maker_inc_connections.rs

@ -11,7 +11,6 @@ use crate::model::Identity;
use crate::noise;
use crate::noise::TransportStateExt;
use crate::rollover_maker;
use crate::send_to_socket;
use crate::setup_maker;
use crate::tokio_ext::FutureExt;
use crate::wire;
@ -97,8 +96,7 @@ pub enum ListenerMessage {
}
pub struct Actor {
write_connections:
HashMap<Identity, Address<send_to_socket::Actor<wire::TakerToMaker, wire::MakerToTaker>>>,
connections: HashMap<Identity, Connection>,
taker_connected_channel: Box<dyn MessageChannel<TakerConnected>>,
taker_disconnected_channel: Box<dyn MessageChannel<TakerDisconnected>>,
taker_msg_channel: Box<dyn MessageChannel<FromTaker>>,
@ -107,7 +105,28 @@ pub struct Actor {
setup_actors: AddressMap<OrderId, setup_maker::Actor>,
settlement_actors: AddressMap<OrderId, collab_settlement_maker::Actor>,
rollover_actors: AddressMap<OrderId, rollover_maker::Actor>,
connection_tasks: HashMap<Identity, Tasks>,
}
/// A connection to a taker.
struct Connection {
taker: Identity,
write: wire::Write<wire::TakerToMaker, wire::MakerToTaker>,
_tasks: Tasks,
}
impl Connection {
async fn send(&mut self, msg: wire::MakerToTaker) -> Result<()> {
let msg_str = msg.to_string();
tracing::trace!(target = "wire", taker_id = %self.taker, "Sending {}", msg_str);
self.write
.send(msg)
.await
.with_context(|| format!("Failed to send msg {} to taker {}", msg_str, self.taker))?;
Ok(())
}
}
impl Actor {
@ -119,7 +138,7 @@ impl Actor {
heartbeat_interval: Duration,
) -> Self {
Self {
write_connections: HashMap::new(),
connections: HashMap::new(),
taker_connected_channel: taker_connected_channel.clone_channel(),
taker_disconnected_channel: taker_disconnected_channel.clone_channel(),
taker_msg_channel: taker_msg_channel.clone_channel(),
@ -128,19 +147,17 @@ impl Actor {
setup_actors: AddressMap::default(),
settlement_actors: AddressMap::default(),
rollover_actors: AddressMap::default(),
connection_tasks: HashMap::new(),
}
}
async fn drop_taker_connection(&mut self, taker_id: &Identity) {
if self.write_connections.remove(taker_id).is_some() {
if self.connections.remove(taker_id).is_some() {
tracing::info!(%taker_id, "Dropping connection");
let _ = self
.taker_disconnected_channel
.send(maker_cfd::TakerDisconnected { id: *taker_id })
.log_failure("Failed to inform about taker disconnect")
.await;
let _ = self.connection_tasks.remove(taker_id);
}
}
@ -150,14 +167,11 @@ impl Actor {
msg: wire::MakerToTaker,
) -> Result<(), NoConnection> {
let conn = self
.write_connections
.get(taker_id)
.connections
.get_mut(taker_id)
.ok_or_else(|| NoConnection(*taker_id))?;
let msg_str = msg.to_string();
if conn.send(msg).await.is_err() {
tracing::error!(%taker_id, "Failed to send message to taker: {}", msg_str);
self.drop_taker_connection(taker_id).await;
return Err(NoConnection(*taker_id));
}
@ -226,22 +240,22 @@ impl Actor {
let _ = this.send(ReadFail(taker_id)).await;
};
let (out_msg, mut out_msg_actor_context) = xtra::Context::new(None);
let send_to_socket_actor = send_to_socket::Actor::new(write);
let heartbeat_fut = out_msg_actor_context
.notify_interval(self.heartbeat_interval, || wire::MakerToTaker::Heartbeat)
let heartbeat_fut = ctx
.notify_interval(self.heartbeat_interval, move || SendHeartbeat(taker_id))
.expect("actor not to shutdown");
let write_fut = out_msg_actor_context.run(send_to_socket_actor);
self.write_connections.insert(taker_id, out_msg);
let mut tasks = Tasks::default();
tasks.add(read_fut);
tasks.add(heartbeat_fut);
tasks.add(write_fut);
self.connection_tasks.insert(taker_id, tasks);
self.connections.insert(
taker_id,
Connection {
_tasks: tasks,
taker: taker_id,
write,
},
);
let _ = self
.taker_connected_channel
@ -253,6 +267,8 @@ impl Actor {
}
}
pub struct SendHeartbeat(Identity);
#[derive(Debug, thiserror::Error)]
#[error("No connection to taker {0}")]
pub struct NoConnection(Identity);
@ -261,10 +277,35 @@ pub struct NoConnection(Identity);
impl Actor {
async fn handle_broadcast_order(&mut self, msg: BroadcastOrder) {
let order = msg.0;
for taker_id in self.write_connections.clone().keys() {
self.send_to_taker(taker_id, wire::MakerToTaker::CurrentOrder(order.clone())).await.expect("send_to_taker only fails on missing hashmap entry and we are iterating over those entries");
tracing::trace!(%taker_id, "sent new order: {:?}", order.as_ref().map(|o| o.id));
let mut broken_connections = Vec::with_capacity(self.connections.len());
for (id, conn) in &mut self.connections {
if let Err(e) = conn
.send(wire::MakerToTaker::CurrentOrder(order.clone()))
.await
{
tracing::warn!("{:#}", e);
broken_connections.push(*id);
continue;
}
tracing::trace!(taker_id = %id, "Sent new order: {:?}", order.as_ref().map(|o| o.id));
}
for id in broken_connections {
self.drop_taker_connection(&id).await;
}
}
async fn handle_send_heartbeat(&mut self, msg: SendHeartbeat) {
let result = self
.send_to_taker(&msg.0, wire::MakerToTaker::Heartbeat)
.await;
// use explicit match on `Err` to catch fn signature changes
debug_assert!(!matches!(result, Err(NoConnection(_))), "`send_to_taker` only fails if we don't have a HashMap entry. We clean those up together with the heartbeat task. How did we get called without a connection?");
}
async fn handle_confirm_order(&mut self, msg: ConfirmOrder) -> Result<()> {
@ -341,6 +382,10 @@ impl Actor {
#[xtra_productivity(message_impl = false)]
impl Actor {
async fn handle_msg_from_taker(&mut self, msg: FromTaker) -> Result<()> {
let msg_str = msg.msg.to_string();
tracing::trace!(target = "wire", taker_id = %msg.taker_id, "Received {}", msg_str);
use wire::TakerToMaker::*;
match msg.msg {
Protocol { order_id, msg } => match self.setup_actors.get_connected(&order_id) {

52
daemon/src/send_to_socket.rs

@ -1,52 +0,0 @@
use crate::wire;
use futures::stream::SplitSink;
use futures::SinkExt;
use serde::Serialize;
use std::fmt;
use tokio::net::TcpStream;
use tokio_util::codec::Framed;
use xtra::Handler;
use xtra::Message;
pub struct Actor<D, E> {
write: SplitSink<Framed<TcpStream, wire::EncryptedJsonCodec<D, E>>, E>,
}
impl<D, E> Actor<D, E> {
pub fn new(write: SplitSink<Framed<TcpStream, wire::EncryptedJsonCodec<D, E>>, E>) -> Self {
Self { write }
}
}
#[async_trait::async_trait]
impl<D, E> Handler<E> for Actor<D, E>
where
D: Send + Sync + 'static,
E: Message<Result = ()> + Serialize + fmt::Display + Sync,
{
async fn handle(&mut self, message: E, ctx: &mut xtra::Context<Self>) {
let message_name = message.to_string(); // send consumes the message, avoid a clone just in case it errors by getting the name here
tracing::trace!("Sending '{}'", message_name);
if let Err(e) = self.write.send(message).await {
tracing::error!("Failed to write message {} to socket: {}", message_name, e);
ctx.stop();
}
}
}
impl<D, E> xtra::Actor for Actor<D, E>
where
D: 'static + Send,
E: 'static + Send,
{
}
impl xtra::Message for wire::MakerToTaker {
type Result = ();
}
impl xtra::Message for wire::TakerToMaker {
type Result = ();
}

7
daemon/src/wire.rs

@ -14,6 +14,8 @@ use bdk::bitcoin::Address;
use bdk::bitcoin::Amount;
use bdk::bitcoin::PublicKey;
use bytes::BytesMut;
use futures::stream::SplitSink;
use futures::stream::SplitStream;
use maia::secp256k1_zkp::EcdsaAdaptorSignature;
use maia::secp256k1_zkp::SecretKey;
use maia::CfdTransactions;
@ -27,10 +29,15 @@ use std::collections::HashMap;
use std::fmt;
use std::marker::PhantomData;
use std::ops::RangeInclusive;
use tokio::net::TcpStream;
use tokio_util::codec::Decoder;
use tokio_util::codec::Encoder;
use tokio_util::codec::Framed;
use tokio_util::codec::LengthDelimitedCodec;
pub type Read<D, E> = SplitStream<Framed<TcpStream, EncryptedJsonCodec<D, E>>>;
pub type Write<D, E> = SplitSink<Framed<TcpStream, EncryptedJsonCodec<D, E>>, E>;
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, PartialOrd)]
pub struct Version(semver::Version);

Loading…
Cancel
Save