From bea968d27b81d2b2e87451c8d024aa7f6f6f97ac Mon Sep 17 00:00:00 2001 From: Mariusz Klochowicz Date: Wed, 17 Nov 2021 14:32:22 +1030 Subject: [PATCH 1/4] Tie lifetimes of spawned tasks to actors --- daemon/src/connection.rs | 20 ++++++++++++-------- daemon/src/maker.rs | 8 +++++--- daemon/src/maker_cfd.rs | 30 +++++++++++++++++++----------- daemon/src/taker.rs | 8 ++++---- daemon/src/taker_cfd.rs | 30 +++++++++++++++++++++--------- daemon/tests/harness/mod.rs | 22 +++++++++++++++++++--- 6 files changed, 80 insertions(+), 38 deletions(-) diff --git a/daemon/src/connection.rs b/daemon/src/connection.rs index 40f0f95..b16f045 100644 --- a/daemon/src/connection.rs +++ b/daemon/src/connection.rs @@ -15,7 +15,7 @@ use xtra_productivity::xtra_productivity; struct ConnectedState { last_heartbeat: SystemTime, - _pulse_handle: RemoteHandle<()>, + _tasks: Vec>, } pub struct Actor { @@ -97,22 +97,26 @@ impl Actor { let send_to_socket = send_to_socket::Actor::new(write, noise.clone()); - tokio::spawn(self.send_to_maker_ctx.attach(send_to_socket)); + let mut tasks = vec![self + .send_to_maker_ctx + .attach(send_to_socket) + .spawn_with_handle()]; let read = FramedRead::new(read, wire::EncryptedJsonCodec::new(noise)) .map(move |item| MakerStreamMessage { item }); let this = ctx.address().expect("self to be alive"); - tokio::spawn(this.attach_stream(read)); + tasks.push(this.attach_stream(read).spawn_with_handle()); - let pulse_remote_handle = ctx - .notify_interval(self.timeout, || MeasurePulse) - .expect("we just started") - .spawn_with_handle(); + tasks.push( + ctx.notify_interval(self.timeout, || MeasurePulse) + .expect("we just started") + .spawn_with_handle(), + ); self.connected_state = Some(ConnectedState { last_heartbeat: SystemTime::now(), - _pulse_handle: pulse_remote_handle, + _tasks: tasks, }); self.status_sender .send(ConnectionStatus::Online) diff --git a/daemon/src/maker.rs b/daemon/src/maker.rs index a8c7e0d..930f64e 100644 --- a/daemon/src/maker.rs +++ b/daemon/src/maker.rs @@ -224,7 +224,7 @@ async fn main() -> Result<()> { tracing::info!("Listening on {}", local_addr); let (task, quote_updates) = bitmex_price_feed::new().await?; - tokio::spawn(task); + let _task = task.spawn_with_handle(); let db = SqlitePool::connect_with( SqliteConnectOptions::new() @@ -282,9 +282,11 @@ async fn main() -> Result<()> { Poll::Ready(Some(message)) }); - tokio::spawn(incoming_connection_addr.attach_stream(listener_stream)); + let _listener_task = incoming_connection_addr + .attach_stream(listener_stream) + .spawn_with_handle(); - tokio::spawn(wallet_sync::new(wallet, wallet_feed_sender)); + let _wallet_sync_task = wallet_sync::new(wallet, wallet_feed_sender).spawn_with_handle(); let cfd_action_channel = MessageChannel::::clone_channel(&cfd_actor_addr); let new_order_channel = MessageChannel::::clone_channel(&cfd_actor_addr); diff --git a/daemon/src/maker_cfd.rs b/daemon/src/maker_cfd.rs index 4f3afd2..fae4699 100644 --- a/daemon/src/maker_cfd.rs +++ b/daemon/src/maker_cfd.rs @@ -8,11 +8,13 @@ use crate::model::cfd::{ }; use crate::model::{Price, TakerId, Timestamp, Usd}; use crate::monitor::MonitorParams; +use crate::tokio_ext::FutureExt; use crate::{log_error, maker_inc_connections, monitor, oracle, setup_contract, wallet, wire}; use anyhow::{Context as _, Result}; use async_trait::async_trait; use bdk::bitcoin::secp256k1::schnorrsig; use futures::channel::mpsc; +use futures::future::RemoteHandle; use futures::{future, SinkExt}; use maia::secp256k1_zkp::Signature; use sqlx::pool::PoolConnection; @@ -81,6 +83,7 @@ enum SetupState { Active { taker: TakerId, sender: mpsc::UnboundedSender, + _task: RemoteHandle<()>, }, None, } @@ -89,6 +92,7 @@ enum RollOverState { Active { taker: TakerId, sender: mpsc::UnboundedSender, + _task: RemoteHandle<()>, }, None, } @@ -198,7 +202,7 @@ impl Actor { msg: wire::SetupMsg, ) -> Result<()> { match &mut self.setup_state { - SetupState::Active { taker, sender } if taker_id == *taker => { + SetupState::Active { taker, sender, .. } if taker_id == *taker => { sender.send(msg).await?; } SetupState::Active { taker, .. } => { @@ -218,7 +222,7 @@ impl Actor { msg: wire::RollOverMsg, ) -> Result<()> { match &mut self.roll_over_state { - RollOverState::Active { taker, sender } if taker_id == *taker => { + RollOverState::Active { taker, sender, .. } if taker_id == *taker => { sender.send(msg).await?; } RollOverState::Active { taker, .. } => { @@ -622,18 +626,20 @@ where .address() .expect("actor to be able to give address to itself"); - tokio::spawn(async move { + let task = async move { let dlc = contract_future.await; this.send(CfdSetupCompleted { order_id, dlc }) .await .expect("always connected to ourselves"); - }); + } + .spawn_with_handle(); // 6. Record that we are in an active contract setup self.setup_state = SetupState::Active { sender, taker: taker_id, + _task: task, }; Ok(()) @@ -785,18 +791,20 @@ where .address() .expect("actor to be able to give address to itself"); - self.roll_over_state = RollOverState::Active { - sender, - taker: taker_id, - }; - - tokio::spawn(async move { + let task = async move { let dlc = contract_future.await; this.send(CfdRollOverCompleted { order_id, dlc }) .await .expect("always connected to ourselves") - }); + } + .spawn_with_handle(); + + self.roll_over_state = RollOverState::Active { + sender, + taker: taker_id, + _task: task, + }; self.remove_pending_proposal(&order_id) .context("accepted roll_over")?; diff --git a/daemon/src/taker.rs b/daemon/src/taker.rs index 6cc1c95..c2aa29f 100644 --- a/daemon/src/taker.rs +++ b/daemon/src/taker.rs @@ -208,7 +208,7 @@ async fn main() -> Result<()> { let (wallet_feed_sender, wallet_feed_receiver) = watch::channel::(wallet_info); let (task, quote_updates) = bitmex_price_feed::new().await?; - tokio::spawn(task); + let _task = task.spawn_with_handle(); let figment = rocket::Config::figment() .merge(("address", opts.http_address.ip())) @@ -258,7 +258,7 @@ async fn main() -> Result<()> { connect(connection_actor_addr, opts.maker_id, opts.maker).await?; - tokio::spawn(wallet_sync::new(wallet, wallet_feed_sender)); + let _wallet_sync_task = wallet_sync::new(wallet, wallet_feed_sender).spawn_with_handle(); let take_offer_channel = MessageChannel::::clone_channel(&cfd_actor_addr); let cfd_action_channel = MessageChannel::::clone_channel(&cfd_actor_addr); @@ -290,7 +290,7 @@ async fn main() -> Result<()> { let shutdown_handle = rocket.shutdown(); // shutdown the rocket server maker if goes offline - tokio::spawn(async move { + let _rocket_shutdown_task = (async move { loop { maker_online_status_feed_receiver.changed().await.unwrap(); if maker_online_status_feed_receiver.borrow().clone() == ConnectionStatus::Offline { @@ -299,7 +299,7 @@ async fn main() -> Result<()> { return; } } - }); + }).spawn_with_handle(); rocket.launch().await?; db.close().await; diff --git a/daemon/src/taker_cfd.rs b/daemon/src/taker_cfd.rs index 3ac2ce1..2887816 100644 --- a/daemon/src/taker_cfd.rs +++ b/daemon/src/taker_cfd.rs @@ -7,12 +7,14 @@ use crate::model::cfd::{ }; use crate::model::{BitMexPriceEventId, Price, Timestamp, Usd}; use crate::monitor::{self, MonitorParams}; +use crate::tokio_ext::FutureExt; use crate::wire::{MakerToTaker, RollOverMsg, SetupMsg}; use crate::{log_error, oracle, setup_contract, wallet, wire}; use anyhow::{bail, Context as _, Result}; use async_trait::async_trait; use bdk::bitcoin::secp256k1::schnorrsig; use futures::channel::mpsc; +use futures::future::RemoteHandle; use futures::{future, SinkExt}; use std::collections::HashMap; use tokio::sync::watch; @@ -49,6 +51,7 @@ pub struct CfdRollOverCompleted { enum SetupState { Active { sender: mpsc::UnboundedSender, + _task: RemoteHandle<()>, }, None, } @@ -56,6 +59,7 @@ enum SetupState { enum RollOverState { Active { sender: mpsc::UnboundedSender, + _task: RemoteHandle<()>, }, None, } @@ -252,7 +256,7 @@ where async fn handle_inc_protocol_msg(&mut self, msg: SetupMsg) -> Result<()> { match &mut self.setup_state { - SetupState::Active { sender } => { + SetupState::Active { sender, .. } => { sender.send(msg).await?; } SetupState::None => { @@ -265,7 +269,7 @@ where async fn handle_inc_roll_over_msg(&mut self, msg: RollOverMsg) -> Result<()> { match &mut self.roll_over_state { - RollOverState::Active { sender } => { + RollOverState::Active { sender, .. } => { sender.send(msg).await?; } RollOverState::None => { @@ -515,15 +519,19 @@ where .address() .expect("actor to be able to give address to itself"); - tokio::spawn(async move { + let task = async move { let dlc = contract_future.await; this.send(CfdSetupCompleted { order_id, dlc }) .await .expect("always connected to ourselves") - }); + } + .spawn_with_handle(); - self.setup_state = SetupState::Active { sender }; + self.setup_state = SetupState::Active { + sender, + _task: task, + }; Ok(()) } @@ -578,15 +586,19 @@ where .address() .expect("actor to be able to give address to itself"); - self.roll_over_state = RollOverState::Active { sender }; - - tokio::spawn(async move { + let task = async move { let dlc = contract_future.await; this.send(CfdRollOverCompleted { order_id, dlc }) .await .expect("always connected to ourselves") - }); + } + .spawn_with_handle(); + + self.roll_over_state = RollOverState::Active { + sender, + _task: task, + }; self.remove_pending_proposal(&order_id) .context("Could not remove accepted roll over")?; diff --git a/daemon/tests/harness/mod.rs b/daemon/tests/harness/mod.rs index a564326..e7dc818 100644 --- a/daemon/tests/harness/mod.rs +++ b/daemon/tests/harness/mod.rs @@ -7,7 +7,9 @@ use daemon::maker_cfd::CfdAction; use daemon::model::cfd::{Cfd, Order, Origin}; use daemon::model::{Price, Usd}; use daemon::seed::Seed; +use daemon::tokio_ext::FutureExt; use daemon::{db, maker_cfd, maker_inc_connections, taker_cfd, MakerActorSystem}; +use futures::future::RemoteHandle; use rust_decimal_macros::dec; use sqlx::SqlitePool; use std::net::SocketAddr; @@ -47,6 +49,7 @@ pub struct Maker { pub mocks: mocks::Mocks, pub listen_addr: SocketAddr, pub identity_pk: x25519_dalek::PublicKey, + _tasks: Vec>, } impl Maker { @@ -64,8 +67,10 @@ impl Maker { let mut mocks = mocks::Mocks::default(); let (oracle, monitor, wallet) = mocks::create_actors(&mocks); + let mut tasks = vec![]; + let (wallet_addr, wallet_fut) = wallet.create(None).run(); - tokio::spawn(wallet_fut); + tasks.push(wallet_fut.spawn_with_handle()); let settlement_time_interval_hours = time::Duration::hours(24); @@ -109,13 +114,20 @@ impl Maker { Poll::Ready(Some(message)) }); - tokio::spawn(maker.inc_conn_addr.clone().attach_stream(listener_stream)); + tasks.push( + maker + .inc_conn_addr + .clone() + .attach_stream(listener_stream) + .spawn_with_handle(), + ); Self { system: maker, identity_pk, listen_addr: address, mocks, + _tasks: tasks, } } @@ -153,6 +165,7 @@ impl Maker { pub struct Taker { pub system: daemon::TakerActorSystem, pub mocks: mocks::Mocks, + _tasks: Vec>, } impl Taker { @@ -182,8 +195,10 @@ impl Taker { let mut mocks = mocks::Mocks::default(); let (oracle, monitor, wallet) = mocks::create_actors(&mocks); + let mut tasks = vec![]; + let (wallet_addr, wallet_fut) = wallet.create(None).run(); - tokio::spawn(wallet_fut); + tasks.push(wallet_fut.spawn_with_handle()); // system startup sends sync messages, mock them mocks.mock_sync_handlers().await; @@ -213,6 +228,7 @@ impl Taker { Self { system: taker, mocks, + _tasks: tasks, } } From 6afe3f8e6c1c1cbfe875749a975af58a6e0597b8 Mon Sep 17 00:00:00 2001 From: Mariusz Klochowicz Date: Wed, 17 Nov 2021 15:27:26 +1030 Subject: [PATCH 2/4] Disallow calls to tokio::spawn outside dedicated extension trait Prevent accidentally spawning a task that could outlive the actor system and become a zombie. Any long-lived task should be tied to its owner by keeping RemoteHandle around. --- daemon/clippy.toml | 3 +++ daemon/src/lib.rs | 1 - daemon/src/tokio_ext.rs | 4 ++++ 3 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 daemon/clippy.toml diff --git a/daemon/clippy.toml b/daemon/clippy.toml new file mode 100644 index 0000000..1e152c1 --- /dev/null +++ b/daemon/clippy.toml @@ -0,0 +1,3 @@ +disallowed-methods = [ + "tokio::spawn", # tasks can outlive the actor system, prefer spawn_with_handle() +] diff --git a/daemon/src/lib.rs b/daemon/src/lib.rs index 7530aec..3f31aa0 100644 --- a/daemon/src/lib.rs +++ b/daemon/src/lib.rs @@ -1,6 +1,5 @@ #![cfg_attr(not(test), warn(clippy::unwrap_used))] #![warn(clippy::disallowed_method)] - use crate::db::load_all_cfds; use crate::maker_cfd::{FromTaker, NewTakerOnline}; use crate::model::cfd::{Cfd, Order, UpdateCfdProposals}; diff --git a/daemon/src/tokio_ext.rs b/daemon/src/tokio_ext.rs index f5783b2..c004af6 100644 --- a/daemon/src/tokio_ext.rs +++ b/daemon/src/tokio_ext.rs @@ -10,6 +10,8 @@ where F: Future> + Send + 'static, E: fmt::Display, { + // we want to disallow calls to tokio::spawn outside FutureExt + #[allow(clippy::disallowed_method)] tokio::spawn(async move { if let Err(e) = future.await { tracing::warn!("Task failed: {:#}", e); @@ -40,6 +42,8 @@ where Self: Future + Send + 'static, { let (future, handle) = self.remote_handle(); + // we want to disallow calls to tokio::spawn outside FutureExt + #[allow(clippy::disallowed_method)] tokio::spawn(future); handle } From 855181f06cd445c2c6495be4d780c0d28b9a7fcc Mon Sep 17 00:00:00 2001 From: Mariusz Klochowicz Date: Wed, 17 Nov 2021 16:15:02 +1030 Subject: [PATCH 3/4] Expand Tasks API for ergonomics The fact that we're storing a Vec) internally is an implementation detail. Calling `spawn_with_handle` is done internally now, guarded by a `debug_assert!` macro discouraging from doing it twice. --- daemon/src/connection.rs | 19 ++--- daemon/src/lib.rs | 104 +++++++++++++--------------- daemon/src/maker.rs | 12 ++-- daemon/src/maker_inc_connections.rs | 36 +++++----- daemon/src/taker.rs | 12 ++-- daemon/src/tokio_ext.rs | 44 +++++++++++- daemon/tests/harness/mod.rs | 24 +++---- 7 files changed, 135 insertions(+), 116 deletions(-) diff --git a/daemon/src/connection.rs b/daemon/src/connection.rs index b16f045..7d0ac5b 100644 --- a/daemon/src/connection.rs +++ b/daemon/src/connection.rs @@ -1,7 +1,5 @@ -use crate::tokio_ext::FutureExt; -use crate::{log_error, noise, send_to_socket, wire}; +use crate::{log_error, noise, send_to_socket, wire, Tasks}; use anyhow::Result; -use futures::future::RemoteHandle; use futures::StreamExt; use std::net::SocketAddr; use std::sync::{Arc, Mutex}; @@ -15,7 +13,7 @@ use xtra_productivity::xtra_productivity; struct ConnectedState { last_heartbeat: SystemTime, - _tasks: Vec>, + _tasks: Tasks, } pub struct Actor { @@ -97,21 +95,18 @@ impl Actor { let send_to_socket = send_to_socket::Actor::new(write, noise.clone()); - let mut tasks = vec![self - .send_to_maker_ctx - .attach(send_to_socket) - .spawn_with_handle()]; + let mut tasks = Tasks::default(); + tasks.add(self.send_to_maker_ctx.attach(send_to_socket)); let read = FramedRead::new(read, wire::EncryptedJsonCodec::new(noise)) .map(move |item| MakerStreamMessage { item }); let this = ctx.address().expect("self to be alive"); - tasks.push(this.attach_stream(read).spawn_with_handle()); + tasks.add(this.attach_stream(read)); - tasks.push( + tasks.add( ctx.notify_interval(self.timeout, || MeasurePulse) - .expect("we just started") - .spawn_with_handle(), + .expect("we just started"), ); self.connected_state = Some(ConnectedState { diff --git a/daemon/src/lib.rs b/daemon/src/lib.rs index 3f31aa0..164251d 100644 --- a/daemon/src/lib.rs +++ b/daemon/src/lib.rs @@ -57,6 +57,22 @@ pub const N_PAYOUTS: usize = 200; /// If it gets dropped, all tasks are cancelled. pub struct Tasks(Vec>); +impl Tasks { + /// Spawn the task on the runtime and remembers the handle + /// NOTE: Do *not* call spawn_with_handle() before calling `add`, + /// such calls will trigger panic in debug mode. + pub fn add(&mut self, f: impl Future + Send + 'static) { + let handle = f.spawn_with_handle(); + self.0.push(handle); + } +} + +impl Default for Tasks { + fn default() -> Self { + Tasks(vec![]) + } +} + pub struct MakerActorSystem { pub cfd_actor_addr: Address>, pub cfd_feed_receiver: watch::Receiver>, @@ -112,7 +128,7 @@ where let (oracle_addr, mut oracle_ctx) = xtra::Context::new(None); let (inc_conn_addr, inc_conn_ctx) = xtra::Context::new(None); - let mut tasks = vec![]; + let mut tasks = Tasks::default(); let (cfd_actor_addr, cfd_actor_fut) = maker_cfd::Actor::new( db, @@ -130,46 +146,35 @@ where .create(None) .run(); - tasks.push(cfd_actor_fut.spawn_with_handle()); + tasks.add(cfd_actor_fut); - tasks.push( - inc_conn_ctx - .run(inc_conn_constructor( - Box::new(cfd_actor_addr.clone()), - Box::new(cfd_actor_addr.clone()), - )) - .spawn_with_handle(), - ); + tasks.add(inc_conn_ctx.run(inc_conn_constructor( + Box::new(cfd_actor_addr.clone()), + Box::new(cfd_actor_addr.clone()), + ))); - tasks.push( + tasks.add( monitor_ctx .notify_interval(Duration::from_secs(20), || monitor::Sync) - .map_err(|e| anyhow::anyhow!(e))? - .spawn_with_handle(), + .map_err(|e| anyhow::anyhow!(e))?, ); - tasks.push( + tasks.add( monitor_ctx - .run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?) - .spawn_with_handle(), + .run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?), ); - tasks.push( + tasks.add( oracle_ctx .notify_interval(Duration::from_secs(5), || oracle::Sync) - .map_err(|e| anyhow::anyhow!(e))? - .spawn_with_handle(), + .map_err(|e| anyhow::anyhow!(e))?, ); let (fan_out_actor, fan_out_actor_fut) = fan_out::Actor::new(&[&cfd_actor_addr, &monitor_addr]) .create(None) .run(); - tasks.push(fan_out_actor_fut.spawn_with_handle()); + tasks.add(fan_out_actor_fut); - tasks.push( - oracle_ctx - .run(oracle_constructor(cfds, Box::new(fan_out_actor))) - .spawn_with_handle(), - ); + tasks.add(oracle_ctx.run(oracle_constructor(cfds, Box::new(fan_out_actor)))); oracle_addr.send(oracle::Sync).await?; @@ -181,7 +186,7 @@ where order_feed_receiver, update_cfd_feed_receiver, inc_conn_addr, - tasks: Tasks(tasks), + tasks, }) } } @@ -239,7 +244,7 @@ where let (monitor_addr, mut monitor_ctx) = xtra::Context::new(None); let (oracle_addr, mut oracle_ctx) = xtra::Context::new(None); - let mut tasks = vec![]; + let mut tasks = Tasks::default(); let (connection_actor_addr, connection_actor_ctx) = xtra::Context::new(None); let (cfd_actor_addr, cfd_actor_fut) = taker_cfd::Actor::new( @@ -257,36 +262,29 @@ where .create(None) .run(); - tasks.push(cfd_actor_fut.spawn_with_handle()); - - tasks.push( - connection_actor_ctx - .run(connection::Actor::new( - maker_online_status_feed_sender, - Box::new(cfd_actor_addr.clone()), - identity_sk, - maker_heartbeat_interval, - )) - .spawn_with_handle(), - ); + tasks.add(cfd_actor_fut); - tasks.push( + tasks.add(connection_actor_ctx.run(connection::Actor::new( + maker_online_status_feed_sender, + Box::new(cfd_actor_addr.clone()), + identity_sk, + maker_heartbeat_interval, + ))); + + tasks.add( monitor_ctx .notify_interval(Duration::from_secs(20), || monitor::Sync) - .map_err(|e| anyhow::anyhow!(e))? - .spawn_with_handle(), + .map_err(|e| anyhow::anyhow!(e))?, ); - tasks.push( + tasks.add( monitor_ctx - .run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?) - .spawn_with_handle(), + .run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?), ); - tasks.push( + tasks.add( oracle_ctx .notify_interval(Duration::from_secs(5), || oracle::Sync) - .map_err(|e| anyhow::anyhow!(e))? - .spawn_with_handle(), + .map_err(|e| anyhow::anyhow!(e))?, ); let (fan_out_actor, fan_out_actor_fut) = @@ -294,13 +292,9 @@ where .create(None) .run(); - tasks.push(fan_out_actor_fut.spawn_with_handle()); + tasks.add(fan_out_actor_fut); - tasks.push( - oracle_ctx - .run(oracle_constructor(cfds, Box::new(fan_out_actor))) - .spawn_with_handle(), - ); + tasks.add(oracle_ctx.run(oracle_constructor(cfds, Box::new(fan_out_actor)))); tracing::debug!("Taker actor system ready"); @@ -311,7 +305,7 @@ where order_feed_receiver, update_cfd_feed_receiver, maker_online_status_feed_receiver, - tasks: Tasks(tasks), + tasks, }) } } diff --git a/daemon/src/maker.rs b/daemon/src/maker.rs index 930f64e..6a32231 100644 --- a/daemon/src/maker.rs +++ b/daemon/src/maker.rs @@ -9,7 +9,7 @@ use daemon::seed::Seed; use daemon::tokio_ext::FutureExt; use daemon::{ bitmex_price_feed, db, housekeeping, logger, maker_cfd, maker_inc_connections, monitor, oracle, - wallet, wallet_sync, MakerActorSystem, HEARTBEAT_INTERVAL, N_PAYOUTS, + wallet, wallet_sync, MakerActorSystem, Tasks, HEARTBEAT_INTERVAL, N_PAYOUTS, }; use sqlx::sqlite::SqliteConnectOptions; use sqlx::SqlitePool; @@ -223,8 +223,10 @@ async fn main() -> Result<()> { tracing::info!("Listening on {}", local_addr); + let mut tasks = Tasks::default(); + let (task, quote_updates) = bitmex_price_feed::new().await?; - let _task = task.spawn_with_handle(); + tasks.add(task); let db = SqlitePool::connect_with( SqliteConnectOptions::new() @@ -282,11 +284,9 @@ async fn main() -> Result<()> { Poll::Ready(Some(message)) }); - let _listener_task = incoming_connection_addr - .attach_stream(listener_stream) - .spawn_with_handle(); + tasks.add(incoming_connection_addr.attach_stream(listener_stream)); - let _wallet_sync_task = wallet_sync::new(wallet, wallet_feed_sender).spawn_with_handle(); + tasks.add(wallet_sync::new(wallet, wallet_feed_sender)); let cfd_action_channel = MessageChannel::::clone_channel(&cfd_actor_addr); let new_order_channel = MessageChannel::::clone_channel(&cfd_actor_addr); diff --git a/daemon/src/maker_inc_connections.rs b/daemon/src/maker_inc_connections.rs index dec196f..9b69fe9 100644 --- a/daemon/src/maker_inc_connections.rs +++ b/daemon/src/maker_inc_connections.rs @@ -3,9 +3,8 @@ use crate::model::cfd::{Order, OrderId}; use crate::model::{BitMexPriceEventId, TakerId}; use crate::noise::TransportStateExt; use crate::tokio_ext::FutureExt; -use crate::{forward_only_ok, maker_cfd, noise, send_to_socket, wire}; +use crate::{forward_only_ok, maker_cfd, noise, send_to_socket, wire, Tasks}; use anyhow::Result; -use futures::future::RemoteHandle; use futures::{StreamExt, TryStreamExt}; use std::collections::HashMap; use std::io; @@ -72,7 +71,7 @@ pub struct Actor { taker_msg_channel: Box>, noise_priv_key: x25519_dalek::StaticSecret, heartbeat_interval: Duration, - tasks: Vec>, + tasks: Tasks, } impl Actor { @@ -88,7 +87,7 @@ impl Actor { taker_msg_channel: taker_msg_channel.clone_channel(), noise_priv_key, heartbeat_interval, - tasks: Vec::new(), + tasks: Tasks::default(), } } @@ -136,29 +135,26 @@ impl Actor { forward_only_ok::Actor::new(self.taker_msg_channel.clone_channel()) .create(None) .run(); - self.tasks.push(forward_to_cfd_fut.spawn_with_handle()); + self.tasks.add(forward_to_cfd_fut); // only allow outgoing messages while we are successfully reading incoming ones let heartbeat_interval = self.heartbeat_interval; - self.tasks.push( - async move { - let mut actor = send_to_socket::Actor::new(write, transport_state.clone()); + self.tasks.add(async move { + let mut actor = send_to_socket::Actor::new(write, transport_state.clone()); - let _heartbeat_handle = out_msg_actor_context - .notify_interval(heartbeat_interval, || wire::MakerToTaker::Heartbeat) - .expect("actor not to shutdown") - .spawn_with_handle(); + let _heartbeat_handle = out_msg_actor_context + .notify_interval(heartbeat_interval, || wire::MakerToTaker::Heartbeat) + .expect("actor not to shutdown") + .spawn_with_handle(); - out_msg_actor_context - .handle_while(&mut actor, forward_to_cfd.attach_stream(read)) - .await; + out_msg_actor_context + .handle_while(&mut actor, forward_to_cfd.attach_stream(read)) + .await; - tracing::error!("Closing connection to taker {}", taker_id); + tracing::error!("Closing connection to taker {}", taker_id); - actor.shutdown().await; - } - .spawn_with_handle(), - ); + actor.shutdown().await; + }); self.write_connections .insert(taker_id, out_msg_actor_address); diff --git a/daemon/src/taker.rs b/daemon/src/taker.rs index c2aa29f..dabfd9a 100644 --- a/daemon/src/taker.rs +++ b/daemon/src/taker.rs @@ -9,7 +9,7 @@ use daemon::seed::Seed; use daemon::tokio_ext::FutureExt; use daemon::{ bitmex_price_feed, connection, db, housekeeping, logger, monitor, oracle, taker_cfd, wallet, - wallet_sync, TakerActorSystem, HEARTBEAT_INTERVAL, N_PAYOUTS, + wallet_sync, TakerActorSystem, Tasks, HEARTBEAT_INTERVAL, N_PAYOUTS, }; use sqlx::sqlite::SqliteConnectOptions; use sqlx::SqlitePool; @@ -207,8 +207,10 @@ async fn main() -> Result<()> { let (wallet_feed_sender, wallet_feed_receiver) = watch::channel::(wallet_info); + let mut tasks = Tasks::default(); + let (task, quote_updates) = bitmex_price_feed::new().await?; - let _task = task.spawn_with_handle(); + tasks.add(task); let figment = rocket::Config::figment() .merge(("address", opts.http_address.ip())) @@ -258,7 +260,7 @@ async fn main() -> Result<()> { connect(connection_actor_addr, opts.maker_id, opts.maker).await?; - let _wallet_sync_task = wallet_sync::new(wallet, wallet_feed_sender).spawn_with_handle(); + tasks.add(wallet_sync::new(wallet, wallet_feed_sender)); let take_offer_channel = MessageChannel::::clone_channel(&cfd_actor_addr); let cfd_action_channel = MessageChannel::::clone_channel(&cfd_actor_addr); @@ -290,7 +292,7 @@ async fn main() -> Result<()> { let shutdown_handle = rocket.shutdown(); // shutdown the rocket server maker if goes offline - let _rocket_shutdown_task = (async move { + tasks.add(async move { loop { maker_online_status_feed_receiver.changed().await.unwrap(); if maker_online_status_feed_receiver.borrow().clone() == ConnectionStatus::Offline { @@ -299,7 +301,7 @@ async fn main() -> Result<()> { return; } } - }).spawn_with_handle(); + }); rocket.launch().await?; db.close().await; diff --git a/daemon/src/tokio_ext.rs b/daemon/src/tokio_ext.rs index c004af6..7357dc3 100644 --- a/daemon/src/tokio_ext.rs +++ b/daemon/src/tokio_ext.rs @@ -1,5 +1,6 @@ use futures::future::RemoteHandle; use futures::FutureExt as _; +use std::any::{Any, TypeId}; use std::fmt; use std::future::Future; use std::time::Duration; @@ -26,7 +27,7 @@ pub trait FutureExt: Future + Sized { /// The task will be stopped when the handle gets dropped. fn spawn_with_handle(self) -> RemoteHandle where - Self: Future + Send + 'static; + Self: Future + Send + Any + 'static; } impl FutureExt for F @@ -39,8 +40,13 @@ where fn spawn_with_handle(self) -> RemoteHandle<()> where - Self: Future + Send + 'static, + Self: Future + Send + Any + 'static, { + debug_assert!( + TypeId::of::>() != self.type_id(), + "RemoteHandle<()> is a handle to already spawned task", + ); + let (future, handle) = self.remote_handle(); // we want to disallow calls to tokio::spawn outside FutureExt #[allow(clippy::disallowed_method)] @@ -48,3 +54,37 @@ where handle } } + +#[cfg(test)] +mod tests { + use std::panic; + + use tokio::time::sleep; + + use super::*; + + #[tokio::test] + async fn spawning_a_regular_future_does_not_panic() { + let result = panic::catch_unwind(|| { + let _handle = sleep(Duration::from_secs(2)).spawn_with_handle(); + }); + assert!(result.is_ok()); + } + + #[tokio::test] + async fn panics_when_called_spawn_with_handle_on_remote_handle() { + let result = panic::catch_unwind(|| { + let handle = sleep(Duration::from_secs(2)).spawn_with_handle(); + let _handle_to_a_handle = handle.spawn_with_handle(); + }); + + if cfg!(debug_assertions) { + assert!( + result.is_err(), + "Spawning a remote handle into a separate task should panic_in_debug_mode" + ); + } else { + assert!(result.is_ok(), "Do not panic in release mode"); + } + } +} diff --git a/daemon/tests/harness/mod.rs b/daemon/tests/harness/mod.rs index e7dc818..42fe437 100644 --- a/daemon/tests/harness/mod.rs +++ b/daemon/tests/harness/mod.rs @@ -7,9 +7,7 @@ use daemon::maker_cfd::CfdAction; use daemon::model::cfd::{Cfd, Order, Origin}; use daemon::model::{Price, Usd}; use daemon::seed::Seed; -use daemon::tokio_ext::FutureExt; -use daemon::{db, maker_cfd, maker_inc_connections, taker_cfd, MakerActorSystem}; -use futures::future::RemoteHandle; +use daemon::{db, maker_cfd, maker_inc_connections, taker_cfd, MakerActorSystem, Tasks}; use rust_decimal_macros::dec; use sqlx::SqlitePool; use std::net::SocketAddr; @@ -49,7 +47,7 @@ pub struct Maker { pub mocks: mocks::Mocks, pub listen_addr: SocketAddr, pub identity_pk: x25519_dalek::PublicKey, - _tasks: Vec>, + _tasks: Tasks, } impl Maker { @@ -67,10 +65,10 @@ impl Maker { let mut mocks = mocks::Mocks::default(); let (oracle, monitor, wallet) = mocks::create_actors(&mocks); - let mut tasks = vec![]; + let mut tasks = Tasks::default(); let (wallet_addr, wallet_fut) = wallet.create(None).run(); - tasks.push(wallet_fut.spawn_with_handle()); + tasks.add(wallet_fut); let settlement_time_interval_hours = time::Duration::hours(24); @@ -114,13 +112,7 @@ impl Maker { Poll::Ready(Some(message)) }); - tasks.push( - maker - .inc_conn_addr - .clone() - .attach_stream(listener_stream) - .spawn_with_handle(), - ); + tasks.add(maker.inc_conn_addr.clone().attach_stream(listener_stream)); Self { system: maker, @@ -165,7 +157,7 @@ impl Maker { pub struct Taker { pub system: daemon::TakerActorSystem, pub mocks: mocks::Mocks, - _tasks: Vec>, + _tasks: Tasks, } impl Taker { @@ -195,10 +187,10 @@ impl Taker { let mut mocks = mocks::Mocks::default(); let (oracle, monitor, wallet) = mocks::create_actors(&mocks); - let mut tasks = vec![]; + let mut tasks = Tasks::default(); let (wallet_addr, wallet_fut) = wallet.create(None).run(); - tasks.push(wallet_fut.spawn_with_handle()); + tasks.add(wallet_fut); // system startup sends sync messages, mock them mocks.mock_sync_handlers().await; From beb10e84fdb50f919848816da42a54855dce836c Mon Sep 17 00:00:00 2001 From: Mariusz Klochowicz Date: Wed, 17 Nov 2021 14:06:19 +1030 Subject: [PATCH 4/4] Don't use tokio feature from xtra Prefer using our FutureExt to enforce keeping RemoteHandles to prevent zombie tasks. --- Cargo.lock | 1 - daemon/Cargo.toml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d353f09..114b3cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3826,7 +3826,6 @@ dependencies = [ "futures-timer", "futures-util", "pollster", - "tokio", ] [[package]] diff --git a/daemon/Cargo.toml b/daemon/Cargo.toml index 1cb748d..f3c21c8 100644 --- a/daemon/Cargo.toml +++ b/daemon/Cargo.toml @@ -45,7 +45,7 @@ tracing = { version = "0.1" } tracing-subscriber = { version = "0.2", default-features = false, features = ["fmt", "ansi", "env-filter", "chrono", "tracing-log", "json"] } uuid = { version = "0.8", features = ["serde", "v4"] } x25519-dalek = { version = "1.1" } -xtra = { version = "0.6", features = ["with-tokio-1"] } +xtra = { version = "0.6" } xtra_productivity = { version = "0.1.0" } [[bin]]