From 855181f06cd445c2c6495be4d780c0d28b9a7fcc Mon Sep 17 00:00:00 2001 From: Mariusz Klochowicz Date: Wed, 17 Nov 2021 16:15:02 +1030 Subject: [PATCH] Expand Tasks API for ergonomics The fact that we're storing a Vec) internally is an implementation detail. Calling `spawn_with_handle` is done internally now, guarded by a `debug_assert!` macro discouraging from doing it twice. --- daemon/src/connection.rs | 19 ++--- daemon/src/lib.rs | 104 +++++++++++++--------------- daemon/src/maker.rs | 12 ++-- daemon/src/maker_inc_connections.rs | 36 +++++----- daemon/src/taker.rs | 12 ++-- daemon/src/tokio_ext.rs | 44 +++++++++++- daemon/tests/harness/mod.rs | 24 +++---- 7 files changed, 135 insertions(+), 116 deletions(-) diff --git a/daemon/src/connection.rs b/daemon/src/connection.rs index b16f045..7d0ac5b 100644 --- a/daemon/src/connection.rs +++ b/daemon/src/connection.rs @@ -1,7 +1,5 @@ -use crate::tokio_ext::FutureExt; -use crate::{log_error, noise, send_to_socket, wire}; +use crate::{log_error, noise, send_to_socket, wire, Tasks}; use anyhow::Result; -use futures::future::RemoteHandle; use futures::StreamExt; use std::net::SocketAddr; use std::sync::{Arc, Mutex}; @@ -15,7 +13,7 @@ use xtra_productivity::xtra_productivity; struct ConnectedState { last_heartbeat: SystemTime, - _tasks: Vec>, + _tasks: Tasks, } pub struct Actor { @@ -97,21 +95,18 @@ impl Actor { let send_to_socket = send_to_socket::Actor::new(write, noise.clone()); - let mut tasks = vec![self - .send_to_maker_ctx - .attach(send_to_socket) - .spawn_with_handle()]; + let mut tasks = Tasks::default(); + tasks.add(self.send_to_maker_ctx.attach(send_to_socket)); let read = FramedRead::new(read, wire::EncryptedJsonCodec::new(noise)) .map(move |item| MakerStreamMessage { item }); let this = ctx.address().expect("self to be alive"); - tasks.push(this.attach_stream(read).spawn_with_handle()); + tasks.add(this.attach_stream(read)); - tasks.push( + tasks.add( ctx.notify_interval(self.timeout, || MeasurePulse) - .expect("we just started") - .spawn_with_handle(), + .expect("we just started"), ); self.connected_state = Some(ConnectedState { diff --git a/daemon/src/lib.rs b/daemon/src/lib.rs index 3f31aa0..164251d 100644 --- a/daemon/src/lib.rs +++ b/daemon/src/lib.rs @@ -57,6 +57,22 @@ pub const N_PAYOUTS: usize = 200; /// If it gets dropped, all tasks are cancelled. pub struct Tasks(Vec>); +impl Tasks { + /// Spawn the task on the runtime and remembers the handle + /// NOTE: Do *not* call spawn_with_handle() before calling `add`, + /// such calls will trigger panic in debug mode. + pub fn add(&mut self, f: impl Future + Send + 'static) { + let handle = f.spawn_with_handle(); + self.0.push(handle); + } +} + +impl Default for Tasks { + fn default() -> Self { + Tasks(vec![]) + } +} + pub struct MakerActorSystem { pub cfd_actor_addr: Address>, pub cfd_feed_receiver: watch::Receiver>, @@ -112,7 +128,7 @@ where let (oracle_addr, mut oracle_ctx) = xtra::Context::new(None); let (inc_conn_addr, inc_conn_ctx) = xtra::Context::new(None); - let mut tasks = vec![]; + let mut tasks = Tasks::default(); let (cfd_actor_addr, cfd_actor_fut) = maker_cfd::Actor::new( db, @@ -130,46 +146,35 @@ where .create(None) .run(); - tasks.push(cfd_actor_fut.spawn_with_handle()); + tasks.add(cfd_actor_fut); - tasks.push( - inc_conn_ctx - .run(inc_conn_constructor( - Box::new(cfd_actor_addr.clone()), - Box::new(cfd_actor_addr.clone()), - )) - .spawn_with_handle(), - ); + tasks.add(inc_conn_ctx.run(inc_conn_constructor( + Box::new(cfd_actor_addr.clone()), + Box::new(cfd_actor_addr.clone()), + ))); - tasks.push( + tasks.add( monitor_ctx .notify_interval(Duration::from_secs(20), || monitor::Sync) - .map_err(|e| anyhow::anyhow!(e))? - .spawn_with_handle(), + .map_err(|e| anyhow::anyhow!(e))?, ); - tasks.push( + tasks.add( monitor_ctx - .run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?) - .spawn_with_handle(), + .run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?), ); - tasks.push( + tasks.add( oracle_ctx .notify_interval(Duration::from_secs(5), || oracle::Sync) - .map_err(|e| anyhow::anyhow!(e))? - .spawn_with_handle(), + .map_err(|e| anyhow::anyhow!(e))?, ); let (fan_out_actor, fan_out_actor_fut) = fan_out::Actor::new(&[&cfd_actor_addr, &monitor_addr]) .create(None) .run(); - tasks.push(fan_out_actor_fut.spawn_with_handle()); + tasks.add(fan_out_actor_fut); - tasks.push( - oracle_ctx - .run(oracle_constructor(cfds, Box::new(fan_out_actor))) - .spawn_with_handle(), - ); + tasks.add(oracle_ctx.run(oracle_constructor(cfds, Box::new(fan_out_actor)))); oracle_addr.send(oracle::Sync).await?; @@ -181,7 +186,7 @@ where order_feed_receiver, update_cfd_feed_receiver, inc_conn_addr, - tasks: Tasks(tasks), + tasks, }) } } @@ -239,7 +244,7 @@ where let (monitor_addr, mut monitor_ctx) = xtra::Context::new(None); let (oracle_addr, mut oracle_ctx) = xtra::Context::new(None); - let mut tasks = vec![]; + let mut tasks = Tasks::default(); let (connection_actor_addr, connection_actor_ctx) = xtra::Context::new(None); let (cfd_actor_addr, cfd_actor_fut) = taker_cfd::Actor::new( @@ -257,36 +262,29 @@ where .create(None) .run(); - tasks.push(cfd_actor_fut.spawn_with_handle()); - - tasks.push( - connection_actor_ctx - .run(connection::Actor::new( - maker_online_status_feed_sender, - Box::new(cfd_actor_addr.clone()), - identity_sk, - maker_heartbeat_interval, - )) - .spawn_with_handle(), - ); + tasks.add(cfd_actor_fut); - tasks.push( + tasks.add(connection_actor_ctx.run(connection::Actor::new( + maker_online_status_feed_sender, + Box::new(cfd_actor_addr.clone()), + identity_sk, + maker_heartbeat_interval, + ))); + + tasks.add( monitor_ctx .notify_interval(Duration::from_secs(20), || monitor::Sync) - .map_err(|e| anyhow::anyhow!(e))? - .spawn_with_handle(), + .map_err(|e| anyhow::anyhow!(e))?, ); - tasks.push( + tasks.add( monitor_ctx - .run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?) - .spawn_with_handle(), + .run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?), ); - tasks.push( + tasks.add( oracle_ctx .notify_interval(Duration::from_secs(5), || oracle::Sync) - .map_err(|e| anyhow::anyhow!(e))? - .spawn_with_handle(), + .map_err(|e| anyhow::anyhow!(e))?, ); let (fan_out_actor, fan_out_actor_fut) = @@ -294,13 +292,9 @@ where .create(None) .run(); - tasks.push(fan_out_actor_fut.spawn_with_handle()); + tasks.add(fan_out_actor_fut); - tasks.push( - oracle_ctx - .run(oracle_constructor(cfds, Box::new(fan_out_actor))) - .spawn_with_handle(), - ); + tasks.add(oracle_ctx.run(oracle_constructor(cfds, Box::new(fan_out_actor)))); tracing::debug!("Taker actor system ready"); @@ -311,7 +305,7 @@ where order_feed_receiver, update_cfd_feed_receiver, maker_online_status_feed_receiver, - tasks: Tasks(tasks), + tasks, }) } } diff --git a/daemon/src/maker.rs b/daemon/src/maker.rs index 930f64e..6a32231 100644 --- a/daemon/src/maker.rs +++ b/daemon/src/maker.rs @@ -9,7 +9,7 @@ use daemon::seed::Seed; use daemon::tokio_ext::FutureExt; use daemon::{ bitmex_price_feed, db, housekeeping, logger, maker_cfd, maker_inc_connections, monitor, oracle, - wallet, wallet_sync, MakerActorSystem, HEARTBEAT_INTERVAL, N_PAYOUTS, + wallet, wallet_sync, MakerActorSystem, Tasks, HEARTBEAT_INTERVAL, N_PAYOUTS, }; use sqlx::sqlite::SqliteConnectOptions; use sqlx::SqlitePool; @@ -223,8 +223,10 @@ async fn main() -> Result<()> { tracing::info!("Listening on {}", local_addr); + let mut tasks = Tasks::default(); + let (task, quote_updates) = bitmex_price_feed::new().await?; - let _task = task.spawn_with_handle(); + tasks.add(task); let db = SqlitePool::connect_with( SqliteConnectOptions::new() @@ -282,11 +284,9 @@ async fn main() -> Result<()> { Poll::Ready(Some(message)) }); - let _listener_task = incoming_connection_addr - .attach_stream(listener_stream) - .spawn_with_handle(); + tasks.add(incoming_connection_addr.attach_stream(listener_stream)); - let _wallet_sync_task = wallet_sync::new(wallet, wallet_feed_sender).spawn_with_handle(); + tasks.add(wallet_sync::new(wallet, wallet_feed_sender)); let cfd_action_channel = MessageChannel::::clone_channel(&cfd_actor_addr); let new_order_channel = MessageChannel::::clone_channel(&cfd_actor_addr); diff --git a/daemon/src/maker_inc_connections.rs b/daemon/src/maker_inc_connections.rs index dec196f..9b69fe9 100644 --- a/daemon/src/maker_inc_connections.rs +++ b/daemon/src/maker_inc_connections.rs @@ -3,9 +3,8 @@ use crate::model::cfd::{Order, OrderId}; use crate::model::{BitMexPriceEventId, TakerId}; use crate::noise::TransportStateExt; use crate::tokio_ext::FutureExt; -use crate::{forward_only_ok, maker_cfd, noise, send_to_socket, wire}; +use crate::{forward_only_ok, maker_cfd, noise, send_to_socket, wire, Tasks}; use anyhow::Result; -use futures::future::RemoteHandle; use futures::{StreamExt, TryStreamExt}; use std::collections::HashMap; use std::io; @@ -72,7 +71,7 @@ pub struct Actor { taker_msg_channel: Box>, noise_priv_key: x25519_dalek::StaticSecret, heartbeat_interval: Duration, - tasks: Vec>, + tasks: Tasks, } impl Actor { @@ -88,7 +87,7 @@ impl Actor { taker_msg_channel: taker_msg_channel.clone_channel(), noise_priv_key, heartbeat_interval, - tasks: Vec::new(), + tasks: Tasks::default(), } } @@ -136,29 +135,26 @@ impl Actor { forward_only_ok::Actor::new(self.taker_msg_channel.clone_channel()) .create(None) .run(); - self.tasks.push(forward_to_cfd_fut.spawn_with_handle()); + self.tasks.add(forward_to_cfd_fut); // only allow outgoing messages while we are successfully reading incoming ones let heartbeat_interval = self.heartbeat_interval; - self.tasks.push( - async move { - let mut actor = send_to_socket::Actor::new(write, transport_state.clone()); + self.tasks.add(async move { + let mut actor = send_to_socket::Actor::new(write, transport_state.clone()); - let _heartbeat_handle = out_msg_actor_context - .notify_interval(heartbeat_interval, || wire::MakerToTaker::Heartbeat) - .expect("actor not to shutdown") - .spawn_with_handle(); + let _heartbeat_handle = out_msg_actor_context + .notify_interval(heartbeat_interval, || wire::MakerToTaker::Heartbeat) + .expect("actor not to shutdown") + .spawn_with_handle(); - out_msg_actor_context - .handle_while(&mut actor, forward_to_cfd.attach_stream(read)) - .await; + out_msg_actor_context + .handle_while(&mut actor, forward_to_cfd.attach_stream(read)) + .await; - tracing::error!("Closing connection to taker {}", taker_id); + tracing::error!("Closing connection to taker {}", taker_id); - actor.shutdown().await; - } - .spawn_with_handle(), - ); + actor.shutdown().await; + }); self.write_connections .insert(taker_id, out_msg_actor_address); diff --git a/daemon/src/taker.rs b/daemon/src/taker.rs index c2aa29f..dabfd9a 100644 --- a/daemon/src/taker.rs +++ b/daemon/src/taker.rs @@ -9,7 +9,7 @@ use daemon::seed::Seed; use daemon::tokio_ext::FutureExt; use daemon::{ bitmex_price_feed, connection, db, housekeeping, logger, monitor, oracle, taker_cfd, wallet, - wallet_sync, TakerActorSystem, HEARTBEAT_INTERVAL, N_PAYOUTS, + wallet_sync, TakerActorSystem, Tasks, HEARTBEAT_INTERVAL, N_PAYOUTS, }; use sqlx::sqlite::SqliteConnectOptions; use sqlx::SqlitePool; @@ -207,8 +207,10 @@ async fn main() -> Result<()> { let (wallet_feed_sender, wallet_feed_receiver) = watch::channel::(wallet_info); + let mut tasks = Tasks::default(); + let (task, quote_updates) = bitmex_price_feed::new().await?; - let _task = task.spawn_with_handle(); + tasks.add(task); let figment = rocket::Config::figment() .merge(("address", opts.http_address.ip())) @@ -258,7 +260,7 @@ async fn main() -> Result<()> { connect(connection_actor_addr, opts.maker_id, opts.maker).await?; - let _wallet_sync_task = wallet_sync::new(wallet, wallet_feed_sender).spawn_with_handle(); + tasks.add(wallet_sync::new(wallet, wallet_feed_sender)); let take_offer_channel = MessageChannel::::clone_channel(&cfd_actor_addr); let cfd_action_channel = MessageChannel::::clone_channel(&cfd_actor_addr); @@ -290,7 +292,7 @@ async fn main() -> Result<()> { let shutdown_handle = rocket.shutdown(); // shutdown the rocket server maker if goes offline - let _rocket_shutdown_task = (async move { + tasks.add(async move { loop { maker_online_status_feed_receiver.changed().await.unwrap(); if maker_online_status_feed_receiver.borrow().clone() == ConnectionStatus::Offline { @@ -299,7 +301,7 @@ async fn main() -> Result<()> { return; } } - }).spawn_with_handle(); + }); rocket.launch().await?; db.close().await; diff --git a/daemon/src/tokio_ext.rs b/daemon/src/tokio_ext.rs index c004af6..7357dc3 100644 --- a/daemon/src/tokio_ext.rs +++ b/daemon/src/tokio_ext.rs @@ -1,5 +1,6 @@ use futures::future::RemoteHandle; use futures::FutureExt as _; +use std::any::{Any, TypeId}; use std::fmt; use std::future::Future; use std::time::Duration; @@ -26,7 +27,7 @@ pub trait FutureExt: Future + Sized { /// The task will be stopped when the handle gets dropped. fn spawn_with_handle(self) -> RemoteHandle where - Self: Future + Send + 'static; + Self: Future + Send + Any + 'static; } impl FutureExt for F @@ -39,8 +40,13 @@ where fn spawn_with_handle(self) -> RemoteHandle<()> where - Self: Future + Send + 'static, + Self: Future + Send + Any + 'static, { + debug_assert!( + TypeId::of::>() != self.type_id(), + "RemoteHandle<()> is a handle to already spawned task", + ); + let (future, handle) = self.remote_handle(); // we want to disallow calls to tokio::spawn outside FutureExt #[allow(clippy::disallowed_method)] @@ -48,3 +54,37 @@ where handle } } + +#[cfg(test)] +mod tests { + use std::panic; + + use tokio::time::sleep; + + use super::*; + + #[tokio::test] + async fn spawning_a_regular_future_does_not_panic() { + let result = panic::catch_unwind(|| { + let _handle = sleep(Duration::from_secs(2)).spawn_with_handle(); + }); + assert!(result.is_ok()); + } + + #[tokio::test] + async fn panics_when_called_spawn_with_handle_on_remote_handle() { + let result = panic::catch_unwind(|| { + let handle = sleep(Duration::from_secs(2)).spawn_with_handle(); + let _handle_to_a_handle = handle.spawn_with_handle(); + }); + + if cfg!(debug_assertions) { + assert!( + result.is_err(), + "Spawning a remote handle into a separate task should panic_in_debug_mode" + ); + } else { + assert!(result.is_ok(), "Do not panic in release mode"); + } + } +} diff --git a/daemon/tests/harness/mod.rs b/daemon/tests/harness/mod.rs index e7dc818..42fe437 100644 --- a/daemon/tests/harness/mod.rs +++ b/daemon/tests/harness/mod.rs @@ -7,9 +7,7 @@ use daemon::maker_cfd::CfdAction; use daemon::model::cfd::{Cfd, Order, Origin}; use daemon::model::{Price, Usd}; use daemon::seed::Seed; -use daemon::tokio_ext::FutureExt; -use daemon::{db, maker_cfd, maker_inc_connections, taker_cfd, MakerActorSystem}; -use futures::future::RemoteHandle; +use daemon::{db, maker_cfd, maker_inc_connections, taker_cfd, MakerActorSystem, Tasks}; use rust_decimal_macros::dec; use sqlx::SqlitePool; use std::net::SocketAddr; @@ -49,7 +47,7 @@ pub struct Maker { pub mocks: mocks::Mocks, pub listen_addr: SocketAddr, pub identity_pk: x25519_dalek::PublicKey, - _tasks: Vec>, + _tasks: Tasks, } impl Maker { @@ -67,10 +65,10 @@ impl Maker { let mut mocks = mocks::Mocks::default(); let (oracle, monitor, wallet) = mocks::create_actors(&mocks); - let mut tasks = vec![]; + let mut tasks = Tasks::default(); let (wallet_addr, wallet_fut) = wallet.create(None).run(); - tasks.push(wallet_fut.spawn_with_handle()); + tasks.add(wallet_fut); let settlement_time_interval_hours = time::Duration::hours(24); @@ -114,13 +112,7 @@ impl Maker { Poll::Ready(Some(message)) }); - tasks.push( - maker - .inc_conn_addr - .clone() - .attach_stream(listener_stream) - .spawn_with_handle(), - ); + tasks.add(maker.inc_conn_addr.clone().attach_stream(listener_stream)); Self { system: maker, @@ -165,7 +157,7 @@ impl Maker { pub struct Taker { pub system: daemon::TakerActorSystem, pub mocks: mocks::Mocks, - _tasks: Vec>, + _tasks: Tasks, } impl Taker { @@ -195,10 +187,10 @@ impl Taker { let mut mocks = mocks::Mocks::default(); let (oracle, monitor, wallet) = mocks::create_actors(&mocks); - let mut tasks = vec![]; + let mut tasks = Tasks::default(); let (wallet_addr, wallet_fut) = wallet.create(None).run(); - tasks.push(wallet_fut.spawn_with_handle()); + tasks.add(wallet_fut); // system startup sends sync messages, mock them mocks.mock_sync_handlers().await;