diff --git a/Cargo.lock b/Cargo.lock index 4673bcd..ddb769a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -531,7 +531,6 @@ dependencies = [ "reqwest", "rocket", "rocket-basicauth", - "rocket_db_pools", "rust-embed", "rust_decimal", "rust_decimal_macros", @@ -2074,26 +2073,6 @@ dependencies = [ "unicode-xid", ] -[[package]] -name = "rocket_db_pools" -version = "0.1.0-rc" -source = "git+https://github.com/SergioBenitez/Rocket#8cae077ba1d54b92cdef3e171a730b819d5eeb8e" -dependencies = [ - "rocket", - "rocket_db_pools_codegen", - "sqlx", - "version_check", -] - -[[package]] -name = "rocket_db_pools_codegen" -version = "0.1.0-rc" -source = "git+https://github.com/SergioBenitez/Rocket#8cae077ba1d54b92cdef3e171a730b819d5eeb8e" -dependencies = [ - "devise", - "quote", -] - [[package]] name = "rocket_http" version = "0.5.0-rc.1" diff --git a/daemon/Cargo.toml b/daemon/Cargo.toml index fbd611b..2b854f8 100644 --- a/daemon/Cargo.toml +++ b/daemon/Cargo.toml @@ -23,7 +23,6 @@ rand = "0.6" reqwest = { version = "0.11", default-features = false, features = ["json", "rustls-tls-webpki-roots"] } rocket = { version = "0.5.0-rc.1", features = ["json"] } rocket-basicauth = { version = "2", default-features = false } -rocket_db_pools = { git = "https://github.com/SergioBenitez/Rocket", features = ["sqlx_sqlite"] } rust-embed = "6.2" rust_decimal = { version = "1.16", features = ["serde-float", "serde-arbitrary-precision"] } rust_decimal_macros = "1.16" @@ -32,7 +31,7 @@ serde_json = "1" serde_plain = "1" serde_with = { version = "1", features = ["macros"] } sha2 = "0.9" -sqlx = { version = "0.5", features = ["offline", "sqlite", "uuid"] } +sqlx = { version = "0.5", features = ["offline", "sqlite", "uuid", "runtime-tokio-rustls"] } thiserror = "1" time = { version = "0.3", features = ["serde"] } tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "net"] } diff --git a/daemon/src/db.rs b/daemon/src/db.rs index 27fad5a..1575081 100644 --- a/daemon/src/db.rs +++ b/daemon/src/db.rs @@ -1,7 +1,6 @@ use crate::model::cfd::{Cfd, CfdState, Order, OrderId}; use crate::model::{BitMexPriceEventId, Usd}; use anyhow::{Context, Result}; -use rocket_db_pools::sqlx; use rust_decimal::Decimal; use sqlx::pool::PoolConnection; use sqlx::{Sqlite, SqlitePool}; diff --git a/daemon/src/maker.rs b/daemon/src/maker.rs index 72d25e5..eb56ed6 100644 --- a/daemon/src/maker.rs +++ b/daemon/src/maker.rs @@ -16,7 +16,7 @@ use daemon::{ }; use futures::Future; use rocket::fairing::AdHoc; -use rocket_db_pools::Database; +use sqlx::sqlite::SqliteConnectOptions; use sqlx::SqlitePool; use std::collections::HashMap; use std::net::SocketAddr; @@ -32,10 +32,6 @@ use xtra::spawn::TokioGlobalSpawnExt; mod routes_maker; -#[derive(Database)] -#[database("maker")] -pub struct Db(sqlx::SqlitePool); - #[derive(Clap)] struct Opts { /// The port to listen on for p2p connections. @@ -156,7 +152,6 @@ async fn main() -> Result<()> { let (wallet_feed_sender, wallet_feed_receiver) = watch::channel::(wallet_info); let figment = rocket::Config::figment() - .merge(("databases.maker.url", data_dir.join("maker.sqlite"))) .merge(("address", opts.http_address.ip())) .merge(("port", opts.http_address.port())); @@ -173,31 +168,32 @@ async fn main() -> Result<()> { let (task, quote_updates) = bitmex_price_feed::new().await?; tokio::spawn(task); + let db = SqlitePool::connect_with( + SqliteConnectOptions::new() + .create_if_missing(true) + .filename(data_dir.join("maker.sqlite")), + ) + .await?; + rocket::custom(figment) .manage(wallet_feed_receiver) .manage(auth_password) .manage(quote_updates) .manage(bitcoin_network) - .attach(Db::init()) - .attach(AdHoc::try_on_ignite( - "SQL migrations", - |rocket| async move { - match Db::fetch(&rocket) { - Some(db) => match db::run_migrations(&**db).await { - Ok(_) => Ok(rocket), - Err(_) => Err(rocket), - }, - None => Err(rocket), + .attach(AdHoc::try_on_ignite("SQL migrations", { + let db = db.clone(); + + move |rocket| async move { + match db::run_migrations(&db).await { + Ok(_) => Ok(rocket), + Err(_) => Err(rocket), } - }, - )) + } + })) .attach(AdHoc::try_on_ignite("Create actors", { - move |rocket| async move { - let db = match Db::fetch(&rocket) { - Some(db) => (**db).clone(), - None => return Err(rocket), - }; + let db = db.clone(); + move |rocket| async move { let mut conn = db.acquire().await.unwrap(); housekeeping::transition_non_continue_cfds_to_setup_failed(&mut conn) @@ -208,27 +204,30 @@ async fn main() -> Result<()> { .unwrap(); let ActorSystem { - cfd_actor_addr, - cfd_feed_receiver, - order_feed_receiver, - update_cfd_feed_receiver, - } = ActorSystem::new( - db, - wallet.clone(), - oracle, - |cfds, channel| oracle::Actor::new(cfds, channel), - { - |channel, cfds| { - let electrum = opts.network.electrum().to_string(); - async move { - monitor::Actor::new(electrum, channel, cfds.clone()).await - } - } - }, - |channel0, channel1| maker_inc_connections::Actor::new(channel0, channel1), - listener, - ) - .await; + cfd_actor_addr, + cfd_feed_receiver, + order_feed_receiver, + update_cfd_feed_receiver, + } = + ActorSystem::new( + db, + wallet.clone(), + oracle, + |cfds, channel| oracle::Actor::new(cfds, channel), + { + |channel, cfds| { + let electrum = opts.network.electrum().to_string(); + async move { + monitor::Actor::new(electrum, channel, cfds.clone()).await + } + } + }, + |channel0, channel1| { + maker_inc_connections::Actor::new(channel0, channel1) + }, + listener, + ) + .await; tokio::spawn(wallet_sync::new(wallet, wallet_feed_sender)); @@ -263,6 +262,8 @@ async fn main() -> Result<()> { .launch() .await?; + db.close().await; + Ok(()) } diff --git a/daemon/src/taker.rs b/daemon/src/taker.rs index ab646d2..0db2523 100644 --- a/daemon/src/taker.rs +++ b/daemon/src/taker.rs @@ -13,7 +13,8 @@ use daemon::{ }; use futures::StreamExt; use rocket::fairing::AdHoc; -use rocket_db_pools::Database; +use sqlx::sqlite::SqliteConnectOptions; +use sqlx::SqlitePool; use std::collections::HashMap; use std::net::SocketAddr; use std::path::PathBuf; @@ -31,10 +32,6 @@ mod routes_taker; const CONNECTION_RETRY_INTERVAL: Duration = Duration::from_secs(5); -#[derive(Database)] -#[database("taker")] -pub struct Db(sqlx::SqlitePool); - #[derive(Clap)] struct Opts { /// The IP address of the other party (i.e. the maker). @@ -164,36 +161,36 @@ async fn main() -> Result<()> { tokio::spawn(task); let figment = rocket::Config::figment() - .merge(("databases.taker.url", data_dir.join("taker.sqlite"))) .merge(("address", opts.http_address.ip())) .merge(("port", opts.http_address.port())); + let db = SqlitePool::connect_with( + SqliteConnectOptions::new() + .create_if_missing(true) + .filename(data_dir.join("taker.sqlite")), + ) + .await?; + rocket::custom(figment) .manage(order_feed_receiver) .manage(wallet_feed_receiver) .manage(update_feed_receiver) .manage(quote_updates) .manage(bitcoin_network) - .attach(Db::init()) - .attach(AdHoc::try_on_ignite( - "SQL migrations", - |rocket| async move { - match Db::fetch(&rocket) { - Some(db) => match db::run_migrations(&**db).await { - Ok(_) => Ok(rocket), - Err(_) => Err(rocket), - }, - None => Err(rocket), + .attach(AdHoc::try_on_ignite("SQL migrations", { + let db = db.clone(); + + move |rocket| async move { + match db::run_migrations(&db).await { + Ok(_) => Ok(rocket), + Err(_) => Err(rocket), } - }, - )) - .attach(AdHoc::try_on_ignite( - "Create actors", + } + })) + .attach(AdHoc::try_on_ignite("Create actors", { + let db = db.clone(); + move |rocket| async move { - let db = match Db::fetch(&rocket) { - Some(db) => (**db).clone(), - None => return Err(rocket), - }; let mut conn = db.acquire().await.unwrap(); housekeeping::transition_non_continue_cfds_to_setup_failed(&mut conn) @@ -212,8 +209,6 @@ async fn main() -> Result<()> { let (monitor_actor_address, mut monitor_actor_context) = xtra::Context::new(None); let (oracle_actor_address, mut oracle_actor_context) = xtra::Context::new(None); - let mut conn = db.acquire().await.unwrap(); - let cfds = load_all_cfds(&mut conn).await.unwrap(); let cfd_actor_inbox = taker_cfd::Actor::new( db.clone(), wallet.clone(), @@ -273,8 +268,8 @@ async fn main() -> Result<()> { .manage(take_offer_channel) .manage(cfd_action_channel) .manage(cfd_feed_receiver)) - }, - )) + } + })) .mount( "/api", rocket::routes![ @@ -292,5 +287,7 @@ async fn main() -> Result<()> { .launch() .await?; + db.close().await; + Ok(()) }