peroxide/complex/
matrix.rs

1use std::{
2    cmp::{max, min},
3    fmt,
4    ops::{Add, Div, Index, IndexMut, Mul, Neg, Sub},
5};
6
7use anyhow::{bail, Result};
8use matrixmultiply::CGemmOption;
9use num_complex::Complex;
10use peroxide_num::{ExpLogOps, PowOps, TrigOps};
11use rand_distr::num_traits::{One, Zero};
12
13use crate::{
14    complex::C64,
15    structure::matrix::Shape,
16    traits::fp::{FPMatrix, FPVector},
17    traits::general::Algorithm,
18    traits::math::{InnerProduct, LinearOp, MatrixProduct, Norm, Normed, Vector},
19    traits::matrix::{Form, LinearAlgebra, MatrixTrait, SolveKind, PQLU, QR, SVD, UPLO, WAZD},
20    traits::mutable::MutMatrix,
21    util::low_level::{copy_vec_ptr, swap_vec_ptr},
22    util::non_macro::ConcatenateError,
23    util::useful::{nearly_eq, tab},
24};
25
26/// R-like complex matrix structure
27///
28/// # Examples
29///
30/// ```rust
31/// use peroxide::fuga::*;
32/// use peroxide::complex::matrix::ComplexMatrix;
33///
34/// let v1 = ComplexMatrix {
35/// data: vec![
36///     C64::new(1f64, 1f64),
37///     C64::new(2f64, 2f64),
38///     C64::new(3f64, 3f64),
39///     C64::new(4f64, 4f64),
40/// ],
41/// row: 2,
42/// col: 2,
43/// shape: Row,
44/// }; // [[1+1i,2+2i],[3+3i,4+4i]]
45/// ```
46#[derive(Debug, Clone, Default)]
47pub struct ComplexMatrix {
48    pub data: Vec<C64>,
49    pub row: usize,
50    pub col: usize,
51    pub shape: Shape,
52}
53
54// =============================================================================
55// Various complex matrix constructor
56// =============================================================================
57
58/// R-like complex matrix constructor
59///
60/// # Examples
61/// ```rust
62/// #[macro_use]
63/// extern crate peroxide;
64/// use peroxide::fuga::*;
65/// use peroxide::complex::matrix::cmatrix;
66///
67/// fn main() {
68///     let a = cmatrix(vec![C64::new(1f64, 1f64),
69///                       C64::new(2f64, 2f64),
70///                       C64::new(3f64, 3f64),
71///                       C64::new(4f64, 4f64)],
72///                    2, 2, Row
73///     );
74///     a.col.print(); // Print matrix column
75/// }
76/// ```
77pub fn cmatrix<T>(v: Vec<T>, r: usize, c: usize, shape: Shape) -> ComplexMatrix
78where
79    T: Into<C64>,
80{
81    ComplexMatrix {
82        data: v.into_iter().map(|t| t.into()).collect::<Vec<C64>>(),
83        row: r,
84        col: c,
85        shape,
86    }
87}
88
89/// R-like complex matrix constructor (Explicit ver.)
90pub fn r_cmatrix<T>(v: Vec<T>, r: usize, c: usize, shape: Shape) -> ComplexMatrix
91where
92    T: Into<C64>,
93{
94    cmatrix(v, r, c, shape)
95}
96
97/// Python-like complex matrix constructor
98///
99/// # Examples
100/// ```rust
101/// #[macro_use]
102/// extern crate peroxide;
103/// use peroxide::fuga::*;
104/// use peroxide::complex::matrix::*;
105///
106/// fn main() {
107///     let a = py_cmatrix(vec![vec![C64::new(1f64, 1f64),
108///                                         C64::new(2f64, 2f64)],
109///                                    vec![C64::new(3f64, 3f64),
110///                                         C64::new(4f64, 4f64)]
111///     ]);
112///     let b = cmatrix(vec![C64::new(1f64, 1f64),
113///                                 C64::new(2f64, 2f64),
114///                                 C64::new(3f64, 3f64),
115///                                 C64::new(4f64, 4f64)],
116///                             2, 2, Row
117///     );
118///     assert_eq!(a, b);
119/// }
120/// ```
121pub fn py_cmatrix<T>(v: Vec<Vec<T>>) -> ComplexMatrix
122where
123    T: Into<C64> + Copy,
124{
125    let r = v.len();
126    let c = v[0].len();
127    let data: Vec<T> = v.into_iter().flatten().collect();
128    cmatrix(data, r, c, Shape::Row)
129}
130
131/// Matlab-like matrix constructor
132///
133/// Note that the entries to the `ml_cmatrix`
134/// needs to be in the `a+bi` format
135/// without any spaces between the real and imaginary
136/// parts of the Complex number.
137///
138/// # Examples
139/// ```rust
140/// #[macro_use]
141/// extern crate peroxide;
142/// use peroxide::fuga::*;
143/// use peroxide::complex::matrix::*;
144///
145/// fn main() {
146///     let a = ml_cmatrix("1.0+1.0i 2.0+2.0i;
147///                                3.0+3.0i 4.0+4.0i");
148///     let b = cmatrix(vec![C64::new(1f64, 1f64),
149///                                 C64::new(2f64, 2f64),
150///                                 C64::new(3f64, 3f64),
151///                                 C64::new(4f64, 4f64)],
152///                             2, 2, Row
153///     );
154///     assert_eq!(a, b);
155/// }
156/// ```
157pub fn ml_cmatrix(s: &str) -> ComplexMatrix {
158    let str_row = s.split(";").collect::<Vec<&str>>();
159    let r = str_row.len();
160    let str_data = str_row
161        .iter()
162        .map(|x| x.trim().split(" ").collect::<Vec<&str>>())
163        .collect::<Vec<Vec<&str>>>();
164    let c = str_data[0].len();
165    let data = str_data
166        .iter()
167        .flat_map(|t| {
168            t.iter()
169                .map(|x| x.parse::<C64>().unwrap())
170                .collect::<Vec<C64>>()
171        })
172        .collect::<Vec<C64>>();
173
174    cmatrix(data, r, c, Shape::Row)
175}
176
177///  Pretty Print
178impl fmt::Display for ComplexMatrix {
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
180        write!(f, "{}", self.spread())
181    }
182}
183
184/// PartialEq implements
185impl PartialEq for ComplexMatrix {
186    fn eq(&self, other: &ComplexMatrix) -> bool {
187        if self.shape == other.shape {
188            self.data
189                .clone()
190                .into_iter()
191                .zip(other.data.clone())
192                .all(|(x, y)| nearly_eq(x.re, y.re) && nearly_eq(x.im, y.im))
193                && self.row == other.row
194        } else {
195            self.eq(&other.change_shape())
196        }
197    }
198}
199
200impl MatrixTrait for ComplexMatrix {
201    type Scalar = C64;
202
203    /// Raw pointer for `self.data`
204    fn ptr(&self) -> *const C64 {
205        &self.data[0] as *const C64
206    }
207
208    /// Raw mutable pointer for `self.data`
209    fn mut_ptr(&mut self) -> *mut C64 {
210        &mut self.data[0] as *mut C64
211    }
212
213    /// Slice of `self.data`
214    ///
215    /// # Examples
216    /// ```rust
217    /// use peroxide::fuga::*;
218    /// use peroxide::complex::matrix::*;
219    ///
220    /// let a = cmatrix(vec![C64::new(1f64, 1f64),
221    ///                             C64::new(2f64, 2f64),
222    ///                             C64::new(3f64, 3f64),
223    ///                             C64::new(4f64, 4f64)],
224    ///                             2, 2, Row
225    ///     );
226    /// let b = a.as_slice();
227    /// assert_eq!(b, &[C64::new(1f64, 1f64),
228    ///                 C64::new(2f64, 2f64),
229    ///                 C64::new(3f64, 3f64),
230    ///                 C64::new(4f64, 4f64)]);
231    /// ```
232    fn as_slice(&self) -> &[C64] {
233        &self.data[..]
234    }
235
236    /// Mutable slice of `self.data`
237    ///
238    /// # Examples
239    /// ```rust
240    /// use peroxide::fuga::*;
241    /// use peroxide::complex::matrix::*;
242    ///
243    /// let mut a = cmatrix(vec![C64::new(1f64, 1f64),
244    ///                                 C64::new(2f64, 2f64),
245    ///                                 C64::new(3f64, 3f64),
246    ///                                 C64::new(4f64, 4f64)],
247    ///                             2, 2, Row
248    ///     );
249    /// let mut b = a.as_mut_slice();
250    /// b[1] = C64::new(5f64, 5f64);
251    /// assert_eq!(b, &[C64::new(1f64, 1f64),
252    ///                 C64::new(5f64, 5f64),
253    ///                 C64::new(3f64, 3f64),
254    ///                 C64::new(4f64, 4f64)]);
255    /// assert_eq!(a, cmatrix(vec![C64::new(1f64, 1f64),
256    ///                                   C64::new(5f64, 5f64),
257    ///                                   C64::new(3f64, 3f64),
258    ///                                   C64::new(4f64, 4f64)],
259    ///                               2, 2, Row));
260    /// ```
261    fn as_mut_slice(&mut self) -> &mut [C64] {
262        &mut self.data[..]
263    }
264
265    /// Change Bindings
266    ///
267    /// `Row` -> `Col` or `Col` -> `Row`
268    ///
269    /// # Examples
270    /// ```rust
271    /// use peroxide::fuga::*;
272    /// use peroxide::complex::matrix::*;
273    ///
274    /// let mut a = cmatrix(vec![C64::new(1f64, 1f64),
275    ///                                 C64::new(2f64, 2f64),
276    ///                                 C64::new(3f64, 3f64),
277    ///                                 C64::new(4f64, 4f64)],
278    ///                             2, 2, Row
279    ///     );
280    /// assert_eq!(a.shape, Row);
281    /// let b = a.change_shape();
282    /// assert_eq!(b.shape, Col);
283    /// ```
284    fn change_shape(&self) -> Self {
285        let r = self.row;
286        let c = self.col;
287        assert_eq!(r * c, self.data.len());
288        let l = r * c - 1;
289        let mut data: Vec<C64> = self.data.clone();
290        let ref_data = &self.data;
291
292        match self.shape {
293            Shape::Row => {
294                for i in 0..l {
295                    let s = (i * c) % l;
296                    data[i] = ref_data[s];
297                }
298                data[l] = ref_data[l];
299                cmatrix(data, r, c, Shape::Col)
300            }
301            Shape::Col => {
302                for i in 0..l {
303                    let s = (i * r) % l;
304                    data[i] = ref_data[s];
305                }
306                data[l] = ref_data[l];
307                cmatrix(data, r, c, Shape::Row)
308            }
309        }
310    }
311
312    /// Change Bindings Mutably
313    ///
314    /// `Row` -> `Col` or `Col` -> `Row`
315    ///
316    /// # Examples
317    /// ```rust
318    /// use peroxide::fuga::*;
319    /// use peroxide::complex::matrix::*;
320    ///
321    /// let mut a = cmatrix(vec![
322    ///         C64::new(1f64, 1f64),
323    ///         C64::new(2f64, 2f64),
324    ///         C64::new(3f64, 3f64),
325    ///         C64::new(4f64, 4f64)
326    ///     ],
327    ///     2, 2, Row
328    /// );
329    /// assert_eq!(a.shape, Row);
330    /// a.change_shape_mut();
331    /// assert_eq!(a.shape, Col);
332    /// ```
333    fn change_shape_mut(&mut self) {
334        let r = self.row;
335        let c = self.col;
336        assert_eq!(r * c, self.data.len());
337        let l = r * c - 1;
338        let ref_data = self.data.clone();
339
340        match self.shape {
341            Shape::Row => {
342                for i in 0..l {
343                    let s = (i * c) % l;
344                    self.data[i] = ref_data[s];
345                }
346                self.data[l] = ref_data[l];
347                self.shape = Shape::Col;
348            }
349            Shape::Col => {
350                for i in 0..l {
351                    let s = (i * r) % l;
352                    self.data[i] = ref_data[s];
353                }
354                self.data[l] = ref_data[l];
355                self.shape = Shape::Row;
356            }
357        }
358    }
359
360    /// Spread data(1D vector) to 2D formatted String
361    ///
362    /// # Examples
363    /// ```rust
364    /// use peroxide::fuga::*;
365    /// use peroxide::complex::matrix::*;
366    ///
367    /// let a = cmatrix(vec![C64::new(1f64, 1f64),
368    ///                                 C64::new(2f64, 2f64),
369    ///                                 C64::new(3f64, 3f64),
370    ///                                 C64::new(4f64, 4f64)],
371    ///                             2, 2, Row
372    ///     );
373    /// println!("{}", a.spread()); // same as println!("{}", a);
374    /// // Result:
375    /// //       c[0]    c[1]
376    /// // r[0]  1+1i    3+3i
377    /// // r[1]  2+2i    4+4i
378    /// ```
379    fn spread(&self) -> String {
380        assert_eq!(self.row * self.col, self.data.len());
381        let r = self.row;
382        let c = self.col;
383        let mut key_row = 20usize;
384        let mut key_col = 20usize;
385
386        if r > 100 || c > 100 || (r > 20 && c > 20) {
387            let part = if r <= 10 {
388                key_row = r;
389                key_col = 100;
390                self.take_col(100)
391            } else if c <= 10 {
392                key_row = 100;
393                key_col = c;
394                self.take_row(100)
395            } else {
396                self.take_row(20).take_col(20)
397            };
398            return format!(
399                "Result is too large to print - {}x{}\n only print {}x{} parts:\n{}",
400                self.row.to_string(),
401                self.col.to_string(),
402                key_row.to_string(),
403                key_col.to_string(),
404                part.spread()
405            );
406        }
407
408        // Find maximum length of data
409        let sample = self.data.clone();
410        let mut space: usize = sample
411            .into_iter()
412            .map(
413                |x| min(format!("{:.4}", x).len(), x.to_string().len()), // Choose minimum of approx vs normal
414            )
415            .fold(0, |x, y| max(x, y))
416            + 1;
417
418        if space < 5 {
419            space = 5;
420        }
421
422        let mut result = String::new();
423
424        result.push_str(&tab("", 5));
425        for i in 0..c {
426            result.push_str(&tab(&format!("c[{}]", i), space)); // Header
427        }
428        result.push('\n');
429
430        for i in 0..r {
431            result.push_str(&tab(&format!("r[{}]", i), 5));
432            for j in 0..c {
433                let st1 = format!("{:.4}", self[(i, j)]); // Round at fourth position
434                let st2 = self[(i, j)].to_string(); // Normal string
435                let mut st = st2.clone();
436
437                // Select more small thing
438                if st1.len() < st2.len() {
439                    st = st1;
440                }
441
442                result.push_str(&tab(&st, space));
443            }
444            if i == (r - 1) {
445                break;
446            }
447            result.push('\n');
448        }
449
450        return result;
451    }
452
453    /// Extract Column
454    ///
455    /// # Examples
456    /// ```rust
457    /// #[macro_use]
458    /// extern crate peroxide;
459    /// use peroxide::fuga::*;
460    /// use peroxide::complex::matrix::*;
461    ///
462    /// fn main() {
463    ///     let a = cmatrix(vec![C64::new(1f64, 1f64),
464    ///                             C64::new(2f64, 2f64),
465    ///                             C64::new(3f64, 3f64),
466    ///                             C64::new(4f64, 4f64)],
467    ///                             2, 2, Row
468    ///         );
469    ///     assert_eq!(a.col(0), vec![C64::new(1f64, 1f64), C64::new(3f64, 3f64)]);
470    /// }
471    /// ```
472    fn col(&self, index: usize) -> Vec<C64> {
473        assert!(index < self.col);
474        let mut container: Vec<C64> = vec![Complex::zero(); self.row];
475        for i in 0..self.row {
476            container[i] = self[(i, index)];
477        }
478        container
479    }
480
481    /// Extract Row
482    ///
483    /// # Examples
484    /// ```rust
485    /// #[macro_use]
486    /// extern crate peroxide;
487    /// use peroxide::fuga::*;
488    /// use peroxide::complex::matrix::*;
489    ///
490    /// fn main() {
491    ///     let a = cmatrix(vec![C64::new(1f64, 1f64),
492    ///                             C64::new(2f64, 2f64),
493    ///                             C64::new(3f64, 3f64),
494    ///                             C64::new(4f64, 4f64)],
495    ///                             2, 2, Row
496    ///         );
497    ///     assert_eq!(a.row(0), vec![C64::new(1f64, 1f64), C64::new(2f64, 2f64)]);
498    /// }
499    /// ```
500    fn row(&self, index: usize) -> Vec<C64> {
501        assert!(index < self.row);
502        let mut container: Vec<C64> = vec![Complex::zero(); self.col];
503        for i in 0..self.col {
504            container[i] = self[(index, i)];
505        }
506        container
507    }
508
509    /// Extract diagonal components
510    ///
511    /// # Examples
512    /// ```rust
513    /// #[macro_use]
514    /// extern crate peroxide;
515    /// use peroxide::fuga::*;
516    /// use peroxide::complex::matrix::*;
517    ///
518    /// fn main() {
519    ///     let a = cmatrix(vec![C64::new(1f64, 1f64),
520    ///                                 C64::new(2f64, 2f64),
521    ///                                 C64::new(3f64, 3f64),
522    ///                                 C64::new(4f64, 4f64)],
523    ///                             2, 2, Row
524    ///          );
525    ///     assert_eq!(a.diag(), vec![C64::new(1f64, 1f64) ,C64::new(4f64, 4f64)]);
526    /// }
527    /// ```
528    fn diag(&self) -> Vec<C64> {
529        let mut container = vec![Complex::zero(); self.row];
530        let r = self.row;
531        let c = self.col;
532        assert_eq!(r, c);
533
534        let c2 = c + 1;
535        for i in 0..r {
536            container[i] = self.data[i * c2];
537        }
538        container
539    }
540
541    /// Transpose
542    ///
543    /// # Examples
544    /// ```rust
545    /// use peroxide::fuga::*;
546    /// use peroxide::complex::matrix::*;
547    ///
548    /// let a = cmatrix(vec![C64::new(1f64, 1f64),
549    ///                             C64::new(2f64, 2f64),
550    ///                             C64::new(3f64, 3f64),
551    ///                             C64::new(4f64, 4f64)],
552    ///                             2, 2, Row
553    ///     );
554    /// let a_t = cmatrix(vec![C64::new(1f64, 1f64),
555    ///                               C64::new(2f64, 2f64),
556    ///                               C64::new(3f64, 3f64),
557    ///                               C64::new(4f64, 4f64)],
558    ///                             2, 2, Col
559    ///     );
560    ///
561    /// assert_eq!(a.transpose(), a_t);
562    /// ```
563    fn transpose(&self) -> Self {
564        match self.shape {
565            Shape::Row => cmatrix(self.data.clone(), self.col, self.row, Shape::Col),
566            Shape::Col => cmatrix(self.data.clone(), self.col, self.row, Shape::Row),
567        }
568    }
569
570    /// Substitute Col
571    #[inline]
572    fn subs_col(&mut self, idx: usize, v: &[C64]) {
573        for i in 0..self.row {
574            self[(i, idx)] = v[i];
575        }
576    }
577
578    /// Substitute Row
579    #[inline]
580    fn subs_row(&mut self, idx: usize, v: &[C64]) {
581        for j in 0..self.col {
582            self[(idx, j)] = v[j];
583        }
584    }
585
586    /// From index operations
587    fn from_index<F, G>(f: F, size: (usize, usize)) -> ComplexMatrix
588    where
589        F: Fn(usize, usize) -> G + Copy,
590        G: Into<C64>,
591    {
592        let row = size.0;
593        let col = size.1;
594
595        let mut mat = cmatrix(vec![Complex::zero(); row * col], row, col, Shape::Row);
596
597        for i in 0..row {
598            for j in 0..col {
599                mat[(i, j)] = f(i, j).into();
600            }
601        }
602        mat
603    }
604
605    /// Matrix to `Vec<Vec<C64>>`
606    ///
607    /// To send `Matrix` to `inline-python`
608    fn to_vec(&self) -> Vec<Vec<C64>> {
609        let mut result = vec![vec![Complex::zero(); self.col]; self.row];
610        for i in 0..self.row {
611            result[i] = self.row(i);
612        }
613        result
614    }
615
616    fn to_diag(&self) -> ComplexMatrix {
617        assert_eq!(self.row, self.col, "Should be square matrix");
618        let mut result = cmatrix(
619            vec![Complex::zero(); self.row * self.col],
620            self.row,
621            self.col,
622            Shape::Row,
623        );
624        let diag = self.diag();
625        for i in 0..self.row {
626            result[(i, i)] = diag[i];
627        }
628        result
629    }
630
631    /// Submatrix
632    ///
633    /// # Description
634    /// Return below elements of complex matrix to a new complex matrix
635    ///
636    /// $$
637    /// \begin{pmatrix}
638    /// \\ddots & & & & \\\\
639    ///   & start & \\cdots & end.1 & \\\\
640    ///   & \\vdots & \\ddots & \\vdots & \\\\
641    ///   & end.0 & \\cdots & end & \\\\
642    ///   & & & & \\ddots
643    /// \end{pmatrix}
644    /// $$
645    ///
646    /// # Examples
647    /// ```rust
648    /// #[macro_use]
649    /// extern crate peroxide;
650    /// use peroxide::fuga::*;
651    /// use peroxide::complex::matrix::*;
652    ///
653    /// fn main() {
654    ///     let a = ml_cmatrix("1.0+1.0i 2.0+2.0i 3.0+3.0i;
655    ///                                4.0+4.0i 5.0+5.0i 6.0+6.0i;
656    ///                                7.0+7.0i 8.0+8.0i 9.0+9.0i");
657    ///     let b = cmatrix(vec![C64::new(5f64, 5f64),
658    ///                                 C64::new(6f64, 6f64),
659    ///                                 C64::new(8f64, 8f64),
660    ///                                 C64::new(9f64, 9f64)],
661    ///                             2, 2, Row
662    ///     );
663    ///     let c = a.submat((1, 1), (2, 2));
664    ///     assert_eq!(b, c);
665    /// }
666    /// ```
667    fn submat(&self, start: (usize, usize), end: (usize, usize)) -> ComplexMatrix {
668        let row = end.0 - start.0 + 1;
669        let col = end.1 - start.1 + 1;
670        let mut result = cmatrix(vec![Complex::zero(); row * col], row, col, self.shape);
671        for i in 0..row {
672            for j in 0..col {
673                result[(i, j)] = self[(start.0 + i, start.1 + j)];
674            }
675        }
676        result
677    }
678
679    /// Substitute complex matrix to specific position
680    ///
681    /// # Description
682    /// Substitute below elements of complex matrix
683    ///
684    /// $$
685    /// \begin{pmatrix}
686    /// \\ddots & & & & \\\\
687    ///   & start & \\cdots & end.1 & \\\\
688    ///   & \\vdots & \\ddots & \\vdots & \\\\
689    ///   & end.0 & \\cdots & end & \\\\
690    ///   & & & & \\ddots
691    /// \end{pmatrix}
692    /// $$
693    ///
694    /// # Examples
695    /// ```
696    /// extern crate peroxide;
697    /// use peroxide::fuga::*;
698    /// use peroxide::complex::matrix::*;
699    ///
700    /// fn main() {
701    ///     let mut a = ml_cmatrix("1.0+1.0i 2.0+2.0i 3.0+3.0i;
702    ///                                4.0+4.0i 5.0+5.0i 6.0+6.0i;
703    ///                                7.0+7.0i 8.0+8.0i 9.0+9.0i");
704    ///     let b = cmatrix(vec![C64::new(1f64, 1f64),
705    ///                                 C64::new(2f64, 2f64),
706    ///                                 C64::new(3f64, 3f64),
707    ///                                 C64::new(4f64, 4f64)],
708    ///                             2, 2, Row);
709    ///     let c = ml_cmatrix("1.0+1.0i 2.0+2.0i 3.0+3.0i;
710    ///                                4.0+4.0i 1.0+1.0i 2.0+2.0i;
711    ///                                7.0+7.0i 3.0+3.0i 4.0+4.0i");
712    ///     a.subs_mat((1,1), (2,2), &b);
713    ///     assert_eq!(a, c);       
714    /// }
715    /// ```
716    fn subs_mat(&mut self, start: (usize, usize), end: (usize, usize), m: &ComplexMatrix) {
717        let row = end.0 - start.0 + 1;
718        let col = end.1 - start.1 + 1;
719        for i in 0..row {
720            for j in 0..col {
721                self[(start.0 + i, start.1 + j)] = m[(i, j)];
722            }
723        }
724    }
725}
726
727// =============================================================================
728// Mathematics for Matrix
729// =============================================================================
730impl Vector for ComplexMatrix {
731    type Scalar = C64;
732
733    fn add_vec(&self, other: &Self) -> Self {
734        assert_eq!(self.row, other.row);
735        assert_eq!(self.col, other.col);
736
737        let mut result = cmatrix(self.data.clone(), self.row, self.col, self.shape);
738        for i in 0..self.row {
739            for j in 0..self.col {
740                result[(i, j)] += other[(i, j)];
741            }
742        }
743        result
744    }
745
746    fn sub_vec(&self, other: &Self) -> Self {
747        assert_eq!(self.row, other.row);
748        assert_eq!(self.col, other.col);
749
750        let mut result = cmatrix(self.data.clone(), self.row, self.col, self.shape);
751        for i in 0..self.row {
752            for j in 0..self.col {
753                result[(i, j)] -= other[(i, j)];
754            }
755        }
756        result
757    }
758
759    fn mul_scalar(&self, other: Self::Scalar) -> Self {
760        let scalar = other;
761        self.fmap(|x| x * scalar)
762    }
763}
764
765impl Normed for ComplexMatrix {
766    type UnsignedScalar = f64;
767
768    fn norm(&self, kind: Norm) -> Self::UnsignedScalar {
769        match kind {
770            Norm::F => {
771                let mut s = Complex::zero();
772                for i in 0..self.data.len() {
773                    s += self.data[i].powi(2);
774                }
775                s.sqrt().re
776            }
777            Norm::Lpq(p, q) => {
778                let mut s = Complex::zero();
779                for j in 0..self.col {
780                    let mut s_row = Complex::zero();
781                    for i in 0..self.row {
782                        s_row += self[(i, j)].powi(p as i32);
783                    }
784                    s += s_row.powf(q / p);
785                }
786                s.powf(1f64 / q).re
787            }
788            Norm::L1 => {
789                let mut m = Complex::zero();
790                match self.shape {
791                    Shape::Row => self.change_shape().norm(Norm::L1),
792                    Shape::Col => {
793                        for c in 0..self.col {
794                            let s: C64 = self.col(c).iter().sum();
795                            if s.re > m.re {
796                                m = s;
797                            }
798                        }
799                        m.re
800                    }
801                }
802            }
803            Norm::LInf => {
804                let mut m = Complex::zero();
805                match self.shape {
806                    Shape::Col => self.change_shape().norm(Norm::LInf),
807                    Shape::Row => {
808                        for r in 0..self.row {
809                            let s: C64 = self.row(r).iter().sum();
810                            if s.re > m.re {
811                                m = s;
812                            }
813                        }
814                        m.re
815                    }
816                }
817            }
818            Norm::L2 => {
819                unimplemented!()
820            }
821            Norm::Lp(_) => unimplemented!(),
822        }
823    }
824
825    fn normalize(&self, _kind: Norm) -> Self
826    where
827        Self: Sized,
828    {
829        unimplemented!()
830    }
831}
832
833/// Frobenius inner product
834impl InnerProduct for ComplexMatrix {
835    fn dot(&self, rhs: &Self) -> C64 {
836        if self.shape == rhs.shape {
837            self.data.dot(&rhs.data)
838        } else {
839            self.data.dot(&rhs.change_shape().data)
840        }
841    }
842}
843
844/// TODO: Transpose
845
846/// Matrix as Linear operator for Vector
847#[allow(non_snake_case)]
848impl LinearOp<Vec<C64>, Vec<C64>> for ComplexMatrix {
849    fn apply(&self, other: &Vec<C64>) -> Vec<C64> {
850        assert_eq!(self.col, other.len());
851        let mut c = vec![Complex::zero(); self.row];
852        cgemv(Complex::one(), self, other, Complex::zero(), &mut c);
853        c
854    }
855}
856
857/// R like cbind - concatenate two comlex matrix by column direction
858pub fn complex_cbind(m1: ComplexMatrix, m2: ComplexMatrix) -> Result<ComplexMatrix> {
859    let mut temp = m1;
860    if temp.shape != Shape::Col {
861        temp = temp.change_shape();
862    }
863
864    let mut temp2 = m2;
865    if temp2.shape != Shape::Col {
866        temp2 = temp2.change_shape();
867    }
868
869    let mut v = temp.data;
870    let mut c = temp.col;
871    let r = temp.row;
872
873    if r != temp2.row {
874        bail!(ConcatenateError::DifferentLength);
875    }
876    v.extend_from_slice(&temp2.data[..]);
877    c += temp2.col;
878    Ok(cmatrix(v, r, c, Shape::Col))
879}
880
881/// R like rbind - concatenate two complex matrix by row direction
882pub fn complex_rbind(m1: ComplexMatrix, m2: ComplexMatrix) -> Result<ComplexMatrix> {
883    let mut temp = m1;
884    if temp.shape != Shape::Row {
885        temp = temp.change_shape();
886    }
887
888    let mut temp2 = m2;
889    if temp2.shape != Shape::Row {
890        temp2 = temp2.change_shape();
891    }
892
893    let mut v = temp.data;
894    let c = temp.col;
895    let mut r = temp.row;
896
897    if c != temp2.col {
898        bail!(ConcatenateError::DifferentLength);
899    }
900    v.extend_from_slice(&temp2.data[..]);
901    r += temp2.row;
902    Ok(cmatrix(v, r, c, Shape::Row))
903}
904
905impl MatrixProduct for ComplexMatrix {
906    fn kronecker(&self, other: &Self) -> Self {
907        let r1 = self.row;
908        let c1 = self.col;
909
910        let mut result = self[(0, 0)] * other;
911
912        for j in 1..c1 {
913            let n = self[(0, j)] * other;
914            result = complex_cbind(result, n).unwrap();
915        }
916
917        for i in 1..r1 {
918            let mut m = self[(i, 0)] * other;
919            for j in 1..c1 {
920                let n = self[(i, j)] * other;
921                m = complex_cbind(m, n).unwrap();
922            }
923            result = complex_rbind(result, m).unwrap();
924        }
925        result
926    }
927
928    fn hadamard(&self, other: &Self) -> Self {
929        assert_eq!(self.row, other.row);
930        assert_eq!(self.col, other.col);
931
932        let r = self.row;
933        let c = self.col;
934
935        let mut m = cmatrix(vec![Complex::zero(); r * c], r, c, self.shape);
936        for i in 0..r {
937            for j in 0..c {
938                m[(i, j)] = self[(i, j)] * other[(i, j)]
939            }
940        }
941        m
942    }
943}
944
945// =============================================================================
// Common Properties of ComplexMatrix & Vec<C64>
947// =============================================================================
948/// `Complex Matrix` to `Vec<C64>`
949impl Into<Vec<C64>> for ComplexMatrix {
950    fn into(self) -> Vec<C64> {
951        self.data
952    }
953}
954
955/// `&ComplexMatrix` to `&Vec<C64>`
956impl<'a> Into<&'a Vec<C64>> for &'a ComplexMatrix {
957    fn into(self) -> &'a Vec<C64> {
958        &self.data
959    }
960}
961
962/// `Vec<C64>` to `ComplexMatrix`
963impl Into<ComplexMatrix> for Vec<C64> {
964    fn into(self) -> ComplexMatrix {
965        let l = self.len();
966        cmatrix(self, l, 1, Shape::Col)
967    }
968}
969
970/// `&Vec<C64>` to `ComplexMatrix`
971impl Into<ComplexMatrix> for &Vec<C64> {
972    fn into(self) -> ComplexMatrix {
973        let l = self.len();
974        cmatrix(self.clone(), l, 1, Shape::Col)
975    }
976}
977
978// =============================================================================
979// Standard Operation for Complex Matrix (ADD)
980// =============================================================================
981
982/// Element-wise addition of Complex Matrix
983///
984/// # Caution
985/// > You should remember ownership.
986/// > If you use ComplexMatrix `a,b` then you can't use them after.
987impl Add<ComplexMatrix> for ComplexMatrix {
988    type Output = Self;
989
990    fn add(self, other: Self) -> Self {
991        assert_eq!(&self.row, &other.row);
992        assert_eq!(&self.col, &other.col);
993
994        let mut result = cmatrix(self.data.clone(), self.row, self.col, self.shape);
995        for i in 0..self.row {
996            for j in 0..self.col {
997                result[(i, j)] += other[(i, j)];
998            }
999        }
1000        result
1001    }
1002}
1003
1004impl<'a, 'b> Add<&'b ComplexMatrix> for &'a ComplexMatrix {
1005    type Output = ComplexMatrix;
1006
1007    fn add(self, rhs: &'b ComplexMatrix) -> Self::Output {
1008        self.add_vec(rhs)
1009    }
1010}
1011
1012/// Element-wise addition between Complex Matrix & C64
1013///
1014/// # Examples
1015/// ```rust
1016/// #[macro_use]
1017/// extern crate peroxide;
1018/// use peroxide::fuga::*;
1019/// use peroxide::complex::matrix::*;
1020///
1021/// fn main() {
1022///     let mut a = ml_cmatrix("1.0+1.0i 2.0+2.0i;
1023///                                    4.0+4.0i 5.0+5.0i");
1024///     let a_exp = ml_cmatrix("2.0+2.0i 3.0+3.0i;
1025///                                    5.0+5.0i 6.0+6.0i");
1026///     assert_eq!(a + C64::new(1_f64, 1_f64), a_exp);
1027/// }
1028/// ```
1029impl<T> Add<T> for ComplexMatrix
1030where
1031    T: Into<C64> + Copy,
1032{
1033    type Output = Self;
1034    fn add(self, other: T) -> Self {
1035        self.fmap(|x| x + other.into())
1036    }
1037}
1038
1039/// Element-wise addition between &ComplexMatrix & C64
1040impl<'a, T> Add<T> for &'a ComplexMatrix
1041where
1042    T: Into<C64> + Copy,
1043{
1044    type Output = ComplexMatrix;
1045
1046    fn add(self, other: T) -> Self::Output {
1047        self.fmap(|x| x + other.into())
1048    }
1049}
1050
1051// Element-wise addition between C64 & ComplexMatrix
1052///
1053/// # Examples
1054///
1055/// ```rust
1056/// #[macro_use]
1057/// extern crate peroxide;
1058/// use peroxide::fuga::*;
1059/// use peroxide::complex::matrix::*;
1060///
1061/// fn main() {
1062///     let mut a = ml_cmatrix("1.0+1.0i 2.0+2.0i;
1063///                                    4.0+4.0i 5.0+5.0i");
1064///     let a_exp = ml_cmatrix("2.0+2.0i 3.0+3.0i;
1065///                                    5.0+5.0i 6.0+6.0i");
1066///     assert_eq!(C64::new(1_f64, 1_f64) + a, a_exp);
1067/// }
1068/// ```
1069impl Add<ComplexMatrix> for C64 {
1070    type Output = ComplexMatrix;
1071
1072    fn add(self, other: ComplexMatrix) -> Self::Output {
1073        other.add(self)
1074    }
1075}
1076
1077/// Element-wise addition between C64 & &ComplexMatrix
1078impl<'a> Add<&'a ComplexMatrix> for C64 {
1079    type Output = ComplexMatrix;
1080
1081    fn add(self, other: &'a ComplexMatrix) -> Self::Output {
1082        other.add(self)
1083    }
1084}
1085
1086// =============================================================================
1087// Standard Operation for Matrix (Neg)
1088// =============================================================================
1089/// Negation of Complex Matrix
1090///
1091/// # Examples
1092/// ```rust
1093/// extern crate peroxide;
1094/// use peroxide::fuga::*;
1095/// use peroxide::complex::matrix::*;
1096///
1097/// let a = cmatrix(vec![C64::new(1f64, 1f64),
1098///                             C64::new(2f64, 2f64),
1099///                             C64::new(3f64, 3f64),
1100///                             C64::new(4f64, 4f64)],
1101///                             2, 2, Row);
1102/// let a_neg = cmatrix(vec![C64::new(-1f64, -1f64),
1103///                                 C64::new(-2f64, -2f64),
1104///                                 C64::new(-3f64, -3f64),
1105///                                 C64::new(-4f64, -4f64)],
1106///                             2, 2, Row);
1107/// assert_eq!(-a, a_neg);
1108/// ```
1109impl Neg for ComplexMatrix {
1110    type Output = Self;
1111
1112    fn neg(self) -> Self {
1113        cmatrix(
1114            self.data.into_iter().map(|x: C64| -x).collect::<Vec<C64>>(),
1115            self.row,
1116            self.col,
1117            self.shape,
1118        )
1119    }
1120}
1121
1122/// Negation of &'a Complex Matrix
1123impl<'a> Neg for &'a ComplexMatrix {
1124    type Output = ComplexMatrix;
1125
1126    fn neg(self) -> Self::Output {
1127        cmatrix(
1128            self.data
1129                .clone()
1130                .into_iter()
1131                .map(|x: C64| -x)
1132                .collect::<Vec<C64>>(),
1133            self.row,
1134            self.col,
1135            self.shape,
1136        )
1137    }
1138}
1139
1140// =============================================================================
1141// Standard Operation for Matrix (Sub)
1142// =============================================================================
1143/// Subtraction between Complex Matrix
1144///
1145/// # Examples
1146/// ```rust
1147/// #[macro_use]
1148/// extern crate peroxide;
1149/// use peroxide::fuga::*;
1150/// use peroxide::complex::matrix::*;
1151///
1152/// fn main() {
1153///     let a = ml_cmatrix("10.0+10.0i 20.0+20.0i;
1154///                                40.0+40.0i 50.0+50.0i");
1155///     let b = ml_cmatrix("1.0+1.0i 2.0+2.0i;
1156///                                4.0+4.0i 5.0+5.0i");
1157///     let diff = ml_cmatrix("9.0+9.0i 18.0+18.0i;
1158///                                   36.0+36.0i 45.0+45.0i");
1159///     assert_eq!(a-b, diff);
1160/// }
1161/// ```
1162impl Sub<ComplexMatrix> for ComplexMatrix {
1163    type Output = Self;
1164
1165    fn sub(self, other: Self) -> Self::Output {
1166        assert_eq!(&self.row, &other.row);
1167        assert_eq!(&self.col, &other.col);
1168        let mut result = cmatrix(self.data.clone(), self.row, self.col, self.shape);
1169        for i in 0..self.row {
1170            for j in 0..self.col {
1171                result[(i, j)] -= other[(i, j)];
1172            }
1173        }
1174        result
1175    }
1176}
1177
1178impl<'a, 'b> Sub<&'b ComplexMatrix> for &'a ComplexMatrix {
1179    type Output = ComplexMatrix;
1180
1181    fn sub(self, rhs: &'b ComplexMatrix) -> Self::Output {
1182        self.sub_vec(rhs)
1183    }
1184}
1185
1186/// Subtraction between Complex Matrix & C64
1187impl<T> Sub<T> for ComplexMatrix
1188where
1189    T: Into<C64> + Copy,
1190{
1191    type Output = Self;
1192
1193    fn sub(self, other: T) -> Self::Output {
1194        self.fmap(|x| x - other.into())
1195    }
1196}
1197
1198/// Subtraction between &Complex Matrix & C64
1199impl<'a, T> Sub<T> for &'a ComplexMatrix
1200where
1201    T: Into<C64> + Copy,
1202{
1203    type Output = ComplexMatrix;
1204
1205    fn sub(self, other: T) -> Self::Output {
1206        self.fmap(|x| x - other.into())
1207    }
1208}
1209
1210/// Subtraction Complex Matrix with C64
1211///
1212/// # Examples
1213/// ```rust
1214/// #[macro_use]
1215/// extern crate peroxide;
1216/// use peroxide::fuga::*;
1217/// use peroxide::complex::matrix::*;
1218///
1219/// fn main() {
1220///     let mut a = ml_cmatrix("1.0+1.0i 2.0+2.0i;
1221///                                    4.0+4.0i 5.0+5.0i");
1222///     let a_exp = ml_cmatrix("0.0+0.0i 1.0+1.0i;
1223///                                    3.0+3.0i 4.0+4.0i");
1224///     assert_eq!(a - C64::new(1_f64, 1_f64), a_exp);
1225/// }
1226/// ```
1227impl Sub<ComplexMatrix> for C64 {
1228    type Output = ComplexMatrix;
1229
1230    fn sub(self, other: ComplexMatrix) -> Self::Output {
1231        -other.sub(self)
1232    }
1233}
1234
1235impl<'a> Sub<&'a ComplexMatrix> for f64 {
1236    type Output = ComplexMatrix;
1237
1238    fn sub(self, other: &'a ComplexMatrix) -> Self::Output {
1239        -other.sub(self)
1240    }
1241}
1242
1243// =============================================================================
1244// Multiplication for Complex Matrix
1245// =============================================================================
1246/// Element-wise multiplication between Complex Matrix vs C64
1247impl Mul<C64> for ComplexMatrix {
1248    type Output = Self;
1249
1250    fn mul(self, other: C64) -> Self::Output {
1251        self.fmap(|x| x * other)
1252    }
1253}
1254
1255impl Mul<ComplexMatrix> for C64 {
1256    type Output = ComplexMatrix;
1257
1258    fn mul(self, other: ComplexMatrix) -> Self::Output {
1259        other.mul(self)
1260    }
1261}
1262
1263impl<'a> Mul<&'a ComplexMatrix> for C64 {
1264    type Output = ComplexMatrix;
1265
1266    fn mul(self, other: &'a ComplexMatrix) -> Self::Output {
1267        other.mul_scalar(self)
1268    }
1269}
1270
1271/// Matrix Multiplication
1272///
1273/// # Examples
1274/// ```rust
1275/// #[macro_use]
1276/// extern crate peroxide;
1277/// use peroxide::fuga::*;
1278/// use peroxide::complex::matrix::*;
1279///
1280/// fn main() {
1281///     let mut a = ml_cmatrix("1.0+1.0i 2.0+2.0i;
1282///                                    4.0+4.0i 5.0+5.0i");
1283///     let mut b = ml_cmatrix("2.0+2.0i 2.0+2.0i;
1284///                                    5.0+5.0i 5.0+5.0i");
1285///     let prod = ml_cmatrix("0.0+24.0i 0.0+24.0i;
1286///                                    0.0+66.0i 0.0+66.0i");
1287///     assert_eq!(a * b, prod);
1288/// }
1289/// ```
1290impl Mul<ComplexMatrix> for ComplexMatrix {
1291    type Output = Self;
1292
1293    fn mul(self, other: Self) -> Self::Output {
1294        cmatmul(&self, &other)
1295    }
1296}
1297
1298impl<'a, 'b> Mul<&'b ComplexMatrix> for &'a ComplexMatrix {
1299    type Output = ComplexMatrix;
1300
1301    fn mul(self, other: &'b ComplexMatrix) -> Self::Output {
1302        cmatmul(self, other)
1303    }
1304}
1305
1306#[allow(non_snake_case)]
1307impl Mul<Vec<C64>> for ComplexMatrix {
1308    type Output = Vec<C64>;
1309
1310    fn mul(self, other: Vec<C64>) -> Self::Output {
1311        self.apply(&other)
1312    }
1313}
1314
1315#[allow(non_snake_case)]
1316impl<'a, 'b> Mul<&'b Vec<C64>> for &'a ComplexMatrix {
1317    type Output = Vec<C64>;
1318
1319    fn mul(self, other: &'b Vec<C64>) -> Self::Output {
1320        self.apply(other)
1321    }
1322}
1323
1324/// Matrix multiplication for `Vec<C64>` vs `ComplexMatrix`
1325impl Mul<ComplexMatrix> for Vec<C64> {
1326    type Output = Vec<C64>;
1327
1328    fn mul(self, other: ComplexMatrix) -> Self::Output {
1329        assert_eq!(self.len(), other.row);
1330        let mut c = vec![Complex::zero(); other.col];
1331        complex_gevm(Complex::one(), &self, &other, Complex::zero(), &mut c);
1332        c
1333    }
1334}
1335
1336impl<'a, 'b> Mul<&'b ComplexMatrix> for &'a Vec<C64> {
1337    type Output = Vec<C64>;
1338
1339    fn mul(self, other: &'b ComplexMatrix) -> Self::Output {
1340        assert_eq!(self.len(), other.row);
1341        let mut c = vec![Complex::zero(); other.col];
1342        complex_gevm(Complex::one(), self, other, Complex::zero(), &mut c);
1343        c
1344    }
1345}
1346
1347// =============================================================================
1348// Standard Operation for Matrix (DIV)
1349// =============================================================================
1350/// Element-wise division between Complex Matrix vs C64
1351impl Div<C64> for ComplexMatrix {
1352    type Output = Self;
1353
1354    fn div(self, other: C64) -> Self::Output {
1355        self.fmap(|x| x / other)
1356    }
1357}
1358
1359impl<'a> Div<C64> for &'a ComplexMatrix {
1360    type Output = ComplexMatrix;
1361
1362    fn div(self, other: C64) -> Self::Output {
1363        self.fmap(|x| x / other)
1364    }
1365}
1366
1367/// Index for Complex Matrix
1368///
1369/// `(usize, usize) -> C64`
1370///
1371/// # Examples
1372/// ```rust
1373/// extern crate peroxide;
1374/// use peroxide::fuga::*;
1375/// use peroxide::complex::matrix::*;
1376///
1377/// let a = cmatrix(vec![C64::new(1f64, 1f64),
1378///                             C64::new(2f64, 2f64),
1379///                             C64::new(3f64, 3f64),
1380///                             C64::new(4f64, 4f64)],
1381///                             2, 2, Row
1382///     );
1383/// assert_eq!(a[(0,1)], C64::new(2f64, 2f64));
1384/// ```
1385impl Index<(usize, usize)> for ComplexMatrix {
1386    type Output = C64;
1387
1388    fn index(&self, pair: (usize, usize)) -> &C64 {
1389        let p = self.ptr();
1390        let i = pair.0;
1391        let j = pair.1;
1392        assert!(i < self.row && j < self.col, "Index out of range");
1393        match self.shape {
1394            Shape::Row => unsafe { &*p.add(i * self.col + j) },
1395            Shape::Col => unsafe { &*p.add(i + j * self.row) },
1396        }
1397    }
1398}
1399
1400/// IndexMut for Complex Matrix (Assign)
1401///
1402/// `(usize, usize) -> C64`
1403///
1404/// # Examples
1405/// ```rust
1406/// extern crate peroxide;
1407/// use peroxide::fuga::*;
1408/// use peroxide::complex::matrix::*;
1409///
1410/// let mut a = cmatrix(vec![C64::new(1f64, 1f64),
1411///                             C64::new(2f64, 2f64),
1412///                             C64::new(3f64, 3f64),
1413///                             C64::new(4f64, 4f64)],
1414///                             2, 2, Row
1415///     );
1416/// assert_eq!(a[(0,1)], C64::new(2f64, 2f64));
1417/// ```
1418impl IndexMut<(usize, usize)> for ComplexMatrix {
1419    fn index_mut(&mut self, pair: (usize, usize)) -> &mut C64 {
1420        let i = pair.0;
1421        let j = pair.1;
1422        let r = self.row;
1423        let c = self.col;
1424        assert!(i < self.row && j < self.col, "Index out of range");
1425        let p = self.mut_ptr();
1426        match self.shape {
1427            Shape::Row => {
1428                let idx = i * c + j;
1429                unsafe { &mut *p.add(idx) }
1430            }
1431            Shape::Col => {
1432                let idx = i + j * r;
1433                unsafe { &mut *p.add(idx) }
1434            }
1435        }
1436    }
1437}
1438
1439// =============================================================================
1440// Functional Programming Tools (Hand-written)
1441// =============================================================================
1442
1443impl FPMatrix for ComplexMatrix {
1444    type Scalar = C64;
1445
1446    fn take_row(&self, n: usize) -> Self {
1447        if n >= self.row {
1448            return self.clone();
1449        }
1450        match self.shape {
1451            Shape::Row => {
1452                let new_data = self
1453                    .data
1454                    .clone()
1455                    .into_iter()
1456                    .take(n * self.col)
1457                    .collect::<Vec<C64>>();
1458                cmatrix(new_data, n, self.col, Shape::Row)
1459            }
1460            Shape::Col => {
1461                let mut temp_data: Vec<C64> = Vec::new();
1462                for i in 0..n {
1463                    temp_data.extend(self.row(i));
1464                }
1465                cmatrix(temp_data, n, self.col, Shape::Row)
1466            }
1467        }
1468    }
1469
1470    fn take_col(&self, n: usize) -> Self {
1471        if n >= self.col {
1472            return self.clone();
1473        }
1474        match self.shape {
1475            Shape::Col => {
1476                let new_data = self
1477                    .data
1478                    .clone()
1479                    .into_iter()
1480                    .take(n * self.row)
1481                    .collect::<Vec<C64>>();
1482                cmatrix(new_data, self.row, n, Shape::Col)
1483            }
1484            Shape::Row => {
1485                let mut temp_data: Vec<C64> = Vec::new();
1486                for i in 0..n {
1487                    temp_data.extend(self.col(i));
1488                }
1489                cmatrix(temp_data, self.row, n, Shape::Col)
1490            }
1491        }
1492    }
1493
1494    fn skip_row(&self, n: usize) -> Self {
1495        assert!(n < self.row, "Skip range is larger than row of matrix");
1496
1497        let mut temp_data: Vec<C64> = Vec::new();
1498        for i in n..self.row {
1499            temp_data.extend(self.row(i));
1500        }
1501        cmatrix(temp_data, self.row - n, self.col, Shape::Row)
1502    }
1503
1504    fn skip_col(&self, n: usize) -> Self {
1505        assert!(n < self.col, "Skip range is larger than col of matrix");
1506
1507        let mut temp_data: Vec<C64> = Vec::new();
1508        for i in n..self.col {
1509            temp_data.extend(self.col(i));
1510        }
1511        cmatrix(temp_data, self.row, self.col - n, Shape::Col)
1512    }
1513
1514    fn fmap<F>(&self, f: F) -> Self
1515    where
1516        F: Fn(C64) -> C64,
1517    {
1518        let result = self.data.iter().map(|x| f(*x)).collect::<Vec<C64>>();
1519        cmatrix(result, self.row, self.col, self.shape)
1520    }
1521
1522    /// Column map
1523    ///
1524    /// # Example
1525    /// ```rust
1526    /// use peroxide::fuga::*;
1527    /// use peroxide::complex::matrix::*;
1528    /// use peroxide::traits::fp::FPMatrix;
1529    ///
1530    /// fn main() {
1531    ///     let x = cmatrix(vec![C64::new(1f64, 1f64),
1532    ///                                 C64::new(2f64, 2f64),
1533    ///                                 C64::new(3f64, 3f64),
1534    ///                                 C64::new(4f64, 4f64)],
1535    ///                             2, 2, Row
1536    ///     );
1537    ///     let y = x.col_map(|r| r.fmap(|t| t + r[0]));
1538    ///
1539    ///     let y_col_map = cmatrix(vec![C64::new(2f64, 2f64),
1540    ///                                         C64::new(4f64, 4f64),
1541    ///                                         C64::new(4f64, 4f64),
1542    ///                                         C64::new(6f64, 6f64)],
1543    ///                             2, 2, Col
1544    ///     );
1545    ///
1546    ///     assert_eq!(y, y_col_map);
1547    /// }
1548    /// ```
1549    fn col_map<F>(&self, f: F) -> ComplexMatrix
1550    where
1551        F: Fn(Vec<C64>) -> Vec<C64>,
1552    {
1553        let mut result = cmatrix(
1554            vec![Complex::zero(); self.row * self.col],
1555            self.row,
1556            self.col,
1557            Shape::Col,
1558        );
1559
1560        for i in 0..self.col {
1561            result.subs_col(i, &f(self.col(i)));
1562        }
1563
1564        result
1565    }
1566
1567    /// Row map
1568    ///
1569    /// # Example
1570    /// ```rust
1571    /// use peroxide::fuga::*;
1572    /// use peroxide::complex::matrix::*;
1573    /// use peroxide::traits::fp::FPMatrix;
1574    ///
1575    /// fn main() {
1576    ///     let x = cmatrix(vec![C64::new(1f64, 1f64),
1577    ///                                 C64::new(2f64, 2f64),
1578    ///                                 C64::new(3f64, 3f64),
1579    ///                                 C64::new(4f64, 4f64)],
1580    ///                             2, 2, Row
1581    ///     );
1582    ///     let y = x.row_map(|r| r.fmap(|t| t + r[0]));
1583    ///
1584    ///     let y_row_map = cmatrix(vec![C64::new(2f64, 2f64),
1585    ///                                         C64::new(3f64, 3f64),
1586    ///                                         C64::new(6f64, 6f64),
1587    ///                                         C64::new(7f64, 7f64)],
1588    ///                             2, 2, Row
1589    ///     );
1590    ///
1591    ///     assert_eq!(y, y_row_map);
1592    /// }
1593    /// ```
1594    fn row_map<F>(&self, f: F) -> ComplexMatrix
1595    where
1596        F: Fn(Vec<C64>) -> Vec<C64>,
1597    {
1598        let mut result = cmatrix(
1599            vec![Complex::zero(); self.row * self.col],
1600            self.row,
1601            self.col,
1602            Shape::Row,
1603        );
1604
1605        for i in 0..self.row {
1606            result.subs_row(i, &f(self.row(i)));
1607        }
1608
1609        result
1610    }
1611
1612    fn col_mut_map<F>(&mut self, f: F)
1613    where
1614        F: Fn(Vec<C64>) -> Vec<C64>,
1615    {
1616        for i in 0..self.col {
1617            unsafe {
1618                let mut p = self.col_mut(i);
1619                let fv = f(self.col(i));
1620                for j in 0..p.len() {
1621                    *p[j] = fv[j];
1622                }
1623            }
1624        }
1625    }
1626
1627    fn row_mut_map<F>(&mut self, f: F)
1628    where
1629        F: Fn(Vec<C64>) -> Vec<C64>,
1630    {
1631        for i in 0..self.col {
1632            unsafe {
1633                let mut p = self.row_mut(i);
1634                let fv = f(self.row(i));
1635                for j in 0..p.len() {
1636                    *p[j] = fv[j];
1637                }
1638            }
1639        }
1640    }
1641
1642    fn reduce<F, T>(&self, init: T, f: F) -> C64
1643    where
1644        F: Fn(C64, C64) -> C64,
1645        T: Into<C64>,
1646    {
1647        self.data.iter().fold(init.into(), |x, y| f(x, *y))
1648    }
1649
1650    fn zip_with<F>(&self, f: F, other: &ComplexMatrix) -> Self
1651    where
1652        F: Fn(C64, C64) -> C64,
1653    {
1654        assert_eq!(self.data.len(), other.data.len());
1655        let mut a = other.clone();
1656        if self.shape != other.shape {
1657            a = a.change_shape();
1658        }
1659        let result = self
1660            .data
1661            .iter()
1662            .zip(a.data.iter())
1663            .map(|(x, y)| f(*x, *y))
1664            .collect::<Vec<C64>>();
1665        cmatrix(result, self.row, self.col, self.shape)
1666    }
1667
1668    fn col_reduce<F>(&self, f: F) -> Vec<C64>
1669    where
1670        F: Fn(Vec<C64>) -> C64,
1671    {
1672        let mut v = vec![Complex::zero(); self.col];
1673        for i in 0..self.col {
1674            v[i] = f(self.col(i));
1675        }
1676        v
1677    }
1678
1679    fn row_reduce<F>(&self, f: F) -> Vec<C64>
1680    where
1681        F: Fn(Vec<C64>) -> C64,
1682    {
1683        let mut v = vec![Complex::zero(); self.row];
1684        for i in 0..self.row {
1685            v[i] = f(self.row(i));
1686        }
1687        v
1688    }
1689}
1690
1691pub fn cdiag(n: usize) -> ComplexMatrix {
1692    let mut v: Vec<C64> = vec![Complex::zero(); n * n];
1693    for i in 0..n {
1694        let idx = i * (n + 1);
1695        v[idx] = Complex::one();
1696    }
1697    cmatrix(v, n, n, Shape::Row)
1698}
1699
1700impl PQLU<ComplexMatrix> {
1701    /// Extract PQLU
1702    ///
1703    /// # Usage
1704    /// ```rust
1705    /// extern crate peroxide;
1706    /// use peroxide::fuga::*;
1707    ///
1708    /// let a = cmatrix(vec![C64::new(1f64, 1f64),
1709    ///                                 C64::new(2f64, 2f64),
1710    ///                                 C64::new(3f64, 3f64),
1711    ///                                 C64::new(4f64, 4f64)],
1712    ///                             2, 2, Row
1713    ///     );
1714    /// let pqlu = a.lu();
1715    /// let (p, q, l, u) = pqlu.extract();
1716    /// // p, q are permutations
1717    /// // l, u are matrices
1718    /// println!("{}", l); // lower triangular
1719    /// println!("{}", u); // upper triangular
1720    /// ```
1721    pub fn extract(&self) -> (Vec<usize>, Vec<usize>, ComplexMatrix, ComplexMatrix) {
1722        (
1723            self.p.clone(),
1724            self.q.clone(),
1725            self.l.clone(),
1726            self.u.clone(),
1727        )
1728    }
1729
1730    pub fn det(&self) -> C64 {
1731        // sgn of perms
1732        let mut sgn_p = 1f64;
1733        let mut sgn_q = 1f64;
1734        for (i, &j) in self.p.iter().enumerate() {
1735            if i != j {
1736                sgn_p *= -1f64;
1737            }
1738        }
1739        for (i, &j) in self.q.iter().enumerate() {
1740            if i != j {
1741                sgn_q *= -1f64;
1742            }
1743        }
1744
1745        self.u.diag().reduce(Complex::one(), |x, y| x * y) * sgn_p * sgn_q
1746    }
1747
1748    pub fn inv(&self) -> ComplexMatrix {
1749        let (p, q, l, u) = self.extract();
1750        let mut m = complex_inv_u(u) * complex_inv_l(l);
1751        // Q = Q1 Q2 Q3 ..
1752        for (idx1, idx2) in q.into_iter().enumerate().rev() {
1753            unsafe {
1754                m.swap(idx1, idx2, Shape::Row);
1755            }
1756        }
1757        // P = Pn-1 .. P3 P2 P1
1758        for (idx1, idx2) in p.into_iter().enumerate().rev() {
1759            unsafe {
1760                m.swap(idx1, idx2, Shape::Col);
1761            }
1762        }
1763        m
1764    }
1765}
1766
1767/// MATLAB like eye - Identity matrix
1768pub fn ceye(n: usize) -> ComplexMatrix {
1769    let mut m = cmatrix(vec![Complex::zero(); n * n], n, n, Shape::Row);
1770    for i in 0..n {
1771        m[(i, i)] = Complex::one();
1772    }
1773    m
1774}
1775
1776// =============================================================================
1777// Linear Algebra
1778// =============================================================================
1779
1780impl LinearAlgebra<ComplexMatrix> for ComplexMatrix {
1781    /// Backward Substitution for Upper Triangular
1782    fn back_subs(&self, b: &[C64]) -> Vec<C64> {
1783        let n = self.col;
1784        let mut y = vec![Complex::zero(); n];
1785        y[n - 1] = b[n - 1] / self[(n - 1, n - 1)];
1786        for i in (0..n - 1).rev() {
1787            let mut s = Complex::zero();
1788            for j in i + 1..n {
1789                s += self[(i, j)] * y[j];
1790            }
1791            y[i] = 1f64 / self[(i, i)] * (b[i] - s);
1792        }
1793        y
1794    }
1795
1796    /// Forward substitution for Lower Triangular
1797    fn forward_subs(&self, b: &[C64]) -> Vec<C64> {
1798        let n = self.col;
1799        let mut y = vec![Complex::zero(); n];
1800        y[0] = b[0] / self[(0, 0)];
1801        for i in 1..n {
1802            let mut s = Complex::zero();
1803            for j in 0..i {
1804                s += self[(i, j)] * y[j];
1805            }
1806            y[i] = 1f64 / self[(i, i)] * (b[i] - s);
1807        }
1808        y
1809    }
1810
1811    /// LU Decomposition Implements (Complete Pivot)
1812    ///
1813    /// # Description
1814    /// It use complete pivoting LU decomposition.
1815    /// You can get two permutations, and LU matrices.
1816    ///
1817    /// # Caution
1818    /// It returns `Option<PQLU>` - You should unwrap to obtain real value.
1819    /// `PQLU` has four field - `p`, `q`, `l`, `u`.
1820    /// `p`, `q` are permutations.
1821    /// `l`, `u` are matrices.
1822    ///
1823    /// # Examples
1824    /// ```
1825    /// #[macro_use]
1826    /// use peroxide::fuga::*;
1827    /// use peroxide::complex::matrix::*;
1828    ///
1829    /// fn main() {
1830    ///     let a = cmatrix(vec![
1831    ///             C64::new(1f64, 1f64),
1832    ///             C64::new(2f64, 2f64),
1833    ///             C64::new(3f64, 3f64),
1834    ///             C64::new(4f64, 4f64)
1835    ///         ],
1836    ///         2, 2, Row
1837    ///     );
1838    ///
1839    ///     let l_exp = cmatrix(vec![
1840    ///             C64::new(1f64, 0f64),
1841    ///             C64::new(0f64, 0f64),
1842    ///             C64::new(0.5f64, -0.0f64),
1843    ///             C64::new(1f64, 0f64)
1844    ///         ],
1845    ///         2, 2, Row
1846    ///     );
1847    ///
1848    ///     let u_exp = cmatrix(vec![
1849    ///             C64::new(4f64, 4f64),
1850    ///             C64::new(3f64, 3f64),
1851    ///             C64::new(0f64, 0f64),
1852    ///             C64::new(-0.5f64, -0.5f64)
1853    ///         ],
1854    ///         2, 2, Row
1855    ///     );
1856    ///     let pqlu = a.lu();
1857    ///     let (p,q,l,u) = (pqlu.p, pqlu.q, pqlu.l, pqlu.u);
1858    ///     assert_eq!(p, vec![1]); // swap 0 & 1 (Row)
1859    ///     assert_eq!(q, vec![1]); // swap 0 & 1 (Col)
1860    ///     assert_eq!(l, l_exp);
1861    ///     assert_eq!(u, u_exp);
1862    /// }
1863    /// ```
1864    fn lu(&self) -> PQLU<ComplexMatrix> {
1865        assert_eq!(self.col, self.row);
1866        let n = self.row;
1867        let len: usize = n * n;
1868
1869        let mut l = ceye(n);
1870        let mut u = cmatrix(vec![Complex::zero(); len], n, n, self.shape);
1871
1872        let mut temp = self.clone();
1873        let (p, q) = gecp(&mut temp);
1874        for i in 0..n {
1875            for j in 0..i {
1876                // Inverse multiplier
1877                l[(i, j)] = -temp[(i, j)];
1878            }
1879            for j in i..n {
1880                u[(i, j)] = temp[(i, j)];
1881            }
1882        }
1883        // Pivoting L
1884        for i in 0..n - 1 {
1885            unsafe {
1886                let l_i = l.col_mut(i);
1887                for j in i + 1..l.col - 1 {
1888                    let dst = p[j];
1889                    std::ptr::swap(l_i[j], l_i[dst]);
1890                }
1891            }
1892        }
1893        PQLU { p, q, l, u }
1894    }
1895
1896    fn waz(&self, _d_form: Form) -> Option<WAZD<ComplexMatrix>> {
1897        unimplemented!()
1898    }
1899
1900    fn qr(&self) -> QR<ComplexMatrix> {
1901        unimplemented!()
1902    }
1903
1904    fn svd(&self) -> SVD<ComplexMatrix> {
1905        unimplemented!()
1906    }
1907
1908    #[cfg(feature = "O3")]
1909    fn cholesky(&self, uplo: UPLO) -> ComplexMatrix {
1910        unimplemented!()
1911    }
1912
1913    fn rref(&self) -> ComplexMatrix {
1914        unimplemented!()
1915    }
1916
1917    /// Determinant
1918    ///
1919    /// # Examples
1920    /// ```
1921    /// #[macro_use]
1922    /// use peroxide::fuga::*;
1923    /// use peroxide::complex::matrix::*;
1924    ///
1925    /// fn main() {
1926    ///     let a = cmatrix(vec![
1927    ///             C64::new(1f64, 1f64),
1928    ///             C64::new(2f64, 2f64),
1929    ///             C64::new(3f64, 3f64),
1930    ///             C64::new(4f64, 4f64)
1931    ///         ],
1932    ///         2, 2, Row
1933    ///     );
1934    ///     assert_eq!(a.det().norm(), 4f64);
1935    /// }
1936    /// ```
1937    fn det(&self) -> C64 {
1938        assert_eq!(self.row, self.col);
1939        self.lu().det()
1940    }
1941
1942    /// Block Partition
1943    ///
1944    /// # Examples
1945    /// ```rust
1946    /// #[macro_use]
1947    /// extern crate peroxide;
1948    /// use peroxide::fuga::*;
1949    /// use peroxide::complex::matrix::*;
1950    ///
1951    /// fn main() {
1952    ///     let a = cmatrix(vec![
1953    ///             C64::new(1f64, 1f64),
1954    ///             C64::new(2f64, 2f64),
1955    ///             C64::new(3f64, 3f64),
1956    ///             C64::new(4f64, 4f64)
1957    ///         ],
1958    ///         2, 2, Row
1959    ///     );
1960    ///     let (m1, m2, m3, m4) = a.block();
1961    ///     assert_eq!(m1, ml_cmatrix("1.0+1.0i"));
1962    ///     assert_eq!(m2, ml_cmatrix("2.0+2.0i"));
1963    ///     assert_eq!(m3, ml_cmatrix("3.0+3.0i"));
1964    ///     assert_eq!(m4, ml_cmatrix("4.0+4.0i"));
1965    /// }
1966    /// ```
1967    fn block(&self) -> (Self, Self, Self, Self) {
1968        let r = self.row;
1969        let c = self.col;
1970        let l_r = self.row / 2;
1971        let l_c = self.col / 2;
1972        let r_l = r - l_r;
1973        let c_l = c - l_c;
1974
1975        let mut m1 = cmatrix(vec![Complex::zero(); l_r * l_c], l_r, l_c, self.shape);
1976        let mut m2 = cmatrix(vec![Complex::zero(); l_r * c_l], l_r, c_l, self.shape);
1977        let mut m3 = cmatrix(vec![Complex::zero(); r_l * l_c], r_l, l_c, self.shape);
1978        let mut m4 = cmatrix(vec![Complex::zero(); r_l * c_l], r_l, c_l, self.shape);
1979
1980        for idx_row in 0..r {
1981            for idx_col in 0..c {
1982                match (idx_row, idx_col) {
1983                    (i, j) if (i < l_r) && (j < l_c) => {
1984                        m1[(i, j)] = self[(i, j)];
1985                    }
1986                    (i, j) if (i < l_r) && (j >= l_c) => {
1987                        m2[(i, j - l_c)] = self[(i, j)];
1988                    }
1989                    (i, j) if (i >= l_r) && (j < l_c) => {
1990                        m3[(i - l_r, j)] = self[(i, j)];
1991                    }
1992                    (i, j) if (i >= l_r) && (j >= l_c) => {
1993                        m4[(i - l_r, j - l_c)] = self[(i, j)];
1994                    }
1995                    _ => (),
1996                }
1997            }
1998        }
1999        (m1, m2, m3, m4)
2000    }
2001
2002    /// Inverse of Matrix
2003    ///
2004    /// # Caution
2005    ///
2006    /// `inv` function returns `Option<Matrix>`
2007    /// Thus, you should use pattern matching or `unwrap` to obtain inverse.
2008    ///
2009    /// # Examples
2010    /// ```
2011    /// #[macro_use]
2012    /// extern crate peroxide;
2013    /// use peroxide::fuga::*;
2014    /// use peroxide::complex::matrix::*;
2015    ///
2016    /// fn main() {
2017    ///     // Non-singular
2018    ///     let a = cmatrix(vec![
2019    ///             C64::new(1f64, 1f64),
2020    ///             C64::new(2f64, 2f64),
2021    ///             C64::new(3f64, 3f64),
2022    ///             C64::new(4f64, 4f64)
2023    ///         ],
2024    ///         2, 2, Row
2025    ///     );
2026    ///
2027    ///     let a_inv_exp = cmatrix(vec![
2028    ///             C64::new(-1.0f64, 1f64),
2029    ///             C64::new(0.5f64, -0.5f64),
2030    ///             C64::new(0.75f64, -0.75f64),
2031    ///             C64::new(-0.25f64, 0.25f64)
2032    ///         ],
2033    ///         2, 2, Row
2034    ///     );
2035    ///     assert_eq!(a.inv(), a_inv_exp);
2036    /// }
2037    /// ```
2038    fn inv(&self) -> Self {
2039        self.lu().inv()
2040    }
2041
    /// Pseudo-inverse.
    ///
    /// # Panics
    ///
    /// Not implemented for `ComplexMatrix` yet: always panics via `unimplemented!()`.
    fn pseudo_inv(&self) -> ComplexMatrix {
        unimplemented!()
    }
2045
2046    /// Solve with Vector
2047    ///
2048    /// # Solve options
2049    ///
2050    /// * LU: Gaussian elimination with Complete pivoting LU (GECP)
2051    /// * WAZ: Solve with WAZ decomposition
2052    fn solve(&self, b: &[C64], sk: SolveKind) -> Vec<C64> {
2053        match sk {
2054            SolveKind::LU => {
2055                let lu = self.lu();
2056                let (p, q, l, u) = lu.extract();
2057                let mut v = b.to_vec();
2058                v.swap_with_perm(&p.into_iter().enumerate().collect());
2059                let z = l.forward_subs(&v);
2060                let mut y = u.back_subs(&z);
2061                y.swap_with_perm(&q.into_iter().enumerate().rev().collect());
2062                y
2063            }
2064            SolveKind::WAZ => {
2065                unimplemented!()
2066            }
2067        }
2068    }
2069
2070    fn solve_mat(&self, m: &ComplexMatrix, sk: SolveKind) -> ComplexMatrix {
2071        match sk {
2072            SolveKind::LU => {
2073                let lu = self.lu();
2074                let (p, q, l, u) = lu.extract();
2075                let mut x = cmatrix(
2076                    vec![Complex::zero(); self.col * m.col],
2077                    self.col,
2078                    m.col,
2079                    Shape::Col,
2080                );
2081                for i in 0..m.col {
2082                    let mut v = m.col(i).clone();
2083                    for (r, &s) in p.iter().enumerate() {
2084                        v.swap(r, s);
2085                    }
2086                    let z = l.forward_subs(&v);
2087                    let mut y = u.back_subs(&z);
2088                    for (r, &s) in q.iter().enumerate() {
2089                        y.swap(r, s);
2090                    }
2091                    unsafe {
2092                        let mut c = x.col_mut(i);
2093                        copy_vec_ptr(&mut c, &y);
2094                    }
2095                }
2096                x
2097            }
2098            SolveKind::WAZ => {
2099                unimplemented!()
2100            }
2101        }
2102    }
2103
2104    fn is_symmetric(&self) -> bool {
2105        if self.row != self.col {
2106            return false;
2107        }
2108
2109        for i in 0..self.row {
2110            for j in i..self.col {
2111                if (!nearly_eq(self[(i, j)].re, self[(j, i)].re))
2112                    && (!nearly_eq(self[(i, j)].im, self[(j, i)].im))
2113                {
2114                    return false;
2115                }
2116            }
2117        }
2118        true
2119    }
2120}
2121
/// Solve `A * X = b` for a complex matrix of right-hand sides `b`.
///
/// Free-function wrapper around `ComplexMatrix::solve_mat`; `sk` selects the
/// solver (see `SolveKind`).
// Upper-case `A` mirrors the mathematical notation for the system matrix.
#[allow(non_snake_case)]
pub fn csolve(A: &ComplexMatrix, b: &ComplexMatrix, sk: SolveKind) -> ComplexMatrix {
    A.solve_mat(b, sk)
}
2126
2127impl MutMatrix for ComplexMatrix {
2128    type Scalar = C64;
2129
2130    unsafe fn col_mut(&mut self, idx: usize) -> Vec<*mut C64> {
2131        assert!(idx < self.col, "Index out of range");
2132        match self.shape {
2133            Shape::Col => {
2134                let mut v: Vec<*mut C64> = vec![&mut Complex::zero(); self.row];
2135                let start_idx = idx * self.row;
2136                let p = self.mut_ptr();
2137                for (i, j) in (start_idx..start_idx + v.len()).enumerate() {
2138                    v[i] = p.add(j);
2139                }
2140                v
2141            }
2142            Shape::Row => {
2143                let mut v: Vec<*mut C64> = vec![&mut Complex::zero(); self.row];
2144                let p = self.mut_ptr();
2145                for i in 0..v.len() {
2146                    v[i] = p.add(idx + i * self.col);
2147                }
2148                v
2149            }
2150        }
2151    }
2152
2153    unsafe fn row_mut(&mut self, idx: usize) -> Vec<*mut C64> {
2154        assert!(idx < self.row, "Index out of range");
2155        match self.shape {
2156            Shape::Row => {
2157                let mut v: Vec<*mut C64> = vec![&mut Complex::zero(); self.col];
2158                let start_idx = idx * self.col;
2159                let p = self.mut_ptr();
2160                for (i, j) in (start_idx..start_idx + v.len()).enumerate() {
2161                    v[i] = p.add(j);
2162                }
2163                v
2164            }
2165            Shape::Col => {
2166                let mut v: Vec<*mut C64> = vec![&mut Complex::zero(); self.col];
2167                let p = self.mut_ptr();
2168                for i in 0..v.len() {
2169                    v[i] = p.add(idx + i * self.row);
2170                }
2171                v
2172            }
2173        }
2174    }
2175
2176    unsafe fn swap(&mut self, idx1: usize, idx2: usize, shape: Shape) {
2177        match shape {
2178            Shape::Col => swap_vec_ptr(&mut self.col_mut(idx1), &mut self.col_mut(idx2)),
2179            Shape::Row => swap_vec_ptr(&mut self.row_mut(idx1), &mut self.row_mut(idx2)),
2180        }
2181    }
2182
2183    unsafe fn swap_with_perm(&mut self, p: &Vec<(usize, usize)>, shape: Shape) {
2184        for (i, j) in p.iter() {
2185            self.swap(*i, *j, shape)
2186        }
2187    }
2188}
2189
2190impl ExpLogOps for ComplexMatrix {
2191    type Float = C64;
2192
2193    fn exp(&self) -> Self {
2194        self.fmap(|x| x.exp())
2195    }
2196    fn ln(&self) -> Self {
2197        self.fmap(|x| x.ln())
2198    }
2199    fn log(&self, base: Self::Float) -> Self {
2200        self.fmap(|x| x.ln() / base.ln()) // Using `Log: change of base` formula
2201    }
2202    fn log2(&self) -> Self {
2203        self.fmap(|x| x.ln() / 2.0.ln()) // Using `Log: change of base` formula
2204    }
2205    fn log10(&self) -> Self {
2206        self.fmap(|x| x.ln() / 10.0.ln()) // Using `Log: change of base` formula
2207    }
2208}
2209
2210impl PowOps for ComplexMatrix {
2211    type Float = C64;
2212
2213    fn powi(&self, n: i32) -> Self {
2214        self.fmap(|x| x.powi(n))
2215    }
2216
2217    fn powf(&self, f: Self::Float) -> Self {
2218        self.fmap(|x| x.powc(f))
2219    }
2220
2221    fn pow(&self, _f: Self) -> Self {
2222        unimplemented!()
2223    }
2224
2225    fn sqrt(&self) -> Self {
2226        self.fmap(|x| x.sqrt())
2227    }
2228}
2229
2230impl TrigOps for ComplexMatrix {
2231    fn sin_cos(&self) -> (Self, Self) {
2232        let (sin, cos) = self.data.iter().map(|x| (x.sin(), x.cos())).unzip();
2233        (
2234            cmatrix(sin, self.row, self.col, self.shape),
2235            cmatrix(cos, self.row, self.col, self.shape),
2236        )
2237    }
2238
2239    fn sin(&self) -> Self {
2240        self.fmap(|x| x.sin())
2241    }
2242
2243    fn cos(&self) -> Self {
2244        self.fmap(|x| x.cos())
2245    }
2246
2247    fn tan(&self) -> Self {
2248        self.fmap(|x| x.tan())
2249    }
2250
2251    fn sinh(&self) -> Self {
2252        self.fmap(|x| x.sinh())
2253    }
2254
2255    fn cosh(&self) -> Self {
2256        self.fmap(|x| x.cosh())
2257    }
2258
2259    fn tanh(&self) -> Self {
2260        self.fmap(|x| x.tanh())
2261    }
2262
2263    fn asin(&self) -> Self {
2264        self.fmap(|x| x.asin())
2265    }
2266
2267    fn acos(&self) -> Self {
2268        self.fmap(|x| x.acos())
2269    }
2270
2271    fn atan(&self) -> Self {
2272        self.fmap(|x| x.atan())
2273    }
2274
2275    fn asinh(&self) -> Self {
2276        self.fmap(|x| x.asinh())
2277    }
2278
2279    fn acosh(&self) -> Self {
2280        self.fmap(|x| x.acosh())
2281    }
2282
2283    fn atanh(&self) -> Self {
2284        self.fmap(|x| x.atanh())
2285    }
2286}
2287
2288// =============================================================================
2289// Back-end Utils
2290// =============================================================================
2291/// Combine separated Complex Matrix to one Complex Matrix
2292///
2293/// # Examples
2294/// ```rust
2295/// use peroxide::fuga::*;
2296/// use peroxide::complex::matrix::*;
2297/// use peroxide::traits::fp::FPMatrix;
2298///
2299/// fn main() {
2300///     let x1 = cmatrix(vec![C64::new(1f64, 1f64)], 1, 1, Row);
2301///     let x2 = cmatrix(vec![C64::new(2f64, 2f64)], 1, 1, Row);
2302///     let x3 = cmatrix(vec![C64::new(3f64, 3f64)], 1, 1, Row);
2303///     let x4 = cmatrix(vec![C64::new(4f64, 4f64)], 1, 1, Row);
2304///
2305///     let y = complex_combine(x1, x2, x3, x4);
2306///
2307///     let y_exp = cmatrix(vec![C64::new(1f64, 1f64),
2308///                                     C64::new(2f64, 2f64),
2309///                                     C64::new(3f64, 3f64),
2310///                                     C64::new(4f64, 4f64)],
2311///                             2, 2, Row
2312///     );
2313///
2314///     assert_eq!(y, y_exp);
2315/// }
2316/// ```
2317pub fn complex_combine(
2318    m1: ComplexMatrix,
2319    m2: ComplexMatrix,
2320    m3: ComplexMatrix,
2321    m4: ComplexMatrix,
2322) -> ComplexMatrix {
2323    let l_r = m1.row;
2324    let l_c = m1.col;
2325    let c_l = m2.col;
2326    let r_l = m3.row;
2327
2328    let r = l_r + r_l;
2329    let c = l_c + c_l;
2330
2331    let mut m = cmatrix(vec![Complex::zero(); r * c], r, c, m1.shape);
2332
2333    for idx_row in 0..r {
2334        for idx_col in 0..c {
2335            match (idx_row, idx_col) {
2336                (i, j) if (i < l_r) && (j < l_c) => {
2337                    m[(i, j)] = m1[(i, j)];
2338                }
2339                (i, j) if (i < l_r) && (j >= l_c) => {
2340                    m[(i, j)] = m2[(i, j - l_c)];
2341                }
2342                (i, j) if (i >= l_r) && (j < l_c) => {
2343                    m[(i, j)] = m3[(i - l_r, j)];
2344                }
2345                (i, j) if (i >= l_r) && (j >= l_c) => {
2346                    m[(i, j)] = m4[(i - l_r, j - l_c)];
2347                }
2348                _ => (),
2349            }
2350        }
2351    }
2352    m
2353}
2354
2355/// Inverse of Lower matrix
2356///
2357/// # Examples
2358///  ```rust
2359/// #[macro_use]
2360/// extern crate peroxide;
2361/// use peroxide::fuga::*;
2362/// use peroxide::complex::matrix::*;
2363///
2364/// fn main() {
2365///     let a = ml_cmatrix("2.0+2.0i 0.0+0.0i;
2366///                                2.0+2.0i 1.0+1.0i");
2367///     let b = cmatrix(vec![C64::new(2f64, 2f64),
2368///                                 C64::new(0f64, 0f64),
2369///                                 C64::new(-2f64, -2f64),
2370///                                 C64::new(1f64, 1f64)],
2371///                             2, 2, Row
2372///     );
2373///     assert_eq!(complex_inv_l(a), b);
2374/// }
2375/// ```
2376pub fn complex_inv_l(l: ComplexMatrix) -> ComplexMatrix {
2377    let mut m = l.clone();
2378
2379    match l.row {
2380        1 => l,
2381        2 => {
2382            m[(1, 0)] = -m[(1, 0)];
2383            m
2384        }
2385        _ => {
2386            let (l1, l2, l3, l4) = l.block();
2387
2388            let m1 = complex_inv_l(l1);
2389            let m2 = l2;
2390            let m4 = complex_inv_l(l4);
2391            let m3 = -(&(&m4 * &l3) * &m1);
2392
2393            complex_combine(m1, m2, m3, m4)
2394        }
2395    }
2396}
2397
2398/// Inverse of upper triangular matrix
2399///
2400/// # Examples
2401///  ```rust
2402/// #[macro_use]
2403/// extern crate peroxide;
2404/// use peroxide::fuga::*;
2405/// use peroxide::complex::matrix::*;
2406///
2407/// fn main() {
2408///     let a = ml_cmatrix("2.0+2.0i 2.0+2.0i;
2409///                                0.0+0.0i 1.0+1.0i");
2410///     let b = cmatrix(vec![C64::new(0.25f64, -0.25f64),
2411///                                 C64::new(-0.5f64, 0.5f64),
2412///                                 C64::new(0.0f64, 0.0f64),
2413///                                 C64::new(0.5f64, -0.5f64)],
2414///                             2, 2, Row
2415///     );
2416///     assert_eq!(complex_inv_u(a), b);
2417/// }
2418/// ```
2419pub fn complex_inv_u(u: ComplexMatrix) -> ComplexMatrix {
2420    let mut w = u.clone();
2421
2422    match u.row {
2423        1 => {
2424            w[(0, 0)] = 1f64 / w[(0, 0)];
2425            w
2426        }
2427        2 => {
2428            let a = w[(0, 0)];
2429            let b = w[(0, 1)];
2430            let c = w[(1, 1)];
2431            let d = a * c;
2432
2433            w[(0, 0)] = 1f64 / a;
2434            w[(0, 1)] = -b / d;
2435            w[(1, 1)] = 1f64 / c;
2436            w
2437        }
2438        _ => {
2439            let (u1, u2, u3, u4) = u.block();
2440            let m1 = complex_inv_u(u1);
2441            let m3 = u3;
2442            let m4 = complex_inv_u(u4);
2443            let m2 = -(m1.clone() * u2 * m4.clone());
2444
2445            complex_combine(m1, m2, m3, m4)
2446        }
2447    }
2448}
2449
2450/// Matrix multiply back-ends
2451pub fn cmatmul(a: &ComplexMatrix, b: &ComplexMatrix) -> ComplexMatrix {
2452    assert_eq!(a.col, b.row);
2453    let mut c = cmatrix(vec![Complex::zero(); a.row * b.col], a.row, b.col, a.shape);
2454    cgemm(Complex::one(), a, b, Complex::zero(), &mut c);
2455    c
2456}
2457
2458/// GEMM wrapper for Matrixmultiply
2459///
2460/// # Examples
2461/// ```rust
2462/// #[macro_use]
2463/// extern crate peroxide;
2464/// use peroxide::fuga::*;
2465///
2466/// use peroxide::complex::matrix::*;
2467///
2468/// fn main() {
2469///     let a = ml_cmatrix("1.0+1.0i 2.0+2.0i;
2470///                                0.0+0.0i 1.0+1.0i");
2471///     let b = ml_cmatrix("1.0+1.0i 0.0+0.0i;
2472///                                2.0+2.0i 1.0+1.0i");
2473///     let mut c1 = ml_cmatrix("1.0+1.0i 1.0+1.0i;
2474///                                    1.0+1.0i 1.0+1.0i");
2475///     let mul_val = ml_cmatrix("-10.0+10.0i -4.0+4.0i;
2476///                                      -4.0+4.0i -2.0+2.0i");
2477///
2478///     cgemm(C64::new(1.0, 1.0), &a, &b, C64::new(0.0, 0.0), &mut c1);
2479///     assert_eq!(c1, mul_val);
2480/// }
2481pub fn cgemm(alpha: C64, a: &ComplexMatrix, b: &ComplexMatrix, beta: C64, c: &mut ComplexMatrix) {
2482    let m = a.row;
2483    let k = a.col;
2484    let n = b.col;
2485    let (rsa, csa) = match a.shape {
2486        Shape::Row => (a.col as isize, 1isize),
2487        Shape::Col => (1isize, a.row as isize),
2488    };
2489    let (rsb, csb) = match b.shape {
2490        Shape::Row => (b.col as isize, 1isize),
2491        Shape::Col => (1isize, b.row as isize),
2492    };
2493    let (rsc, csc) = match c.shape {
2494        Shape::Row => (c.col as isize, 1isize),
2495        Shape::Col => (1isize, c.row as isize),
2496    };
2497
2498    unsafe {
2499        matrixmultiply::zgemm(
2500            // Requires crate feature "cgemm"
2501            CGemmOption::Standard,
2502            CGemmOption::Standard,
2503            m,
2504            k,
2505            n,
2506            [alpha.re, alpha.im],
2507            a.ptr() as *const _,
2508            rsa,
2509            csa,
2510            b.ptr() as *const _,
2511            rsb,
2512            csb,
2513            [beta.re, beta.im],
2514            c.mut_ptr() as *mut _,
2515            rsc,
2516            csc,
2517        )
2518    }
2519}
2520
2521/// General Matrix-Vector multiplication
2522pub fn cgemv(alpha: C64, a: &ComplexMatrix, b: &Vec<C64>, beta: C64, c: &mut Vec<C64>) {
2523    let m = a.row;
2524    let k = a.col;
2525    let n = 1usize;
2526    let (rsa, csa) = match a.shape {
2527        Shape::Row => (a.col as isize, 1isize),
2528        Shape::Col => (1isize, a.row as isize),
2529    };
2530    let (rsb, csb) = (1isize, 1isize);
2531    let (rsc, csc) = (1isize, 1isize);
2532
2533    unsafe {
2534        matrixmultiply::zgemm(
2535            // Requires crate feature "cgemm"
2536            CGemmOption::Standard,
2537            CGemmOption::Standard,
2538            m,
2539            k,
2540            n,
2541            [alpha.re, alpha.im],
2542            a.ptr() as *const _,
2543            rsa,
2544            csa,
2545            b.as_ptr() as *const _,
2546            rsb,
2547            csb,
2548            [beta.re, beta.im],
2549            c.as_mut_ptr() as *mut _,
2550            rsc,
2551            csc,
2552        )
2553    }
2554}
2555
2556/// General Vector-Matrix multiplication
2557pub fn complex_gevm(alpha: C64, a: &Vec<C64>, b: &ComplexMatrix, beta: C64, c: &mut Vec<C64>) {
2558    let m = 1usize;
2559    let k = a.len();
2560    let n = b.col;
2561    let (rsa, csa) = (1isize, 1isize);
2562    let (rsb, csb) = match b.shape {
2563        Shape::Row => (b.col as isize, 1isize),
2564        Shape::Col => (1isize, b.row as isize),
2565    };
2566    let (rsc, csc) = (1isize, 1isize);
2567
2568    unsafe {
2569        matrixmultiply::zgemm(
2570            // Requires crate feature "cgemm"
2571            CGemmOption::Standard,
2572            CGemmOption::Standard,
2573            m,
2574            k,
2575            n,
2576            [alpha.re, alpha.im],
2577            a.as_ptr() as *const _,
2578            rsa,
2579            csa,
2580            b.ptr() as *const _,
2581            rsb,
2582            csb,
2583            [beta.re, beta.im],
2584            c.as_mut_ptr() as *mut _,
2585            rsc,
2586            csc,
2587        )
2588    }
2589}
2590
2591/// LU via Gaussian Elimination with Partial Pivoting
2592#[allow(dead_code)]
2593fn gepp(m: &mut ComplexMatrix) -> Vec<usize> {
2594    let mut r = vec![0usize; m.col - 1];
2595    for k in 0..(m.col - 1) {
2596        // Find the pivot row
2597        let r_k = m
2598            .col(k)
2599            .into_iter()
2600            .skip(k)
2601            .enumerate()
2602            .max_by(|x1, x2| x1.1.norm().partial_cmp(&x2.1.norm()).unwrap())
2603            .unwrap()
2604            .0
2605            + k;
2606        r[k] = r_k;
2607
2608        // Interchange the rows r_k and k
2609        for j in k..m.col {
2610            unsafe {
2611                std::ptr::swap(&mut m[(k, j)], &mut m[(r_k, j)]);
2612                println!("Swap! k:{}, r_k:{}", k, r_k);
2613            }
2614        }
2615        // Form the multipliers
2616        for i in k + 1..m.col {
2617            m[(i, k)] = -m[(i, k)] / m[(k, k)];
2618        }
2619        // Update the entries
2620        for i in k + 1..m.col {
2621            for j in k + 1..m.col {
2622                let local_m = m[(i, k)] * m[(k, j)];
2623                m[(i, j)] += local_m;
2624            }
2625        }
2626    }
2627    r
2628}
2629
2630/// LU via Gauss Elimination with Complete Pivoting
2631fn gecp(m: &mut ComplexMatrix) -> (Vec<usize>, Vec<usize>) {
2632    let n = m.col;
2633    let mut r = vec![0usize; n - 1];
2634    let mut s = vec![0usize; n - 1];
2635    for k in 0..n - 1 {
2636        // Find pivot
2637        let (r_k, s_k) = match m.shape {
2638            Shape::Col => {
2639                let mut row_ics = 0usize;
2640                let mut col_ics = 0usize;
2641                let mut max_val = 0f64;
2642                for i in k..n {
2643                    let c = m
2644                        .col(i)
2645                        .into_iter()
2646                        .skip(k)
2647                        .enumerate()
2648                        .max_by(|x1, x2| x1.1.norm().partial_cmp(&x2.1.norm()).unwrap())
2649                        .unwrap();
2650                    let c_ics = c.0 + k;
2651                    let c_val = c.1.norm();
2652                    if c_val > max_val {
2653                        row_ics = c_ics;
2654                        col_ics = i;
2655                        max_val = c_val;
2656                    }
2657                }
2658                (row_ics, col_ics)
2659            }
2660            Shape::Row => {
2661                let mut row_ics = 0usize;
2662                let mut col_ics = 0usize;
2663                let mut max_val = 0f64;
2664                for i in k..n {
2665                    let c = m
2666                        .row(i)
2667                        .into_iter()
2668                        .skip(k)
2669                        .enumerate()
2670                        .max_by(|x1, x2| x1.1.norm().partial_cmp(&x2.1.norm()).unwrap())
2671                        .unwrap();
2672                    let c_ics = c.0 + k;
2673                    let c_val = c.1.norm();
2674                    if c_val > max_val {
2675                        col_ics = c_ics;
2676                        row_ics = i;
2677                        max_val = c_val;
2678                    }
2679                }
2680                (row_ics, col_ics)
2681            }
2682        };
2683        r[k] = r_k;
2684        s[k] = s_k;
2685
2686        // Interchange rows
2687        for j in k..n {
2688            unsafe {
2689                std::ptr::swap(&mut m[(k, j)], &mut m[(r_k, j)]);
2690            }
2691        }
2692
2693        // Interchange cols
2694        for i in 0..n {
2695            unsafe {
2696                std::ptr::swap(&mut m[(i, k)], &mut m[(i, s_k)]);
2697            }
2698        }
2699
2700        // Form the multipliers
2701        for i in k + 1..n {
2702            m[(i, k)] = -m[(i, k)] / m[(k, k)];
2703            for j in k + 1..n {
2704                let local_m = m[(i, k)] * m[(k, j)];
2705                m[(i, j)] += local_m;
2706            }
2707        }
2708    }
2709    (r, s)
2710}