Browse Source
195: Payout curve r=thomaseizinger a=DeliciousHair Resolves #60. Co-authored-by: DelicioiusHair <mshepit@gmail.com>refactor/no-log-handler
bors[bot]
3 years ago
committed by
GitHub
15 changed files with 2872 additions and 50 deletions
@ -1,30 +1,744 @@ |
|||||
|
use std::fmt; |
||||
|
|
||||
use crate::model::{Leverage, Usd}; |
use crate::model::{Leverage, Usd}; |
||||
use anyhow::Result; |
use crate::payout_curve::curve::Curve; |
||||
|
use anyhow::{Context, Result}; |
||||
use bdk::bitcoin; |
use bdk::bitcoin; |
||||
use cfd_protocol::interval::MAX_PRICE_DEC; |
|
||||
use cfd_protocol::{generate_payouts, Payout}; |
use cfd_protocol::{generate_payouts, Payout}; |
||||
|
use itertools::Itertools; |
||||
|
use ndarray::prelude::*; |
||||
|
use num::{FromPrimitive, ToPrimitive}; |
||||
|
use rust_decimal::Decimal; |
||||
|
|
||||
|
mod basis; |
||||
|
mod basis_eval; |
||||
|
mod compat; |
||||
|
mod csr_tools; |
||||
|
mod curve; |
||||
|
mod curve_factory; |
||||
|
mod splineobject; |
||||
|
mod utils; |
||||
|
|
||||
|
/// function to generate an iterator of values, heuristically viewed as:
|
||||
|
///
|
||||
|
/// `[left_price_boundary, right_price_boundary], maker_payout_value`
|
||||
|
///
|
||||
|
/// with units
|
||||
|
///
|
||||
|
/// `[Usd, Usd], bitcoin::Amount`
|
||||
|
///
|
||||
|
/// A key item to note is that although the POC logic has been to imposed
|
||||
|
/// that maker goes short every time, there is no reason to make the math
|
||||
|
/// have this imposition as well. As such, the `long_position` parameter
|
||||
|
/// is used to indicate which party (Maker or Taker) has the long position,
|
||||
|
/// and everything else is handled internally.
|
||||
|
///
|
||||
|
/// As well, the POC has also demanded that the Maker always has unity
|
||||
|
/// leverage, hence why the ability to to specify this amount has been
|
||||
|
/// omitted from the parameters. Internally, it is hard-coded to unity
|
||||
|
/// in the call to PayoutCurve::new(), so this behaviour can be changed in
|
||||
|
/// the future trivially.
|
||||
|
///
|
||||
|
/// ### Paramters
|
||||
|
///
|
||||
|
/// * price: BTC-USD exchange rate used to create CFD contract
|
||||
|
/// * quantity: Interger number of one-dollar USD contracts contained in the
|
||||
|
/// CFD; expressed as a Usd amount
|
||||
|
/// * leverage: Leveraging used by the taker
|
||||
|
///
|
||||
|
/// ### Returns
|
||||
|
///
|
||||
|
/// The list of [`Payout`]s for the given price, quantity and leverage.
|
||||
|
pub fn calculate(price: Usd, quantity: Usd, leverage: Leverage) -> Result<Vec<Payout>> { |
||||
|
let payouts = calculate_payout_parameters(price, quantity, leverage)? |
||||
|
.into_iter() |
||||
|
.map(PayoutParameter::into_payouts) |
||||
|
.flatten_ok() |
||||
|
.collect::<Result<Vec<_>>>()?; |
||||
|
|
||||
|
Ok(payouts) |
||||
|
} |
||||
|
|
||||
|
const CONTRACT_VALUE: f64 = 1.; |
||||
|
const N_PAYOUTS: usize = 200; |
||||
|
const SHORT_LEVERAGE: usize = 1; |
||||
|
|
||||
pub fn calculate( |
/// Internal calculate function for the payout curve.
|
||||
|
///
|
||||
|
/// To ease testing, we write our tests against this function because it has a more human-friendly
|
||||
|
/// output. The design goal here is that the the above `calculate` function is as thin as possible.
|
||||
|
fn calculate_payout_parameters( |
||||
price: Usd, |
price: Usd, |
||||
_quantity: Usd, |
quantity: Usd, |
||||
maker_payin: bitcoin::Amount, |
long_leverage: Leverage, |
||||
(taker_payin, _leverage): (bitcoin::Amount, Leverage), |
) -> Result<Vec<PayoutParameter>> { |
||||
) -> Result<Vec<Payout>> { |
let initial_rate = price |
||||
let dollars = price.try_into_u64()?; |
.try_into_u64() |
||||
let payouts = vec![ |
.context("Cannot convert price to u64")? as f64; |
||||
generate_payouts( |
let quantity = quantity |
||||
0..=(dollars - 10), |
.try_into_u64() |
||||
maker_payin + taker_payin, |
.context("Cannot convert quantity to u64")? as usize; |
||||
bitcoin::Amount::ZERO, |
|
||||
)?, |
let payout_curve = PayoutCurve::new( |
||||
generate_payouts((dollars - 9)..=(dollars + 10), maker_payin, taker_payin)?, |
initial_rate as f64, |
||||
|
long_leverage.0 as usize, |
||||
|
SHORT_LEVERAGE, |
||||
|
quantity, |
||||
|
CONTRACT_VALUE, |
||||
|
None, |
||||
|
)?; |
||||
|
|
||||
|
let payout_parameters = payout_curve |
||||
|
.generate_payout_scheme(N_PAYOUTS)? |
||||
|
.rows() |
||||
|
.into_iter() |
||||
|
.map(|row| { |
||||
|
let left_bound = row[0] as u64; |
||||
|
let right_bound = row[1] as u64; |
||||
|
let long_amount = row[2]; |
||||
|
|
||||
|
let short_amount = to_sats(payout_curve.total_value - long_amount)?; |
||||
|
let long_amount = to_sats(long_amount)?; |
||||
|
|
||||
|
Ok(PayoutParameter { |
||||
|
left_bound, |
||||
|
right_bound, |
||||
|
long_amount, |
||||
|
short_amount, |
||||
|
}) |
||||
|
}) |
||||
|
.collect::<Result<Vec<_>>>()?; |
||||
|
|
||||
|
Ok(payout_parameters) |
||||
|
} |
||||
|
|
||||
|
#[derive(PartialEq)] |
||||
|
struct PayoutParameter { |
||||
|
left_bound: u64, |
||||
|
right_bound: u64, |
||||
|
long_amount: u64, |
||||
|
short_amount: u64, |
||||
|
} |
||||
|
|
||||
|
impl PayoutParameter { |
||||
|
fn into_payouts(self) -> Result<Vec<Payout>> { |
||||
generate_payouts( |
generate_payouts( |
||||
(dollars + 11)..=MAX_PRICE_DEC, |
self.left_bound..=self.right_bound, |
||||
bitcoin::Amount::ZERO, |
bitcoin::Amount::from_sat(self.short_amount), |
||||
maker_payin + taker_payin, |
bitcoin::Amount::from_sat(self.long_amount), |
||||
)?, |
) |
||||
] |
} |
||||
.concat(); |
} |
||||
|
|
||||
Ok(payouts) |
impl fmt::Debug for PayoutParameter { |
||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
||||
|
write!( |
||||
|
f, |
||||
|
"payout({}..={}, {}, {})", |
||||
|
self.left_bound, self.right_bound, self.short_amount, self.long_amount |
||||
|
) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
/// Converts a float with any precision to a [`bitcoin::Amount`].
|
||||
|
fn to_sats(btc: f64) -> Result<u64> { |
||||
|
let sats_per_btc = Decimal::from(100_000_000); |
||||
|
|
||||
|
let btc = Decimal::from_f64(btc).context("Cannot create decimal from float")?; |
||||
|
let sats = btc * sats_per_btc; |
||||
|
let sats = sats.to_u64().context("Cannot fit sats into u64")?; |
||||
|
|
||||
|
Ok(sats) |
||||
|
} |
||||
|
|
||||
|
#[derive(thiserror::Error, Debug)] |
||||
|
pub enum Error { |
||||
|
#[error("failed to init CSR object--is the specified shape correct?")] |
||||
|
#[allow(clippy::upper_case_acronyms)] |
||||
|
CannotInitCSR, |
||||
|
#[error("matrix must be square")] |
||||
|
MatrixMustBeSquare, |
||||
|
#[error("evaluation outside parametric domain")] |
||||
|
InvalidDomain, |
||||
|
#[error("einsum error--array size mismatch?")] |
||||
|
Einsum, |
||||
|
#[error("no operand string found")] |
||||
|
NoEinsumOperatorString, |
||||
|
#[error("cannot connect periodic curves")] |
||||
|
CannotConnectPeriodicCurves, |
||||
|
#[error("degree must be strictly positive")] |
||||
|
DegreeMustBePositive, |
||||
|
#[error("all parameter arrays must have the same length if not using a tensor grid")] |
||||
|
InvalidDerivative, |
||||
|
#[error("Rational derivative not implemented for order sum(d) > 1")] |
||||
|
DerivativeNotImplemented, |
||||
|
#[error("requested segmentation is too coarse for this curve")] |
||||
|
InvalidSegmentation, |
||||
|
#[error("concatonation error")] |
||||
|
NdArray { |
||||
|
#[from] |
||||
|
source: ndarray::ShapeError, |
||||
|
}, |
||||
|
#[error(transparent)] |
||||
|
NotOneDimensional { |
||||
|
#[from] |
||||
|
source: compat::NotOneDimensional, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
#[derive(Clone, Debug)] |
||||
|
struct PayoutCurve { |
||||
|
curve: Curve, |
||||
|
has_upper_limit: bool, |
||||
|
lower_corner: f64, |
||||
|
upper_corner: f64, |
||||
|
total_value: f64, |
||||
|
} |
||||
|
|
||||
|
impl PayoutCurve { |
||||
|
fn new( |
||||
|
initial_rate: f64, |
||||
|
leverage_long: usize, |
||||
|
leverage_short: usize, |
||||
|
n_contracts: usize, |
||||
|
contract_value: f64, |
||||
|
tolerance: Option<f64>, |
||||
|
) -> Result<Self, Error> { |
||||
|
let tolerance = tolerance.unwrap_or(1e-6); |
||||
|
let bounds = cutoffs(initial_rate, leverage_long, leverage_short); |
||||
|
let total_value = pool_value( |
||||
|
initial_rate, |
||||
|
n_contracts, |
||||
|
contract_value, |
||||
|
leverage_long, |
||||
|
leverage_short, |
||||
|
); |
||||
|
let mut curve = curve_factory::line((0., 0.), (bounds.0, 0.), false)?; |
||||
|
|
||||
|
let payout = |
||||
|
create_long_payout_function(initial_rate, n_contracts, contract_value, leverage_long); |
||||
|
let variable_payout = |
||||
|
curve_factory::fit(payout, bounds.0, bounds.1, Some(tolerance), None)?; |
||||
|
curve.append(variable_payout)?; |
||||
|
|
||||
|
let upper_corner; |
||||
|
if bounds.2 { |
||||
|
let upper_liquidation = curve_factory::line( |
||||
|
(bounds.1, total_value), |
||||
|
(4. * initial_rate, total_value), |
||||
|
false, |
||||
|
)?; |
||||
|
curve.append(upper_liquidation)?; |
||||
|
upper_corner = bounds.1; |
||||
|
} else { |
||||
|
upper_corner = curve.spline.bases[0].end(); |
||||
|
} |
||||
|
|
||||
|
Ok(PayoutCurve { |
||||
|
curve, |
||||
|
has_upper_limit: bounds.2, |
||||
|
lower_corner: bounds.0, |
||||
|
upper_corner, |
||||
|
total_value, |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
pub fn generate_payout_scheme(&self, n_segments: usize) -> Result<Array2<f64>, Error> { |
||||
|
let n_min; |
||||
|
if self.has_upper_limit { |
||||
|
n_min = 3; |
||||
|
} else { |
||||
|
n_min = 2; |
||||
|
} |
||||
|
|
||||
|
if n_segments < n_min { |
||||
|
return Result::Err(Error::InvalidSegmentation); |
||||
|
} |
||||
|
|
||||
|
let t; |
||||
|
if self.has_upper_limit { |
||||
|
t = self.build_sampling_vector_upper_bounded(n_segments); |
||||
|
} else { |
||||
|
t = self.build_sampling_vector_upper_unbounded(n_segments) |
||||
|
} |
||||
|
|
||||
|
let mut z_arr = self.curve.evaluate(&mut &[t][..])?; |
||||
|
if self.has_upper_limit { |
||||
|
self.modify_samples_bounded(&mut z_arr); |
||||
|
} else { |
||||
|
self.modify_samples_unbounded(&mut z_arr); |
||||
|
} |
||||
|
self.generate_segments(&mut z_arr); |
||||
|
|
||||
|
Ok(z_arr) |
||||
|
} |
||||
|
|
||||
|
fn build_sampling_vector_upper_bounded(&self, n_segs: usize) -> Array1<f64> { |
||||
|
let knots = &self.curve.spline.knots(0, None).unwrap()[0]; |
||||
|
let klen = knots.len(); |
||||
|
let n_64 = (n_segs + 1) as f64; |
||||
|
let d = knots[klen - 2] - knots[1]; |
||||
|
let delta_0 = d / (2. * (n_64 - 5.)); |
||||
|
let delta_1 = d * (n_64 - 6.) / ((n_64 - 5.) * (n_64 - 4.)); |
||||
|
|
||||
|
let mut vec = Vec::<f64>::with_capacity(n_segs + 2); |
||||
|
for i in 0..n_segs + 2 { |
||||
|
if i == 0 { |
||||
|
vec.push(self.curve.spline.bases[0].start()); |
||||
|
} else if i == 1 { |
||||
|
vec.push(knots[1]); |
||||
|
} else if i == 2 { |
||||
|
vec.push(knots[1] + delta_0); |
||||
|
} else if i == n_segs - 1 { |
||||
|
vec.push(knots[klen - 2] - delta_0); |
||||
|
} else if i == n_segs { |
||||
|
vec.push(knots[klen - 2]); |
||||
|
} else if i == n_segs + 1 { |
||||
|
vec.push(self.curve.spline.bases[0].end()); |
||||
|
} else { |
||||
|
let c = (i - 2) as f64; |
||||
|
vec.push(knots[1] + delta_0 + c * delta_1); |
||||
|
} |
||||
|
} |
||||
|
Array1::<f64>::from_vec(vec) |
||||
|
} |
||||
|
|
||||
|
fn build_sampling_vector_upper_unbounded(&self, n_segs: usize) -> Array1<f64> { |
||||
|
let knots = &self.curve.spline.knots(0, None).unwrap()[0]; |
||||
|
let klen = knots.len(); |
||||
|
let n_64 = (n_segs + 1) as f64; |
||||
|
let d = knots[klen - 1] - knots[1]; |
||||
|
let delta = d / (n_64 - 1_f64); |
||||
|
let delta_x = d / (2. * (n_64 - 1_f64)); |
||||
|
let delta_y = 3. * d / (2. * (n_64 - 1_f64)); |
||||
|
|
||||
|
let mut vec = Vec::<f64>::with_capacity(n_segs + 2); |
||||
|
for i in 0..n_segs + 2 { |
||||
|
if i == 0 { |
||||
|
vec.push(self.curve.spline.bases[0].start()); |
||||
|
} else if i == 1 { |
||||
|
vec.push(knots[1]); |
||||
|
} else if i == 2 { |
||||
|
vec.push(knots[1] + delta_x); |
||||
|
} else if i == n_segs { |
||||
|
vec.push(knots[klen - 1] - delta_y); |
||||
|
} else if i == n_segs + 1 { |
||||
|
vec.push(knots[klen - 1]); |
||||
|
} else { |
||||
|
let c = (i - 2) as f64; |
||||
|
vec.push(knots[1] + delta_x + c * delta); |
||||
|
} |
||||
|
} |
||||
|
Array1::<f64>::from_vec(vec) |
||||
|
} |
||||
|
|
||||
|
fn modify_samples_bounded(&self, arr: &mut Array2<f64>) { |
||||
|
let n = arr.shape()[0]; |
||||
|
let capacity = 2 * (n - 2); |
||||
|
let mut vec = Vec::<f64>::with_capacity(2 * capacity); |
||||
|
for (i, e) in arr.slice(s![.., 0]).iter().enumerate() { |
||||
|
if i < 2 || i > n - 3 { |
||||
|
vec.push(*e); |
||||
|
} else if i == 2 { |
||||
|
vec.push(arr[[i - 1, 0]]); |
||||
|
vec.push(arr[[i, 1]]); |
||||
|
vec.push((*e + arr[[i + 1, 0]]) / 2.); |
||||
|
} else if i == n - 3 { |
||||
|
vec.push((arr[[i - 1, 0]] + *e) / 2.); |
||||
|
vec.push(arr[[i, 1]]); |
||||
|
vec.push(arr[[i + 1, 0]]); |
||||
|
} else { |
||||
|
vec.push((arr[[i - 1, 0]] + *e) / 2.); |
||||
|
vec.push(arr[[i, 1]]); |
||||
|
vec.push((*e + arr[[i + 1, 0]]) / 2.); |
||||
|
} |
||||
|
vec.push(arr[[i, 1]]); |
||||
|
} |
||||
|
|
||||
|
*arr = Array2::<f64>::from_shape_vec((capacity, 2), vec).unwrap(); |
||||
|
} |
||||
|
|
||||
|
fn modify_samples_unbounded(&self, arr: &mut Array2<f64>) { |
||||
|
let n = arr.shape()[0]; |
||||
|
let capacity = 2 * (n - 1); |
||||
|
let mut vec = Vec::<f64>::with_capacity(2 * capacity); |
||||
|
for (i, e) in arr.slice(s![.., 0]).iter().enumerate() { |
||||
|
if i < 2 { |
||||
|
vec.push(*e); |
||||
|
} else if i == 2 { |
||||
|
vec.push(arr[[i - 1, 0]]); |
||||
|
vec.push(arr[[i, 1]]); |
||||
|
vec.push((*e + arr[[i + 1, 0]]) / 2.); |
||||
|
} else if i == n - 1 { |
||||
|
vec.push((arr[[i - 1, 0]] + *e) / 2.); |
||||
|
vec.push(arr[[i, 1]]); |
||||
|
vec.push(arr[[i, 0]]); |
||||
|
} else { |
||||
|
vec.push((arr[[i - 1, 0]] + *e) / 2.); |
||||
|
vec.push(arr[[i, 1]]); |
||||
|
vec.push((*e + arr[[i + 1, 0]]) / 2.); |
||||
|
} |
||||
|
vec.push(arr[[i, 1]]); |
||||
|
} |
||||
|
|
||||
|
*arr = Array2::<f64>::from_shape_vec((capacity, 2), vec).unwrap(); |
||||
|
} |
||||
|
|
||||
|
/// this should only be used on an array `arr` that has been
|
||||
|
/// processed by self.modify_samples_* first, otherwise the results
|
||||
|
/// will be jibberish.
|
||||
|
fn generate_segments(&self, arr: &mut Array2<f64>) { |
||||
|
let capacity = 3 * arr.shape()[0] / 2; |
||||
|
let mut vec = Vec::<f64>::with_capacity(capacity); |
||||
|
for (i, e) in arr.slice(s![.., 0]).iter().enumerate() { |
||||
|
if i == 0 { |
||||
|
vec.push(e.floor()); |
||||
|
} else if i % 2 == 1 { |
||||
|
vec.push(e.round()); |
||||
|
vec.push(arr[[i, 1]]); |
||||
|
} else { |
||||
|
vec.push(e.round() + 1_f64); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
*arr = Array2::<f64>::from_shape_vec((capacity / 3, 3), vec).unwrap(); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
fn cutoffs(initial_rate: f64, leverage_long: usize, leverage_short: usize) -> (f64, f64, bool) { |
||||
|
let ll_64 = leverage_long as f64; |
||||
|
let ls_64 = leverage_short as f64; |
||||
|
let a = initial_rate * ll_64 / (ll_64 + 1_f64); |
||||
|
if leverage_short == 1 { |
||||
|
let b = 2. * initial_rate; |
||||
|
return (a, b, false); |
||||
|
} |
||||
|
let b = initial_rate * ls_64 / (ls_64 - 1_f64); |
||||
|
|
||||
|
(a, b, true) |
||||
|
} |
||||
|
|
||||
|
fn pool_value( |
||||
|
initial_rate: f64, |
||||
|
n_contracts: usize, |
||||
|
contract_value: f64, |
||||
|
leverage_long: usize, |
||||
|
leverage_short: usize, |
||||
|
) -> f64 { |
||||
|
let ll_64 = leverage_long as f64; |
||||
|
let ls_64 = leverage_short as f64; |
||||
|
let n_64 = n_contracts as f64; |
||||
|
|
||||
|
(n_64 * contract_value / initial_rate) * (1_f64 / ll_64 + 1_f64 / ls_64) |
||||
|
} |
||||
|
|
||||
|
fn create_long_payout_function( |
||||
|
initial_rate: f64, |
||||
|
n_contracts: usize, |
||||
|
contract_value: f64, |
||||
|
leverage_long: usize, |
||||
|
) -> impl Fn(&Array1<f64>) -> Array2<f64> { |
||||
|
let n_64 = n_contracts as f64; |
||||
|
let ll_64 = leverage_long as f64; |
||||
|
|
||||
|
move |t: &Array1<f64>| { |
||||
|
let mut vec = Vec::<f64>::with_capacity(2 * t.len()); |
||||
|
for e in t.iter() { |
||||
|
let eval = (n_64 * contract_value) |
||||
|
* (1_f64 / (initial_rate * ll_64) + (1_f64 / initial_rate - 1_f64 / e)); |
||||
|
vec.push(*e); |
||||
|
vec.push(eval); |
||||
|
} |
||||
|
|
||||
|
Array2::<f64>::from_shape_vec((t.len(), 2), vec).unwrap() |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
#[cfg(test)] |
||||
|
mod tests { |
||||
|
use super::*; |
||||
|
use rust_decimal_macros::dec; |
||||
|
use std::ops::RangeInclusive; |
||||
|
|
||||
|
#[test] |
||||
|
fn test_bounded() { |
||||
|
let initial_rate = 40000.0; |
||||
|
let leverage_long = 5; |
||||
|
let leverage_short = 2; |
||||
|
let n_contracts = 200; |
||||
|
let contract_value = 100.; |
||||
|
|
||||
|
let payout = PayoutCurve::new( |
||||
|
initial_rate, |
||||
|
leverage_long, |
||||
|
leverage_short, |
||||
|
n_contracts, |
||||
|
contract_value, |
||||
|
None, |
||||
|
) |
||||
|
.unwrap(); |
||||
|
|
||||
|
let z = payout.generate_payout_scheme(5000).unwrap(); |
||||
|
|
||||
|
assert!(z.shape()[0] == 5000); |
||||
|
} |
||||
|
|
||||
|
#[test] |
||||
|
fn test_unbounded() { |
||||
|
let initial_rate = 40000.0; |
||||
|
let leverage_long = 5; |
||||
|
let leverage_short = 1; |
||||
|
let n_contracts = 200; |
||||
|
let contract_value = 100.; |
||||
|
|
||||
|
let payout = PayoutCurve::new( |
||||
|
initial_rate, |
||||
|
leverage_long, |
||||
|
leverage_short, |
||||
|
n_contracts, |
||||
|
contract_value, |
||||
|
None, |
||||
|
) |
||||
|
.unwrap(); |
||||
|
|
||||
|
let z = payout.generate_payout_scheme(5000).unwrap(); |
||||
|
|
||||
|
// out-by-one error expected at this point in time
|
||||
|
assert!(z.shape()[0] == 5001); |
||||
|
} |
||||
|
|
||||
|
#[test] |
||||
|
fn calculate_snapshot() { |
||||
|
let actual_payouts = |
||||
|
calculate_payout_parameters(Usd(dec!(54000.00)), Usd(dec!(3500.00)), Leverage(5)) |
||||
|
.unwrap(); |
||||
|
|
||||
|
let expected_payouts = vec![ |
||||
|
payout(0..=45000, 7777777, 0), |
||||
|
payout(45001..=45315, 7750759, 27018), |
||||
|
payout(45316..=45630, 7697244, 80533), |
||||
|
payout(45631..=45945, 7644417, 133359), |
||||
|
payout(45946..=46260, 7592270, 185507), |
||||
|
payout(46261..=46575, 7540793, 236984), |
||||
|
payout(46576..=46890, 7489978, 287799), |
||||
|
payout(46891..=47205, 7439816, 337961), |
||||
|
payout(47206..=47520, 7390298, 387479), |
||||
|
payout(47521..=47835, 7341415, 436362), |
||||
|
payout(47836..=48150, 7293159, 484618), |
||||
|
payout(48151..=48465, 7245520, 532257), |
||||
|
payout(48466..=48780, 7198490, 579287), |
||||
|
payout(48781..=49095, 7152060, 625717), |
||||
|
payout(49096..=49410, 7106222, 671555), |
||||
|
payout(49411..=49725, 7060965, 716812), |
||||
|
payout(49726..=50040, 7016282, 761494), |
||||
|
payout(50041..=50355, 6972164, 805612), |
||||
|
payout(50356..=50670, 6928602, 849174), |
||||
|
payout(50671..=50985, 6885587, 892189), |
||||
|
payout(50986..=51300, 6843111, 934666), |
||||
|
payout(51301..=51615, 6801163, 976613), |
||||
|
payout(51616..=51930, 6759737, 1018040), |
||||
|
payout(51931..=52245, 6718822, 1058955), |
||||
|
payout(52246..=52560, 6678410, 1099367), |
||||
|
payout(52561..=52875, 6638493, 1139284), |
||||
|
payout(52876..=53190, 6599060, 1178716), |
||||
|
payout(53191..=53505, 6560105, 1217672), |
||||
|
payout(53506..=53820, 6521617, 1256160), |
||||
|
payout(53821..=54135, 6483588, 1294189), |
||||
|
payout(54136..=54450, 6446009, 1331768), |
||||
|
payout(54451..=54765, 6408872, 1368905), |
||||
|
payout(54766..=55080, 6372166, 1405610), |
||||
|
payout(55081..=55395, 6335885, 1441892), |
||||
|
payout(55396..=55710, 6300018, 1477758), |
||||
|
payout(55711..=56025, 6264558, 1513219), |
||||
|
payout(56026..=56340, 6229494, 1548282), |
||||
|
payout(56341..=56655, 6194820, 1582957), |
||||
|
payout(56656..=56970, 6160524, 1617253), |
||||
|
payout(56971..=57285, 6126599, 1651177), |
||||
|
payout(57286..=57600, 6093037, 1684740), |
||||
|
payout(57601..=57915, 6059827, 1717949), |
||||
|
payout(57916..=58230, 6026965, 1750812), |
||||
|
payout(58231..=58545, 5994445, 1783332), |
||||
|
payout(58546..=58860, 5962264, 1815512), |
||||
|
payout(58861..=59175, 5930419, 1847358), |
||||
|
payout(59176..=59490, 5898905, 1878872), |
||||
|
payout(59491..=59805, 5867718, 1910059), |
||||
|
payout(59806..=60120, 5836855, 1940922), |
||||
|
payout(60121..=60435, 5806311, 1971465), |
||||
|
payout(60436..=60750, 5776084, 2001693), |
||||
|
payout(60751..=61065, 5746168, 2031608), |
||||
|
payout(61066..=61380, 5716561, 2061216), |
||||
|
payout(61381..=61695, 5687258, 2090519), |
||||
|
payout(61696..=62010, 5658255, 2119522), |
||||
|
payout(62011..=62325, 5629549, 2148228), |
||||
|
payout(62326..=62640, 5601135, 2176642), |
||||
|
payout(62641..=62955, 5573010, 2204767), |
||||
|
payout(62956..=63270, 5545170, 2232607), |
||||
|
payout(63271..=63585, 5517611, 2260165), |
||||
|
payout(63586..=63900, 5490330, 2287447), |
||||
|
payout(63901..=64215, 5463321, 2314455), |
||||
|
payout(64216..=64530, 5436583, 2341194), |
||||
|
payout(64531..=64845, 5410109, 2367667), |
||||
|
payout(64846..=65160, 5383898, 2393879), |
||||
|
payout(65161..=65475, 5357944, 2419833), |
||||
|
payout(65476..=65790, 5332245, 2445532), |
||||
|
payout(65791..=66105, 5306795, 2470982), |
||||
|
payout(66106..=66420, 5281592, 2496185), |
||||
|
payout(66421..=66735, 5256631, 2521146), |
||||
|
payout(66736..=67050, 5231909, 2545868), |
||||
|
payout(67051..=67365, 5207421, 2570356), |
||||
|
payout(67366..=67680, 5183164, 2594612), |
||||
|
payout(67681..=67995, 5159135, 2618642), |
||||
|
payout(67996..=68310, 5135328, 2642449), |
||||
|
payout(68311..=68625, 5111740, 2666037), |
||||
|
payout(68626..=68940, 5088368, 2689409), |
||||
|
payout(68941..=69255, 5065207, 2712569), |
||||
|
payout(69256..=69570, 5042254, 2735523), |
||||
|
payout(69571..=69885, 5019505, 2758272), |
||||
|
payout(69886..=70200, 4996955, 2780821), |
||||
|
payout(70201..=70515, 4974602, 2803175), |
||||
|
payout(70516..=70830, 4952442, 2825335), |
||||
|
payout(70831..=71145, 4930473, 2847304), |
||||
|
payout(71146..=71460, 4908694, 2869083), |
||||
|
payout(71461..=71775, 4887102, 2890675), |
||||
|
payout(71776..=72090, 4865695, 2912081), |
||||
|
payout(72091..=72405, 4844473, 2933304), |
||||
|
payout(72406..=72720, 4823433, 2954344), |
||||
|
payout(72721..=73035, 4802573, 2975204), |
||||
|
payout(73036..=73350, 4781891, 2995886), |
||||
|
payout(73351..=73665, 4761385, 3016391), |
||||
|
payout(73666..=73980, 4741054, 3036722), |
||||
|
payout(73981..=74295, 4720896, 3056881), |
||||
|
payout(74296..=74610, 4700909, 3076868), |
||||
|
payout(74611..=74925, 4681090, 3096686), |
||||
|
payout(74926..=75240, 4661439, 3116338), |
||||
|
payout(75241..=75555, 4641953, 3135824), |
||||
|
payout(75556..=75870, 4622630, 3155146), |
||||
|
payout(75871..=76185, 4603469, 3174307), |
||||
|
payout(76186..=76500, 4584468, 3193309), |
||||
|
payout(76501..=76815, 4565624, 3212153), |
||||
|
payout(76816..=77130, 4546937, 3230840), |
||||
|
payout(77131..=77445, 4528403, 3249374), |
||||
|
payout(77446..=77760, 4510022, 3267755), |
||||
|
payout(77761..=78075, 4491791, 3285986), |
||||
|
payout(78076..=78390, 4473708, 3304068), |
||||
|
payout(78391..=78705, 4455773, 3322004), |
||||
|
payout(78706..=79020, 4437982, 3339795), |
||||
|
payout(79021..=79335, 4420333, 3357443), |
||||
|
payout(79336..=79650, 4402827, 3374950), |
||||
|
payout(79651..=79965, 4385459, 3392318), |
||||
|
payout(79966..=80280, 4368228, 3409548), |
||||
|
payout(80281..=80595, 4351133, 3426643), |
||||
|
payout(80596..=80910, 4334172, 3443605), |
||||
|
payout(80911..=81225, 4317343, 3460434), |
||||
|
payout(81226..=81540, 4300643, 3477134), |
||||
|
payout(81541..=81855, 4284071, 3493705), |
||||
|
payout(81856..=82170, 4267626, 3510151), |
||||
|
payout(82171..=82485, 4251305, 3526472), |
||||
|
payout(82486..=82800, 4235107, 3542670), |
||||
|
payout(82801..=83115, 4219029, 3558748), |
||||
|
payout(83116..=83430, 4203070, 3574707), |
||||
|
payout(83431..=83745, 4187229, 3590547), |
||||
|
payout(83746..=84060, 4171506, 3606271), |
||||
|
payout(84061..=84375, 4155899, 3621878), |
||||
|
payout(84376..=84690, 4140406, 3637371), |
||||
|
payout(84691..=85005, 4125028, 3652749), |
||||
|
payout(85006..=85320, 4109763, 3668014), |
||||
|
payout(85321..=85635, 4094610, 3683167), |
||||
|
payout(85636..=85950, 4079567, 3698209), |
||||
|
payout(85951..=86265, 4064635, 3713142), |
||||
|
payout(86266..=86580, 4049812, 3727965), |
||||
|
payout(86581..=86895, 4035096, 3742680), |
||||
|
payout(86896..=87210, 4020488, 3757289), |
||||
|
payout(87211..=87525, 4005985, 3771792), |
||||
|
payout(87526..=87840, 3991587, 3786189), |
||||
|
payout(87841..=88155, 3977293, 3800484), |
||||
|
payout(88156..=88470, 3963102, 3814675), |
||||
|
payout(88471..=88785, 3949013, 3828764), |
||||
|
payout(88786..=89100, 3935024, 3842753), |
||||
|
payout(89101..=89415, 3921135, 3856642), |
||||
|
payout(89416..=89730, 3907344, 3870432), |
||||
|
payout(89731..=90045, 3893652, 3884125), |
||||
|
payout(90046..=90360, 3880056, 3897721), |
||||
|
payout(90361..=90675, 3866555, 3911221), |
||||
|
payout(90676..=90990, 3853150, 3924627), |
||||
|
payout(90991..=91305, 3839837, 3937940), |
||||
|
payout(91306..=91620, 3826618, 3951159), |
||||
|
payout(91621..=91935, 3813489, 3964287), |
||||
|
payout(91936..=92250, 3800452, 3977325), |
||||
|
payout(92251..=92565, 3787504, 3990273), |
||||
|
payout(92566..=92880, 3774644, 4003133), |
||||
|
payout(92881..=93195, 3761872, 4015905), |
||||
|
payout(93196..=93510, 3749186, 4028591), |
||||
|
payout(93511..=93825, 3736585, 4041192), |
||||
|
payout(93826..=94140, 3724069, 4053708), |
||||
|
payout(94141..=94455, 3711636, 4066141), |
||||
|
payout(94456..=94770, 3699286, 4078491), |
||||
|
payout(94771..=95085, 3687016, 4090760), |
||||
|
payout(95086..=95400, 3674827, 4102949), |
||||
|
payout(95401..=95715, 3662718, 4115059), |
||||
|
payout(95716..=96030, 3650686, 4127091), |
||||
|
payout(96031..=96345, 3638733, 4139044), |
||||
|
payout(96346..=96660, 3626856, 4150920), |
||||
|
payout(96661..=96975, 3615057, 4162720), |
||||
|
payout(96976..=97290, 3603333, 4174444), |
||||
|
payout(97291..=97605, 3591684, 4186093), |
||||
|
payout(97606..=97920, 3580110, 4197666), |
||||
|
payout(97921..=98235, 3568611, 4209166), |
||||
|
payout(98236..=98550, 3557184, 4220593), |
||||
|
payout(98551..=98865, 3545831, 4231946), |
||||
|
payout(98866..=99180, 3534550, 4243227), |
||||
|
payout(99181..=99495, 3523340, 4254437), |
||||
|
payout(99496..=99810, 3512201, 4265576), |
||||
|
payout(99811..=100125, 3501133, 4276644), |
||||
|
payout(100126..=100440, 3490134, 4287643), |
||||
|
payout(100441..=100755, 3479205, 4298572), |
||||
|
payout(100756..=101070, 3468344, 4309433), |
||||
|
payout(101071..=101385, 3457551, 4320226), |
||||
|
payout(101386..=101700, 3446825, 4330952), |
||||
|
payout(101701..=102015, 3436165, 4341611), |
||||
|
payout(102016..=102330, 3425572, 4352205), |
||||
|
payout(102331..=102645, 3415044, 4362732), |
||||
|
payout(102646..=102960, 3404581, 4373196), |
||||
|
payout(102961..=103275, 3394182, 4383594), |
||||
|
payout(103276..=103590, 3383847, 4393930), |
||||
|
payout(103591..=103905, 3373574, 4404202), |
||||
|
payout(103906..=104220, 3363364, 4414412), |
||||
|
payout(104221..=104535, 3353216, 4424561), |
||||
|
payout(104536..=104850, 3343128, 4434648), |
||||
|
payout(104851..=105165, 3333101, 4444675), |
||||
|
payout(105166..=105480, 3323134, 4454643), |
||||
|
payout(105481..=105795, 3313226, 4464551), |
||||
|
payout(105796..=106110, 3303377, 4474400), |
||||
|
payout(106111..=106425, 3293585, 4484192), |
||||
|
payout(106426..=106740, 3283851, 4493926), |
||||
|
payout(106741..=107055, 3274174, 4503603), |
||||
|
payout(107056..=107370, 3264552, 4513225), |
||||
|
payout(107371..=107764, 3254986, 4522790), |
||||
|
payout(107765..=108000, 3240740, 4537037), |
||||
|
]; |
||||
|
|
||||
|
pretty_assertions::assert_eq!(actual_payouts, expected_payouts); |
||||
|
} |
||||
|
|
||||
|
#[test] |
||||
|
fn verfiy_tails() { |
||||
|
let actual_payouts = |
||||
|
calculate_payout_parameters(Usd(dec!(54000.00)), Usd(dec!(3500.00)), Leverage(5)) |
||||
|
.unwrap(); |
||||
|
|
||||
|
let lower_tail = payout(0..=45000, 7777777, 0); |
||||
|
let upper_tail = payout(107765..=108000, 3240740, 4537037); |
||||
|
|
||||
|
pretty_assertions::assert_eq!(actual_payouts.first().unwrap(), &lower_tail); |
||||
|
pretty_assertions::assert_eq!(actual_payouts.last().unwrap(), &upper_tail); |
||||
|
} |
||||
|
|
||||
|
fn payout(range: RangeInclusive<u64>, short: u64, long: u64) -> PayoutParameter { |
||||
|
PayoutParameter { |
||||
|
left_bound: *range.start(), |
||||
|
right_bound: *range.end(), |
||||
|
long_amount: long, |
||||
|
short_amount: short, |
||||
|
} |
||||
|
} |
||||
} |
} |
||||
|
@ -0,0 +1,3 @@ |
|||||
|
## Note |
||||
|
|
||||
|
This codebase is effectively a brute-force copy of the the [Splipy](https://github.com/SINTEF/Splipy) python lib, cherry-picked to do what we need while still remaining consistent with the source material. This is quite a bit of overkill for our purposes, so this will either be refined in the futre **or** a complete translation of the Splipy lib will be created as it's own crate. As such, this code should be considered transitional at best. |
@ -0,0 +1,208 @@ |
|||||
|
use crate::payout_curve::basis_eval::*; |
||||
|
use crate::payout_curve::csr_tools::CSR; |
||||
|
use crate::payout_curve::utils::*; |
||||
|
use crate::payout_curve::Error; |
||||
|
|
||||
|
use core::cmp::max; |
||||
|
use ndarray::prelude::*; |
||||
|
use ndarray::{concatenate, s}; |
||||
|
|
||||
|
#[derive(Clone, Debug)] |
||||
|
pub struct BSplineBasis { |
||||
|
pub knots: Array1<f64>, |
||||
|
pub order: usize, |
||||
|
pub periodic: isize, |
||||
|
pub knot_tol: f64, |
||||
|
} |
||||
|
|
||||
|
impl BSplineBasis { |
||||
|
pub fn new( |
||||
|
order: Option<usize>, |
||||
|
knots: Option<Array1<f64>>, |
||||
|
periodic: Option<isize>, |
||||
|
) -> Result<Self, Error> { |
||||
|
let order = order.unwrap_or(2); |
||||
|
let periodic = periodic.unwrap_or(-1); |
||||
|
let knots = match knots { |
||||
|
Some(knots) => knots, |
||||
|
None => default_knot(order, periodic)?, |
||||
|
}; |
||||
|
let ktol = knot_tolerance(None); |
||||
|
|
||||
|
Ok(BSplineBasis { |
||||
|
order, |
||||
|
knots, |
||||
|
periodic, |
||||
|
knot_tol: ktol, |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
pub fn num_functions(&self) -> usize { |
||||
|
let p = (self.periodic + 1) as usize; |
||||
|
|
||||
|
self.knots.len() - self.order - p |
||||
|
} |
||||
|
|
||||
|
/// Start point of parametric domain. For open knot vectors, this is the
|
||||
|
/// first knot.
|
||||
|
pub fn start(&self) -> f64 { |
||||
|
self.knots[self.order - 1] |
||||
|
} |
||||
|
|
||||
|
/// End point of parametric domain. For open knot vectors, this is the
|
||||
|
/// last knot.
|
||||
|
pub fn end(&self) -> f64 { |
||||
|
self.knots[self.knots.len() - self.order] |
||||
|
} |
||||
|
|
||||
|
/// Fetch greville points, also known as knot averages
|
||||
|
/// over entire knot vector:
|
||||
|
/// .. math:: \\sum_{j=i+1}^{i+p-1} \\frac{t_j}{p-1}
|
||||
|
pub fn greville(&self) -> Array1<f64> { |
||||
|
let n = self.num_functions() as i32; |
||||
|
|
||||
|
(0..n).map(|idx| self.greville_single(idx)).collect() |
||||
|
} |
||||
|
|
||||
|
fn greville_single(&self, index: i32) -> f64 { |
||||
|
let p = self.order as i32; |
||||
|
let den = (self.order - 1) as f64; |
||||
|
|
||||
|
self.knots.slice(s![index + 1..index + p]).sum() / den |
||||
|
} |
||||
|
|
||||
|
/// Evaluate all basis functions in a given set of points.
|
||||
|
/// ## parameters:
|
||||
|
/// * t: The parametric coordinate(s) in which to evaluate
|
||||
|
/// * d: Number of derivatives to compute
|
||||
|
/// * from_right: true if evaluation should be done in the limit from above
|
||||
|
/// ## returns:
|
||||
|
/// * CSR (sparse) matrix N\[i,j\] of all basis functions j evaluated in all points j
|
||||
|
pub fn evaluate(&self, t: &mut Array1<f64>, d: usize, from_right: bool) -> Result<CSR, Error> { |
||||
|
let basis = Basis::new( |
||||
|
self.order, |
||||
|
self.knots.clone().to_owned(), |
||||
|
Some(self.periodic), |
||||
|
Some(self.knot_tol), |
||||
|
); |
||||
|
snap(t, &basis.knots, Some(basis.ktol)); |
||||
|
|
||||
|
if self.order <= d { |
||||
|
let csr = CSR::new( |
||||
|
Array1::<f64>::zeros(0), |
||||
|
Array1::<usize>::zeros(0), |
||||
|
Array1::<usize>::zeros(t.len() + 1), |
||||
|
(t.len(), self.num_functions()), |
||||
|
)?; |
||||
|
|
||||
|
return Ok(csr); |
||||
|
} |
||||
|
|
||||
|
let out = basis.evaluate(t, d, Some(from_right))?; |
||||
|
|
||||
|
Ok(out) |
||||
|
} |
||||
|
|
||||
|
/// Snap evaluation points to knots if they are sufficiently close
|
||||
|
/// as given in by knot_tolerance.
|
||||
|
///
|
||||
|
/// * t: The parametric coordinate(s) in which to evaluate
|
||||
|
pub fn snap(&self, t: &mut Array1<f64>) { |
||||
|
snap(t, &self.knots, Some(self.knot_tol)) |
||||
|
} |
||||
|
|
||||
|
/// Create a knot vector with higher order.
|
||||
|
///
|
||||
|
/// The continuity at the knots are kept unchanged by increasing their
|
||||
|
/// multiplicities.
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * amount: relative (polynomial) degree to raise the basis function by
|
||||
|
pub fn raise_order(&mut self, amount: usize) { |
||||
|
if amount > 0 { |
||||
|
let knot_spans_arr = self.knot_spans(true); |
||||
|
let knot_spans = knot_spans_arr.iter().collect::<Vec<_>>(); |
||||
|
let temp = self.knots.clone(); |
||||
|
let mut knots = temp.iter().collect::<Vec<_>>(); |
||||
|
|
||||
|
for _ in 0..amount { |
||||
|
knots.append(&mut knot_spans.clone()); |
||||
|
} |
||||
|
|
||||
|
let mut knots_vec = knots.iter().map(|e| **e).collect::<Vec<_>>(); |
||||
|
knots_vec.sort_by(cmp_f64); |
||||
|
|
||||
|
let knots_arr = Array1::<f64>::from_vec(knots_vec); |
||||
|
|
||||
|
let new_knot; |
||||
|
if self.periodic > -1 { |
||||
|
let n_0 = bisect_left(&knots_arr, &self.start(), knots_arr.len()); |
||||
|
let n_1 = |
||||
|
knot_spans.len() - bisect_left(&knots_arr, &self.end(), knots_arr.len()) - 1; |
||||
|
let mut new_knot_vec = knots[n_0 * amount..n_1 * amount] |
||||
|
.iter() |
||||
|
.map(|e| **e) |
||||
|
.collect::<Vec<_>>(); |
||||
|
new_knot_vec.sort_by(cmp_f64); |
||||
|
new_knot = Array1::<f64>::from_vec(new_knot_vec); |
||||
|
} else { |
||||
|
new_knot = knots_arr; |
||||
|
} |
||||
|
|
||||
|
self.order += amount; |
||||
|
self.knots = new_knot; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
/// Return the set of unique knots in the knot vector.
|
||||
|
///
|
||||
|
/// ### parameters:
|
||||
|
/// * include_ghosts: if knots outside start/end are to be included. These \
|
||||
|
/// knots are used by periodic basis.
|
||||
|
///
|
||||
|
/// ### returns:
|
||||
|
/// * 1-D array of unique knots
|
||||
|
pub fn knot_spans(&self, include_ghosts: bool) -> Array1<f64> { |
||||
|
let p = &self.order; |
||||
|
|
||||
|
// TODO: this is VERY sloppy!
|
||||
|
let mut res: Vec<f64> = vec![]; |
||||
|
if include_ghosts { |
||||
|
res.push(self.knots[0]); |
||||
|
for elem in self.knots.slice(s![1..]).iter() { |
||||
|
if (elem - res[res.len() - 1]).abs() > self.knot_tol { |
||||
|
res.push(*elem); |
||||
|
} |
||||
|
} |
||||
|
} else { |
||||
|
res.push(self.knots[p - 1]); |
||||
|
let klen = self.knots.len(); |
||||
|
for elem in self.knots.slice(s![p - 1..klen - p + 1]).iter() { |
||||
|
if (elem - res[res.len() - 1]).abs() > self.knot_tol { |
||||
|
res.push(*elem); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
Array1::<f64>::from_vec(res) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
fn default_knot(order: usize, periodic: isize) -> Result<Array1<f64>, Error> { |
||||
|
let prd = max(periodic, -1); |
||||
|
let p = (prd + 1) as usize; |
||||
|
let mut knots = concatenate( |
||||
|
Axis(0), |
||||
|
&[ |
||||
|
Array1::<f64>::zeros(order).view(), |
||||
|
Array1::<f64>::ones(order).view(), |
||||
|
], |
||||
|
)?; |
||||
|
|
||||
|
for i in 0..p { |
||||
|
knots[i] = -1.; |
||||
|
knots[2 * order - i - 1] = 2.; |
||||
|
} |
||||
|
|
||||
|
Ok(knots) |
||||
|
} |
@ -0,0 +1,177 @@ |
|||||
|
use crate::payout_curve::csr_tools::CSR; |
||||
|
use crate::payout_curve::utils::*; |
||||
|
use crate::payout_curve::Error; |
||||
|
|
||||
|
use ndarray::prelude::*; |
||||
|
use std::cmp::min; |
||||
|
|
||||
|
#[derive(Clone, Debug)] |
||||
|
pub struct Basis { |
||||
|
pub knots: Array1<f64>, |
||||
|
pub order: usize, |
||||
|
pub periodic: isize, |
||||
|
n_all: usize, |
||||
|
n: usize, |
||||
|
pub start: f64, |
||||
|
pub end: f64, |
||||
|
pub ktol: f64, |
||||
|
} |
||||
|
|
||||
|
impl Basis { |
||||
|
pub fn new( |
||||
|
order: usize, |
||||
|
knots: Array1<f64>, |
||||
|
periodic: Option<isize>, |
||||
|
knot_tol: Option<f64>, |
||||
|
) -> Self { |
||||
|
let p = periodic.unwrap_or(-1); |
||||
|
let n_all = knots.len() - order; |
||||
|
let n = knots.len() - order - ((p + 1) as usize); |
||||
|
let start = knots[order - 1]; |
||||
|
let end = knots[n_all]; |
||||
|
|
||||
|
Basis { |
||||
|
order, |
||||
|
knots, |
||||
|
periodic: p, |
||||
|
n_all, |
||||
|
n, |
||||
|
start, |
||||
|
end, |
||||
|
ktol: knot_tolerance(knot_tol), |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
/// Wrap periodic evaluation into domain
|
||||
|
fn wrap_periodic(&self, t: &mut Array1<f64>, right: &bool) { |
||||
|
for i in 0..t.len() { |
||||
|
if t[i] < self.start || t[i] > self.end { |
||||
|
t[i] = (t[i] - self.start) % (self.end - self.start) + self.start; |
||||
|
} |
||||
|
if (t[i] - self.start).abs() < self.ktol && !right { |
||||
|
t[i] = self.end; |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
pub fn evaluate( |
||||
|
&self, |
||||
|
t: &mut Array1<f64>, |
||||
|
d: usize, |
||||
|
from_right: Option<bool>, |
||||
|
) -> Result<CSR, Error> { |
||||
|
let m = t.len(); |
||||
|
let mut right = from_right.unwrap_or(true); |
||||
|
|
||||
|
if self.periodic >= 0 { |
||||
|
self.wrap_periodic(t, &right); |
||||
|
} |
||||
|
|
||||
|
let mut store = Array1::<f64>::zeros(self.order); |
||||
|
let mut mu; |
||||
|
let mut idx_j; |
||||
|
let mut idx_k; |
||||
|
|
||||
|
let mut data = Array1::<f64>::zeros(m * self.order); |
||||
|
let mut indices = Array1::<usize>::zeros(m * self.order); |
||||
|
let indptr = |
||||
|
Array1::<usize>::from_vec((0..m * self.order + 1).step_by(self.order).collect()); |
||||
|
|
||||
|
for i in 0..m { |
||||
|
right = from_right.unwrap_or(true); |
||||
|
// Special-case the endpoint, so the user doesn't need to
|
||||
|
if (t[i] - self.end).abs() < self.ktol { |
||||
|
right = false; |
||||
|
} |
||||
|
// Skip non-periodic evaluation points outside the domain
|
||||
|
if t[i] < self.start |
||||
|
|| t[i] > self.end |
||||
|
|| ((t[i] - self.start).abs() < self.ktol && !right) |
||||
|
{ |
||||
|
continue; |
||||
|
} |
||||
|
|
||||
|
// mu = index of last non-zero basis function
|
||||
|
if right { |
||||
|
mu = bisect_right(&self.knots, &t[i], self.n_all + self.order); |
||||
|
} else { |
||||
|
mu = bisect_left(&self.knots, &t[i], self.n_all + self.order); |
||||
|
} |
||||
|
mu = min(mu, self.n_all); |
||||
|
|
||||
|
for k in 0..self.order - 1 { |
||||
|
store[k] = 0.; |
||||
|
} |
||||
|
|
||||
|
// the last entry is a dummy-zero which is never used
|
||||
|
store[self.order - 1] = 1.; |
||||
|
|
||||
|
for q in 1..self.order - d { |
||||
|
idx_j = self.order - q - 1; |
||||
|
idx_k = mu - q - 1; |
||||
|
store[idx_j] += store[idx_j + 1] * (self.knots[idx_k + q + 1] - t[i]) |
||||
|
/ (self.knots[idx_k + q + 1] - self.knots[idx_k + 1]); |
||||
|
|
||||
|
for j in self.order - q..self.order - 1 { |
||||
|
// 'i'-index in global knot vector (ref Hughes book pg.21)
|
||||
|
let k = mu - self.order + j; |
||||
|
store[j] = |
||||
|
store[j] * (t[i] - self.knots[k]) / (self.knots[k + q] - self.knots[k]); |
||||
|
store[j] += store[j + 1] * (self.knots[k + q + 1] - t[i]) |
||||
|
/ (self.knots[k + q + 1] - self.knots[k + 1]); |
||||
|
} |
||||
|
idx_j = self.order - 1; |
||||
|
idx_k = mu - 1; |
||||
|
store[idx_j] = store[idx_j] * (t[i] - self.knots[idx_k]) |
||||
|
/ (self.knots[idx_k + q] - self.knots[idx_k]); |
||||
|
} |
||||
|
|
||||
|
for q in self.order - d..self.order { |
||||
|
for j in self.order - q - 1..self.order { |
||||
|
// 'i'-index in global knot vector (ref Hughes book pg.21)
|
||||
|
idx_k = mu - self.order + j; |
||||
|
if j != self.order - q - 1 { |
||||
|
store[j] = |
||||
|
store[j] * (q as f64) / (self.knots[idx_k + q] - self.knots[idx_k]); |
||||
|
} |
||||
|
if j != self.order - 1 { |
||||
|
store[j] -= store[j + 1] * (q as f64) |
||||
|
/ (self.knots[idx_k + q + 1] - self.knots[idx_k + 1]); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
for (j, k) in (i * self.order..(i + 1) * self.order).enumerate() { |
||||
|
data[k] = store[j]; |
||||
|
indices[k] = (mu - self.order + j) % self.n; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
let csr = CSR::new(data, indices, indptr, (m, self.n))?; |
||||
|
|
||||
|
Ok(csr) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
pub fn knot_tolerance(tol: Option<f64>) -> f64 { |
||||
|
tol.unwrap_or(1e-10) |
||||
|
} |
||||
|
|
||||
|
/// Snap evaluation points to knots if they are sufficiently close
|
||||
|
/// as specified by self.ktol
|
||||
|
///
|
||||
|
/// * t: The parametric coordinate(s) in which to evaluate
|
||||
|
/// * knots: knot-vector
|
||||
|
/// * knot_tol: default=1e-10
|
||||
|
pub fn snap(t: &mut Array1<f64>, knots: &Array1<f64>, knot_tol: Option<f64>) { |
||||
|
let ktol = knot_tolerance(knot_tol); |
||||
|
|
||||
|
for j in 0..t.len() { |
||||
|
let i = bisect_left(knots, &t[j], knots.len()); |
||||
|
if i < knots.len() && (knots[i] - t[j]).abs() < ktol { |
||||
|
t[j] = knots[i]; |
||||
|
} else if i > 0 && (knots[i - 1] - t[j]).abs() < ktol { |
||||
|
t[j] = knots[i - 1]; |
||||
|
} |
||||
|
} |
||||
|
} |
@ -0,0 +1,69 @@ |
|||||
|
use nalgebra::{ComplexField, DMatrix, Dynamic, Scalar}; |
||||
|
use ndarray::{Array1, Array2}; |
||||
|
use std::fmt::Debug; |
||||
|
|
||||
|
pub trait ToNAlgebraMatrix<T> { |
||||
|
fn to_nalgebra_matrix(&self) -> DMatrix<T>; |
||||
|
} |
||||
|
|
||||
|
pub trait To1DArray<T> { |
||||
|
fn to_1d_array(&self) -> Result<Array1<T>, NotOneDimensional>; |
||||
|
} |
||||
|
|
||||
|
pub trait To2DArray<T> { |
||||
|
fn to_2d_array(&self) -> Array2<T>; |
||||
|
} |
||||
|
|
||||
|
impl<T: Clone + ComplexField + Scalar + PartialEq + Debug + PartialEq> ToNAlgebraMatrix<T> |
||||
|
for Array1<T> |
||||
|
{ |
||||
|
fn to_nalgebra_matrix(&self) -> DMatrix<T> { |
||||
|
DMatrix::from_row_slice_generic( |
||||
|
Dynamic::new(self.len()), |
||||
|
Dynamic::new(1), |
||||
|
self.to_vec().as_slice(), |
||||
|
) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
impl<T: Clone + ComplexField + Scalar + PartialEq + Debug + PartialEq> ToNAlgebraMatrix<T> |
||||
|
for Array2<T> |
||||
|
{ |
||||
|
fn to_nalgebra_matrix(&self) -> DMatrix<T> { |
||||
|
let flattened = self.rows().into_iter().fold(Vec::new(), |mut acc, next| { |
||||
|
acc.extend(next.into_iter().cloned()); |
||||
|
|
||||
|
acc |
||||
|
}); |
||||
|
|
||||
|
DMatrix::from_row_slice_generic( |
||||
|
Dynamic::new(self.nrows()), |
||||
|
Dynamic::new(self.ncols()), |
||||
|
flattened.as_slice(), |
||||
|
) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
impl<T: Clone + PartialEq + Scalar> To1DArray<T> for DMatrix<T> { |
||||
|
fn to_1d_array(&self) -> Result<Array1<T>, NotOneDimensional> { |
||||
|
if self.ncols() != 1 { |
||||
|
return Err(NotOneDimensional); |
||||
|
} |
||||
|
|
||||
|
Ok(Array1::from_shape_fn(self.nrows(), |index| { |
||||
|
self.get(index).unwrap().clone() |
||||
|
})) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
impl<T: Clone + PartialEq + Scalar> To2DArray<T> for DMatrix<T> { |
||||
|
fn to_2d_array(&self) -> Array2<T> { |
||||
|
Array2::from_shape_fn((self.nrows(), self.ncols()), |indices| { |
||||
|
self.get(indices).unwrap().clone() |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
#[derive(Debug, thiserror::Error)] |
||||
|
#[error("The provided matrix is not one-dimensional and cannot be converted into a 1D array")] |
||||
|
pub struct NotOneDimensional; |
@ -0,0 +1,152 @@ |
|||||
|
use crate::payout_curve::compat::{To1DArray, ToNAlgebraMatrix}; |
||||
|
use crate::payout_curve::Error; |
||||
|
use ndarray::prelude::*; |
||||
|
use std::ops::Mul; |
||||
|
|
||||
|
/// NOTE:
|
||||
|
/// This struct is provided here in this form as nalgebra_sparse
|
||||
|
/// is rather embryonic and incluldes (basically) no solvers
|
||||
|
/// at present. As we only need to be able construct as CSR
|
||||
|
/// matrix and perform multiplication with a (dense) vector, it
|
||||
|
/// seemed to make more sense to define our own rudementary CSR struct
|
||||
|
/// and avoid the bloat of using a crate that will introduce
|
||||
|
/// breaking changes regularly, and we need to write our own
|
||||
|
/// solver regardless.
|
||||
|
#[derive(Clone, Debug, PartialEq)] |
||||
|
#[allow(clippy::upper_case_acronyms)] |
||||
|
pub struct CSR { |
||||
|
pub data: Array1<f64>, |
||||
|
pub indices: Array1<usize>, |
||||
|
pub indptr: Array1<usize>, |
||||
|
pub shape: (usize, usize), |
||||
|
pub nnz: usize, |
||||
|
} |
||||
|
|
||||
|
impl CSR { |
||||
|
pub fn new( |
||||
|
data: Array1<f64>, |
||||
|
indices: Array1<usize>, |
||||
|
indptr: Array1<usize>, |
||||
|
shape: (usize, usize), |
||||
|
) -> Result<Self, Error> { |
||||
|
let major_dim: isize = (indptr.len() as isize) - 1; |
||||
|
let nnz = &data.len(); |
||||
|
|
||||
|
if major_dim > 1 && shape.0 as isize == major_dim { |
||||
|
Result::Ok(CSR { |
||||
|
data, |
||||
|
indices, |
||||
|
indptr, |
||||
|
shape, |
||||
|
nnz: *nnz, |
||||
|
}) |
||||
|
} else { |
||||
|
Result::Err(Error::CannotInitCSR) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// matrix version of `solve()`; useful for solving AX = B. Implementation
|
||||
|
// is horrible.
|
||||
|
pub fn matrix_solve(&self, b_arr: &Array2<f64>) -> Result<Array2<f64>, Error> { |
||||
|
let a_arr = self.todense(); |
||||
|
let ncols = b_arr.shape()[1]; |
||||
|
let mut temp = (0..ncols) |
||||
|
.rev() |
||||
|
.map(|e| { |
||||
|
let b = b_arr.slice(s![.., e]).to_owned(); |
||||
|
|
||||
|
let sol = lu_solve(&a_arr, &b).unwrap(); |
||||
|
Ok(sol.to_vec()) |
||||
|
}) |
||||
|
.collect::<Result<Vec<_>, Error>>()?; |
||||
|
|
||||
|
let nrows = temp[0].len(); |
||||
|
let mut raveled = Vec::with_capacity(nrows * temp.len()); |
||||
|
|
||||
|
for _ in 0..nrows { |
||||
|
for vec in &mut temp { |
||||
|
raveled.push(vec.pop().unwrap()); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
raveled.reverse(); |
||||
|
|
||||
|
let result = Array2::<f64>::from_shape_vec((nrows, ncols), raveled)?.to_owned(); |
||||
|
|
||||
|
Ok(result) |
||||
|
} |
||||
|
|
||||
|
pub fn todense(&self) -> Array2<f64> { |
||||
|
let mut out = Array2::<f64>::zeros(self.shape); |
||||
|
for i in 0..self.shape.0 { |
||||
|
for j in self.indptr[i]..self.indptr[i + 1] { |
||||
|
out[[i, self.indices[j]]] += self.data[j]; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
out |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
fn lu_solve(a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>, Error> { |
||||
|
let a = a.to_nalgebra_matrix().lu(); |
||||
|
let b = b.to_nalgebra_matrix(); |
||||
|
|
||||
|
let x = a |
||||
|
.solve(&b) |
||||
|
.ok_or(Error::MatrixMustBeSquare)? |
||||
|
.to_1d_array()?; |
||||
|
|
||||
|
Ok(x) |
||||
|
} |
||||
|
|
||||
|
impl Mul<&Array1<f64>> for CSR { |
||||
|
type Output = Array1<f64>; |
||||
|
|
||||
|
fn mul(self, rhs: &Array1<f64>) -> Array1<f64> { |
||||
|
let mut out = Array1::<f64>::zeros(self.shape.0); |
||||
|
for i in 0..self.shape.0 { |
||||
|
for j in self.indptr[i]..self.indptr[i + 1] { |
||||
|
out[i] += self.data[j] * rhs[self.indices[j]]; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
out |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
#[cfg(test)] |
||||
|
mod tests { |
||||
|
use super::*; |
||||
|
|
||||
|
#[test] |
||||
|
fn test_lu_solve() { |
||||
|
let a = Array2::<f64>::from(vec![[11., 12., 0.], [0., 22., 23.], [31., 0., 33.]]); |
||||
|
let b = Array1::<f64>::from_vec(vec![35., 113., 130.]); |
||||
|
let x_expected = Array1::<f64>::from_vec(vec![1., 2., 3.]); |
||||
|
|
||||
|
let x = lu_solve(&a, &b).unwrap(); |
||||
|
|
||||
|
for (x, expected) in x.into_iter().zip(x_expected) { |
||||
|
assert!( |
||||
|
(x - expected).abs() < f64::EPSILON * 10., |
||||
|
"{} {}", |
||||
|
x, |
||||
|
expected |
||||
|
) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
#[test] |
||||
|
fn negative_csr_test_00() { |
||||
|
let a = CSR::new( |
||||
|
Array1::<f64>::zeros(0), |
||||
|
Array1::<usize>::zeros(0), |
||||
|
Array1::<usize>::zeros(11), |
||||
|
(1, 3), |
||||
|
) |
||||
|
.unwrap_err(); |
||||
|
|
||||
|
assert!(matches!(a, Error::CannotInitCSR)); |
||||
|
} |
||||
|
} |
@ -0,0 +1,453 @@ |
|||||
|
use crate::payout_curve::basis::BSplineBasis; |
||||
|
use crate::payout_curve::splineobject::SplineObject; |
||||
|
use crate::payout_curve::utils::*; |
||||
|
use crate::payout_curve::Error; |
||||
|
|
||||
|
use ndarray::prelude::*; |
||||
|
use ndarray::s; |
||||
|
use std::cmp::max; |
||||
|
|
||||
|
fn default_basis() -> Result<Vec<BSplineBasis>, Error> { |
||||
|
let out = vec![BSplineBasis::new(None, None, None)?]; |
||||
|
Ok(out) |
||||
|
} |
||||
|
|
||||
|
#[derive(Clone, Debug)] |
||||
|
pub struct Curve { |
||||
|
pub spline: SplineObject, |
||||
|
} |
||||
|
|
||||
|
/// Represents a curve: an object with a one-dimensional parameter space.
|
||||
|
impl Curve { |
||||
|
/// Construct a curve with the given basis and control points. In theory,
|
||||
|
/// the curve could be defined in some Euclidean space of dimension N, but
|
||||
|
/// for the moment only curves in E^2 are supported via hard-coding. That is,
|
||||
|
/// any valid basis set can be passed in when instantiating this object,
|
||||
|
/// but only the first one is every considered in the methods provided.
|
||||
|
///
|
||||
|
/// The default is to create a linear one-element mapping from (0,1) to the
|
||||
|
/// unit interval.
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * basis: The underlying B-Spline basis
|
||||
|
/// * controlpoints: An *n* × *d* matrix of control points
|
||||
|
/// * rational: Whether the curve is rational (in which case the
|
||||
|
/// control points are interpreted as pre-multiplied with the weight,
|
||||
|
/// which is the last coordinate)
|
||||
|
pub fn new( |
||||
|
bases: Option<Vec<BSplineBasis>>, |
||||
|
controlpoints: Option<Array2<f64>>, |
||||
|
rational: Option<bool>, |
||||
|
) -> Result<Self, Error> { |
||||
|
let bases = bases.unwrap_or(default_basis()?); |
||||
|
let spline = SplineObject::new(bases, controlpoints, rational)?; |
||||
|
|
||||
|
Ok(Curve { spline }) |
||||
|
} |
||||
|
|
||||
|
/// Evaluate the object at given parametric values.
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * t: collection of parametric coordinates in which to evaluate
|
||||
|
/// Realistically, this should actually be an Array1 object, but for
|
||||
|
/// consistency with the underlying SplineObject methods the collection
|
||||
|
/// is used instead.
|
||||
|
///
|
||||
|
/// ### returns
|
||||
|
/// * 2D array
|
||||
|
// pub fn evaluate(&self, t: &mut &[Array1<f64>], tensor: bool) -> Result<ArrayD<f64>, Error> {
|
||||
|
pub fn evaluate(&self, t: &mut &[Array1<f64>]) -> Result<Array2<f64>, Error> { |
||||
|
self.spline.validate_domain(t)?; |
||||
|
|
||||
|
let mut tx = t[0].clone().to_owned(); |
||||
|
let n_csr = self.spline.bases[0].evaluate(&mut tx, 0, true)?; |
||||
|
|
||||
|
// kludge...
|
||||
|
let n0 = self.spline.controlpoints.shape()[0]; |
||||
|
let n1 = self.spline.controlpoints.shape()[1]; |
||||
|
let flat_controlpoints = self |
||||
|
.ravel(&self.spline.controlpoints) |
||||
|
.into_raw_vec() |
||||
|
.to_vec(); |
||||
|
|
||||
|
let arr = Array2::<f64>::from_shape_vec((n0, n1), flat_controlpoints)?; |
||||
|
let mut result = n_csr.todense().dot(&arr); |
||||
|
|
||||
|
// if the spline is rational, we apply the weights and omit the weights column
|
||||
|
if self.spline.rational { |
||||
|
let wpos = result.shape()[1] - 1; |
||||
|
let weights = &&result.slice(s![.., wpos]); |
||||
|
let mut temp = Array2::<f64>::zeros((result.shape()[0], wpos)); |
||||
|
|
||||
|
for i in 0..self.spline.dimension { |
||||
|
{ |
||||
|
let mut slice = temp.slice_mut(s![.., i]); |
||||
|
let update = result.slice(s![.., i]).to_owned() / weights.to_owned(); |
||||
|
slice.assign(&update.view()); |
||||
|
} |
||||
|
} |
||||
|
result = temp; |
||||
|
} |
||||
|
|
||||
|
Ok(result) |
||||
|
} |
||||
|
|
||||
|
/// Extend the curve by merging another curve to the end of it.
|
||||
|
///
|
||||
|
/// The curves are glued together in a C0 fashion with enough repeated
|
||||
|
/// knots. The function assumes that the end of this curve perfectly
|
||||
|
/// matches the start of the input curve.
|
||||
|
///
|
||||
|
/// Obviously, neither curve can be periodic, which enables us
|
||||
|
/// to assume all knot vectors are open.
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * curve: Another curve
|
||||
|
///
|
||||
|
/// ### returns
|
||||
|
pub fn append(&mut self, othercurve: Curve) -> Result<(), Error> { |
||||
|
if self.spline.bases[0].periodic > -1 || othercurve.spline.bases[0].periodic > -1 { |
||||
|
return Result::Err(Error::CannotConnectPeriodicCurves); |
||||
|
}; |
||||
|
|
||||
|
let mut extending_curve = othercurve; |
||||
|
|
||||
|
// make sure both are in the same space, and (if needed) have rational weights
|
||||
|
self.spline |
||||
|
.make_splines_compatible(&mut extending_curve.spline)?; |
||||
|
let p1 = self.spline.order(0)?[0]; |
||||
|
let p2 = extending_curve.spline.order(0)?[0]; |
||||
|
|
||||
|
if p1 < p2 { |
||||
|
self.raise_order(p2 - p1)?; |
||||
|
} else { |
||||
|
extending_curve.raise_order(p1 - p2)?; |
||||
|
} |
||||
|
|
||||
|
let p = max(p1, p2); |
||||
|
|
||||
|
let old_knot = self.spline.knots(0, Some(true))?[0].clone(); |
||||
|
let mut add_knot = extending_curve.spline.knots(0, Some(true))?[0].clone(); |
||||
|
add_knot -= add_knot[0]; |
||||
|
add_knot += old_knot[old_knot.len() - 1]; |
||||
|
|
||||
|
let mut new_knot = Array1::<f64>::zeros(add_knot.len() + old_knot.len() - p - 1); |
||||
|
{ |
||||
|
let mut slice = new_knot.slice_mut(s![..old_knot.len() - 1]); |
||||
|
let update = old_knot.slice(s![..old_knot.len() - 1]); |
||||
|
slice.assign(&update.view()); |
||||
|
} |
||||
|
{ |
||||
|
let mut slice = new_knot.slice_mut(s![old_knot.len() - 1..]); |
||||
|
let update = add_knot.slice(s![p..]); |
||||
|
slice.assign(&update.view()); |
||||
|
} |
||||
|
|
||||
|
let rational = self.spline.rational as usize; |
||||
|
let n1 = self.spline.controlpoints.shape()[0]; |
||||
|
let n2 = extending_curve.spline.controlpoints.shape()[0]; |
||||
|
let n3 = self.spline.dimension + rational; |
||||
|
let mut new_controlpoints = Array2::<f64>::zeros((n1 + n2 - 1, n3)); |
||||
|
{ |
||||
|
let mut slice = new_controlpoints.slice_mut(s![..n1, ..]); |
||||
|
let update = self.spline.controlpoints.slice(s![.., ..]); |
||||
|
slice.assign(&update.view()); |
||||
|
} |
||||
|
{ |
||||
|
let mut slice = new_controlpoints.slice_mut(s![n1.., ..]); |
||||
|
let update = extending_curve.spline.controlpoints.slice(s![1.., ..]); |
||||
|
slice.assign(&update.view()); |
||||
|
} |
||||
|
|
||||
|
self.spline.bases = vec![BSplineBasis::new(Some(p), Some(new_knot), None)?]; |
||||
|
self.spline.controlpoints = new_controlpoints.into_dyn().to_owned(); |
||||
|
|
||||
|
Ok(()) |
||||
|
} |
||||
|
|
||||
|
/// Raise the polynomial order of the curve.
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * amount: Number of times to raise the order
|
||||
|
pub fn raise_order(&mut self, amount: usize) -> Result<(), Error> { |
||||
|
if amount == 0 { |
||||
|
return Ok(()); |
||||
|
} |
||||
|
|
||||
|
// work outside of self, copy back in at the end
|
||||
|
let mut new_basis = self.spline.bases[0].clone(); |
||||
|
new_basis.raise_order(amount); |
||||
|
|
||||
|
// set up an interpolation problem. This is in projective space,
|
||||
|
// so no problems for rational cases
|
||||
|
let mut interpolation_pts_t = new_basis.greville(); |
||||
|
let n_old = self.spline.bases[0].evaluate(&mut interpolation_pts_t, 0, true)?; |
||||
|
let n_new = new_basis.evaluate(&mut interpolation_pts_t, 0, true)?; |
||||
|
|
||||
|
// Some kludge required to enable .dot(), which doesn't work on dynamic
|
||||
|
// arrays. Surely a better way to do this, but this is quick and dirty
|
||||
|
// and valid since we're in curve land
|
||||
|
let n0 = self.spline.controlpoints.shape()[0]; |
||||
|
let n1 = self.spline.controlpoints.shape()[1]; |
||||
|
let flat_controlpoints = self |
||||
|
.ravel(&self.spline.controlpoints) |
||||
|
.into_raw_vec() |
||||
|
.to_vec(); |
||||
|
|
||||
|
let arr = Array2::<f64>::from_shape_vec((n0, n1), flat_controlpoints)?; |
||||
|
let interpolation_pts_x = n_old.todense().dot(&arr); |
||||
|
let res = n_new.matrix_solve(&interpolation_pts_x)?; |
||||
|
|
||||
|
self.spline.controlpoints = res.into_dyn().to_owned(); |
||||
|
self.spline.bases = vec![new_basis]; |
||||
|
|
||||
|
Ok(()) |
||||
|
} |
||||
|
|
||||
|
/// Computes the euclidian length of the curve in geometric space
|
||||
|
///
|
||||
|
/// .. math:: \\int_{t_0}^{t_1}\\sqrt{x(t)^2 + y(t)^2 + z(t)^2} dt
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * t0: lower integration limit
|
||||
|
/// * t1: upper integration limit
|
||||
|
pub fn length(&self, t0: Option<f64>, t1: Option<f64>) -> Result<f64, Error> { |
||||
|
let mut knots = &self.spline.knots(0, Some(false))?[0]; |
||||
|
|
||||
|
// keep only integration boundaries within given start (t0) and stop (t1) interval
|
||||
|
let new_knots_0 = t0 |
||||
|
.map(|t0| { |
||||
|
let i = bisect_left(knots, &t0, knots.len()); |
||||
|
let mut vec = Vec::<f64>::with_capacity(&knots.to_vec()[i..].len() + 1); |
||||
|
vec.push(t0); |
||||
|
for elem in knots.to_vec()[i..].iter() { |
||||
|
vec.push(*elem); |
||||
|
} |
||||
|
Array1::<f64>::from_vec(vec) |
||||
|
}) |
||||
|
.unwrap_or_else(|| knots.to_owned()); |
||||
|
knots = &new_knots_0; |
||||
|
|
||||
|
let new_knots_1 = t1 |
||||
|
.map(|t1| { |
||||
|
let i = bisect_right(knots, &t1, knots.len()); |
||||
|
let mut vec = Vec::<f64>::with_capacity(&knots.to_vec()[..i].len() + 1); |
||||
|
for elem in knots.to_vec()[..i].iter() { |
||||
|
vec.push(*elem); |
||||
|
} |
||||
|
vec.push(t1); |
||||
|
Array1::<f64>::from_vec(vec) |
||||
|
}) |
||||
|
.unwrap_or_else(|| knots.to_owned()); |
||||
|
knots = &new_knots_1; |
||||
|
|
||||
|
let klen = knots.len(); |
||||
|
let gleg = GaussLegendreQuadrature::new(self.spline.order(0)?[0] + 1)?; |
||||
|
|
||||
|
let t = &knots.to_vec()[..klen - 1] |
||||
|
.iter() |
||||
|
.zip(knots.to_vec()[1..].iter()) |
||||
|
.map(|(t0, t1)| (&gleg.sample_points + 1.) / 2. * (t1 - t0) + *t0) |
||||
|
.collect::<Vec<_>>(); |
||||
|
|
||||
|
let w = &knots.to_vec()[..klen - 1] |
||||
|
.iter() |
||||
|
.zip(knots.to_vec()[1..].iter()) |
||||
|
.map(|(t0, t1)| &gleg.weights / 2. * (t1 - t0)) |
||||
|
.collect::<Vec<_>>(); |
||||
|
|
||||
|
let t_flat = self.flattened(&t[..]); |
||||
|
let w_flat = self.flattened(&w[..]); |
||||
|
|
||||
|
let dx = self.derivative(&t_flat, 1, Some(true))?; |
||||
|
let det_j = Array1::<f64>::from_vec( |
||||
|
dx.mapv(|e| e.powi(2)) |
||||
|
.sum_axis(Axis(1)) |
||||
|
.mapv(f64::sqrt) |
||||
|
.iter() |
||||
|
.copied() |
||||
|
.collect::<Vec<_>>(), |
||||
|
); |
||||
|
let out = det_j.dot(&w_flat); |
||||
|
|
||||
|
Ok(out) |
||||
|
} |
||||
|
|
||||
|
fn flattened(&self, vec_arr: &[Array1<f64>]) -> Array1<f64> { |
||||
|
let alloc = vec_arr.iter().fold(0, |sum, e| sum + e.len()); |
||||
|
let mut vec_out = Vec::<f64>::with_capacity(alloc); |
||||
|
for arr in vec_arr.iter() { |
||||
|
for e in arr.to_vec().iter() { |
||||
|
vec_out.push(*e); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
Array1::<f64>::from_vec(vec_out) |
||||
|
} |
||||
|
|
||||
|
/// left here as a private method as it assumes C-contiguous ordering,
|
||||
|
/// which is fine for where we use it here.
|
||||
|
fn ravel(&self, arr: &ArrayD<f64>) -> Array1<f64> { |
||||
|
let alloc = arr.shape().iter().product(); |
||||
|
let mut vec = Vec::<f64>::with_capacity(alloc); |
||||
|
for e in arr.iter() { |
||||
|
vec.push(*e) |
||||
|
} |
||||
|
|
||||
|
Array1::<f64>::from_vec(vec) |
||||
|
} |
||||
|
|
||||
|
/// Evaluate the derivative of the curve at the given parametric values.
|
||||
|
///
|
||||
|
/// This function returns an *n* × *dim* array, where *n* is the number of
|
||||
|
/// evaluation points, and *dim* is the physical dimension of the curve.
|
||||
|
/// **At this point in time, only `dim == 1` works, owing to the provisional
|
||||
|
/// constraints on the struct `Curve` itself.**
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * t: Parametric coordinates in which to evaluate
|
||||
|
/// * d: Number of derivatives to compute
|
||||
|
/// * from_right: Evaluation in the limit from right side
|
||||
|
pub fn derivative( |
||||
|
&self, |
||||
|
t: &Array1<f64>, |
||||
|
d: usize, |
||||
|
from_right: Option<bool>, |
||||
|
) -> Result<ArrayD<f64>, Error> { |
||||
|
let from_right = from_right.unwrap_or(true); |
||||
|
|
||||
|
if !self.spline.rational || d < 2 || d > 3 { |
||||
|
let mut tx = &vec![t.clone().to_owned()][..]; |
||||
|
let res = self.spline.derivative(&mut tx, &[d], &[from_right], true)?; |
||||
|
return Ok(res); |
||||
|
} |
||||
|
|
||||
|
// init rusult array
|
||||
|
let mut res = Array2::<f64>::zeros((t.len(), self.spline.dimension)); |
||||
|
|
||||
|
// annoying fix to make the controlpoints not dynamic--implicit
|
||||
|
// assumption of 2D curve only!
|
||||
|
let n0 = self.spline.controlpoints.shape()[0]; |
||||
|
let n1 = self.spline.controlpoints.shape()[1]; |
||||
|
let flat_controlpoints = self |
||||
|
.ravel(&self.spline.controlpoints) |
||||
|
.into_raw_vec() |
||||
|
.to_vec(); |
||||
|
let static_controlpoints = Array2::<f64>::from_shape_vec((n0, n1), flat_controlpoints)?; |
||||
|
let mut t_eval = t.clone(); |
||||
|
let d2 = self.spline.bases[0] |
||||
|
.evaluate(&mut t_eval, 2, from_right)? |
||||
|
.todense() |
||||
|
.dot(&static_controlpoints); |
||||
|
let d1 = self.spline.bases[0] |
||||
|
.evaluate(&mut t_eval, 1, from_right)? |
||||
|
.todense() |
||||
|
.dot(&static_controlpoints); |
||||
|
let d0 = self.spline.bases[0] |
||||
|
.evaluate(&mut t_eval, 0, true)? |
||||
|
.todense() |
||||
|
.dot(&static_controlpoints); |
||||
|
let w0 = &d0.slice(s![.., d0.shape()[1] - 1]).to_owned(); |
||||
|
let w1 = &d1.slice(s![.., d1.shape()[1] - 1]).to_owned(); |
||||
|
let w2 = &d2.slice(s![.., d2.shape()[1] - 1]).to_owned(); |
||||
|
|
||||
|
if d == 2 { |
||||
|
let w0_cube = &w0.mapv(|e| e.powi(3)).to_owned(); |
||||
|
for i in 0..self.spline.dimension { |
||||
|
{ |
||||
|
let update = &((d2.slice(s![.., i]).to_owned() * w0 * w0 |
||||
|
- 2. * w1 |
||||
|
* (d1.slice(s![.., i]).to_owned() * w0 |
||||
|
- d0.slice(s![.., i]).to_owned() * w1) |
||||
|
- d0.slice(s![.., i]).to_owned() * w2 * w1) |
||||
|
/ w0_cube); |
||||
|
|
||||
|
let mut slice = res.slice_mut(s![.., i]); |
||||
|
slice.assign(update); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if d == 3 { |
||||
|
let d3 = self.spline.bases[0] |
||||
|
.evaluate(&mut t_eval, 3, from_right)? |
||||
|
.todense() |
||||
|
.dot(&static_controlpoints); |
||||
|
let w3 = &d3.slice(s![.., d3.shape()[1] - 1]); |
||||
|
let w0_four = w0.mapv(|e| e.powi(6)); |
||||
|
for i in 0..self.spline.dimension { |
||||
|
{ |
||||
|
let h0 = &(d1.slice(s![.., i]).to_owned() * w0 |
||||
|
- d0.slice(s![.., i]).to_owned() * w1); |
||||
|
let h1 = &(d2.slice(s![.., i]).to_owned() * w0 |
||||
|
- d0.slice(s![.., i]).to_owned() * w2); |
||||
|
let h2 = &(d3.slice(s![.., i]).to_owned() * w0 |
||||
|
+ d2.slice(s![.., i]).to_owned() * w1 |
||||
|
- d1.slice(s![.., i]).to_owned() * w2 |
||||
|
- d0.slice(s![.., i]).to_owned() * w3); |
||||
|
let g0 = &(h1 * w0 - 2. * h0 * w1); |
||||
|
let g1 = &(h2 * w0 - 2. * h0 * w2 - h1 * w1); |
||||
|
|
||||
|
let update = (g1 * w0 - 3. * g0 * w1) / &w0_four; |
||||
|
let mut slice = res.slice_mut(s![.., i]); |
||||
|
slice.assign(&update); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
Ok(res.into_dyn().to_owned()) |
||||
|
} |
||||
|
|
||||
|
/// Computes the L2 (squared and per knot span) between this
|
||||
|
/// curve and a target curve as well as the L_infinity error:
|
||||
|
///
|
||||
|
/// .. math:: ||\\boldsymbol{x_h}(t)-\\boldsymbol{x}(t)||_{L^2(t_1,t_2)}^2 = \\int_{t_1}^{t_2}
|
||||
|
/// |\\boldsymbol{x_h}(t)-\\boldsymbol{x}(t)|^2 dt, \\quad \\forall \\;\\text{knots}\\;t_1 <
|
||||
|
/// t_2
|
||||
|
///
|
||||
|
/// .. math:: ||\\boldsymbol{x_h}(t)-\\boldsymbol{x}(t)||_{L^\\infty} = \\max_t
|
||||
|
/// |\\boldsymbol{x_h}(t)-\\boldsymbol{x}(t)|
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * function target: callable function which takes as input a vector
|
||||
|
/// of evaluation points t and gives as output a matrix x where
|
||||
|
/// x[i,j] is component j evaluated at point t[i]
|
||||
|
///
|
||||
|
/// ### returns
|
||||
|
/// * L2 error per knot-span
|
||||
|
pub fn error( |
||||
|
&self, |
||||
|
target: impl Fn(&Array1<f64>) -> Array2<f64>, |
||||
|
) -> Result<(Array1<f64>, f64), Error> { |
||||
|
let knots = &self.spline.knots(0, Some(false))?[0]; |
||||
|
let n = self.spline.order(0)?[0]; |
||||
|
let gleg = GaussLegendreQuadrature::new(n + 1)?; |
||||
|
|
||||
|
let mut error_l2 = Vec::with_capacity(knots.len() - 1); |
||||
|
let mut error_linf = Vec::with_capacity(knots.len() - 1); |
||||
|
|
||||
|
for (t0, t1) in knots.to_vec()[..knots.len() - 1] |
||||
|
.iter() |
||||
|
.zip(&mut knots.to_vec()[1..].iter()) |
||||
|
{ |
||||
|
let tg = (&gleg.sample_points + 1.) / 2. * (t1 - t0) + *t0; |
||||
|
let eval = vec![tg.clone()]; |
||||
|
let wg = &gleg.weights / 2. * (t1 - t0); |
||||
|
|
||||
|
let exact = target(&tg); |
||||
|
let error = self.evaluate(&mut &eval[..])? - exact; |
||||
|
let error_2 = &error.mapv(|e| e.powi(2)).sum_axis(Axis(1)); |
||||
|
let error_abs = &error_2.mapv(|e| e.sqrt()); |
||||
|
|
||||
|
let l2_val = error_2.dot(&wg); |
||||
|
let linf_val = error_abs.iter().copied().fold(f64::NEG_INFINITY, f64::max); |
||||
|
|
||||
|
error_l2.push(l2_val); |
||||
|
error_linf.push(linf_val); |
||||
|
} |
||||
|
|
||||
|
let out_inf = error_linf.iter().copied().fold(f64::NEG_INFINITY, f64::max); |
||||
|
|
||||
|
Ok((Array1::<f64>::from_vec(error_l2), out_inf)) |
||||
|
} |
||||
|
} |
@ -0,0 +1,159 @@ |
|||||
|
use crate::payout_curve::basis::BSplineBasis; |
||||
|
use crate::payout_curve::curve::Curve; |
||||
|
use crate::payout_curve::utils::cmp_f64; |
||||
|
use crate::payout_curve::Error; |
||||
|
|
||||
|
use ndarray::prelude::*; |
||||
|
|
||||
|
/// Perform general spline interpolation on a provided basis.
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * x: Matrix *X\[i,j\]* of interpolation points *x_i* with components *j*
|
||||
|
/// * basis: Basis on which to interpolate
|
||||
|
/// * t: parametric values at interpolation points; defaults to
|
||||
|
/// Greville points if not provided
|
||||
|
///
|
||||
|
/// ### returns
|
||||
|
/// * Interpolated curve
|
||||
|
pub fn interpolate( |
||||
|
x: &Array2<f64>, |
||||
|
basis: &BSplineBasis, |
||||
|
t: Option<Array1<f64>>, |
||||
|
) -> Result<Curve, Error> { |
||||
|
let mut t = t.unwrap_or_else(|| basis.greville()); |
||||
|
let evals = basis.evaluate(&mut t, 0, true)?; |
||||
|
let controlpoints = evals.matrix_solve(x)?; |
||||
|
let out = Curve::new(Some(vec![basis.clone()]), Some(controlpoints), None)?; |
||||
|
|
||||
|
Ok(out) |
||||
|
} |
||||
|
|
||||
|
/// Computes an interpolation for a parametric curve up to a specified
|
||||
|
/// tolerance. The method will iteratively refine parts where needed
|
||||
|
/// resulting in a non-uniform knot vector with as optimized knot
|
||||
|
/// locations as possible.
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * x: callable function `x: t --> (t, x(t))` which takes as input a vector
|
||||
|
/// of evaluation points `t` and gives as output a matrix `x` where `x\[i,j\]`
|
||||
|
/// is component `j` evaluated at point `t\[i\]`
|
||||
|
/// * t0: start of parametric domain
|
||||
|
/// * t1: end of parametric domain
|
||||
|
/// * rtol: relative tolerance for stopping criterium. It is defined to be
|
||||
|
/// `||e||_L2 / D`, where `D` is the length of the curve and `||e||_L2` is
|
||||
|
/// the L2-error (see Curve.error)
|
||||
|
/// * atol: absolute tolerance for stopping criterium. It is defined to be
|
||||
|
/// the maximal distance between the curve approximation and the exact curve
|
||||
|
///
|
||||
|
/// ### returns
|
||||
|
/// Curve (NURBS)
|
||||
|
pub fn fit( |
||||
|
x: impl Fn(&Array1<f64>) -> Array2<f64>, |
||||
|
t0: f64, |
||||
|
t1: f64, |
||||
|
rtol: Option<f64>, |
||||
|
atol: Option<f64>, |
||||
|
) -> Result<Curve, Error> { |
||||
|
let rtol = rtol.unwrap_or(1e-4); |
||||
|
let atol = atol.unwrap_or(0.0); |
||||
|
|
||||
|
let knot_vector = Array1::<f64>::from_vec(vec![t0, t0, t0, t0, t1, t1, t1, t1]); |
||||
|
let b = BSplineBasis::new(Some(4), Some(knot_vector), None)?; |
||||
|
let t = b.greville(); |
||||
|
let exact = &x(&t); |
||||
|
|
||||
|
let mut crv = interpolate(exact, &b, Some(t))?; |
||||
|
let err = crv.error(&x)?; |
||||
|
|
||||
|
// polynomial input (which can be exactly represented) only use one knot span
|
||||
|
if err.1 < 1e-13 { |
||||
|
return Ok(crv); |
||||
|
} |
||||
|
|
||||
|
// for all other curves, start with 4 knot spans
|
||||
|
let mut knot_vec = Vec::<f64>::with_capacity(12); |
||||
|
for _ in 0..4 { |
||||
|
knot_vec.push(t0) |
||||
|
} |
||||
|
for i in 0..4 { |
||||
|
let i_64 = (i + 1) as f64; |
||||
|
let val = i_64 / 5. * (t1 - t0) + t0; |
||||
|
knot_vec.push(val); |
||||
|
} |
||||
|
for _ in 0..4 { |
||||
|
knot_vec.push(t1) |
||||
|
} |
||||
|
let knot_vector = Array1::<f64>::from_vec(knot_vec.clone()); |
||||
|
let b = BSplineBasis::new(Some(4), Some(knot_vector), None)?; |
||||
|
let t = b.greville(); |
||||
|
let exact = &x(&t); |
||||
|
|
||||
|
crv = interpolate(exact, &b, Some(t))?; |
||||
|
let err = crv.error(&x)?; |
||||
|
let mut err_l2 = err.0; |
||||
|
let mut err_max = err.1; |
||||
|
|
||||
|
// this is technically false since we need the length of the target function *x*
|
||||
|
// and not our approximation *crv*, but we don't have the derivative of *x*, so
|
||||
|
// we can't compute it. This seems like a healthy compromise
|
||||
|
let length = crv.length(None, None)?; |
||||
|
let mut target = (err_l2.sum() / length).sqrt(); |
||||
|
|
||||
|
// conv_order = 4
|
||||
|
// square_conv_order = 2 * conv_order
|
||||
|
// scale = square_conv_order + 4
|
||||
|
let scale_64 = 12_f64; |
||||
|
|
||||
|
while target > rtol && err_max > atol { |
||||
|
let knot_span = &crv.spline.knots(0, Some(false))?[0]; |
||||
|
let target_error = (rtol * length).powi(2) / err_l2.len() as f64; |
||||
|
for i in 0..err_l2.len() { |
||||
|
// figure out how many new knots we require in this knot interval:
|
||||
|
// if we converge with *scale* and want an error of *target_error*
|
||||
|
// |e|^2 * (1/n)^scale = target_error^2
|
||||
|
let n = ((err_l2[i].ln() - target_error.ln()) / scale_64) |
||||
|
.exp() |
||||
|
.ceil() as usize; |
||||
|
|
||||
|
// add *n* new interior knots to this knot span
|
||||
|
let new_knots = Array1::<f64>::linspace(knot_span[i], knot_span[i + 1], n + 1); |
||||
|
for e in new_knots.slice(s![1..new_knots.len() - 1]).iter() { |
||||
|
knot_vec.push(*e); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// build new refined knot vector
|
||||
|
knot_vec.sort_by(cmp_f64); |
||||
|
let knot_vector = Array1::<f64>::from_vec(knot_vec.clone()); |
||||
|
let b = BSplineBasis::new(Some(4), Some(knot_vector), None)?; |
||||
|
|
||||
|
// do interpolation and return result
|
||||
|
let t = b.greville(); |
||||
|
let exact = &x(&t); |
||||
|
|
||||
|
crv = interpolate(exact, &b, Some(t))?; |
||||
|
let err = crv.error(&x)?; |
||||
|
err_l2 = err.0; |
||||
|
err_max = err.1; |
||||
|
target = err_l2.sum().sqrt() / length; |
||||
|
} |
||||
|
|
||||
|
Ok(crv) |
||||
|
} |
||||
|
|
||||
|
/// Create a line between two points.
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * a, b: start and end points (resp.)
|
||||
|
/// * relative: whether `b` is relative to `a` or absolute
|
||||
|
pub fn line(a: (f64, f64), b: (f64, f64), relative: bool) -> Result<Curve, Error> { |
||||
|
let vec; |
||||
|
if relative { |
||||
|
vec = vec![[a.0, a.1], [a.0 + b.0, a.1 + b.1]]; |
||||
|
} else { |
||||
|
vec = vec![[a.0, a.1], [b.0, b.1]]; |
||||
|
} |
||||
|
let controlpoints = Array2::<f64>::from(vec); |
||||
|
|
||||
|
Curve::new(None, Some(controlpoints), None) |
||||
|
} |
@ -0,0 +1,494 @@ |
|||||
|
use crate::payout_curve::basis::BSplineBasis; |
||||
|
use crate::payout_curve::csr_tools::CSR; |
||||
|
use crate::payout_curve::Error; |
||||
|
use itertools::Itertools; |
||||
|
use ndarray::prelude::*; |
||||
|
use ndarray::{concatenate, Order}; |
||||
|
use ndarray_einsum_beta::{einsum, tensordot}; |
||||
|
use std::collections::HashMap; |
||||
|
|
||||
|
#[derive(Clone, Debug)] |
||||
|
pub struct SplineObject { |
||||
|
pub bases: Vec<BSplineBasis>, |
||||
|
pub controlpoints: ArrayD<f64>, |
||||
|
pub dimension: usize, |
||||
|
pub rational: bool, |
||||
|
pub pardim: usize, |
||||
|
} |
||||
|
|
||||
|
/// Master struct for spline objects with arbitrary dimensions.
|
||||
|
///
|
||||
|
/// This class should be composed instead of used directly.
|
||||
|
impl SplineObject { |
||||
|
pub fn new( |
||||
|
bases: Vec<BSplineBasis>, |
||||
|
controlpoints: Option<Array2<f64>>, |
||||
|
rational: Option<bool>, |
||||
|
) -> Result<Self, Error> { |
||||
|
let mut controlpoints = match controlpoints { |
||||
|
Some(controlpoints) => controlpoints, |
||||
|
None => default_control_points(&bases)?, |
||||
|
}; |
||||
|
let rational = rational.unwrap_or(false); |
||||
|
|
||||
|
if controlpoints.slice(s![0, ..]).shape()[0] == 1 { |
||||
|
controlpoints = concatenate( |
||||
|
Axis(1), |
||||
|
&[ |
||||
|
controlpoints.view(), |
||||
|
Array1::<f64>::zeros(controlpoints.shape()[0]) |
||||
|
.insert_axis(Axis(1)) |
||||
|
.view(), |
||||
|
], |
||||
|
)?; |
||||
|
} |
||||
|
|
||||
|
if rational { |
||||
|
controlpoints = concatenate( |
||||
|
Axis(1), |
||||
|
&[ |
||||
|
controlpoints.view(), |
||||
|
Array1::<f64>::ones(controlpoints.shape()[0]) |
||||
|
.insert_axis(Axis(1)) |
||||
|
.view(), |
||||
|
], |
||||
|
)?; |
||||
|
} |
||||
|
|
||||
|
let dim = controlpoints.shape()[1] - (rational as usize); |
||||
|
let bases_shape = determine_shape(&bases)?; |
||||
|
let ncomps = dim + (rational as usize); |
||||
|
let cpts_shaped = reshaper(controlpoints, bases_shape, ncomps)?; |
||||
|
let pardim = cpts_shaped.shape().len() - 1; |
||||
|
|
||||
|
Ok(SplineObject { |
||||
|
bases, |
||||
|
controlpoints: cpts_shaped, |
||||
|
dimension: dim, |
||||
|
rational, |
||||
|
pardim, |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
/// Check whether the given evaluation parameters are valid
|
||||
|
pub fn validate_domain(&self, t: &[Array1<f64>]) -> Result<(), Error> { |
||||
|
for (basis, params) in self.bases.iter().zip(t.to_owned().iter_mut()) { |
||||
|
if basis.periodic < 0 { |
||||
|
basis.snap(&mut *params); |
||||
|
let p_max = ¶ms.iter().copied().fold(f64::NEG_INFINITY, f64::max); |
||||
|
let p_min = ¶ms.iter().copied().fold(f64::INFINITY, f64::min); |
||||
|
if *p_min < basis.start() || basis.end() < *p_max { |
||||
|
return Result::Err(Error::InvalidDomain); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
Ok(()) |
||||
|
} |
||||
|
|
||||
|
fn tensor_evaluate(&self, eval_bases: &mut [CSR], tensor: bool) -> Result<ArrayD<f64>, Error> { |
||||
|
// KLUDGE!
|
||||
|
// owing to the fact that the conventional ellipsis notation is not yet
|
||||
|
// implemented for einsum, we use this workaround that should cover us.
|
||||
|
// If not, just grow the maps as needed or address the issue:
|
||||
|
// https://github.com/oracleofnj/einsum/issues/6
|
||||
|
let init_map: HashMap<usize, &str> = [ |
||||
|
(2, "ij,jp->ip"), |
||||
|
(3, "ij,jpq->ipq"), |
||||
|
(4, "ij,jpqr->ipqr"), |
||||
|
(5, "ij,jpqrs->ipqrs"), |
||||
|
(6, "ij,jpqrst->ipqrst"), |
||||
|
] |
||||
|
.iter() |
||||
|
.cloned() |
||||
|
.collect(); |
||||
|
|
||||
|
let iter_map: HashMap<usize, &str> = [ |
||||
|
(3, "ij,ijp->ip"), |
||||
|
(4, "ij,ijpq->ipq"), |
||||
|
(5, "ij,ijpqr->ipqr"), |
||||
|
(6, "ij,ijpqrs->ipqrs"), |
||||
|
] |
||||
|
.iter() |
||||
|
.cloned() |
||||
|
.collect(); |
||||
|
|
||||
|
let mut out; |
||||
|
if tensor { |
||||
|
eval_bases.reverse(); |
||||
|
let cpts = self.controlpoints.clone().to_owned(); |
||||
|
let idx = eval_bases.len() - 1; |
||||
|
|
||||
|
out = eval_bases.iter().fold(cpts, |e, tns| { |
||||
|
tensordot(&tns.todense(), &e, &[Axis(1)], &[Axis(idx)]) |
||||
|
}); |
||||
|
} else { |
||||
|
let mut pos = 0; |
||||
|
let mut key = self.bases.len() + 1; |
||||
|
let mut val = match init_map.get(&key) { |
||||
|
Some(val) => Ok(val), |
||||
|
_ => Result::Err(Error::NoEinsumOperatorString), |
||||
|
}?; |
||||
|
|
||||
|
out = einsum(val, &[&eval_bases[pos].todense(), &self.controlpoints]) |
||||
|
.map_err(|_| Error::Einsum)?; |
||||
|
|
||||
|
for _ in eval_bases.iter().skip(1) { |
||||
|
pos += 1; |
||||
|
val = match iter_map.get(&key) { |
||||
|
Some(val) => Ok(val), |
||||
|
_ => Result::Err(Error::NoEinsumOperatorString), |
||||
|
}?; |
||||
|
let temp = out.clone().to_owned(); |
||||
|
|
||||
|
out = |
||||
|
einsum(val, &[&eval_bases[pos].todense(), &temp]).map_err(|_| Error::Einsum)?; |
||||
|
|
||||
|
key -= 1; |
||||
|
} |
||||
|
} |
||||
|
// *** END KLUDGE ****
|
||||
|
|
||||
|
Ok(out) |
||||
|
} |
||||
|
|
||||
|
/// Evaluate the derivative of the object at the given parametric values.
|
||||
|
///
|
||||
|
/// If *tensor* is true, evaluation will take place on a tensor product
|
||||
|
/// grid, i.e. it will return an *n1* × *n2* × ... × *dim* array, where
|
||||
|
/// *ni* is the number of evaluation points in direction *i*, and *dim* is
|
||||
|
/// the physical dimension of the object.
|
||||
|
///
|
||||
|
/// If *tensor* is false, there must be an equal number *n* of evaluation
|
||||
|
/// points in all directions, and the return value will be an *n* × *dim*
|
||||
|
/// array.
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * t = [u,v,...]: Parametric coordinates in which to evaluate
|
||||
|
/// * d: Order of derivative to compute, index corresponds to bases index
|
||||
|
/// * from_right: Evaluation in the limit from above; index orresponds to bases index
|
||||
|
/// * tensor: Whether to evaluate on a tensor product grid
|
||||
|
pub fn derivative( |
||||
|
&self, |
||||
|
t: &mut &[Array1<f64>], |
||||
|
d: &[usize], |
||||
|
from_right: &[bool], |
||||
|
tensor: bool, |
||||
|
) -> Result<ArrayD<f64>, Error> { |
||||
|
// check
|
||||
|
let testlen = t.len(); |
||||
|
let ops = [t.len(), d.len(), from_right.len()] |
||||
|
.iter() |
||||
|
.all(|e| *e == testlen); |
||||
|
if !tensor && !ops { |
||||
|
return Result::Err(Error::InvalidDerivative); |
||||
|
} |
||||
|
|
||||
|
self.validate_domain(t)?; |
||||
|
|
||||
|
// Evaluate the derivatives of the corresponding bases at the corresponding points
|
||||
|
// and build the result array
|
||||
|
let mut evals = self |
||||
|
.bases |
||||
|
.iter() |
||||
|
.zip(t.iter().zip(d.iter().zip(from_right.iter()))) |
||||
|
.map(|(b, (t, (d, r)))| { |
||||
|
let mut tx = t.clone(); |
||||
|
let eval = b.evaluate(&mut tx, *d, *r)?; |
||||
|
Ok(eval) |
||||
|
}) |
||||
|
.collect::<Result<Vec<_>, Error>>()?; |
||||
|
let mut result = self.tensor_evaluate(&mut evals, tensor)?; |
||||
|
|
||||
|
// For rational curves, we need to use the quotient rule
|
||||
|
// (n/W)' = (n' W - n W') / W^2 = n'/W - nW'/W^2
|
||||
|
// n'(i) = result[..., i]
|
||||
|
// W'(i) = result[..., -1]
|
||||
|
if self.rational { |
||||
|
if d.iter().sum::<usize>() > 1 { |
||||
|
return Result::Err(Error::DerivativeNotImplemented); |
||||
|
} |
||||
|
|
||||
|
let mut ns = self |
||||
|
.bases |
||||
|
.iter() |
||||
|
.zip(t.iter()) |
||||
|
.map(|(b, t)| { |
||||
|
let mut tx = t.clone(); |
||||
|
let eval = b.evaluate(&mut tx, 0, true)?; |
||||
|
Ok(eval) |
||||
|
}) |
||||
|
.collect::<Result<Vec<_>, Error>>()?; |
||||
|
let non_derivative = self.tensor_evaluate(&mut ns, tensor)?; |
||||
|
|
||||
|
let axis_w = non_derivative.shape().len() - 1; |
||||
|
let idx_w = non_derivative.shape()[axis_w] - 1; |
||||
|
let w = &non_derivative.index_axis(Axis(axis_w), idx_w).to_owned(); |
||||
|
let w_square = &w.mapv(|e| e.powi(2)).to_owned(); |
||||
|
|
||||
|
let axis_r = result.shape().len() - 1; |
||||
|
let idx_r = result.shape()[axis_r] - 1; |
||||
|
let wd = &result.index_axis(Axis(axis_r), idx_r).to_owned(); |
||||
|
|
||||
|
for i in 0..self.dimension { |
||||
|
{ |
||||
|
let update = &(result.index_axis(Axis(axis_r), i).to_owned() / w |
||||
|
- non_derivative.index_axis(Axis(axis_w), i).to_owned() * wd / w_square); |
||||
|
let mut slice = result.index_axis_mut(Axis(axis_r), i); |
||||
|
slice.assign(update); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// delete the last column; some faffing about required to maintain
|
||||
|
// C-contiguous ordering. Probably a much better way to do this...
|
||||
|
let res_shape = &result.shape().iter().copied().collect::<Vec<_>>(); |
||||
|
let mut n_res: usize = res_shape[..axis_r].iter().product(); |
||||
|
n_res *= idx_r; |
||||
|
let idx = (0..axis_r).collect::<Vec<_>>(); |
||||
|
|
||||
|
// this ends up being F-contiguous, every time
|
||||
|
let res_slice = result.select(Axis(axis_r), &idx[..]).to_owned(); |
||||
|
let raveled = res_slice.to_shape(((n_res,), Order::C)).unwrap(); |
||||
|
let fixed = raveled.to_shape((res_slice.shape(), Order::C))?.to_owned(); |
||||
|
|
||||
|
result = fixed; |
||||
|
} |
||||
|
|
||||
|
Ok(result) |
||||
|
} |
||||
|
|
||||
|
/// Return knots vector
|
||||
|
///
|
||||
|
/// If `direction` is given, returns the knots in that direction only.
|
||||
|
/// Otherwise, specifying direction as a negative value returns the
|
||||
|
/// knots of all directions.
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * direction: Direction number (axis) in which to get the knots.
|
||||
|
/// * with_multiplicities: If true, return knots with multiplicities \
|
||||
|
/// (i.e. repeated).
|
||||
|
pub fn knots( |
||||
|
&self, |
||||
|
direction: isize, |
||||
|
with_multiplicities: Option<bool>, |
||||
|
) -> Result<Vec<Array1<f64>>, Error> { |
||||
|
let with_multiplicities = with_multiplicities.unwrap_or(false); |
||||
|
let out; |
||||
|
|
||||
|
if direction < 0 { |
||||
|
if with_multiplicities { |
||||
|
out = self |
||||
|
.bases |
||||
|
.iter() |
||||
|
.map(|e| e.knots.clone().to_owned()) |
||||
|
.collect::<Vec<_>>(); |
||||
|
} else { |
||||
|
out = self |
||||
|
.bases |
||||
|
.iter() |
||||
|
.map(|e| e.knot_spans(false).to_owned()) |
||||
|
.collect::<Vec<_>>(); |
||||
|
} |
||||
|
} else { |
||||
|
let p = direction as usize; |
||||
|
if with_multiplicities { |
||||
|
out = (&[self.bases[p].knots.clone().to_owned()]).to_vec(); |
||||
|
} else { |
||||
|
out = (&[self.bases[p].knot_spans(false).to_owned()]).to_vec(); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
Ok(out) |
||||
|
} |
||||
|
|
||||
|
/// This will manipulate one or both to ensure that they are both rational
|
||||
|
/// or nonrational, and that they lie in the same physical space.
|
||||
|
pub fn make_splines_compatible(&mut self, otherspline: &mut SplineObject) -> Result<(), Error> { |
||||
|
if self.rational { |
||||
|
otherspline.force_rational()?; |
||||
|
} else if otherspline.rational { |
||||
|
self.force_rational()?; |
||||
|
} |
||||
|
|
||||
|
if self.dimension > otherspline.dimension { |
||||
|
otherspline.set_dimension(self.dimension)?; |
||||
|
} else { |
||||
|
self.set_dimension(otherspline.dimension)?; |
||||
|
} |
||||
|
|
||||
|
Ok(()) |
||||
|
} |
||||
|
|
||||
|
/// Force a rational representation of the object.
|
||||
|
pub fn force_rational(&mut self) -> Result<(), Error> { |
||||
|
if !self.rational { |
||||
|
self.controlpoints = self.insert_phys(&self.controlpoints, 1f64)?; |
||||
|
self.rational = true; |
||||
|
} |
||||
|
|
||||
|
Ok(()) |
||||
|
} |
||||
|
|
||||
|
/// Sets the physical dimension of the object. If increased, the new
|
||||
|
/// components are set to zero.
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * new_dim: New dimension
|
||||
|
pub fn set_dimension(&mut self, new_dim: usize) -> Result<(), Error> { |
||||
|
let mut dim = self.dimension; |
||||
|
|
||||
|
while new_dim > dim { |
||||
|
self.controlpoints = self.insert_phys(&self.controlpoints, 0f64)?; |
||||
|
dim += 1; |
||||
|
} |
||||
|
|
||||
|
while new_dim < dim { |
||||
|
let axis = if self.rational { -2 } else { -1 }; |
||||
|
self.controlpoints = self.delete_phys(&self.controlpoints, axis)?; |
||||
|
dim -= 1; |
||||
|
} |
||||
|
|
||||
|
self.dimension = new_dim; |
||||
|
|
||||
|
Ok(()) |
||||
|
} |
||||
|
|
||||
|
fn insert_phys(&self, arr: &ArrayD<f64>, insert_value: f64) -> Result<ArrayD<f64>, Error> { |
||||
|
let mut arr_shape = arr.shape().to_vec(); |
||||
|
let n = arr_shape[arr_shape.len() - 1]; |
||||
|
let arr_prod = arr_shape.iter().product(); |
||||
|
|
||||
|
let raveled = arr.to_shape(((arr_prod,), Order::C))?; |
||||
|
let mut new_arr = Array1::<f64>::zeros(0); |
||||
|
|
||||
|
for i in (0..raveled.len()).step_by(n) { |
||||
|
let new_row = concatenate( |
||||
|
Axis(0), |
||||
|
&[ |
||||
|
raveled.slice(s![i..i + n]).view(), |
||||
|
(insert_value * Array1::<f64>::ones(1)).view(), |
||||
|
], |
||||
|
)?; |
||||
|
new_arr = concatenate(Axis(0), &[new_arr.view(), new_row.view()])?; |
||||
|
} |
||||
|
|
||||
|
arr_shape[n] += 1; |
||||
|
let out = new_arr.to_shape((&arr_shape[..], Order::C))?.to_owned(); |
||||
|
|
||||
|
Ok(out) |
||||
|
} |
||||
|
|
||||
|
fn delete_phys(&self, arr: &ArrayD<f64>, axis: isize) -> Result<ArrayD<f64>, Error> { |
||||
|
let mut arr_shape = arr.shape().to_vec(); |
||||
|
let n = arr_shape[arr_shape.len() - 1]; |
||||
|
let step = (n as isize + axis) as usize; |
||||
|
let arr_prod = arr_shape.iter().product(); |
||||
|
|
||||
|
let raveled = arr.to_shape(((arr_prod,), Order::C))?; |
||||
|
let mut new_arr = Array1::<f64>::zeros(0); |
||||
|
|
||||
|
for i in (0..raveled.len()).step_by(n) { |
||||
|
let new_row; |
||||
|
if axis < -1 { |
||||
|
let front = raveled.slice(s![i..i + step]).clone().to_owned(); |
||||
|
let tail = raveled.slice(s![i + step + 1..i + n]).clone().to_owned(); |
||||
|
new_row = concatenate(Axis(0), &[front.view(), tail.view()])?; |
||||
|
} else { |
||||
|
new_row = raveled.slice(s![i..i + step]).clone().to_owned(); |
||||
|
} |
||||
|
new_arr = concatenate(Axis(0), &[new_arr.view(), new_row.view()])?; |
||||
|
} |
||||
|
|
||||
|
arr_shape[n - 1] -= 1; |
||||
|
let out = new_arr.to_shape((&arr_shape[..], Order::C))?.to_owned(); |
||||
|
|
||||
|
Ok(out) |
||||
|
} |
||||
|
|
||||
|
/// Return polynomial order (degree + 1).
|
||||
|
///
|
||||
|
/// If `direction` is given, returns the order in that direction only.
|
||||
|
/// Otherwise, specifying direction as a negative value returns the
|
||||
|
/// order of all directions.
|
||||
|
///
|
||||
|
/// ### parameters
|
||||
|
/// * direction: Direction in which to get the order.
|
||||
|
pub fn order(&self, direction: isize) -> Result<Vec<usize>, Error> { |
||||
|
let out; |
||||
|
if direction < 0 { |
||||
|
out = self.bases.iter().map(|e| e.order).collect::<Vec<_>>(); |
||||
|
} else { |
||||
|
let p = direction as usize; |
||||
|
out = (&[self.bases[p].order]).to_vec(); |
||||
|
} |
||||
|
|
||||
|
Ok(out) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
fn default_control_points(bases: &[BSplineBasis]) -> Result<Array2<f64>, Error> { |
||||
|
let mut temp = bases |
||||
|
.iter() |
||||
|
.rev() |
||||
|
.map(|b| { |
||||
|
let mut v = b.greville().into_raw_vec(); |
||||
|
v.reverse(); |
||||
|
v |
||||
|
}) |
||||
|
.multi_cartesian_product() |
||||
|
.collect::<Vec<_>>(); |
||||
|
temp.reverse(); |
||||
|
|
||||
|
// because the above is just a little bit incorrect...
|
||||
|
for elem in temp.iter_mut() { |
||||
|
elem.reverse(); |
||||
|
} |
||||
|
|
||||
|
let mut data = Vec::new(); |
||||
|
let ncols = temp.first().map_or(0, |row| row.len()); |
||||
|
let mut nrows = 0; |
||||
|
|
||||
|
for elem in temp.iter() { |
||||
|
data.extend_from_slice(elem); |
||||
|
nrows += 1; |
||||
|
} |
||||
|
|
||||
|
let out = Array2::from_shape_vec((nrows, ncols), data)?; |
||||
|
|
||||
|
Ok(out) |
||||
|
} |
||||
|
|
||||
|
/// Custom reshaping function to preserve control points of several
|
||||
|
/// dimensions that are stored contiguously.
|
||||
|
///
|
||||
|
/// The return value has shape (*newshape, ncomps), where ncomps is
|
||||
|
/// the number of components per control point, as inferred by the
|
||||
|
/// size of `arr` and the desired shape.
|
||||
|
fn reshaper( |
||||
|
arr: Array2<f64>, |
||||
|
mut newshape: Vec<usize>, |
||||
|
ncomps: usize, |
||||
|
) -> Result<ArrayD<f64>, Error> { |
||||
|
newshape.reverse(); |
||||
|
newshape.push(ncomps); |
||||
|
|
||||
|
let mut spec: Vec<usize> = (0..newshape.len() - 1).collect(); |
||||
|
spec.reverse(); |
||||
|
spec.push(newshape.len() - 1); |
||||
|
|
||||
|
let tmp = arr.to_shape((&newshape[..], Order::C))?; |
||||
|
let tmp = tmp.to_owned().into_dyn(); |
||||
|
|
||||
|
let out = tmp.view().permuted_axes(&spec[..]).to_owned(); |
||||
|
|
||||
|
Ok(out) |
||||
|
} |
||||
|
|
||||
|
fn determine_shape(bases: &[BSplineBasis]) -> Result<Vec<usize>, Error> { |
||||
|
let out = bases |
||||
|
.iter() |
||||
|
.map(|e| e.num_functions()) |
||||
|
.collect::<Vec<usize>>(); |
||||
|
|
||||
|
Ok(out) |
||||
|
} |
@ -0,0 +1,168 @@ |
|||||
|
use crate::payout_curve::Error; |
||||
|
use ndarray::prelude::*; |
||||
|
use ndarray::s; |
||||
|
use std::cmp::Ordering; |
||||
|
use std::f64::consts::PI; |
||||
|
|
||||
|
pub fn bisect_left(arr: &Array1<f64>, val: &f64, mut hi: usize) -> usize { |
||||
|
let mut lo: usize = 0; |
||||
|
while lo < hi { |
||||
|
let mid = (lo + hi) / 2; |
||||
|
if arr[mid] < *val { |
||||
|
lo = mid + 1; |
||||
|
} else { |
||||
|
hi = mid; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
lo |
||||
|
} |
||||
|
|
||||
|
pub fn bisect_right(arr: &Array1<f64>, val: &f64, mut hi: usize) -> usize { |
||||
|
let mut lo: usize = 0; |
||||
|
while lo < hi { |
||||
|
let mid = (lo + hi) / 2; |
||||
|
if *val < arr[mid] { |
||||
|
hi = mid; |
||||
|
} else { |
||||
|
lo = mid + 1; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
lo |
||||
|
} |
||||
|
|
||||
|
pub fn cmp_f64(a: &f64, b: &f64) -> Ordering { |
||||
|
if a < b { |
||||
|
return Ordering::Less; |
||||
|
} else if a > b { |
||||
|
return Ordering::Greater; |
||||
|
} |
||||
|
Ordering::Equal |
||||
|
} |
||||
|
|
||||
|
/// Gauss-Legendre_Quadrature
|
||||
|
///
|
||||
|
/// Could not find a rust implementation of this, so have created one from
|
||||
|
/// a C implementation found
|
||||
|
/// [here](https://rosettacode.org/wiki/Numerical_integration/Gauss-Legendre_Quadrature#C).
|
||||
|
///
|
||||
|
/// The code is well short of optimal, but it gets things moving. Better
|
||||
|
/// versions are provided by, for example,
|
||||
|
/// [numpy.polynomial.legendre.leggauss](https://github.com/numpy/numpy/blob/v1.21.0/numpy/polynomial/legendre.py#L1519-L1584)
|
||||
|
/// but the implementaitons are more involved so we have opted for quick and
|
||||
|
/// dirty for the time being.
|
||||
|
#[derive(Debug, Clone)] |
||||
|
pub struct GaussLegendreQuadrature { |
||||
|
pub sample_points: Array1<f64>, |
||||
|
pub weights: Array1<f64>, |
||||
|
} |
||||
|
|
||||
|
impl GaussLegendreQuadrature { |
||||
|
pub fn new(order: usize) -> Result<Self, Error> { |
||||
|
if order < 1 { |
||||
|
return Result::Err(Error::DegreeMustBePositive); |
||||
|
} |
||||
|
|
||||
|
let data = legendre_wrapper(&order); |
||||
|
|
||||
|
Ok(GaussLegendreQuadrature { |
||||
|
sample_points: data.0, |
||||
|
weights: data.1, |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
fn legendre_wrapper(order: &usize) -> (Array1<f64>, Array1<f64>) { |
||||
|
let arr = legendre_coefficients(order); |
||||
|
legendre_roots(&arr) |
||||
|
} |
||||
|
|
||||
|
fn legendre_coefficients(order: &usize) -> Array2<f64> { |
||||
|
let mut lcoef_arr = Array2::<f64>::zeros((*order + 1, *order + 1)); |
||||
|
lcoef_arr[[0, 0]] = 1.; |
||||
|
lcoef_arr[[1, 1]] = 1.; |
||||
|
|
||||
|
for n in 2..*order + 1 { |
||||
|
let n_64 = n as f64; |
||||
|
lcoef_arr[[n, 0]] = -(n_64 - 1.) * lcoef_arr[[n - 2, 0]] / n_64; |
||||
|
|
||||
|
for i in 1..n + 1 { |
||||
|
lcoef_arr[[n, i]] = ((2. * n_64 - 1.) * lcoef_arr[[n - 1, i - 1]] |
||||
|
- (n_64 - 1.) * lcoef_arr[[n - 2, i]]) |
||||
|
/ n_64; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
lcoef_arr |
||||
|
} |
||||
|
|
||||
|
fn legendre_eval(coeff_arr: &Array2<f64>, n: &usize, x: &f64) -> f64 { |
||||
|
let mut s = coeff_arr[[*n, *n]]; |
||||
|
for i in (1..*n + 1).rev() { |
||||
|
s = s * (*x) + coeff_arr[[*n, i - 1]]; |
||||
|
} |
||||
|
|
||||
|
s |
||||
|
} |
||||
|
|
||||
|
fn legendre_diff(coeff_arr: &Array2<f64>, n: &usize, x: &f64) -> f64 { |
||||
|
let n_64 = *n as f64; |
||||
|
n_64 * (x * legendre_eval(coeff_arr, n, x) - legendre_eval(coeff_arr, &(n - 1), x)) |
||||
|
/ (x * x - 1.) |
||||
|
} |
||||
|
|
||||
|
fn legendre_roots(coeff_arr: &Array2<f64>) -> (Array1<f64>, Array1<f64>) { |
||||
|
let n = coeff_arr.shape()[0] - 1; |
||||
|
let n_64 = n as f64; |
||||
|
|
||||
|
let mut sample_points_arr = Array1::<f64>::zeros(n + 1); |
||||
|
let mut weights_arr = Array1::<f64>::zeros(n + 1); |
||||
|
|
||||
|
for i in 1..n + 1 { |
||||
|
let i_64 = i as f64; |
||||
|
let mut x = (PI * (i_64 - 0.25) / (n_64 + 0.5)).cos(); |
||||
|
let mut x1 = x; |
||||
|
x -= legendre_eval(coeff_arr, &n, &x) / legendre_diff(coeff_arr, &n, &x); |
||||
|
|
||||
|
while fdim(&x, &x1) > 2e-16 { |
||||
|
x1 = x; |
||||
|
x -= legendre_eval(coeff_arr, &n, &x) / legendre_diff(coeff_arr, &n, &x); |
||||
|
} |
||||
|
|
||||
|
sample_points_arr[i - 1] = x; |
||||
|
x1 = legendre_diff(coeff_arr, &n, &x); |
||||
|
weights_arr[i - 1] = 2. / ((1. - x * x) * x1 * x1); |
||||
|
} |
||||
|
|
||||
|
// truncate the dummy value off the end + reverse sample points +
|
||||
|
// use symmetry to stable things up a bit.
|
||||
|
let mut samples = sample_points_arr.slice(s![..n; -1]).to_owned(); |
||||
|
samples = symmetric_samples(&samples); |
||||
|
|
||||
|
let mut weights = weights_arr.slice(s![..n]).to_owned(); |
||||
|
weights = symmetric_weights(&weights); |
||||
|
|
||||
|
(samples, weights) |
||||
|
} |
||||
|
|
||||
|
fn symmetric_samples(arr: &Array1<f64>) -> Array1<f64> { |
||||
|
let arr_rev = arr.slice(s![..; -1]).to_owned(); |
||||
|
(arr - &arr_rev) / 2. |
||||
|
} |
||||
|
|
||||
|
fn symmetric_weights(arr: &Array1<f64>) -> Array1<f64> { |
||||
|
let s = &arr.sum_axis(Axis(0)); |
||||
|
let arr_rev = arr.slice(s![..; -1]).to_owned(); |
||||
|
(arr + &arr_rev) / s |
||||
|
} |
||||
|
|
||||
|
fn fdim(a: &f64, b: &f64) -> f64 { |
||||
|
let res; |
||||
|
if a - b > 0f64 { |
||||
|
res = a - b; |
||||
|
} else { |
||||
|
res = 0f64; |
||||
|
} |
||||
|
res |
||||
|
} |
Loading…
Reference in new issue