You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
208 lines
6.3 KiB
208 lines
6.3 KiB
use crate::payout_curve::basis_eval::*;
|
|
use crate::payout_curve::csr_tools::CSR;
|
|
use crate::payout_curve::utils::*;
|
|
use crate::payout_curve::Error;
|
|
|
|
use core::cmp::max;
|
|
use ndarray::prelude::*;
|
|
use ndarray::{concatenate, s};
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct BSplineBasis {
|
|
pub knots: Array1<f64>,
|
|
pub order: usize,
|
|
pub periodic: isize,
|
|
pub knot_tol: f64,
|
|
}
|
|
|
|
impl BSplineBasis {
|
|
pub fn new(
|
|
order: Option<usize>,
|
|
knots: Option<Array1<f64>>,
|
|
periodic: Option<isize>,
|
|
) -> Result<Self, Error> {
|
|
let order = order.unwrap_or(2);
|
|
let periodic = periodic.unwrap_or(-1);
|
|
let knots = match knots {
|
|
Some(knots) => knots,
|
|
None => default_knot(order, periodic)?,
|
|
};
|
|
let ktol = knot_tolerance(None);
|
|
|
|
Ok(BSplineBasis {
|
|
order,
|
|
knots,
|
|
periodic,
|
|
knot_tol: ktol,
|
|
})
|
|
}
|
|
|
|
pub fn num_functions(&self) -> usize {
|
|
let p = (self.periodic + 1) as usize;
|
|
|
|
self.knots.len() - self.order - p
|
|
}
|
|
|
|
/// Start point of parametric domain. For open knot vectors, this is the
|
|
/// first knot.
|
|
pub fn start(&self) -> f64 {
|
|
self.knots[self.order - 1]
|
|
}
|
|
|
|
/// End point of parametric domain. For open knot vectors, this is the
|
|
/// last knot.
|
|
pub fn end(&self) -> f64 {
|
|
self.knots[self.knots.len() - self.order]
|
|
}
|
|
|
|
/// Fetch greville points, also known as knot averages
|
|
/// over entire knot vector:
|
|
/// .. math:: \\sum_{j=i+1}^{i+p-1} \\frac{t_j}{p-1}
|
|
pub fn greville(&self) -> Array1<f64> {
|
|
let n = self.num_functions() as i32;
|
|
|
|
(0..n).map(|idx| self.greville_single(idx)).collect()
|
|
}
|
|
|
|
fn greville_single(&self, index: i32) -> f64 {
|
|
let p = self.order as i32;
|
|
let den = (self.order - 1) as f64;
|
|
|
|
self.knots.slice(s![index + 1..index + p]).sum() / den
|
|
}
|
|
|
|
/// Evaluate all basis functions in a given set of points.
|
|
/// ## parameters:
|
|
/// * t: The parametric coordinate(s) in which to evaluate
|
|
/// * d: Number of derivatives to compute
|
|
/// * from_right: true if evaluation should be done in the limit from above
|
|
/// ## returns:
|
|
/// * CSR (sparse) matrix N\[i,j\] of all basis functions j evaluated in all points j
|
|
pub fn evaluate(&self, t: &mut Array1<f64>, d: usize, from_right: bool) -> Result<CSR, Error> {
|
|
let basis = Basis::new(
|
|
self.order,
|
|
self.knots.clone().to_owned(),
|
|
Some(self.periodic),
|
|
Some(self.knot_tol),
|
|
);
|
|
snap(t, &basis.knots, Some(basis.ktol));
|
|
|
|
if self.order <= d {
|
|
let csr = CSR::new(
|
|
Array1::<f64>::zeros(0),
|
|
Array1::<usize>::zeros(0),
|
|
Array1::<usize>::zeros(t.len() + 1),
|
|
(t.len(), self.num_functions()),
|
|
)?;
|
|
|
|
return Ok(csr);
|
|
}
|
|
|
|
let out = basis.evaluate(t, d, Some(from_right))?;
|
|
|
|
Ok(out)
|
|
}
|
|
|
|
/// Snap evaluation points to knots if they are sufficiently close
|
|
/// as given in by knot_tolerance.
|
|
///
|
|
/// * t: The parametric coordinate(s) in which to evaluate
|
|
pub fn snap(&self, t: &mut Array1<f64>) {
|
|
snap(t, &self.knots, Some(self.knot_tol))
|
|
}
|
|
|
|
/// Create a knot vector with higher order.
|
|
///
|
|
/// The continuity at the knots are kept unchanged by increasing their
|
|
/// multiplicities.
|
|
///
|
|
/// ### parameters
|
|
/// * amount: relative (polynomial) degree to raise the basis function by
|
|
pub fn raise_order(&mut self, amount: usize) {
|
|
if amount > 0 {
|
|
let knot_spans_arr = self.knot_spans(true);
|
|
let knot_spans = knot_spans_arr.iter().collect::<Vec<_>>();
|
|
let temp = self.knots.clone();
|
|
let mut knots = temp.iter().collect::<Vec<_>>();
|
|
|
|
for _ in 0..amount {
|
|
knots.append(&mut knot_spans.clone());
|
|
}
|
|
|
|
let mut knots_vec = knots.iter().map(|e| **e).collect::<Vec<_>>();
|
|
knots_vec.sort_by(cmp_f64);
|
|
|
|
let knots_arr = Array1::<f64>::from_vec(knots_vec);
|
|
|
|
let new_knot;
|
|
if self.periodic > -1 {
|
|
let n_0 = bisect_left(&knots_arr, &self.start(), knots_arr.len());
|
|
let n_1 =
|
|
knot_spans.len() - bisect_left(&knots_arr, &self.end(), knots_arr.len()) - 1;
|
|
let mut new_knot_vec = knots[n_0 * amount..n_1 * amount]
|
|
.iter()
|
|
.map(|e| **e)
|
|
.collect::<Vec<_>>();
|
|
new_knot_vec.sort_by(cmp_f64);
|
|
new_knot = Array1::<f64>::from_vec(new_knot_vec);
|
|
} else {
|
|
new_knot = knots_arr;
|
|
}
|
|
|
|
self.order += amount;
|
|
self.knots = new_knot;
|
|
}
|
|
}
|
|
|
|
/// Return the set of unique knots in the knot vector.
|
|
///
|
|
/// ### parameters:
|
|
/// * include_ghosts: if knots outside start/end are to be included. These \
|
|
/// knots are used by periodic basis.
|
|
///
|
|
/// ### returns:
|
|
/// * 1-D array of unique knots
|
|
pub fn knot_spans(&self, include_ghosts: bool) -> Array1<f64> {
|
|
let p = &self.order;
|
|
|
|
// TODO: this is VERY sloppy!
|
|
let mut res: Vec<f64> = vec![];
|
|
if include_ghosts {
|
|
res.push(self.knots[0]);
|
|
for elem in self.knots.slice(s![1..]).iter() {
|
|
if (elem - res[res.len() - 1]).abs() > self.knot_tol {
|
|
res.push(*elem);
|
|
}
|
|
}
|
|
} else {
|
|
res.push(self.knots[p - 1]);
|
|
let klen = self.knots.len();
|
|
for elem in self.knots.slice(s![p - 1..klen - p + 1]).iter() {
|
|
if (elem - res[res.len() - 1]).abs() > self.knot_tol {
|
|
res.push(*elem);
|
|
}
|
|
}
|
|
}
|
|
|
|
Array1::<f64>::from_vec(res)
|
|
}
|
|
}
|
|
|
|
fn default_knot(order: usize, periodic: isize) -> Result<Array1<f64>, Error> {
|
|
let prd = max(periodic, -1);
|
|
let p = (prd + 1) as usize;
|
|
let mut knots = concatenate(
|
|
Axis(0),
|
|
&[
|
|
Array1::<f64>::zeros(order).view(),
|
|
Array1::<f64>::ones(order).view(),
|
|
],
|
|
)?;
|
|
|
|
for i in 0..p {
|
|
knots[i] = -1.;
|
|
knots[2 * order - i - 1] = 2.;
|
|
}
|
|
|
|
Ok(knots)
|
|
}
|
|
|