You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

702 lines
22 KiB

use crate::model::cfd::{Cfd, CfdState, Order, OrderId, Origin};
use crate::model::{Leverage, OracleEventId, Position};
use anyhow::{Context, Result};
use rocket_db_pools::sqlx;
use sqlx::pool::PoolConnection;
use sqlx::{Acquire, Sqlite, SqlitePool};
use std::convert::TryInto;
use std::mem;
/// Runs the migrations found in `./migrations` against the database behind `pool`.
///
/// # Errors
///
/// Returns an error if running the migrations fails.
pub async fn run_migrations(pool: &SqlitePool) -> anyhow::Result<()> {
    let migrator = sqlx::migrate!("./migrations");
    migrator.run(pool).await?;

    Ok(())
}
/// Inserts `order` as a new row into the `orders` table.
///
/// `leverage` and `oracle_event_id` are stored as their inner values; every other field is
/// stored as its JSON serialization.
///
/// # Errors
///
/// Returns an error if a field cannot be serialized or if the insert statement fails.
pub async fn insert_order(order: &Order, conn: &mut PoolConnection<Sqlite>) -> anyhow::Result<()> {
    // Serialization failures are propagated to the caller instead of panicking.
    let uuid = serde_json::to_string(&order.id)?;
    let trading_pair = serde_json::to_string(&order.trading_pair)?;
    let position = serde_json::to_string(&order.position)?;
    let initial_price = serde_json::to_string(&order.price)?;
    let min_quantity = serde_json::to_string(&order.min_quantity)?;
    let max_quantity = serde_json::to_string(&order.max_quantity)?;
    let leverage = order.leverage.0;
    let liquidation_price = serde_json::to_string(&order.liquidation_price)?;
    let creation_timestamp = serde_json::to_string(&order.creation_timestamp)?;
    let term = serde_json::to_string(&order.term)?;
    let origin = serde_json::to_string(&order.origin)?;
    let oracle_event_id = order.oracle_event_id.0.clone();

    sqlx::query!(
        r#"
insert into orders (
uuid,
trading_pair,
position,
initial_price,
min_quantity,
max_quantity,
leverage,
liquidation_price,
creation_timestamp,
term,
origin,
oracle_event_id
) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
"#,
        uuid,
        trading_pair,
        position,
        initial_price,
        min_quantity,
        max_quantity,
        leverage,
        liquidation_price,
        creation_timestamp,
        term,
        origin,
        oracle_event_id
    )
    .execute(conn)
    .await?;

    Ok(())
}
/// Loads the order whose `uuid` column equals the JSON serialization of `id`.
///
/// # Errors
///
/// Returns an error if `id` cannot be serialized, if the query fails (including when no row
/// matches, because exactly one row is fetched), if a stored column cannot be deserialized, or
/// if the stored leverage does not fit the `Leverage` value type.
pub async fn load_order_by_id(
    id: OrderId,
    conn: &mut PoolConnection<Sqlite>,
) -> anyhow::Result<Order> {
    let uuid = serde_json::to_string(&id)?;

    let row = sqlx::query!(
        r#"
select * from orders where uuid = ?;
"#,
        uuid
    )
    .fetch_one(conn)
    .await?;

    // Malformed database contents are reported as errors rather than panics.
    let uuid = serde_json::from_str(row.uuid.as_str())?;
    let trading_pair = serde_json::from_str(row.trading_pair.as_str())?;
    let position = serde_json::from_str(row.position.as_str())?;
    let initial_price = serde_json::from_str(row.initial_price.as_str())?;
    let min_quantity = serde_json::from_str(row.min_quantity.as_str())?;
    let max_quantity = serde_json::from_str(row.max_quantity.as_str())?;
    let leverage = Leverage(row.leverage.try_into()?);
    let liquidation_price = serde_json::from_str(row.liquidation_price.as_str())?;
    let creation_timestamp = serde_json::from_str(row.creation_timestamp.as_str())?;
    let term = serde_json::from_str(row.term.as_str())?;
    let origin = serde_json::from_str(row.origin.as_str())?;
    let oracle_event_id = OracleEventId(row.oracle_event_id);

    Ok(Order {
        id: uuid,
        trading_pair,
        position,
        price: initial_price,
        min_quantity,
        max_quantity,
        leverage,
        liquidation_price,
        creation_timestamp,
        term,
        origin,
        oracle_event_id,
    })
}
/// Stores `cfd` together with its current state.
///
/// The order belonging to the CFD is looked up in `orders` by its serialized id, so it has to be
/// stored before this function is called.
///
/// # Errors
///
/// Returns an error if serialization fails or if any of the statements (including the order
/// lookup and the final commit) fails.
pub async fn insert_cfd(cfd: Cfd, conn: &mut PoolConnection<Sqlite>) -> anyhow::Result<()> {
    // The CFD row and its state row are written in one transaction, so that a state is only
    // ever stored for a CFD that was stored as well.
    let mut transaction = conn.begin().await?;

    let serialized_order_id = serde_json::to_string(&cfd.order.id)?;
    let order_pk = sqlx::query!(
        r#"
select * from orders where uuid = ?;
"#,
        serialized_order_id
    )
    .fetch_one(&mut transaction)
    .await?
    .id;

    let serialized_quantity = serde_json::to_string(&cfd.quantity_usd)?;
    let serialized_state = serde_json::to_string(&cfd.state)?;

    let cfd_insert = sqlx::query!(
        r#"
insert into cfds (
order_id,
order_uuid,
quantity_usd
) values (?, ?, ?);
"#,
        order_pk,
        serialized_order_id,
        serialized_quantity,
    )
    .execute(&mut transaction)
    .await?;
    let new_cfd_id = cfd_insert.last_insert_rowid();

    sqlx::query!(
        r#"
insert into cfd_states (
cfd_id,
state
) values (?, ?);
"#,
        new_cfd_id,
        serialized_state,
    )
    .execute(&mut transaction)
    .await?;

    transaction.commit().await?;

    Ok(())
}
/// Appends `new_state` to the state history of the CFD that belongs to `order_id`.
///
/// # Errors
///
/// Returns an error if the CFD or its latest state cannot be loaded, if `new_state` cannot be
/// serialized, or if the insert statement fails.
#[allow(dead_code)]
pub async fn insert_new_cfd_state_by_order_id(
    order_id: OrderId,
    new_state: CfdState,
    conn: &mut PoolConnection<Sqlite>,
) -> anyhow::Result<()> {
    let cfd_id = load_cfd_id_by_order_uuid(order_id, conn).await?;
    let current_state = load_latest_cfd_state(cfd_id, conn)
        .await
        .context("loading latest state failed")?;

    // A new state of the same variant as the current one is only reported with a warning; the
    // state is still inserted below.
    let repeats_current_variant =
        mem::discriminant(&current_state) == mem::discriminant(&new_state);
    if repeats_current_variant {
        tracing::warn!(
            "Same state transition for cfd with order_id {}: {}",
            order_id,
            current_state
        );
    }

    let serialized_state = serde_json::to_string(&new_state)?;
    sqlx::query!(
        r#"
insert into cfd_states (
cfd_id,
state
) values (?, ?);
"#,
        cfd_id,
        serialized_state,
    )
    .execute(conn)
    .await?;

    Ok(())
}
/// Looks up the `id` of the row in `cfds` whose `order_uuid` equals the JSON serialization of
/// `order_uuid`.
///
/// # Errors
///
/// Returns an error if serialization or the query fails, or with the message "No cfd found" if
/// the fetched row carries no id.
#[allow(dead_code)]
async fn load_cfd_id_by_order_uuid(
    order_uuid: OrderId,
    conn: &mut PoolConnection<Sqlite>,
) -> anyhow::Result<i64> {
    let serialized_order_id = serde_json::to_string(&order_uuid)?;

    let row = sqlx::query!(
        r#"
select
id
from cfds
where order_uuid = ?;
"#,
        serialized_order_id
    )
    .fetch_one(conn)
    .await?;

    row.id.context("No cfd found")
}
/// Loads the most recently inserted state (highest `id` in `cfd_states`) of the CFD `cfd_id`.
///
/// # Errors
///
/// Returns an error if the query fails or if the stored state cannot be deserialized.
#[allow(dead_code)]
async fn load_latest_cfd_state(
    cfd_id: i64,
    conn: &mut PoolConnection<Sqlite>,
) -> anyhow::Result<CfdState> {
    let row = sqlx::query!(
        r#"
select
state
from cfd_states
where cfd_id = ?
order by id desc
limit 1;
"#,
        cfd_id
    )
    .fetch_one(conn)
    .await?;

    let state = serde_json::from_str::<CfdState>(row.state.as_str())?;

    Ok(state)
}
/// Loads the CFD that belongs to the order `order_id`, using its latest state as the CFD state.
///
/// # Errors
///
/// Returns an error if `order_id` cannot be serialized, if the query fails (including when no
/// row matches, because exactly one row is fetched), if a stored column cannot be deserialized,
/// or if the stored leverage does not fit the `Leverage` value type.
pub async fn load_cfd_by_order_id(
    order_id: OrderId,
    conn: &mut PoolConnection<Sqlite>,
) -> Result<Cfd> {
    let order_uuid = serde_json::to_string(&order_id)?;

    // NOTE(review): the subquery selects the latest state by its serialized text rather than by
    // row id, so several state rows of one CFD with identical text would all match.
    let row = sqlx::query!(
        r#"
select
orders.uuid as order_id,
orders.initial_price as price,
orders.min_quantity as min_quantity,
orders.max_quantity as max_quantity,
orders.leverage as leverage,
orders.trading_pair as trading_pair,
orders.position as position,
orders.origin as origin,
orders.liquidation_price as liquidation_price,
orders.creation_timestamp as creation_timestamp,
orders.term as term,
orders.oracle_event_id,
cfds.quantity_usd as quantity_usd,
cfd_states.state as state
from cfds as cfds
inner join orders as orders on cfds.order_id = orders.id
inner join cfd_states as cfd_states on cfd_states.cfd_id = cfds.id
where cfd_states.state in (
select
state
from cfd_states
where cfd_id = cfds.id
order by id desc
limit 1
)
and orders.uuid = ?
"#,
        order_uuid
    )
    .fetch_one(conn)
    .await?;

    // Malformed database contents are reported as errors rather than panics.
    let order_id = serde_json::from_str(row.order_id.as_str())?;
    let trading_pair = serde_json::from_str(row.trading_pair.as_str())?;
    let position: Position = serde_json::from_str(row.position.as_str())?;
    let price = serde_json::from_str(row.price.as_str())?;
    let min_quantity = serde_json::from_str(row.min_quantity.as_str())?;
    let max_quantity = serde_json::from_str(row.max_quantity.as_str())?;
    let leverage = Leverage(row.leverage.try_into()?);
    let liquidation_price = serde_json::from_str(row.liquidation_price.as_str())?;
    let creation_timestamp = serde_json::from_str(row.creation_timestamp.as_str())?;
    let term = serde_json::from_str(row.term.as_str())?;
    let origin: Origin = serde_json::from_str(row.origin.as_str())?;
    let quantity = serde_json::from_str(row.quantity_usd.as_str())?;
    let latest_state = serde_json::from_str(row.state.as_str())?;
    // The row is owned, so the event id can be moved out instead of cloned.
    let oracle_event_id = OracleEventId(row.oracle_event_id);

    let order = Order {
        id: order_id,
        trading_pair,
        position,
        price,
        min_quantity,
        max_quantity,
        leverage,
        liquidation_price,
        creation_timestamp,
        term,
        origin,
        oracle_event_id,
    };

    Ok(Cfd {
        order,
        quantity_usd: quantity,
        state: latest_state,
    })
}
/// Loads all CFDs with the latest state as the CFD state
pub async fn load_all_cfds(conn: &mut PoolConnection<Sqlite>) -> anyhow::Result<Vec<Cfd>> {
// TODO: Could be optimized with something like but not sure it's worth the complexity:
let rows = sqlx::query!(
r#"
select
orders.uuid as order_id,
orders.initial_price as price,
orders.min_quantity as min_quantity,
orders.max_quantity as max_quantity,
orders.leverage as leverage,
orders.trading_pair as trading_pair,
orders.position as position,
orders.origin as origin,
orders.liquidation_price as liquidation_price,
orders.creation_timestamp as creation_timestamp,
orders.term as term,
orders.oracle_event_id,
cfds.quantity_usd as quantity_usd,
cfd_states.state as state
from cfds as cfds
inner join orders as orders on cfds.order_id = orders.id
inner join cfd_states as cfd_states on cfd_states.cfd_id = cfds.id
where cfd_states.state in (
select
state
from cfd_states
where cfd_id = cfds.id
order by id desc
limit 1
)
"#
)
.fetch_all(conn)
.await?;
let cfds = rows
.iter()
.map(|row| {
let order_id = serde_json::from_str(row.order_id.as_str()).unwrap();
let trading_pair = serde_json::from_str(row.trading_pair.as_str()).unwrap();
let position: Position = serde_json::from_str(row.position.as_str()).unwrap();
let price = serde_json::from_str(row.price.as_str()).unwrap();
let min_quantity = serde_json::from_str(row.min_quantity.as_str()).unwrap();
let max_quantity = serde_json::from_str(row.max_quantity.as_str()).unwrap();
let leverage = Leverage(row.leverage.try_into().unwrap());
let liquidation_price = serde_json::from_str(row.liquidation_price.as_str()).unwrap();
let creation_timestamp = serde_json::from_str(row.creation_timestamp.as_str()).unwrap();
let term = serde_json::from_str(row.term.as_str()).unwrap();
let origin: Origin = serde_json::from_str(row.origin.as_str()).unwrap();
let oracle_event_id = OracleEventId(row.oracle_event_id.clone());
let quantity = serde_json::from_str(row.quantity_usd.as_str()).unwrap();
let latest_state = serde_json::from_str(row.state.as_str()).unwrap();
let order = Order {
id: order_id,
trading_pair,
position,
price,
min_quantity,
max_quantity,
leverage,
liquidation_price,
creation_timestamp,
term,
origin,
oracle_event_id,
};
Cfd {
order,
quantity_usd: quantity,
state: latest_state,
}
})
.collect();
Ok(cfds)
}
/// Loads all CFDs with the latest state as the CFD state
pub async fn load_cfds_by_oracle_event_id(
oracle_event_id: OracleEventId,
conn: &mut PoolConnection<Sqlite>,
) -> anyhow::Result<Vec<Cfd>> {
let rows = sqlx::query!(
r#"
select
orders.uuid as order_id,
orders.initial_price as price,
orders.min_quantity as min_quantity,
orders.max_quantity as max_quantity,
orders.leverage as leverage,
orders.trading_pair as trading_pair,
orders.position as position,
orders.origin as origin,
orders.liquidation_price as liquidation_price,
orders.creation_timestamp as creation_timestamp,
orders.term as term,
orders.oracle_event_id,
cfds.quantity_usd as quantity_usd,
cfd_states.state as state
from cfds as cfds
inner join orders as orders on cfds.order_id = orders.id
inner join cfd_states as cfd_states on cfd_states.cfd_id = cfds.id
where cfd_states.state in (
select
state
from cfd_states
where cfd_id = cfds.id
order by id desc
limit 1
)
and orders.oracle_event_id = ?
"#,
oracle_event_id.0
)
.fetch_all(conn)
.await?;
let cfds = rows
.iter()
.map(|row| {
let order_id = serde_json::from_str(row.order_id.as_str()).unwrap();
let trading_pair = serde_json::from_str(row.trading_pair.as_str()).unwrap();
let position: Position = serde_json::from_str(row.position.as_str()).unwrap();
let price = serde_json::from_str(row.price.as_str()).unwrap();
let min_quantity = serde_json::from_str(row.min_quantity.as_str()).unwrap();
let max_quantity = serde_json::from_str(row.max_quantity.as_str()).unwrap();
let leverage = Leverage(row.leverage.try_into().unwrap());
let liquidation_price = serde_json::from_str(row.liquidation_price.as_str()).unwrap();
let creation_timestamp = serde_json::from_str(row.creation_timestamp.as_str()).unwrap();
let term = serde_json::from_str(row.term.as_str()).unwrap();
let origin: Origin = serde_json::from_str(row.origin.as_str()).unwrap();
let oracle_event_id = OracleEventId(row.oracle_event_id.clone());
let quantity = serde_json::from_str(row.quantity_usd.as_str()).unwrap();
let latest_state = serde_json::from_str(row.state.as_str()).unwrap();
let order = Order {
id: order_id,
trading_pair,
position,
price,
min_quantity,
max_quantity,
leverage,
liquidation_price,
creation_timestamp,
term,
origin,
oracle_event_id,
};
Cfd {
order,
quantity_usd: quantity,
state: latest_state,
}
})
.collect();
Ok(cfds)
}
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::time::SystemTime;

    use rust_decimal_macros::dec;
    use sqlx::SqlitePool;
    use tempfile::tempdir;

    use crate::db::insert_order;
    use crate::model::cfd::{Cfd, CfdState, CfdStateCommon, Order};
    use crate::model::Usd;

    use super::*;

    /// An inserted order is read back unchanged by `load_order_by_id`.
    #[tokio::test]
    async fn test_insert_and_load_order() {
        let pool = setup_test_db().await;
        let mut conn = pool.acquire().await.unwrap();

        let expected = Order::default();
        insert_order(&expected, &mut conn).await.unwrap();

        let actual = load_order_by_id(expected.id, &mut conn).await.unwrap();
        assert_eq!(expected, actual);
    }

    /// An inserted CFD is the first entry returned by `load_all_cfds`.
    #[tokio::test]
    async fn test_insert_and_load_cfd() {
        let pool = setup_test_db().await;
        let mut conn = pool.acquire().await.unwrap();

        let expected = Cfd::default();
        insert_order(&expected.order, &mut conn).await.unwrap();
        insert_cfd(expected.clone(), &mut conn).await.unwrap();

        let loaded = load_all_cfds(&mut conn).await.unwrap();
        let actual = loaded.first().unwrap().clone();
        assert_eq!(expected, actual);
    }

    /// An inserted CFD can be loaded through the id of its order.
    #[tokio::test]
    async fn test_insert_and_load_cfd_by_order_id() {
        let pool = setup_test_db().await;
        let mut conn = pool.acquire().await.unwrap();

        let expected = Cfd::default();
        let id = expected.order.id;
        insert_order(&expected.order, &mut conn).await.unwrap();
        insert_cfd(expected.clone(), &mut conn).await.unwrap();

        let actual = load_cfd_by_order_id(id, &mut conn).await.unwrap();
        assert_eq!(expected, actual);
    }

    /// Loading by order id still works once a second CFD has been stored in the same database.
    #[tokio::test]
    async fn test_insert_and_load_cfd_by_order_id_multiple() {
        let pool = setup_test_db().await;
        let mut conn = pool.acquire().await.unwrap();

        // Same insert-then-load round trip, performed for two CFDs one after the other.
        for _ in 0..2 {
            let expected = Cfd::default();
            let id = expected.order.id;
            insert_order(&expected.order, &mut conn).await.unwrap();
            insert_cfd(expected.clone(), &mut conn).await.unwrap();

            let actual = load_cfd_by_order_id(id, &mut conn).await.unwrap();
            assert_eq!(expected, actual);
        }
    }

    /// `load_cfds_by_oracle_event_id` returns exactly the CFDs stored for the requested event.
    #[tokio::test]
    async fn test_insert_and_load_cfd_by_oracle_event_id() {
        let pool = setup_test_db().await;
        let mut conn = pool.acquire().await.unwrap();

        let first_event = OracleEventId("dummy_1".to_string());
        let second_event = OracleEventId("dummy_2".to_string());

        // One CFD for the first event.
        let cfd_a =
            Cfd::default().with_order(Order::default().with_oracle_event_id(first_event.clone()));
        insert_order(&cfd_a.order, &mut conn).await.unwrap();
        insert_cfd(cfd_a.clone(), &mut conn).await.unwrap();

        let loaded = load_cfds_by_oracle_event_id(first_event.clone(), &mut conn)
            .await
            .unwrap();
        assert_eq!(vec![cfd_a.clone()], loaded);

        // A second CFD for the same event shows up next to the first one.
        let cfd_b =
            Cfd::default().with_order(Order::default().with_oracle_event_id(first_event.clone()));
        insert_order(&cfd_b.order, &mut conn).await.unwrap();
        insert_cfd(cfd_b.clone(), &mut conn).await.unwrap();

        let loaded = load_cfds_by_oracle_event_id(first_event, &mut conn)
            .await
            .unwrap();
        assert_eq!(vec![cfd_a, cfd_b], loaded);

        // A CFD for the other event is the only one returned for that event.
        let cfd_c =
            Cfd::default().with_order(Order::default().with_oracle_event_id(second_event.clone()));
        insert_order(&cfd_c.order, &mut conn).await.unwrap();
        insert_cfd(cfd_c.clone(), &mut conn).await.unwrap();

        let loaded = load_cfds_by_oracle_event_id(second_event, &mut conn)
            .await
            .unwrap();
        assert_eq!(vec![cfd_c], loaded);
    }

    /// A state appended after the initial one is the state reported when loading the CFD.
    #[tokio::test]
    async fn test_insert_new_cfd_state() {
        let pool = setup_test_db().await;
        let mut conn = pool.acquire().await.unwrap();

        let mut expected = Cfd::default();
        insert_order(&expected.order, &mut conn).await.unwrap();
        insert_cfd(expected.clone(), &mut conn).await.unwrap();

        expected.state = CfdState::Accepted {
            common: CfdStateCommon {
                transition_timestamp: SystemTime::now(),
            },
        };
        insert_new_cfd_state_by_order_id(expected.order.id, expected.state.clone(), &mut conn)
            .await
            .unwrap();

        let loaded = load_all_cfds(&mut conn).await.unwrap();
        let actual = loaded.first().unwrap().clone();
        assert_eq!(expected, actual);
    }

    /// Creates a database file in a new temporary directory, connects to it and runs the
    /// migrations.
    async fn setup_test_db() -> SqlitePool {
        let temp_db = tempdir().unwrap().into_path().join("tempdb");

        // file has to exist in order to connect with sqlite
        File::create(&temp_db).unwrap();

        dbg!(&temp_db);

        let url = format!("sqlite:{}", temp_db.display());
        let pool = SqlitePool::connect(url.as_str()).await.unwrap();

        run_migrations(&pool).await.unwrap();

        pool
    }

    impl Default for Cfd {
        /// A CFD over a default order, with a quantity of 1000 USD, in the
        /// `OutgoingOrderRequest` state timestamped with the current time.
        fn default() -> Self {
            let initial_state = CfdState::OutgoingOrderRequest {
                common: CfdStateCommon {
                    transition_timestamp: SystemTime::now(),
                },
            };

            Cfd::new(Order::default(), Usd(dec!(1000)), initial_state)
        }
    }

    impl Cfd {
        /// Replaces the order of this CFD.
        pub fn with_order(mut self, order: Order) -> Self {
            self.order = order;
            self
        }
    }

    impl Default for Order {
        /// An order with fixed dummy values, created through `Order::new`.
        fn default() -> Self {
            let created = Order::new(
                Usd(dec!(1000)),
                Usd(dec!(100)),
                Usd(dec!(1000)),
                Origin::Theirs,
                OracleEventId("Dummy".to_string()),
            );

            created.unwrap()
        }
    }

    impl Order {
        /// Replaces the oracle event id of this order.
        pub fn with_oracle_event_id(mut self, oracle_event_id: OracleEventId) -> Self {
            self.oracle_event_id = oracle_event_id;
            self
        }
    }
}