From b01e3b1dc659181d7ca0a2b873e5558a412ac7d2 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 15 Nov 2021 15:55:55 +1100 Subject: [PATCH] Move handling of bad connection into `send_to_taker` function This makes sure we notice bad connections upon all messages. --- daemon/src/maker_inc_connections.rs | 47 ++++++++++++++--------------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/daemon/src/maker_inc_connections.rs b/daemon/src/maker_inc_connections.rs index 63a1172..1b299f3 100644 --- a/daemon/src/maker_inc_connections.rs +++ b/daemon/src/maker_inc_connections.rs @@ -84,14 +84,18 @@ impl Actor { } } - async fn send_to_taker(&self, taker_id: TakerId, msg: wire::MakerToTaker) -> Result<()> { + async fn send_to_taker(&mut self, taker_id: &TakerId, msg: wire::MakerToTaker) -> Result<()> { let conn = self .write_connections - .get(&taker_id) + .get(taker_id) .context("no connection to taker_id")?; - // use `.send` here to ensure we only continue once the message has been sent - conn.send(msg).await?; + let msg_str = msg.to_string(); + + if conn.send(msg).await.is_err() { + tracing::info!(%taker_id, "Failed to send {} to taker, removing connection", msg_str); + self.write_connections.remove(taker_id); + } Ok(()) } @@ -156,49 +160,42 @@ impl Actor { impl Actor { async fn handle_broadcast_order(&mut self, msg: BroadcastOrder) -> Result<()> { let order = msg.0; - for (taker_id, conn) in self.write_connections.clone() { - if conn - .do_send_async(wire::MakerToTaker::CurrentOrder(order.clone())) - .await - .is_err() - { - tracing::trace!(%taker_id, "removing outdated connection to taker because unable to send order: {:?}", order); - self.write_connections.remove(&taker_id); - } else { - tracing::trace!(%taker_id, "sent new order: {:?}", order.as_ref().map(|o| o.id)); - } + for taker_id in self.write_connections.clone().keys() { + self.send_to_taker(taker_id, wire::MakerToTaker::CurrentOrder(order.clone())).await.expect("send_to_taker only fails on missing hashmap entry and we are iterating over those entries"); + tracing::trace!(%taker_id, "sent new order: {:?}", order.as_ref().map(|o| o.id)); } + Ok(()) } async fn handle_taker_message(&mut self, msg: TakerMessage) -> Result<()> { match msg.command { TakerCommand::SendOrder { order } => { - self.send_to_taker(msg.taker_id, wire::MakerToTaker::CurrentOrder(order)) + self.send_to_taker(&msg.taker_id, wire::MakerToTaker::CurrentOrder(order)) .await?; } TakerCommand::NotifyInvalidOrderId { id } => { - self.send_to_taker(msg.taker_id, wire::MakerToTaker::InvalidOrderId(id)) + self.send_to_taker(&msg.taker_id, wire::MakerToTaker::InvalidOrderId(id)) .await?; } TakerCommand::NotifyOrderAccepted { id } => { - self.send_to_taker(msg.taker_id, wire::MakerToTaker::ConfirmOrder(id)) + self.send_to_taker(&msg.taker_id, wire::MakerToTaker::ConfirmOrder(id)) .await?; } TakerCommand::NotifyOrderRejected { id } => { - self.send_to_taker(msg.taker_id, wire::MakerToTaker::RejectOrder(id)) + self.send_to_taker(&msg.taker_id, wire::MakerToTaker::RejectOrder(id)) .await?; } TakerCommand::NotifySettlementAccepted { id } => { - self.send_to_taker(msg.taker_id, wire::MakerToTaker::ConfirmSettlement(id)) + self.send_to_taker(&msg.taker_id, wire::MakerToTaker::ConfirmSettlement(id)) .await?; } TakerCommand::NotifySettlementRejected { id } => { - self.send_to_taker(msg.taker_id, wire::MakerToTaker::RejectSettlement(id)) + self.send_to_taker(&msg.taker_id, wire::MakerToTaker::RejectSettlement(id)) .await?; } TakerCommand::Protocol(setup_msg) => { - self.send_to_taker(msg.taker_id, wire::MakerToTaker::Protocol(setup_msg)) + self.send_to_taker(&msg.taker_id, wire::MakerToTaker::Protocol(setup_msg)) .await?; } TakerCommand::NotifyRollOverAccepted { @@ -206,7 +203,7 @@ impl Actor { oracle_event_id, } => { self.send_to_taker( - msg.taker_id, + &msg.taker_id, wire::MakerToTaker::ConfirmRollOver { order_id: id, oracle_event_id, @@ -215,12 +212,12 @@ impl Actor { .await?; } TakerCommand::NotifyRollOverRejected { id } => { - self.send_to_taker(msg.taker_id, wire::MakerToTaker::RejectRollOver(id)) + self.send_to_taker(&msg.taker_id, wire::MakerToTaker::RejectRollOver(id)) .await?; } TakerCommand::RollOverProtocol(roll_over_msg) => { self.send_to_taker( - msg.taker_id, + &msg.taker_id, wire::MakerToTaker::RollOverProtocol(roll_over_msg), ) .await?;