Browse Source

Rewrite matrix_solve to not contain `unwrap`

contact-taker-before-changing-cfd-state
Thomas Eizinger 3 years ago
parent
commit
5930059806
No known key found for this signature in database GPG Key ID: 651AC83A6C6C8B96
  1. 37
      daemon/src/payout_curve/csr_tools.rs

37
daemon/src/payout_curve/csr_tools.rs

@ -1,5 +1,6 @@
use crate::payout_curve::compat::{To1DArray, ToNAlgebraMatrix};
use crate::payout_curve::Error;
use itertools::Itertools;
use ndarray::prelude::*;
use std::ops::Mul;
@ -49,29 +50,19 @@ impl CSR {
// is horrible.
pub fn matrix_solve(&self, b_arr: &Array2<f64>) -> Result<Array2<f64>, Error> {
let a_arr = self.todense();
let ncols = b_arr.shape()[1];
let mut temp = (0..ncols)
.rev()
.map(|e| {
let b = b_arr.slice(s![.., e]).to_owned();
let sol = lu_solve(&a_arr, &b)?;
Ok(sol.to_vec())
})
.collect::<Result<Vec<_>, Error>>()?;
let nrows = temp[0].len();
let mut raveled = Vec::with_capacity(nrows * temp.len());
for _ in 0..nrows {
for vec in &mut temp {
raveled.push(vec.pop().unwrap());
}
}
raveled.reverse();
let result = Array2::<f64>::from_shape_vec((nrows, ncols), raveled)?.to_owned();
let zeros = Array2::zeros((b_arr.nrows(), 0));
let result = b_arr
.columns()
.into_iter()
.map(|b| lu_solve(&a_arr, &b.to_owned()))
.fold_ok(zeros, |mut result, column| {
result
.push_column(column.view())
.expect("shape was initialized correctly");
result
})?;
Ok(result)
}

Loading…
Cancel
Save