diff --git a/daemon/src/payout_curve.rs b/daemon/src/payout_curve.rs index 83cb5b1..12c3a48 100644 --- a/daemon/src/payout_curve.rs +++ b/daemon/src/payout_curve.rs @@ -152,6 +152,8 @@ pub enum Error { CannotInitCSR, #[error("matrix must be square")] MatrixMustBeSquare, + #[error("cannot invert singular matrix")] + SingularMatrix, #[error("evaluation outside parametric domain")] InvalidDomain, #[error("einsum error--array size mismatch?")] diff --git a/daemon/src/payout_curve/csr_tools.rs b/daemon/src/payout_curve/csr_tools.rs index 89c0270..e4406bc 100644 --- a/daemon/src/payout_curve/csr_tools.rs +++ b/daemon/src/payout_curve/csr_tools.rs @@ -55,7 +55,7 @@ impl CSR { .map(|e| { let b = b_arr.slice(s![.., e]).to_owned(); - let sol = lu_solve(&a_arr, &b).unwrap(); + let sol = lu_solve(&a_arr, &b)?; Ok(sol.to_vec()) }) .collect::, Error>>()?; @@ -89,17 +89,21 @@ impl CSR { } fn lu_solve(a: &Array2, b: &Array1) -> Result, Error> { + if !is_square(a) { + return Result::Err(Error::MatrixMustBeSquare); + } let a = a.to_nalgebra_matrix().lu(); let b = b.to_nalgebra_matrix(); - let x = a - .solve(&b) - .ok_or(Error::MatrixMustBeSquare)? - .to_1d_array()?; + let x = a.solve(&b).ok_or(Error::SingularMatrix)?.to_1d_array()?; Ok(x) } +fn is_square(arr: &Array2) -> bool { + arr.shape()[0] == arr.shape()[1] +} + impl Mul<&Array1> for CSR { type Output = Array1; @@ -121,32 +125,71 @@ mod tests { #[test] fn test_lu_solve() { - let a = Array2::::from(vec![[11., 12., 0.], [0., 22., 23.], [31., 0., 33.]]); - let b = Array1::::from_vec(vec![35., 113., 130.]); - let x_expected = Array1::::from_vec(vec![1., 2., 3.]); + let a = array![[11., 12., 0.], [0., 22., 23.], [31., 0., 33.]]; + let b = array![35., 113., 130.]; + let x_expected = array![1., 2., 3.]; let x = lu_solve(&a, &b).unwrap(); - for (x, expected) in x.into_iter().zip(x_expected) { - assert!( - (x - expected).abs() < f64::EPSILON * 10., - "{} {}", - x, - expected - ) - } + assert!(x.abs_diff_eq(&x_expected, 1e-10)); } #[test] fn negative_csr_test_00() { let a = CSR::new( - Array1::::zeros(0), - Array1::::zeros(0), - Array1::::zeros(11), + Array1::zeros(0), + Array1::zeros(0), + Array1::zeros(11), (1, 3), ) .unwrap_err(); assert!(matches!(a, Error::CannotInitCSR)); } + + // test that all is good + #[test] + fn test_lu_matrix_solve_00() { + let a = CSR::new( + array![11., 12., 22., 23., 31., 33.], + array![0, 1, 1, 2, 0, 2], + array![0, 2, 4, 6], + (3, 3), + ) + .unwrap(); + let b = array![[36., 59., 82.], [204., 249., 294.], [198., 262., 326.],]; + + let x_expected = array![[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]]; + let x = a.matrix_solve(&b).unwrap(); + + assert!(x.abs_diff_eq(&x_expected, 1e-10)); + } + + // test that an indeterminate system borks + #[test] + fn test_lu_matrix_solve_01() { + let a = CSR::new(array![1.], array![2], array![0, 1, 1, 1], (3, 3)).unwrap(); + let b = array![[36., 59., 82.], [204., 249., 294.], [198., 262., 326.],]; + + let e = a.matrix_solve(&b).unwrap_err(); + + assert!(matches!(e, Error::SingularMatrix)); + } + + // test that an incompatible system borks + #[test] + fn test_lu_matrix_solve_02() { + let a = CSR::new( + array![11., 12., 14., 19., 22., 23., 31., 33.], + array![0, 1, 3, 5, 1, 2, 3, 5], + array![0, 4, 8], + (2, 6), + ) + .unwrap(); + let b = array![[447., 503., 559.], [978., 1087., 1196.]]; + + let e = a.matrix_solve(&b).unwrap_err(); + + assert!(matches!(e, Error::MatrixMustBeSquare)); + } }