|
@ -55,7 +55,7 @@ impl CSR { |
|
|
.map(|e| { |
|
|
.map(|e| { |
|
|
let b = b_arr.slice(s![.., e]).to_owned(); |
|
|
let b = b_arr.slice(s![.., e]).to_owned(); |
|
|
|
|
|
|
|
|
let sol = lu_solve(&a_arr, &b).unwrap(); |
|
|
let sol = lu_solve(&a_arr, &b)?; |
|
|
Ok(sol.to_vec()) |
|
|
Ok(sol.to_vec()) |
|
|
}) |
|
|
}) |
|
|
.collect::<Result<Vec<_>, Error>>()?; |
|
|
.collect::<Result<Vec<_>, Error>>()?; |
|
@ -89,17 +89,21 @@ impl CSR { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
fn lu_solve(a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>, Error> { |
|
|
fn lu_solve(a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>, Error> { |
|
|
|
|
|
if !is_square(a) { |
|
|
|
|
|
return Result::Err(Error::MatrixMustBeSquare); |
|
|
|
|
|
} |
|
|
let a = a.to_nalgebra_matrix().lu(); |
|
|
let a = a.to_nalgebra_matrix().lu(); |
|
|
let b = b.to_nalgebra_matrix(); |
|
|
let b = b.to_nalgebra_matrix(); |
|
|
|
|
|
|
|
|
let x = a |
|
|
let x = a.solve(&b).ok_or(Error::SingularMatrix)?.to_1d_array()?; |
|
|
.solve(&b) |
|
|
|
|
|
.ok_or(Error::MatrixMustBeSquare)? |
|
|
|
|
|
.to_1d_array()?; |
|
|
|
|
|
|
|
|
|
|
|
Ok(x) |
|
|
Ok(x) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
fn is_square<T>(arr: &Array2<T>) -> bool { |
|
|
|
|
|
arr.shape()[0] == arr.shape()[1] |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
impl Mul<&Array1<f64>> for CSR { |
|
|
impl Mul<&Array1<f64>> for CSR { |
|
|
type Output = Array1<f64>; |
|
|
type Output = Array1<f64>; |
|
|
|
|
|
|
|
@ -121,32 +125,71 @@ mod tests { |
|
|
|
|
|
|
|
|
#[test] |
|
|
#[test] |
|
|
fn test_lu_solve() { |
|
|
fn test_lu_solve() { |
|
|
let a = Array2::<f64>::from(vec![[11., 12., 0.], [0., 22., 23.], [31., 0., 33.]]); |
|
|
let a = array![[11., 12., 0.], [0., 22., 23.], [31., 0., 33.]]; |
|
|
let b = Array1::<f64>::from_vec(vec![35., 113., 130.]); |
|
|
let b = array![35., 113., 130.]; |
|
|
let x_expected = Array1::<f64>::from_vec(vec![1., 2., 3.]); |
|
|
let x_expected = array![1., 2., 3.]; |
|
|
|
|
|
|
|
|
let x = lu_solve(&a, &b).unwrap(); |
|
|
let x = lu_solve(&a, &b).unwrap(); |
|
|
|
|
|
|
|
|
for (x, expected) in x.into_iter().zip(x_expected) { |
|
|
assert!(x.abs_diff_eq(&x_expected, 1e-10)); |
|
|
assert!( |
|
|
|
|
|
(x - expected).abs() < f64::EPSILON * 10., |
|
|
|
|
|
"{} {}", |
|
|
|
|
|
x, |
|
|
|
|
|
expected |
|
|
|
|
|
) |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
#[test] |
|
|
#[test] |
|
|
fn negative_csr_test_00() { |
|
|
fn negative_csr_test_00() { |
|
|
let a = CSR::new( |
|
|
let a = CSR::new( |
|
|
Array1::<f64>::zeros(0), |
|
|
Array1::zeros(0), |
|
|
Array1::<usize>::zeros(0), |
|
|
Array1::zeros(0), |
|
|
Array1::<usize>::zeros(11), |
|
|
Array1::zeros(11), |
|
|
(1, 3), |
|
|
(1, 3), |
|
|
) |
|
|
) |
|
|
.unwrap_err(); |
|
|
.unwrap_err(); |
|
|
|
|
|
|
|
|
assert!(matches!(a, Error::CannotInitCSR)); |
|
|
assert!(matches!(a, Error::CannotInitCSR)); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// test that all is good
|
|
|
|
|
|
#[test] |
|
|
|
|
|
fn test_lu_matrix_solve_00() { |
|
|
|
|
|
let a = CSR::new( |
|
|
|
|
|
array![11., 12., 22., 23., 31., 33.], |
|
|
|
|
|
array![0, 1, 1, 2, 0, 2], |
|
|
|
|
|
array![0, 2, 4, 6], |
|
|
|
|
|
(3, 3), |
|
|
|
|
|
) |
|
|
|
|
|
.unwrap(); |
|
|
|
|
|
let b = array![[36., 59., 82.], [204., 249., 294.], [198., 262., 326.],]; |
|
|
|
|
|
|
|
|
|
|
|
let x_expected = array![[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]]; |
|
|
|
|
|
let x = a.matrix_solve(&b).unwrap(); |
|
|
|
|
|
|
|
|
|
|
|
assert!(x.abs_diff_eq(&x_expected, 1e-10)); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// test that an indeterminate system borks
|
|
|
|
|
|
#[test] |
|
|
|
|
|
fn test_lu_matrix_solve_01() { |
|
|
|
|
|
let a = CSR::new(array![1.], array![2], array![0, 1, 1, 1], (3, 3)).unwrap(); |
|
|
|
|
|
let b = array![[36., 59., 82.], [204., 249., 294.], [198., 262., 326.],]; |
|
|
|
|
|
|
|
|
|
|
|
let e = a.matrix_solve(&b).unwrap_err(); |
|
|
|
|
|
|
|
|
|
|
|
assert!(matches!(e, Error::SingularMatrix)); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// test that an incompatible system borks
|
|
|
|
|
|
#[test] |
|
|
|
|
|
fn test_lu_matrix_solve_02() { |
|
|
|
|
|
let a = CSR::new( |
|
|
|
|
|
array![11., 12., 14., 19., 22., 23., 31., 33.], |
|
|
|
|
|
array![0, 1, 3, 5, 1, 2, 3, 5], |
|
|
|
|
|
array![0, 4, 8], |
|
|
|
|
|
(2, 6), |
|
|
|
|
|
) |
|
|
|
|
|
.unwrap(); |
|
|
|
|
|
let b = array![[447., 503., 559.], [978., 1087., 1196.]]; |
|
|
|
|
|
|
|
|
|
|
|
let e = a.matrix_solve(&b).unwrap_err(); |
|
|
|
|
|
|
|
|
|
|
|
assert!(matches!(e, Error::MatrixMustBeSquare)); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|