Browse Source

Added tests for matrix-solver

Added missing tests for `CSR::matrix_solve()` and fixed some
error handling in the process.
feature/actor-custom-derive
DelicioiusHair 3 years ago
parent
commit
8897754138
  1. 2
      daemon/src/payout_curve.rs
  2. 81
      daemon/src/payout_curve/csr_tools.rs

2
daemon/src/payout_curve.rs

@ -152,6 +152,8 @@ pub enum Error {
CannotInitCSR, CannotInitCSR,
#[error("matrix must be square")] #[error("matrix must be square")]
MatrixMustBeSquare, MatrixMustBeSquare,
#[error("cannot invert singular matrix")]
SingularMatrix,
#[error("evaluation outside parametric domain")] #[error("evaluation outside parametric domain")]
InvalidDomain, InvalidDomain,
#[error("einsum error--array size mismatch?")] #[error("einsum error--array size mismatch?")]

81
daemon/src/payout_curve/csr_tools.rs

@ -55,7 +55,7 @@ impl CSR {
.map(|e| { .map(|e| {
let b = b_arr.slice(s![.., e]).to_owned(); let b = b_arr.slice(s![.., e]).to_owned();
let sol = lu_solve(&a_arr, &b).unwrap(); let sol = lu_solve(&a_arr, &b)?;
Ok(sol.to_vec()) Ok(sol.to_vec())
}) })
.collect::<Result<Vec<_>, Error>>()?; .collect::<Result<Vec<_>, Error>>()?;
@ -89,17 +89,21 @@ impl CSR {
} }
fn lu_solve(a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>, Error> { fn lu_solve(a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>, Error> {
if !is_square(a) {
return Result::Err(Error::MatrixMustBeSquare);
}
let a = a.to_nalgebra_matrix().lu(); let a = a.to_nalgebra_matrix().lu();
let b = b.to_nalgebra_matrix(); let b = b.to_nalgebra_matrix();
let x = a let x = a.solve(&b).ok_or(Error::SingularMatrix)?.to_1d_array()?;
.solve(&b)
.ok_or(Error::MatrixMustBeSquare)?
.to_1d_array()?;
Ok(x) Ok(x)
} }
fn is_square<T>(arr: &Array2<T>) -> bool {
arr.shape()[0] == arr.shape()[1]
}
impl Mul<&Array1<f64>> for CSR { impl Mul<&Array1<f64>> for CSR {
type Output = Array1<f64>; type Output = Array1<f64>;
@ -121,32 +125,71 @@ mod tests {
#[test] #[test]
fn test_lu_solve() { fn test_lu_solve() {
let a = Array2::<f64>::from(vec![[11., 12., 0.], [0., 22., 23.], [31., 0., 33.]]); let a = array![[11., 12., 0.], [0., 22., 23.], [31., 0., 33.]];
let b = Array1::<f64>::from_vec(vec![35., 113., 130.]); let b = array![35., 113., 130.];
let x_expected = Array1::<f64>::from_vec(vec![1., 2., 3.]); let x_expected = array![1., 2., 3.];
let x = lu_solve(&a, &b).unwrap(); let x = lu_solve(&a, &b).unwrap();
for (x, expected) in x.into_iter().zip(x_expected) { assert!(x.abs_diff_eq(&x_expected, 1e-10));
assert!(
(x - expected).abs() < f64::EPSILON * 10.,
"{} {}",
x,
expected
)
}
} }
#[test] #[test]
fn negative_csr_test_00() { fn negative_csr_test_00() {
let a = CSR::new( let a = CSR::new(
Array1::<f64>::zeros(0), Array1::zeros(0),
Array1::<usize>::zeros(0), Array1::zeros(0),
Array1::<usize>::zeros(11), Array1::zeros(11),
(1, 3), (1, 3),
) )
.unwrap_err(); .unwrap_err();
assert!(matches!(a, Error::CannotInitCSR)); assert!(matches!(a, Error::CannotInitCSR));
} }
// test that all is good
#[test]
fn test_lu_matrix_solve_00() {
let a = CSR::new(
array![11., 12., 22., 23., 31., 33.],
array![0, 1, 1, 2, 0, 2],
array![0, 2, 4, 6],
(3, 3),
)
.unwrap();
let b = array![[36., 59., 82.], [204., 249., 294.], [198., 262., 326.],];
let x_expected = array![[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]];
let x = a.matrix_solve(&b).unwrap();
assert!(x.abs_diff_eq(&x_expected, 1e-10));
}
// test that an indeterminate system borks
#[test]
fn test_lu_matrix_solve_01() {
let a = CSR::new(array![1.], array![2], array![0, 1, 1, 1], (3, 3)).unwrap();
let b = array![[36., 59., 82.], [204., 249., 294.], [198., 262., 326.],];
let e = a.matrix_solve(&b).unwrap_err();
assert!(matches!(e, Error::SingularMatrix));
}
// test that an incompatible system borks
#[test]
fn test_lu_matrix_solve_02() {
let a = CSR::new(
array![11., 12., 14., 19., 22., 23., 31., 33.],
array![0, 1, 3, 5, 1, 2, 3, 5],
array![0, 4, 8],
(2, 6),
)
.unwrap();
let b = array![[447., 503., 559.], [978., 1087., 1196.]];
let e = a.matrix_solve(&b).unwrap_err();
assert!(matches!(e, Error::MatrixMustBeSquare));
}
} }

Loading…
Cancel
Save