From 143cc09e65258114586e5ac0fc35177c6ee2b909 Mon Sep 17 00:00:00 2001 From: Mariusz Klochowicz Date: Tue, 16 Nov 2021 17:32:46 +1030 Subject: [PATCH 1/3] Extend our futures trait for better ergonomics with remote handles Simplify dealing with spawning a task on a runtime that returns a remote a handle that can be used to stop it. --- daemon/src/connection.rs | 8 ++++---- daemon/src/tokio_ext.rs | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/daemon/src/connection.rs b/daemon/src/connection.rs index 5ce4448..40f0f95 100644 --- a/daemon/src/connection.rs +++ b/daemon/src/connection.rs @@ -1,7 +1,8 @@ +use crate::tokio_ext::FutureExt; use crate::{log_error, noise, send_to_socket, wire}; use anyhow::Result; use futures::future::RemoteHandle; -use futures::{FutureExt, StreamExt}; +use futures::StreamExt; use std::net::SocketAddr; use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime}; @@ -104,11 +105,10 @@ impl Actor { let this = ctx.address().expect("self to be alive"); tokio::spawn(this.attach_stream(read)); - let (pulse_future, pulse_remote_handle) = ctx + let pulse_remote_handle = ctx .notify_interval(self.timeout, || MeasurePulse) .expect("we just started") - .remote_handle(); - tokio::spawn(pulse_future); + .spawn_with_handle(); self.connected_state = Some(ConnectedState { last_heartbeat: SystemTime::now(), diff --git a/daemon/src/tokio_ext.rs b/daemon/src/tokio_ext.rs index 4a7ff52..f5783b2 100644 --- a/daemon/src/tokio_ext.rs +++ b/daemon/src/tokio_ext.rs @@ -1,3 +1,5 @@ +use futures::future::RemoteHandle; +use futures::FutureExt as _; use std::fmt; use std::future::Future; use std::time::Duration; @@ -17,6 +19,12 @@ where pub trait FutureExt: Future + Sized { fn timeout(self, duration: Duration) -> Timeout; + + /// Spawn the future on a task in the runtime and return a RemoteHandle to it. + /// The task will be stopped when the handle gets dropped. + fn spawn_with_handle(self) -> RemoteHandle + where + Self: Future + Send + 'static; } impl FutureExt for F @@ -26,4 +34,13 @@ where fn timeout(self, duration: Duration) -> Timeout { timeout(duration, self) } + + fn spawn_with_handle(self) -> RemoteHandle<()> + where + Self: Future + Send + 'static, + { + let (future, handle) = self.remote_handle(); + tokio::spawn(future); + handle + } } From 9c4e89325f43b64ce5573cbc4bd0d8d92c13b32d Mon Sep 17 00:00:00 2001 From: Mariusz Klochowicz Date: Tue, 16 Nov 2021 17:35:10 +1030 Subject: [PATCH 2/3] Tie tasks spawned inside actor systems to actor system lifetime When the actor system shuts down, async tasks spawned by it on the tokio runtime should be stopped too. --- daemon/src/lib.rs | 115 +++++++++++++++++++--------- daemon/src/maker.rs | 8 +- daemon/src/maker_inc_connections.rs | 39 ++++++---- daemon/src/taker.rs | 8 +- daemon/tests/harness/mod.rs | 13 ++-- 5 files changed, 120 insertions(+), 63 deletions(-) diff --git a/daemon/src/lib.rs b/daemon/src/lib.rs index c8f7014..c3bdb8c 100644 --- a/daemon/src/lib.rs +++ b/daemon/src/lib.rs @@ -3,8 +3,10 @@ use crate::db::load_all_cfds; use crate::maker_cfd::{FromTaker, NewTakerOnline}; use crate::model::cfd::{Cfd, Order, UpdateCfdProposals}; use crate::oracle::Attestation; +use crate::tokio_ext::FutureExt; use anyhow::Result; use connection::ConnectionStatus; +use futures::future::RemoteHandle; use maia::secp256k1_zkp::schnorrsig; use sqlx::SqlitePool; use std::collections::HashMap; @@ -12,7 +14,6 @@ use std::future::Future; use std::time::Duration; use tokio::sync::watch; use xtra::message_channel::{MessageChannel, StrongMessageChannel}; -use xtra::spawn::TokioGlobalSpawnExt; use xtra::{Actor, Address}; pub mod actors; @@ -50,12 +51,18 @@ const HEARTBEAT_INTERVAL: std::time::Duration = Duration::from_secs(5); pub const N_PAYOUTS: usize = 200; +/// Struct controlling the lifetime of the async tasks, +/// such as running actors and periodic notifications. +/// If it gets dropped, all tasks are cancelled. +pub struct Tasks(Vec>); + pub struct MakerActorSystem { pub cfd_actor_addr: Address>, pub cfd_feed_receiver: watch::Receiver>, pub order_feed_receiver: watch::Receiver>, pub update_cfd_feed_receiver: watch::Receiver, pub inc_conn_addr: Address, + pub tasks: Tasks, } impl MakerActorSystem @@ -104,7 +111,9 @@ where let (oracle_addr, mut oracle_ctx) = xtra::Context::new(None); let (inc_conn_addr, inc_conn_ctx) = xtra::Context::new(None); - let cfd_actor_addr = maker_cfd::Actor::new( + let mut tasks = vec![]; + + let (cfd_actor_addr, cfd_actor_fut) = maker_cfd::Actor::new( db, wallet_addr, settlement_time_interval_hours, @@ -118,33 +127,48 @@ where n_payouts, ) .create(None) - .spawn_global(); + .run(); + + tasks.push(cfd_actor_fut.spawn_with_handle()); - tokio::spawn(inc_conn_ctx.run(inc_conn_constructor( - Box::new(cfd_actor_addr.clone()), - Box::new(cfd_actor_addr.clone()), - ))); + tasks.push( + inc_conn_ctx + .run(inc_conn_constructor( + Box::new(cfd_actor_addr.clone()), + Box::new(cfd_actor_addr.clone()), + )) + .spawn_with_handle(), + ); - tokio::spawn( + tasks.push( monitor_ctx .notify_interval(Duration::from_secs(20), || monitor::Sync) - .map_err(|e| anyhow::anyhow!(e))?, + .map_err(|e| anyhow::anyhow!(e))? + .spawn_with_handle(), ); - tokio::spawn( + tasks.push( monitor_ctx - .run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?), + .run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?) + .spawn_with_handle(), ); - tokio::spawn( + tasks.push( oracle_ctx .notify_interval(Duration::from_secs(5), || oracle::Sync) - .map_err(|e| anyhow::anyhow!(e))?, + .map_err(|e| anyhow::anyhow!(e))? + .spawn_with_handle(), ); - let fan_out_actor = fan_out::Actor::new(&[&cfd_actor_addr, &monitor_addr]) - .create(None) - .spawn_global(); + let (fan_out_actor, fan_out_actor_fut) = + fan_out::Actor::new(&[&cfd_actor_addr, &monitor_addr]) + .create(None) + .run(); + tasks.push(fan_out_actor_fut.spawn_with_handle()); - tokio::spawn(oracle_ctx.run(oracle_constructor(cfds, Box::new(fan_out_actor)))); + tasks.push( + oracle_ctx + .run(oracle_constructor(cfds, Box::new(fan_out_actor))) + .spawn_with_handle(), + ); oracle_addr.do_send_async(oracle::Sync).await?; @@ -156,6 +180,7 @@ where order_feed_receiver, update_cfd_feed_receiver, inc_conn_addr, + tasks: Tasks(tasks), }) } } @@ -167,6 +192,7 @@ pub struct TakerActorSystem { pub order_feed_receiver: watch::Receiver>, pub update_cfd_feed_receiver: watch::Receiver, pub maker_online_status_feed_receiver: watch::Receiver, + pub tasks: Tasks, } impl TakerActorSystem @@ -211,8 +237,10 @@ where let (monitor_addr, mut monitor_ctx) = xtra::Context::new(None); let (oracle_addr, mut oracle_ctx) = xtra::Context::new(None); + let mut tasks = vec![]; + let (connection_actor_addr, connection_actor_ctx) = xtra::Context::new(None); - let cfd_actor_addr = taker_cfd::Actor::new( + let (cfd_actor_addr, cfd_actor_fut) = taker_cfd::Actor::new( db, wallet_addr, oracle_pk, @@ -225,36 +253,52 @@ where n_payouts, ) .create(None) - .spawn_global(); + .run(); + + tasks.push(cfd_actor_fut.spawn_with_handle()); - tokio::spawn(connection_actor_ctx.run(connection::Actor::new( - maker_online_status_feed_sender, - Box::new(cfd_actor_addr.clone()), - identity_sk, - HEARTBEAT_INTERVAL * 2, - ))); + tasks.push( + connection_actor_ctx + .run(connection::Actor::new( + maker_online_status_feed_sender, + Box::new(cfd_actor_addr.clone()), + identity_sk, + HEARTBEAT_INTERVAL * 2, + )) + .spawn_with_handle(), + ); - tokio::spawn( + tasks.push( monitor_ctx .notify_interval(Duration::from_secs(20), || monitor::Sync) - .map_err(|e| anyhow::anyhow!(e))?, + .map_err(|e| anyhow::anyhow!(e))? + .spawn_with_handle(), ); - tokio::spawn( + tasks.push( monitor_ctx - .run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?), + .run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?) + .spawn_with_handle(), ); - tokio::spawn( + tasks.push( oracle_ctx .notify_interval(Duration::from_secs(5), || oracle::Sync) - .map_err(|e| anyhow::anyhow!(e))?, + .map_err(|e| anyhow::anyhow!(e))? + .spawn_with_handle(), ); - let fan_out_actor = fan_out::Actor::new(&[&cfd_actor_addr, &monitor_addr]) - .create(None) - .spawn_global(); + let (fan_out_actor, fan_out_actor_fut) = + fan_out::Actor::new(&[&cfd_actor_addr, &monitor_addr]) + .create(None) + .run(); - tokio::spawn(oracle_ctx.run(oracle_constructor(cfds, Box::new(fan_out_actor)))); + tasks.push(fan_out_actor_fut.spawn_with_handle()); + + tasks.push( + oracle_ctx + .run(oracle_constructor(cfds, Box::new(fan_out_actor))) + .spawn_with_handle(), + ); tracing::debug!("Taker actor system ready"); @@ -265,6 +309,7 @@ where order_feed_receiver, update_cfd_feed_receiver, maker_online_status_feed_receiver, + tasks: Tasks(tasks), }) } } diff --git a/daemon/src/maker.rs b/daemon/src/maker.rs index c1185f7..f1fc87c 100644 --- a/daemon/src/maker.rs +++ b/daemon/src/maker.rs @@ -6,6 +6,7 @@ use clap::{Parser, Subcommand}; use daemon::auth::{self, MAKER_USERNAME}; use daemon::model::WalletInfo; use daemon::seed::Seed; +use daemon::tokio_ext::FutureExt; use daemon::{ bitmex_price_feed, db, housekeeping, logger, maker_cfd, maker_inc_connections, monitor, oracle, wallet, wallet_sync, MakerActorSystem, N_PAYOUTS, @@ -19,7 +20,6 @@ use std::task::Poll; use tokio::sync::watch; use tracing_subscriber::filter::LevelFilter; use xtra::prelude::*; -use xtra::spawn::TokioGlobalSpawnExt; use xtra::Actor; mod routes_maker; @@ -159,14 +159,15 @@ async fn main() -> Result<()> { let bitcoin_network = opts.network.bitcoin_network(); let ext_priv_key = seed.derive_extended_priv_key(bitcoin_network)?; - let wallet = wallet::Actor::new( + let (wallet, wallet_fut) = wallet::Actor::new( opts.network.electrum(), &data_dir.join("maker_wallet.sqlite"), ext_priv_key, ) .await? .create(None) - .spawn_global(); + .run(); + let _wallet_handle = wallet_fut.spawn_with_handle(); // do this before withdraw to ensure the wallet is synced let wallet_info = wallet.send(wallet::Sync).await??; @@ -250,6 +251,7 @@ async fn main() -> Result<()> { order_feed_receiver, update_cfd_feed_receiver, inc_conn_addr: incoming_connection_addr, + tasks: _tasks, } = MakerActorSystem::new( db.clone(), wallet.clone(), diff --git a/daemon/src/maker_inc_connections.rs b/daemon/src/maker_inc_connections.rs index 7cd5e0a..daafdbf 100644 --- a/daemon/src/maker_inc_connections.rs +++ b/daemon/src/maker_inc_connections.rs @@ -1,8 +1,10 @@ use crate::maker_cfd::{FromTaker, NewTakerOnline}; use crate::model::cfd::{Order, OrderId}; use crate::model::{BitMexPriceEventId, TakerId}; +use crate::tokio_ext::FutureExt; use crate::{forward_only_ok, maker_cfd, noise, send_to_socket, wire, HEARTBEAT_INTERVAL}; use anyhow::Result; +use futures::future::RemoteHandle; use futures::{StreamExt, TryStreamExt}; use std::collections::HashMap; use std::io; @@ -11,7 +13,6 @@ use std::sync::{Arc, Mutex}; use tokio::net::TcpStream; use tokio_util::codec::FramedRead; use xtra::prelude::*; -use xtra::spawn::TokioGlobalSpawnExt; use xtra::{Actor as _, KeepRunning}; use xtra_productivity::xtra_productivity; @@ -68,6 +69,7 @@ pub struct Actor { new_taker_channel: Box>, taker_msg_channel: Box>, noise_priv_key: x25519_dalek::StaticSecret, + tasks: Vec>, } impl Actor { @@ -81,6 +83,7 @@ impl Actor { new_taker_channel: new_taker_channel.clone_channel(), taker_msg_channel: taker_msg_channel.clone_channel(), noise_priv_key, + tasks: Vec::new(), } } @@ -125,28 +128,32 @@ impl Actor { let (out_msg_actor_address, mut out_msg_actor_context) = xtra::Context::new(None); - let forward_to_cfd = forward_only_ok::Actor::new(self.taker_msg_channel.clone_channel()) - .create(None) - .spawn_global(); + let (forward_to_cfd, forward_to_cfd_fut) = + forward_only_ok::Actor::new(self.taker_msg_channel.clone_channel()) + .create(None) + .run(); + self.tasks.push(forward_to_cfd_fut.spawn_with_handle()); // only allow outgoing messages while we are successfully reading incoming ones - tokio::spawn(async move { - let mut actor = send_to_socket::Actor::new(write, noise.clone()); + self.tasks.push( + async move { + let mut actor = send_to_socket::Actor::new(write, noise.clone()); - tokio::spawn( - out_msg_actor_context + let _heartbeat_handle = out_msg_actor_context .notify_interval(HEARTBEAT_INTERVAL, || wire::MakerToTaker::Heartbeat) - .expect("actor not to shutdown"), - ); + .expect("actor not to shutdown") + .spawn_with_handle(); - out_msg_actor_context - .handle_while(&mut actor, forward_to_cfd.attach_stream(read)) - .await; + out_msg_actor_context + .handle_while(&mut actor, forward_to_cfd.attach_stream(read)) + .await; - tracing::error!("Closing connection to taker {}", taker_id); + tracing::error!("Closing connection to taker {}", taker_id); - actor.shutdown().await; - }); + actor.shutdown().await; + } + .spawn_with_handle(), + ); self.write_connections .insert(taker_id, out_msg_actor_address); diff --git a/daemon/src/taker.rs b/daemon/src/taker.rs index 34aee79..e9b7ec4 100644 --- a/daemon/src/taker.rs +++ b/daemon/src/taker.rs @@ -6,6 +6,7 @@ use clap::{Parser, Subcommand}; use daemon::connection::ConnectionStatus; use daemon::model::WalletInfo; use daemon::seed::Seed; +use daemon::tokio_ext::FutureExt; use daemon::{ bitmex_price_feed, connection, db, housekeeping, logger, monitor, oracle, taker_cfd, wallet, wallet_sync, TakerActorSystem, N_PAYOUTS, @@ -20,7 +21,6 @@ use tokio::sync::watch; use tokio::time::sleep; use tracing_subscriber::filter::LevelFilter; use xtra::prelude::MessageChannel; -use xtra::spawn::TokioGlobalSpawnExt; use xtra::Actor; mod routes_taker; @@ -168,14 +168,15 @@ async fn main() -> Result<()> { let ext_priv_key = seed.derive_extended_priv_key(bitcoin_network)?; let (_, identity_sk) = seed.derive_identity(); - let wallet = wallet::Actor::new( + let (wallet, wallet_fut) = wallet::Actor::new( opts.network.electrum(), &data_dir.join("taker_wallet.sqlite"), ext_priv_key, ) .await? .create(None) - .spawn_global(); + .run(); + let _wallet_handle = wallet_fut.spawn_with_handle(); // do this before withdraw to ensure the wallet is synced let wallet_info = wallet.send(wallet::Sync).await??; @@ -237,6 +238,7 @@ async fn main() -> Result<()> { order_feed_receiver, update_cfd_feed_receiver, mut maker_online_status_feed_receiver, + tasks: _tasks, } = TakerActorSystem::new( db.clone(), wallet.clone(), diff --git a/daemon/tests/harness/mod.rs b/daemon/tests/harness/mod.rs index 7c180d2..4e3bb29 100644 --- a/daemon/tests/harness/mod.rs +++ b/daemon/tests/harness/mod.rs @@ -2,7 +2,7 @@ use crate::harness::mocks::monitor::MonitorActor; use crate::harness::mocks::oracle::OracleActor; use crate::harness::mocks::wallet::WalletActor; use crate::schnorrsig; -use daemon::connection::Connect; +use daemon::connection::{Connect, ConnectionStatus}; use daemon::maker_cfd::CfdAction; use daemon::model::cfd::{Cfd, Order, Origin}; use daemon::model::{Price, Usd}; @@ -18,7 +18,6 @@ use tracing::subscriber::DefaultGuard; use tracing_subscriber::filter::LevelFilter; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::EnvFilter; -use xtra::spawn::TokioGlobalSpawnExt; use xtra::Actor; pub mod bdk; @@ -26,6 +25,8 @@ pub mod flow; pub mod maia; pub mod mocks; +const N_PAYOUTS_FOR_TEST: usize = 5; + pub async fn start_both() -> (Maker, Taker) { let oracle_pk: schnorrsig::PublicKey = schnorrsig::PublicKey::from_str( "ddd4636845a90185991826be5a494cde9f4a6947b1727217afedc6292fa4caf7", @@ -37,8 +38,6 @@ pub async fn start_both() -> (Maker, Taker) { (maker, taker) } -const N_PAYOUTS_FOR_TEST: usize = 5; - /// Maker Test Setup pub struct Maker { pub system: @@ -64,7 +63,8 @@ impl Maker { let (oracle, monitor, wallet) = mocks::create_actors(&mocks); mocks.mock_common_empty_handlers().await; - let wallet_addr = wallet.create(None).spawn_global(); + let (wallet_addr, wallet_fut) = wallet.create(None).run(); + tokio::spawn(wallet_fut); let settlement_time_interval_hours = time::Duration::hours(24); @@ -167,7 +167,8 @@ impl Taker { let (oracle, monitor, wallet) = mocks::create_actors(&mocks); mocks.mock_common_empty_handlers().await; - let wallet_addr = wallet.create(None).spawn_global(); + let (wallet_addr, wallet_fut) = wallet.create(None).run(); + tokio::spawn(wallet_fut); let taker = daemon::TakerActorSystem::new( db, From e8d2c8618299038fe32ae1d2d48e2b0a37f3ebb9 Mon Sep 17 00:00:00 2001 From: Mariusz Klochowicz Date: Tue, 16 Nov 2021 17:40:44 +1030 Subject: [PATCH 3/3] Add a test for triggering and noticing a maker shutdown --- daemon/tests/happy_path.rs | 26 +++++++++++++++++++++++++- daemon/tests/harness/mod.rs | 4 ++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/daemon/tests/happy_path.rs b/daemon/tests/happy_path.rs index d7a70b4..1cb148c 100644 --- a/daemon/tests/happy_path.rs +++ b/daemon/tests/happy_path.rs @@ -1,9 +1,12 @@ -use crate::harness::flow::{is_next_none, next_cfd, next_order, next_some}; +use crate::harness::flow::{is_next_none, next, next_cfd, next_order, next_some}; use crate::harness::{assert_is_same_order, dummy_new_order, init_tracing, start_both}; +use daemon::connection::ConnectionStatus; use daemon::model::cfd::CfdState; use daemon::model::Usd; use maia::secp256k1_zkp::schnorrsig; use rust_decimal_macros::dec; +use std::time::Duration; +use tokio::time::sleep; mod harness; #[tokio::test] @@ -100,3 +103,24 @@ async fn taker_takes_order_and_maker_accepts_and_contract_setup() { assert!(matches!(taker_cfd.state, CfdState::PendingOpen { .. })); assert!(matches!(maker_cfd.state, CfdState::PendingOpen { .. })); } + +#[tokio::test] +async fn taker_notices_lack_of_maker() { + let _guard = init_tracing(); + + let (maker, mut taker) = start_both().await; + assert_eq!( + ConnectionStatus::Online, + next(taker.maker_status_feed()).await.unwrap() + ); + + std::mem::drop(maker); + + // TODO: shorten this sleep by specifying different heartbeat interval for tests + sleep(Duration::from_secs(12)).await; + + assert_eq!( + ConnectionStatus::Offline, + next(taker.maker_status_feed()).await.unwrap(), + ); +} diff --git a/daemon/tests/harness/mod.rs b/daemon/tests/harness/mod.rs index 4e3bb29..687c36f 100644 --- a/daemon/tests/harness/mod.rs +++ b/daemon/tests/harness/mod.rs @@ -152,6 +152,10 @@ impl Taker { &mut self.system.order_feed_receiver } + pub fn maker_status_feed(&mut self) -> &mut watch::Receiver { + &mut self.system.maker_online_status_feed_receiver + } + pub async fn start( oracle_pk: schnorrsig::PublicKey, maker_address: SocketAddr,