Browse Source

Added tests for matrix-solver

Added missing tests for `CSR::matrix_solve()` and fixed some
error handling in the process.
feature/actor-custom-derive
DelicioiusHair 3 years ago
parent
commit
8897754138
  1. 2
      daemon/src/payout_curve.rs
  2. 81
      daemon/src/payout_curve/csr_tools.rs

2
daemon/src/payout_curve.rs

@ -152,6 +152,8 @@ pub enum Error {
CannotInitCSR,
#[error("matrix must be square")]
MatrixMustBeSquare,
#[error("cannot invert singular matrix")]
SingularMatrix,
#[error("evaluation outside parametric domain")]
InvalidDomain,
#[error("einsum error--array size mismatch?")]

81
daemon/src/payout_curve/csr_tools.rs

@ -55,7 +55,7 @@ impl CSR {
.map(|e| {
let b = b_arr.slice(s![.., e]).to_owned();
let sol = lu_solve(&a_arr, &b).unwrap();
let sol = lu_solve(&a_arr, &b)?;
Ok(sol.to_vec())
})
.collect::<Result<Vec<_>, Error>>()?;
@ -89,17 +89,21 @@ impl CSR {
}
fn lu_solve(a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>, Error> {
if !is_square(a) {
return Result::Err(Error::MatrixMustBeSquare);
}
let a = a.to_nalgebra_matrix().lu();
let b = b.to_nalgebra_matrix();
let x = a
.solve(&b)
.ok_or(Error::MatrixMustBeSquare)?
.to_1d_array()?;
let x = a.solve(&b).ok_or(Error::SingularMatrix)?.to_1d_array()?;
Ok(x)
}
fn is_square<T>(arr: &Array2<T>) -> bool {
arr.shape()[0] == arr.shape()[1]
}
impl Mul<&Array1<f64>> for CSR {
type Output = Array1<f64>;
@ -121,32 +125,71 @@ mod tests {
#[test]
fn test_lu_solve() {
let a = Array2::<f64>::from(vec![[11., 12., 0.], [0., 22., 23.], [31., 0., 33.]]);
let b = Array1::<f64>::from_vec(vec![35., 113., 130.]);
let x_expected = Array1::<f64>::from_vec(vec![1., 2., 3.]);
let a = array![[11., 12., 0.], [0., 22., 23.], [31., 0., 33.]];
let b = array![35., 113., 130.];
let x_expected = array![1., 2., 3.];
let x = lu_solve(&a, &b).unwrap();
for (x, expected) in x.into_iter().zip(x_expected) {
assert!(
(x - expected).abs() < f64::EPSILON * 10.,
"{} {}",
x,
expected
)
}
assert!(x.abs_diff_eq(&x_expected, 1e-10));
}
#[test]
fn negative_csr_test_00() {
let a = CSR::new(
Array1::<f64>::zeros(0),
Array1::<usize>::zeros(0),
Array1::<usize>::zeros(11),
Array1::zeros(0),
Array1::zeros(0),
Array1::zeros(11),
(1, 3),
)
.unwrap_err();
assert!(matches!(a, Error::CannotInitCSR));
}
// test that all is good
#[test]
fn test_lu_matrix_solve_00() {
let a = CSR::new(
array![11., 12., 22., 23., 31., 33.],
array![0, 1, 1, 2, 0, 2],
array![0, 2, 4, 6],
(3, 3),
)
.unwrap();
let b = array![[36., 59., 82.], [204., 249., 294.], [198., 262., 326.],];
let x_expected = array![[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]];
let x = a.matrix_solve(&b).unwrap();
assert!(x.abs_diff_eq(&x_expected, 1e-10));
}
// test that an indeterminate system borks
#[test]
fn test_lu_matrix_solve_01() {
let a = CSR::new(array![1.], array![2], array![0, 1, 1, 1], (3, 3)).unwrap();
let b = array![[36., 59., 82.], [204., 249., 294.], [198., 262., 326.],];
let e = a.matrix_solve(&b).unwrap_err();
assert!(matches!(e, Error::SingularMatrix));
}
// test that an incompatible system borks
#[test]
fn test_lu_matrix_solve_02() {
let a = CSR::new(
array![11., 12., 14., 19., 22., 23., 31., 33.],
array![0, 1, 3, 5, 1, 2, 3, 5],
array![0, 4, 8],
(2, 6),
)
.unwrap();
let b = array![[447., 503., 559.], [978., 1087., 1196.]];
let e = a.matrix_solve(&b).unwrap_err();
assert!(matches!(e, Error::MatrixMustBeSquare));
}
}

Loading…
Cancel
Save