1extern crate rand;
34use self::rand::prelude::*;
35use crate::structure::{
36 matrix::Shape::{Col, Row},
37 matrix::{matrix, Matrix, Shape},
38};
39use crate::traits::float::FloatWithPrecision;
40use crate::traits::matrix::MatrixTrait;
41use anyhow::{bail, Result};
42use rand_distr::{Distribution, Uniform};
43
/// Error raised (through `anyhow::bail!`) by the concatenation helpers in this module
/// (`cbind`, `rbind`, `column_stack`, `row_stack`) when their inputs cannot be joined.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ConcatenateError {
    /// The inputs disagree in the dimension that must match: the row count for
    /// column-wise joins, the column count for row-wise joins.
    DifferentLength,
}

impl std::fmt::Display for ConcatenateError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            ConcatenateError::DifferentLength => write!(
                f,
                "To concatenate, vectors or matrices must have the same length"
            ),
        }
    }
}

/// Makes the type a proper standard error, so it can be boxed as
/// `Box<dyn std::error::Error>` and propagated with `?` outside of anyhow.
impl std::error::Error for ConcatenateError {}
59
/// R-like `seq`: values from `start` towards `end` in increments of `step`, including `end`
/// when it is reachable.
///
/// The element count is `floor((end - start) / step) + 1`. A `1e-10` tolerance is added to
/// the step count (the same fuzz R's `seq.default` uses), so that an endpoint lost only to
/// floating-point rounding is still produced (e.g. `seq(0, 0.3, 0.1)` has 4 elements).
/// As a consequence the last element may exceed `end` by a rounding error.
///
/// When `start == end` the result is `[start]`, whatever `step` is.
///
/// # Panics
/// * if `end < start`, or if either bound is NaN;
/// * if `end > start` and `step` is not strictly positive (including NaN);
/// * if the number of steps is not finite (e.g. an infinite bound).
pub fn seq<S, T, U>(start: S, end: T, step: U) -> Vec<f64>
where
    S: Into<f64> + Copy,
    T: Into<f64> + Copy,
    U: Into<f64> + Copy,
{
    // Tolerance on the step count; absorbs the error in `(e - s) / step`.
    const FUZZ: f64 = 1e-10;

    let s = start.into();
    let e = end.into();
    let step = step.into();

    assert!(e >= s);
    if e == s {
        return vec![s];
    }
    assert!(step > 0f64, "seq: step must be positive when end > start");

    let factor = (e - s) / step + FUZZ;
    assert!(factor.is_finite(), "seq: number of steps is not finite");

    let len = factor.floor() as usize + 1;
    (0..len).map(|i| s + step * (i as f64)).collect()
}
96
97pub fn seq_with_precision<S, T, U>(start: S, end: T, step: U, precision: usize) -> Vec<f64>
110where
111 S: Into<f64> + Copy,
112 T: Into<f64> + Copy,
113 U: Into<f64> + Copy,
114{
115 let s = start.into();
116 let e = end.into();
117 let step = step.into();
118
119 assert!(e >= s);
120
121 let factor: f64 = (e - s) / step;
122 let l: usize = factor.floor() as usize + 1;
123 let mut v: Vec<f64> = vec![0f64; l];
124
125 for (i, v) in v.iter_mut().enumerate() {
126 *v = (s + step * (i as f64)).round_with_precision(precision);
127 }
128 v
129}
130
131pub fn cbind(m1: Matrix, m2: Matrix) -> Result<Matrix> {
148 let mut temp = m1;
149 if temp.shape != Col {
150 temp = temp.change_shape();
151 }
152
153 let mut temp2 = m2;
154 if temp2.shape != Col {
155 temp2 = temp2.change_shape();
156 }
157
158 let mut v = temp.data;
159 let mut c = temp.col;
160 let r = temp.row;
161
162 if r != temp2.row {
163 bail!(ConcatenateError::DifferentLength);
164 }
165 v.extend_from_slice(&temp2.data[..]);
166 c += temp2.col;
167 Ok(matrix(v, r, c, Col))
168}
169
170pub fn rbind(m1: Matrix, m2: Matrix) -> Result<Matrix> {
187 let mut temp = m1;
188 if temp.shape != Row {
189 temp = temp.change_shape();
190 }
191
192 let mut temp2 = m2;
193 if temp2.shape != Row {
194 temp2 = temp2.change_shape();
195 }
196
197 let mut v = temp.data;
198 let c = temp.col;
199 let mut r = temp.row;
200
201 if c != temp2.col {
202 bail!(ConcatenateError::DifferentLength);
203 }
204 v.extend_from_slice(&temp2.data[..]);
205 r += temp2.row;
206 Ok(matrix(v, r, c, Row))
207}
208
209pub fn zeros(r: usize, c: usize) -> Matrix {
222 matrix(vec![0f64; r * c], r, c, Row)
223}
224
225pub fn zeros_shape(r: usize, c: usize, shape: Shape) -> Matrix {
227 matrix(vec![0f64; r * c], r, c, shape)
228}
229
230pub fn eye(n: usize) -> Matrix {
240 let mut m = zeros(n, n);
241 for i in 0..n {
242 m[(i, i)] = 1f64;
243 }
244 m
245}
246
247pub fn eye_shape(n: usize, shape: Shape) -> Matrix {
249 let mut m = zeros_shape(n, n, shape);
250 for i in 0..n {
251 m[(i, i)] = 1f64;
252 }
253 m
254}
255
/// `length` evenly spaced values over the closed interval `[start, end]`.
///
/// * `length == 0` returns an empty vector.
/// * `length == 1` returns `[start]`.
/// * Otherwise the first element is exactly `start` and the last exactly `end`. The
///   interior points are `start + step * i` with `step = (end - start) / (length - 1)`.
///
/// `end < start` is allowed and yields a descending sequence.
pub fn linspace<S, T>(start: S, end: T, length: usize) -> Vec<f64>
where
    S: Into<f64> + Copy,
    T: Into<f64> + Copy,
{
    let s: f64 = start.into();
    let e: f64 = end.into();

    match length {
        0 => Vec::new(),
        1 => vec![s],
        _ => {
            let last = length - 1;
            let step = (e - s) / (length as f64 - 1f64);
            (0..length)
                .map(|i| {
                    // Pin both endpoints exactly instead of trusting `s + step * i`.
                    if i == 0 {
                        s
                    } else if i == last {
                        e
                    } else {
                        s + step * (i as f64)
                    }
                })
                .collect()
        }
    }
}
286
287pub fn linspace_with_precision<S, T>(start: S, end: T, length: usize, precision: usize) -> Vec<f64>
300where
301 S: Into<f64> + Copy,
302 T: Into<f64> + Copy,
303{
304 let step: f64 = if length > 1 {
305 (end.into() - start.into()) / (length as f64 - 1f64)
306 } else {
307 0f64
308 };
309
310 let mut v = vec![0f64; length];
311 v[0] = start.into().round_with_precision(precision);
312 v[length - 1] = end.into().round_with_precision(precision);
313
314 for i in 1..length - 1 {
315 v[i] = (v[0] + step * (i as f64)).round_with_precision(precision);
316 }
317 v
318}
319
320pub fn rand(r: usize, c: usize) -> Matrix {
326 let mut rng = rand::rng();
327 rand_with_rng(r, c, &mut rng)
328}
329
330pub fn rand_with_rng<R: Rng>(r: usize, c: usize, rng: &mut R) -> Matrix {
336 let uniform = Uniform::new_inclusive(0f64, 1f64).unwrap();
337 rand_with_dist(r, c, rng, uniform)
338}
339
340pub fn rand_with_dist<T: Into<f64>, R: Rng, D: Distribution<T>>(
346 r: usize,
347 c: usize,
348 rng: &mut R,
349 dist: D,
350) -> Matrix {
351 matrix(rng.sample_iter(dist).take(r * c).collect(), r, c, Row)
352}
353
/// `length` values `base^x`, where the exponents `x` are evenly spaced from `start` to `end`.
///
/// The exponent increment is `(end - start) / (length - 1)`, or `0` when `length <= 1`.
/// `length == 0` yields an empty vector.
///
/// # Panics
/// If `end < start` (or either exponent bound is NaN).
pub fn logspace<S, T, U>(start: S, end: T, length: usize, base: U) -> Vec<f64>
where
    S: Into<f64> + Copy,
    T: Into<f64> + Copy,
    U: Into<f64> + Copy,
{
    let lo: f64 = start.into();
    let hi: f64 = end.into();
    let radix: f64 = base.into();

    assert!(hi >= lo);

    let exponent_step = match length {
        0 | 1 => 0f64,
        _ => (hi - lo) / (length as f64 - 1f64),
    };

    (0..length)
        .map(|k| radix.powf(lo + exponent_step * (k as f64)))
        .collect()
}
395
396pub fn column_stack(v: &[Vec<f64>]) -> Result<Matrix> {
398 let row = v[0].len();
399 if v.iter().any(|x| x.len() != row) {
400 bail!(ConcatenateError::DifferentLength);
401 }
402 let data = v.iter().flatten().copied().collect();
403 Ok(matrix(data, row, v.len(), Col))
404}
405
406pub fn row_stack(v: &[Vec<f64>]) -> Result<Matrix> {
408 let col = v[0].len();
409 if v.iter().any(|x| x.len() != col) {
410 bail!(ConcatenateError::DifferentLength);
411 }
412 let data = v.iter().flatten().copied().collect();
413 Ok(matrix(data, v.len(), col, Row))
414}
415
/// Returns a new vector holding the elements of `v1` followed by those of `v2`.
///
/// Only `Clone` is required (the former `Copy` bound was unnecessary), so this also
/// works for types such as `String`.
pub fn concat<T: Clone>(v1: &[T], v2: &[T]) -> Vec<T> {
    [v1, v2].concat()
}
426
/// Returns a new vector with `val` prepended to the elements of `vec`.
///
/// Only `Clone` is required; the former `Copy + Default` bounds were never used.
pub fn cat<T: Clone>(val: T, vec: &[T]) -> Vec<T> {
    let mut v = Vec::with_capacity(vec.len() + 1);
    v.push(val);
    v.extend_from_slice(vec);
    v
}