Browse Source

Expand Tasks API for ergonomics

The fact that we're storing a Vec<RemoteHandle<()>) internally is an
implementation detail.
Calling `spawn_with_handle` is done internally now, guarded by a `debug_assert!`
macro discouraging from doing it twice.
rollover-test-2
Mariusz Klochowicz 3 years ago
parent
commit
855181f06c
No known key found for this signature in database GPG Key ID: 470C865699C8D4D
  1. 19
      daemon/src/connection.rs
  2. 104
      daemon/src/lib.rs
  3. 12
      daemon/src/maker.rs
  4. 36
      daemon/src/maker_inc_connections.rs
  5. 12
      daemon/src/taker.rs
  6. 44
      daemon/src/tokio_ext.rs
  7. 24
      daemon/tests/harness/mod.rs

19
daemon/src/connection.rs

@ -1,7 +1,5 @@
use crate::tokio_ext::FutureExt;
use crate::{log_error, noise, send_to_socket, wire};
use crate::{log_error, noise, send_to_socket, wire, Tasks};
use anyhow::Result;
use futures::future::RemoteHandle;
use futures::StreamExt;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
@ -15,7 +13,7 @@ use xtra_productivity::xtra_productivity;
struct ConnectedState {
last_heartbeat: SystemTime,
_tasks: Vec<RemoteHandle<()>>,
_tasks: Tasks,
}
pub struct Actor {
@ -97,21 +95,18 @@ impl Actor {
let send_to_socket = send_to_socket::Actor::new(write, noise.clone());
let mut tasks = vec![self
.send_to_maker_ctx
.attach(send_to_socket)
.spawn_with_handle()];
let mut tasks = Tasks::default();
tasks.add(self.send_to_maker_ctx.attach(send_to_socket));
let read = FramedRead::new(read, wire::EncryptedJsonCodec::new(noise))
.map(move |item| MakerStreamMessage { item });
let this = ctx.address().expect("self to be alive");
tasks.push(this.attach_stream(read).spawn_with_handle());
tasks.add(this.attach_stream(read));
tasks.push(
tasks.add(
ctx.notify_interval(self.timeout, || MeasurePulse)
.expect("we just started")
.spawn_with_handle(),
.expect("we just started"),
);
self.connected_state = Some(ConnectedState {

104
daemon/src/lib.rs

@ -57,6 +57,22 @@ pub const N_PAYOUTS: usize = 200;
/// If it gets dropped, all tasks are cancelled.
pub struct Tasks(Vec<RemoteHandle<()>>);
impl Tasks {
/// Spawn the task on the runtime and remembers the handle
/// NOTE: Do *not* call spawn_with_handle() before calling `add`,
/// such calls will trigger panic in debug mode.
pub fn add(&mut self, f: impl Future<Output = ()> + Send + 'static) {
let handle = f.spawn_with_handle();
self.0.push(handle);
}
}
impl Default for Tasks {
fn default() -> Self {
Tasks(vec![])
}
}
pub struct MakerActorSystem<O, M, T, W> {
pub cfd_actor_addr: Address<maker_cfd::Actor<O, M, T, W>>,
pub cfd_feed_receiver: watch::Receiver<Vec<Cfd>>,
@ -112,7 +128,7 @@ where
let (oracle_addr, mut oracle_ctx) = xtra::Context::new(None);
let (inc_conn_addr, inc_conn_ctx) = xtra::Context::new(None);
let mut tasks = vec![];
let mut tasks = Tasks::default();
let (cfd_actor_addr, cfd_actor_fut) = maker_cfd::Actor::new(
db,
@ -130,46 +146,35 @@ where
.create(None)
.run();
tasks.push(cfd_actor_fut.spawn_with_handle());
tasks.add(cfd_actor_fut);
tasks.push(
inc_conn_ctx
.run(inc_conn_constructor(
Box::new(cfd_actor_addr.clone()),
Box::new(cfd_actor_addr.clone()),
))
.spawn_with_handle(),
);
tasks.add(inc_conn_ctx.run(inc_conn_constructor(
Box::new(cfd_actor_addr.clone()),
Box::new(cfd_actor_addr.clone()),
)));
tasks.push(
tasks.add(
monitor_ctx
.notify_interval(Duration::from_secs(20), || monitor::Sync)
.map_err(|e| anyhow::anyhow!(e))?
.spawn_with_handle(),
.map_err(|e| anyhow::anyhow!(e))?,
);
tasks.push(
tasks.add(
monitor_ctx
.run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?)
.spawn_with_handle(),
.run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?),
);
tasks.push(
tasks.add(
oracle_ctx
.notify_interval(Duration::from_secs(5), || oracle::Sync)
.map_err(|e| anyhow::anyhow!(e))?
.spawn_with_handle(),
.map_err(|e| anyhow::anyhow!(e))?,
);
let (fan_out_actor, fan_out_actor_fut) =
fan_out::Actor::new(&[&cfd_actor_addr, &monitor_addr])
.create(None)
.run();
tasks.push(fan_out_actor_fut.spawn_with_handle());
tasks.add(fan_out_actor_fut);
tasks.push(
oracle_ctx
.run(oracle_constructor(cfds, Box::new(fan_out_actor)))
.spawn_with_handle(),
);
tasks.add(oracle_ctx.run(oracle_constructor(cfds, Box::new(fan_out_actor))));
oracle_addr.send(oracle::Sync).await?;
@ -181,7 +186,7 @@ where
order_feed_receiver,
update_cfd_feed_receiver,
inc_conn_addr,
tasks: Tasks(tasks),
tasks,
})
}
}
@ -239,7 +244,7 @@ where
let (monitor_addr, mut monitor_ctx) = xtra::Context::new(None);
let (oracle_addr, mut oracle_ctx) = xtra::Context::new(None);
let mut tasks = vec![];
let mut tasks = Tasks::default();
let (connection_actor_addr, connection_actor_ctx) = xtra::Context::new(None);
let (cfd_actor_addr, cfd_actor_fut) = taker_cfd::Actor::new(
@ -257,36 +262,29 @@ where
.create(None)
.run();
tasks.push(cfd_actor_fut.spawn_with_handle());
tasks.push(
connection_actor_ctx
.run(connection::Actor::new(
maker_online_status_feed_sender,
Box::new(cfd_actor_addr.clone()),
identity_sk,
maker_heartbeat_interval,
))
.spawn_with_handle(),
);
tasks.add(cfd_actor_fut);
tasks.push(
tasks.add(connection_actor_ctx.run(connection::Actor::new(
maker_online_status_feed_sender,
Box::new(cfd_actor_addr.clone()),
identity_sk,
maker_heartbeat_interval,
)));
tasks.add(
monitor_ctx
.notify_interval(Duration::from_secs(20), || monitor::Sync)
.map_err(|e| anyhow::anyhow!(e))?
.spawn_with_handle(),
.map_err(|e| anyhow::anyhow!(e))?,
);
tasks.push(
tasks.add(
monitor_ctx
.run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?)
.spawn_with_handle(),
.run(monitor_constructor(Box::new(cfd_actor_addr.clone()), cfds.clone()).await?),
);
tasks.push(
tasks.add(
oracle_ctx
.notify_interval(Duration::from_secs(5), || oracle::Sync)
.map_err(|e| anyhow::anyhow!(e))?
.spawn_with_handle(),
.map_err(|e| anyhow::anyhow!(e))?,
);
let (fan_out_actor, fan_out_actor_fut) =
@ -294,13 +292,9 @@ where
.create(None)
.run();
tasks.push(fan_out_actor_fut.spawn_with_handle());
tasks.add(fan_out_actor_fut);
tasks.push(
oracle_ctx
.run(oracle_constructor(cfds, Box::new(fan_out_actor)))
.spawn_with_handle(),
);
tasks.add(oracle_ctx.run(oracle_constructor(cfds, Box::new(fan_out_actor))));
tracing::debug!("Taker actor system ready");
@ -311,7 +305,7 @@ where
order_feed_receiver,
update_cfd_feed_receiver,
maker_online_status_feed_receiver,
tasks: Tasks(tasks),
tasks,
})
}
}

12
daemon/src/maker.rs

@ -9,7 +9,7 @@ use daemon::seed::Seed;
use daemon::tokio_ext::FutureExt;
use daemon::{
bitmex_price_feed, db, housekeeping, logger, maker_cfd, maker_inc_connections, monitor, oracle,
wallet, wallet_sync, MakerActorSystem, HEARTBEAT_INTERVAL, N_PAYOUTS,
wallet, wallet_sync, MakerActorSystem, Tasks, HEARTBEAT_INTERVAL, N_PAYOUTS,
};
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::SqlitePool;
@ -223,8 +223,10 @@ async fn main() -> Result<()> {
tracing::info!("Listening on {}", local_addr);
let mut tasks = Tasks::default();
let (task, quote_updates) = bitmex_price_feed::new().await?;
let _task = task.spawn_with_handle();
tasks.add(task);
let db = SqlitePool::connect_with(
SqliteConnectOptions::new()
@ -282,11 +284,9 @@ async fn main() -> Result<()> {
Poll::Ready(Some(message))
});
let _listener_task = incoming_connection_addr
.attach_stream(listener_stream)
.spawn_with_handle();
tasks.add(incoming_connection_addr.attach_stream(listener_stream));
let _wallet_sync_task = wallet_sync::new(wallet, wallet_feed_sender).spawn_with_handle();
tasks.add(wallet_sync::new(wallet, wallet_feed_sender));
let cfd_action_channel = MessageChannel::<maker_cfd::CfdAction>::clone_channel(&cfd_actor_addr);
let new_order_channel = MessageChannel::<maker_cfd::NewOrder>::clone_channel(&cfd_actor_addr);

36
daemon/src/maker_inc_connections.rs

@ -3,9 +3,8 @@ use crate::model::cfd::{Order, OrderId};
use crate::model::{BitMexPriceEventId, TakerId};
use crate::noise::TransportStateExt;
use crate::tokio_ext::FutureExt;
use crate::{forward_only_ok, maker_cfd, noise, send_to_socket, wire};
use crate::{forward_only_ok, maker_cfd, noise, send_to_socket, wire, Tasks};
use anyhow::Result;
use futures::future::RemoteHandle;
use futures::{StreamExt, TryStreamExt};
use std::collections::HashMap;
use std::io;
@ -72,7 +71,7 @@ pub struct Actor {
taker_msg_channel: Box<dyn MessageChannel<FromTaker>>,
noise_priv_key: x25519_dalek::StaticSecret,
heartbeat_interval: Duration,
tasks: Vec<RemoteHandle<()>>,
tasks: Tasks,
}
impl Actor {
@ -88,7 +87,7 @@ impl Actor {
taker_msg_channel: taker_msg_channel.clone_channel(),
noise_priv_key,
heartbeat_interval,
tasks: Vec::new(),
tasks: Tasks::default(),
}
}
@ -136,29 +135,26 @@ impl Actor {
forward_only_ok::Actor::new(self.taker_msg_channel.clone_channel())
.create(None)
.run();
self.tasks.push(forward_to_cfd_fut.spawn_with_handle());
self.tasks.add(forward_to_cfd_fut);
// only allow outgoing messages while we are successfully reading incoming ones
let heartbeat_interval = self.heartbeat_interval;
self.tasks.push(
async move {
let mut actor = send_to_socket::Actor::new(write, transport_state.clone());
self.tasks.add(async move {
let mut actor = send_to_socket::Actor::new(write, transport_state.clone());
let _heartbeat_handle = out_msg_actor_context
.notify_interval(heartbeat_interval, || wire::MakerToTaker::Heartbeat)
.expect("actor not to shutdown")
.spawn_with_handle();
let _heartbeat_handle = out_msg_actor_context
.notify_interval(heartbeat_interval, || wire::MakerToTaker::Heartbeat)
.expect("actor not to shutdown")
.spawn_with_handle();
out_msg_actor_context
.handle_while(&mut actor, forward_to_cfd.attach_stream(read))
.await;
out_msg_actor_context
.handle_while(&mut actor, forward_to_cfd.attach_stream(read))
.await;
tracing::error!("Closing connection to taker {}", taker_id);
tracing::error!("Closing connection to taker {}", taker_id);
actor.shutdown().await;
}
.spawn_with_handle(),
);
actor.shutdown().await;
});
self.write_connections
.insert(taker_id, out_msg_actor_address);

12
daemon/src/taker.rs

@ -9,7 +9,7 @@ use daemon::seed::Seed;
use daemon::tokio_ext::FutureExt;
use daemon::{
bitmex_price_feed, connection, db, housekeeping, logger, monitor, oracle, taker_cfd, wallet,
wallet_sync, TakerActorSystem, HEARTBEAT_INTERVAL, N_PAYOUTS,
wallet_sync, TakerActorSystem, Tasks, HEARTBEAT_INTERVAL, N_PAYOUTS,
};
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::SqlitePool;
@ -207,8 +207,10 @@ async fn main() -> Result<()> {
let (wallet_feed_sender, wallet_feed_receiver) = watch::channel::<WalletInfo>(wallet_info);
let mut tasks = Tasks::default();
let (task, quote_updates) = bitmex_price_feed::new().await?;
let _task = task.spawn_with_handle();
tasks.add(task);
let figment = rocket::Config::figment()
.merge(("address", opts.http_address.ip()))
@ -258,7 +260,7 @@ async fn main() -> Result<()> {
connect(connection_actor_addr, opts.maker_id, opts.maker).await?;
let _wallet_sync_task = wallet_sync::new(wallet, wallet_feed_sender).spawn_with_handle();
tasks.add(wallet_sync::new(wallet, wallet_feed_sender));
let take_offer_channel = MessageChannel::<taker_cfd::TakeOffer>::clone_channel(&cfd_actor_addr);
let cfd_action_channel = MessageChannel::<taker_cfd::CfdAction>::clone_channel(&cfd_actor_addr);
@ -290,7 +292,7 @@ async fn main() -> Result<()> {
let shutdown_handle = rocket.shutdown();
// shutdown the rocket server maker if goes offline
let _rocket_shutdown_task = (async move {
tasks.add(async move {
loop {
maker_online_status_feed_receiver.changed().await.unwrap();
if maker_online_status_feed_receiver.borrow().clone() == ConnectionStatus::Offline {
@ -299,7 +301,7 @@ async fn main() -> Result<()> {
return;
}
}
}).spawn_with_handle();
});
rocket.launch().await?;
db.close().await;

44
daemon/src/tokio_ext.rs

@ -1,5 +1,6 @@
use futures::future::RemoteHandle;
use futures::FutureExt as _;
use std::any::{Any, TypeId};
use std::fmt;
use std::future::Future;
use std::time::Duration;
@ -26,7 +27,7 @@ pub trait FutureExt: Future + Sized {
/// The task will be stopped when the handle gets dropped.
fn spawn_with_handle(self) -> RemoteHandle<Self::Output>
where
Self: Future<Output = ()> + Send + 'static;
Self: Future<Output = ()> + Send + Any + 'static;
}
impl<F> FutureExt for F
@ -39,8 +40,13 @@ where
fn spawn_with_handle(self) -> RemoteHandle<()>
where
Self: Future<Output = ()> + Send + 'static,
Self: Future<Output = ()> + Send + Any + 'static,
{
debug_assert!(
TypeId::of::<RemoteHandle<()>>() != self.type_id(),
"RemoteHandle<()> is a handle to already spawned task",
);
let (future, handle) = self.remote_handle();
// we want to disallow calls to tokio::spawn outside FutureExt
#[allow(clippy::disallowed_method)]
@ -48,3 +54,37 @@ where
handle
}
}
#[cfg(test)]
mod tests {
use std::panic;
use tokio::time::sleep;
use super::*;
#[tokio::test]
async fn spawning_a_regular_future_does_not_panic() {
let result = panic::catch_unwind(|| {
let _handle = sleep(Duration::from_secs(2)).spawn_with_handle();
});
assert!(result.is_ok());
}
#[tokio::test]
async fn panics_when_called_spawn_with_handle_on_remote_handle() {
let result = panic::catch_unwind(|| {
let handle = sleep(Duration::from_secs(2)).spawn_with_handle();
let _handle_to_a_handle = handle.spawn_with_handle();
});
if cfg!(debug_assertions) {
assert!(
result.is_err(),
"Spawning a remote handle into a separate task should panic_in_debug_mode"
);
} else {
assert!(result.is_ok(), "Do not panic in release mode");
}
}
}

24
daemon/tests/harness/mod.rs

@ -7,9 +7,7 @@ use daemon::maker_cfd::CfdAction;
use daemon::model::cfd::{Cfd, Order, Origin};
use daemon::model::{Price, Usd};
use daemon::seed::Seed;
use daemon::tokio_ext::FutureExt;
use daemon::{db, maker_cfd, maker_inc_connections, taker_cfd, MakerActorSystem};
use futures::future::RemoteHandle;
use daemon::{db, maker_cfd, maker_inc_connections, taker_cfd, MakerActorSystem, Tasks};
use rust_decimal_macros::dec;
use sqlx::SqlitePool;
use std::net::SocketAddr;
@ -49,7 +47,7 @@ pub struct Maker {
pub mocks: mocks::Mocks,
pub listen_addr: SocketAddr,
pub identity_pk: x25519_dalek::PublicKey,
_tasks: Vec<RemoteHandle<()>>,
_tasks: Tasks,
}
impl Maker {
@ -67,10 +65,10 @@ impl Maker {
let mut mocks = mocks::Mocks::default();
let (oracle, monitor, wallet) = mocks::create_actors(&mocks);
let mut tasks = vec![];
let mut tasks = Tasks::default();
let (wallet_addr, wallet_fut) = wallet.create(None).run();
tasks.push(wallet_fut.spawn_with_handle());
tasks.add(wallet_fut);
let settlement_time_interval_hours = time::Duration::hours(24);
@ -114,13 +112,7 @@ impl Maker {
Poll::Ready(Some(message))
});
tasks.push(
maker
.inc_conn_addr
.clone()
.attach_stream(listener_stream)
.spawn_with_handle(),
);
tasks.add(maker.inc_conn_addr.clone().attach_stream(listener_stream));
Self {
system: maker,
@ -165,7 +157,7 @@ impl Maker {
pub struct Taker {
pub system: daemon::TakerActorSystem<OracleActor, MonitorActor, WalletActor>,
pub mocks: mocks::Mocks,
_tasks: Vec<RemoteHandle<()>>,
_tasks: Tasks,
}
impl Taker {
@ -195,10 +187,10 @@ impl Taker {
let mut mocks = mocks::Mocks::default();
let (oracle, monitor, wallet) = mocks::create_actors(&mocks);
let mut tasks = vec![];
let mut tasks = Tasks::default();
let (wallet_addr, wallet_fut) = wallet.create(None).run();
tasks.push(wallet_fut.spawn_with_handle());
tasks.add(wallet_fut);
// system startup sends sync messages, mock them
mocks.mock_sync_handlers().await;

Loading…
Cancel
Save