peroxide/complex/matrix.rs
1use std::{
2 cmp::{max, min},
3 fmt,
4 ops::{Add, Div, Index, IndexMut, Mul, Neg, Sub},
5};
6
7use anyhow::{bail, Result};
8use matrixmultiply::CGemmOption;
9use num_complex::Complex;
10use peroxide_num::{ExpLogOps, PowOps, TrigOps};
11use rand_distr::num_traits::{One, Zero};
12
13use crate::{
14 complex::C64,
15 structure::matrix::Shape,
16 traits::fp::{FPMatrix, FPVector},
17 traits::general::Algorithm,
18 traits::math::{InnerProduct, LinearOp, MatrixProduct, Norm, Normed, Vector},
19 traits::matrix::{Form, LinearAlgebra, MatrixTrait, SolveKind, PQLU, QR, SVD, UPLO, WAZD},
20 traits::mutable::MutMatrix,
21 util::low_level::{copy_vec_ptr, swap_vec_ptr},
22 util::non_macro::ConcatenateError,
23 util::useful::{nearly_eq, tab},
24};
25
/// R-like complex matrix structure
///
/// Entries are stored in a flat buffer (`data`) whose ordering is given by
/// `shape`: `Shape::Row` stores the matrix row by row, `Shape::Col` column by
/// column. Methods such as `change_shape` expect `data.len() == row * col`.
///
/// # Examples
///
/// ```rust
/// use peroxide::fuga::*;
/// use peroxide::complex::matrix::ComplexMatrix;
///
/// let v1 = ComplexMatrix {
///     data: vec![
///         C64::new(1f64, 1f64),
///         C64::new(2f64, 2f64),
///         C64::new(3f64, 3f64),
///         C64::new(4f64, 4f64),
///     ],
///     row: 2,
///     col: 2,
///     shape: Row,
/// }; // [[1+1i,2+2i],[3+3i,4+4i]]
/// ```
#[derive(Debug, Clone, Default)]
pub struct ComplexMatrix {
    /// Flat entry buffer, `row * col` long, ordered according to `shape`.
    pub data: Vec<C64>,
    /// Number of rows.
    pub row: usize,
    /// Number of columns.
    pub col: usize,
    /// Storage order of `data` (`Row` = row-major, `Col` = column-major).
    pub shape: Shape,
}
53
54// =============================================================================
55// Various complex matrix constructor
56// =============================================================================
57
58/// R-like complex matrix constructor
59///
60/// # Examples
61/// ```rust
62/// #[macro_use]
63/// extern crate peroxide;
64/// use peroxide::fuga::*;
65/// use peroxide::complex::matrix::cmatrix;
66///
67/// fn main() {
68/// let a = cmatrix(vec![C64::new(1f64, 1f64),
69/// C64::new(2f64, 2f64),
70/// C64::new(3f64, 3f64),
71/// C64::new(4f64, 4f64)],
72/// 2, 2, Row
73/// );
74/// a.col.print(); // Print matrix column
75/// }
76/// ```
77pub fn cmatrix<T>(v: Vec<T>, r: usize, c: usize, shape: Shape) -> ComplexMatrix
78where
79 T: Into<C64>,
80{
81 ComplexMatrix {
82 data: v.into_iter().map(|t| t.into()).collect::<Vec<C64>>(),
83 row: r,
84 col: c,
85 shape,
86 }
87}
88
89/// R-like complex matrix constructor (Explicit ver.)
90pub fn r_cmatrix<T>(v: Vec<T>, r: usize, c: usize, shape: Shape) -> ComplexMatrix
91where
92 T: Into<C64>,
93{
94 cmatrix(v, r, c, shape)
95}
96
97/// Python-like complex matrix constructor
98///
99/// # Examples
100/// ```rust
101/// #[macro_use]
102/// extern crate peroxide;
103/// use peroxide::fuga::*;
104/// use peroxide::complex::matrix::*;
105///
106/// fn main() {
107/// let a = py_cmatrix(vec![vec![C64::new(1f64, 1f64),
108/// C64::new(2f64, 2f64)],
109/// vec![C64::new(3f64, 3f64),
110/// C64::new(4f64, 4f64)]
111/// ]);
112/// let b = cmatrix(vec![C64::new(1f64, 1f64),
113/// C64::new(2f64, 2f64),
114/// C64::new(3f64, 3f64),
115/// C64::new(4f64, 4f64)],
116/// 2, 2, Row
117/// );
118/// assert_eq!(a, b);
119/// }
120/// ```
121pub fn py_cmatrix<T>(v: Vec<Vec<T>>) -> ComplexMatrix
122where
123 T: Into<C64> + Copy,
124{
125 let r = v.len();
126 let c = v[0].len();
127 let data: Vec<T> = v.into_iter().flatten().collect();
128 cmatrix(data, r, c, Shape::Row)
129}
130
131/// Matlab-like matrix constructor
132///
133/// Note that the entries to the `ml_cmatrix`
134/// needs to be in the `a+bi` format
135/// without any spaces between the real and imaginary
136/// parts of the Complex number.
137///
138/// # Examples
139/// ```rust
140/// #[macro_use]
141/// extern crate peroxide;
142/// use peroxide::fuga::*;
143/// use peroxide::complex::matrix::*;
144///
145/// fn main() {
146/// let a = ml_cmatrix("1.0+1.0i 2.0+2.0i;
147/// 3.0+3.0i 4.0+4.0i");
148/// let b = cmatrix(vec![C64::new(1f64, 1f64),
149/// C64::new(2f64, 2f64),
150/// C64::new(3f64, 3f64),
151/// C64::new(4f64, 4f64)],
152/// 2, 2, Row
153/// );
154/// assert_eq!(a, b);
155/// }
156/// ```
157pub fn ml_cmatrix(s: &str) -> ComplexMatrix {
158 let str_row = s.split(";").collect::<Vec<&str>>();
159 let r = str_row.len();
160 let str_data = str_row
161 .iter()
162 .map(|x| x.trim().split(" ").collect::<Vec<&str>>())
163 .collect::<Vec<Vec<&str>>>();
164 let c = str_data[0].len();
165 let data = str_data
166 .iter()
167 .flat_map(|t| {
168 t.iter()
169 .map(|x| x.parse::<C64>().unwrap())
170 .collect::<Vec<C64>>()
171 })
172 .collect::<Vec<C64>>();
173
174 cmatrix(data, r, c, Shape::Row)
175}
176
177/// Pretty Print
178impl fmt::Display for ComplexMatrix {
179 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
180 write!(f, "{}", self.spread())
181 }
182}
183
184/// PartialEq implements
185impl PartialEq for ComplexMatrix {
186 fn eq(&self, other: &ComplexMatrix) -> bool {
187 if self.shape == other.shape {
188 self.data
189 .clone()
190 .into_iter()
191 .zip(other.data.clone())
192 .all(|(x, y)| nearly_eq(x.re, y.re) && nearly_eq(x.im, y.im))
193 && self.row == other.row
194 } else {
195 self.eq(&other.change_shape())
196 }
197 }
198}
199
200impl MatrixTrait for ComplexMatrix {
201 type Scalar = C64;
202
203 /// Raw pointer for `self.data`
204 fn ptr(&self) -> *const C64 {
205 &self.data[0] as *const C64
206 }
207
208 /// Raw mutable pointer for `self.data`
209 fn mut_ptr(&mut self) -> *mut C64 {
210 &mut self.data[0] as *mut C64
211 }
212
213 /// Slice of `self.data`
214 ///
215 /// # Examples
216 /// ```rust
217 /// use peroxide::fuga::*;
218 /// use peroxide::complex::matrix::*;
219 ///
220 /// let a = cmatrix(vec![C64::new(1f64, 1f64),
221 /// C64::new(2f64, 2f64),
222 /// C64::new(3f64, 3f64),
223 /// C64::new(4f64, 4f64)],
224 /// 2, 2, Row
225 /// );
226 /// let b = a.as_slice();
227 /// assert_eq!(b, &[C64::new(1f64, 1f64),
228 /// C64::new(2f64, 2f64),
229 /// C64::new(3f64, 3f64),
230 /// C64::new(4f64, 4f64)]);
231 /// ```
232 fn as_slice(&self) -> &[C64] {
233 &self.data[..]
234 }
235
236 /// Mutable slice of `self.data`
237 ///
238 /// # Examples
239 /// ```rust
240 /// use peroxide::fuga::*;
241 /// use peroxide::complex::matrix::*;
242 ///
243 /// let mut a = cmatrix(vec![C64::new(1f64, 1f64),
244 /// C64::new(2f64, 2f64),
245 /// C64::new(3f64, 3f64),
246 /// C64::new(4f64, 4f64)],
247 /// 2, 2, Row
248 /// );
249 /// let mut b = a.as_mut_slice();
250 /// b[1] = C64::new(5f64, 5f64);
251 /// assert_eq!(b, &[C64::new(1f64, 1f64),
252 /// C64::new(5f64, 5f64),
253 /// C64::new(3f64, 3f64),
254 /// C64::new(4f64, 4f64)]);
255 /// assert_eq!(a, cmatrix(vec![C64::new(1f64, 1f64),
256 /// C64::new(5f64, 5f64),
257 /// C64::new(3f64, 3f64),
258 /// C64::new(4f64, 4f64)],
259 /// 2, 2, Row));
260 /// ```
261 fn as_mut_slice(&mut self) -> &mut [C64] {
262 &mut self.data[..]
263 }
264
265 /// Change Bindings
266 ///
267 /// `Row` -> `Col` or `Col` -> `Row`
268 ///
269 /// # Examples
270 /// ```rust
271 /// use peroxide::fuga::*;
272 /// use peroxide::complex::matrix::*;
273 ///
274 /// let mut a = cmatrix(vec![C64::new(1f64, 1f64),
275 /// C64::new(2f64, 2f64),
276 /// C64::new(3f64, 3f64),
277 /// C64::new(4f64, 4f64)],
278 /// 2, 2, Row
279 /// );
280 /// assert_eq!(a.shape, Row);
281 /// let b = a.change_shape();
282 /// assert_eq!(b.shape, Col);
283 /// ```
284 fn change_shape(&self) -> Self {
285 let r = self.row;
286 let c = self.col;
287 assert_eq!(r * c, self.data.len());
288 let l = r * c - 1;
289 let mut data: Vec<C64> = self.data.clone();
290 let ref_data = &self.data;
291
292 match self.shape {
293 Shape::Row => {
294 for i in 0..l {
295 let s = (i * c) % l;
296 data[i] = ref_data[s];
297 }
298 data[l] = ref_data[l];
299 cmatrix(data, r, c, Shape::Col)
300 }
301 Shape::Col => {
302 for i in 0..l {
303 let s = (i * r) % l;
304 data[i] = ref_data[s];
305 }
306 data[l] = ref_data[l];
307 cmatrix(data, r, c, Shape::Row)
308 }
309 }
310 }
311
312 /// Change Bindings Mutably
313 ///
314 /// `Row` -> `Col` or `Col` -> `Row`
315 ///
316 /// # Examples
317 /// ```rust
318 /// use peroxide::fuga::*;
319 /// use peroxide::complex::matrix::*;
320 ///
321 /// let mut a = cmatrix(vec![
322 /// C64::new(1f64, 1f64),
323 /// C64::new(2f64, 2f64),
324 /// C64::new(3f64, 3f64),
325 /// C64::new(4f64, 4f64)
326 /// ],
327 /// 2, 2, Row
328 /// );
329 /// assert_eq!(a.shape, Row);
330 /// a.change_shape_mut();
331 /// assert_eq!(a.shape, Col);
332 /// ```
333 fn change_shape_mut(&mut self) {
334 let r = self.row;
335 let c = self.col;
336 assert_eq!(r * c, self.data.len());
337 let l = r * c - 1;
338 let ref_data = self.data.clone();
339
340 match self.shape {
341 Shape::Row => {
342 for i in 0..l {
343 let s = (i * c) % l;
344 self.data[i] = ref_data[s];
345 }
346 self.data[l] = ref_data[l];
347 self.shape = Shape::Col;
348 }
349 Shape::Col => {
350 for i in 0..l {
351 let s = (i * r) % l;
352 self.data[i] = ref_data[s];
353 }
354 self.data[l] = ref_data[l];
355 self.shape = Shape::Row;
356 }
357 }
358 }
359
360 /// Spread data(1D vector) to 2D formatted String
361 ///
362 /// # Examples
363 /// ```rust
364 /// use peroxide::fuga::*;
365 /// use peroxide::complex::matrix::*;
366 ///
367 /// let a = cmatrix(vec![C64::new(1f64, 1f64),
368 /// C64::new(2f64, 2f64),
369 /// C64::new(3f64, 3f64),
370 /// C64::new(4f64, 4f64)],
371 /// 2, 2, Row
372 /// );
373 /// println!("{}", a.spread()); // same as println!("{}", a);
374 /// // Result:
375 /// // c[0] c[1]
376 /// // r[0] 1+1i 3+3i
377 /// // r[1] 2+2i 4+4i
378 /// ```
379 fn spread(&self) -> String {
380 assert_eq!(self.row * self.col, self.data.len());
381 let r = self.row;
382 let c = self.col;
383 let mut key_row = 20usize;
384 let mut key_col = 20usize;
385
386 if r > 100 || c > 100 || (r > 20 && c > 20) {
387 let part = if r <= 10 {
388 key_row = r;
389 key_col = 100;
390 self.take_col(100)
391 } else if c <= 10 {
392 key_row = 100;
393 key_col = c;
394 self.take_row(100)
395 } else {
396 self.take_row(20).take_col(20)
397 };
398 return format!(
399 "Result is too large to print - {}x{}\n only print {}x{} parts:\n{}",
400 self.row.to_string(),
401 self.col.to_string(),
402 key_row.to_string(),
403 key_col.to_string(),
404 part.spread()
405 );
406 }
407
408 // Find maximum length of data
409 let sample = self.data.clone();
410 let mut space: usize = sample
411 .into_iter()
412 .map(
413 |x| min(format!("{:.4}", x).len(), x.to_string().len()), // Choose minimum of approx vs normal
414 )
415 .fold(0, |x, y| max(x, y))
416 + 1;
417
418 if space < 5 {
419 space = 5;
420 }
421
422 let mut result = String::new();
423
424 result.push_str(&tab("", 5));
425 for i in 0..c {
426 result.push_str(&tab(&format!("c[{}]", i), space)); // Header
427 }
428 result.push('\n');
429
430 for i in 0..r {
431 result.push_str(&tab(&format!("r[{}]", i), 5));
432 for j in 0..c {
433 let st1 = format!("{:.4}", self[(i, j)]); // Round at fourth position
434 let st2 = self[(i, j)].to_string(); // Normal string
435 let mut st = st2.clone();
436
437 // Select more small thing
438 if st1.len() < st2.len() {
439 st = st1;
440 }
441
442 result.push_str(&tab(&st, space));
443 }
444 if i == (r - 1) {
445 break;
446 }
447 result.push('\n');
448 }
449
450 return result;
451 }
452
453 /// Extract Column
454 ///
455 /// # Examples
456 /// ```rust
457 /// #[macro_use]
458 /// extern crate peroxide;
459 /// use peroxide::fuga::*;
460 /// use peroxide::complex::matrix::*;
461 ///
462 /// fn main() {
463 /// let a = cmatrix(vec![C64::new(1f64, 1f64),
464 /// C64::new(2f64, 2f64),
465 /// C64::new(3f64, 3f64),
466 /// C64::new(4f64, 4f64)],
467 /// 2, 2, Row
468 /// );
469 /// assert_eq!(a.col(0), vec![C64::new(1f64, 1f64), C64::new(3f64, 3f64)]);
470 /// }
471 /// ```
472 fn col(&self, index: usize) -> Vec<C64> {
473 assert!(index < self.col);
474 let mut container: Vec<C64> = vec![Complex::zero(); self.row];
475 for i in 0..self.row {
476 container[i] = self[(i, index)];
477 }
478 container
479 }
480
481 /// Extract Row
482 ///
483 /// # Examples
484 /// ```rust
485 /// #[macro_use]
486 /// extern crate peroxide;
487 /// use peroxide::fuga::*;
488 /// use peroxide::complex::matrix::*;
489 ///
490 /// fn main() {
491 /// let a = cmatrix(vec![C64::new(1f64, 1f64),
492 /// C64::new(2f64, 2f64),
493 /// C64::new(3f64, 3f64),
494 /// C64::new(4f64, 4f64)],
495 /// 2, 2, Row
496 /// );
497 /// assert_eq!(a.row(0), vec![C64::new(1f64, 1f64), C64::new(2f64, 2f64)]);
498 /// }
499 /// ```
500 fn row(&self, index: usize) -> Vec<C64> {
501 assert!(index < self.row);
502 let mut container: Vec<C64> = vec![Complex::zero(); self.col];
503 for i in 0..self.col {
504 container[i] = self[(index, i)];
505 }
506 container
507 }
508
509 /// Extract diagonal components
510 ///
511 /// # Examples
512 /// ```rust
513 /// #[macro_use]
514 /// extern crate peroxide;
515 /// use peroxide::fuga::*;
516 /// use peroxide::complex::matrix::*;
517 ///
518 /// fn main() {
519 /// let a = cmatrix(vec![C64::new(1f64, 1f64),
520 /// C64::new(2f64, 2f64),
521 /// C64::new(3f64, 3f64),
522 /// C64::new(4f64, 4f64)],
523 /// 2, 2, Row
524 /// );
525 /// assert_eq!(a.diag(), vec![C64::new(1f64, 1f64) ,C64::new(4f64, 4f64)]);
526 /// }
527 /// ```
528 fn diag(&self) -> Vec<C64> {
529 let mut container = vec![Complex::zero(); self.row];
530 let r = self.row;
531 let c = self.col;
532 assert_eq!(r, c);
533
534 let c2 = c + 1;
535 for i in 0..r {
536 container[i] = self.data[i * c2];
537 }
538 container
539 }
540
541 /// Transpose
542 ///
543 /// # Examples
544 /// ```rust
545 /// use peroxide::fuga::*;
546 /// use peroxide::complex::matrix::*;
547 ///
548 /// let a = cmatrix(vec![C64::new(1f64, 1f64),
549 /// C64::new(2f64, 2f64),
550 /// C64::new(3f64, 3f64),
551 /// C64::new(4f64, 4f64)],
552 /// 2, 2, Row
553 /// );
554 /// let a_t = cmatrix(vec![C64::new(1f64, 1f64),
555 /// C64::new(2f64, 2f64),
556 /// C64::new(3f64, 3f64),
557 /// C64::new(4f64, 4f64)],
558 /// 2, 2, Col
559 /// );
560 ///
561 /// assert_eq!(a.transpose(), a_t);
562 /// ```
563 fn transpose(&self) -> Self {
564 match self.shape {
565 Shape::Row => cmatrix(self.data.clone(), self.col, self.row, Shape::Col),
566 Shape::Col => cmatrix(self.data.clone(), self.col, self.row, Shape::Row),
567 }
568 }
569
570 /// Substitute Col
571 #[inline]
572 fn subs_col(&mut self, idx: usize, v: &[C64]) {
573 for i in 0..self.row {
574 self[(i, idx)] = v[i];
575 }
576 }
577
578 /// Substitute Row
579 #[inline]
580 fn subs_row(&mut self, idx: usize, v: &[C64]) {
581 for j in 0..self.col {
582 self[(idx, j)] = v[j];
583 }
584 }
585
586 /// From index operations
587 fn from_index<F, G>(f: F, size: (usize, usize)) -> ComplexMatrix
588 where
589 F: Fn(usize, usize) -> G + Copy,
590 G: Into<C64>,
591 {
592 let row = size.0;
593 let col = size.1;
594
595 let mut mat = cmatrix(vec![Complex::zero(); row * col], row, col, Shape::Row);
596
597 for i in 0..row {
598 for j in 0..col {
599 mat[(i, j)] = f(i, j).into();
600 }
601 }
602 mat
603 }
604
605 /// Matrix to `Vec<Vec<C64>>`
606 ///
607 /// To send `Matrix` to `inline-python`
608 fn to_vec(&self) -> Vec<Vec<C64>> {
609 let mut result = vec![vec![Complex::zero(); self.col]; self.row];
610 for i in 0..self.row {
611 result[i] = self.row(i);
612 }
613 result
614 }
615
616 fn to_diag(&self) -> ComplexMatrix {
617 assert_eq!(self.row, self.col, "Should be square matrix");
618 let mut result = cmatrix(
619 vec![Complex::zero(); self.row * self.col],
620 self.row,
621 self.col,
622 Shape::Row,
623 );
624 let diag = self.diag();
625 for i in 0..self.row {
626 result[(i, i)] = diag[i];
627 }
628 result
629 }
630
631 /// Submatrix
632 ///
633 /// # Description
634 /// Return below elements of complex matrix to a new complex matrix
635 ///
636 /// $$
637 /// \begin{pmatrix}
638 /// \\ddots & & & & \\\\
639 /// & start & \\cdots & end.1 & \\\\
640 /// & \\vdots & \\ddots & \\vdots & \\\\
641 /// & end.0 & \\cdots & end & \\\\
642 /// & & & & \\ddots
643 /// \end{pmatrix}
644 /// $$
645 ///
646 /// # Examples
647 /// ```rust
648 /// #[macro_use]
649 /// extern crate peroxide;
650 /// use peroxide::fuga::*;
651 /// use peroxide::complex::matrix::*;
652 ///
653 /// fn main() {
654 /// let a = ml_cmatrix("1.0+1.0i 2.0+2.0i 3.0+3.0i;
655 /// 4.0+4.0i 5.0+5.0i 6.0+6.0i;
656 /// 7.0+7.0i 8.0+8.0i 9.0+9.0i");
657 /// let b = cmatrix(vec![C64::new(5f64, 5f64),
658 /// C64::new(6f64, 6f64),
659 /// C64::new(8f64, 8f64),
660 /// C64::new(9f64, 9f64)],
661 /// 2, 2, Row
662 /// );
663 /// let c = a.submat((1, 1), (2, 2));
664 /// assert_eq!(b, c);
665 /// }
666 /// ```
667 fn submat(&self, start: (usize, usize), end: (usize, usize)) -> ComplexMatrix {
668 let row = end.0 - start.0 + 1;
669 let col = end.1 - start.1 + 1;
670 let mut result = cmatrix(vec![Complex::zero(); row * col], row, col, self.shape);
671 for i in 0..row {
672 for j in 0..col {
673 result[(i, j)] = self[(start.0 + i, start.1 + j)];
674 }
675 }
676 result
677 }
678
679 /// Substitute complex matrix to specific position
680 ///
681 /// # Description
682 /// Substitute below elements of complex matrix
683 ///
684 /// $$
685 /// \begin{pmatrix}
686 /// \\ddots & & & & \\\\
687 /// & start & \\cdots & end.1 & \\\\
688 /// & \\vdots & \\ddots & \\vdots & \\\\
689 /// & end.0 & \\cdots & end & \\\\
690 /// & & & & \\ddots
691 /// \end{pmatrix}
692 /// $$
693 ///
694 /// # Examples
695 /// ```
696 /// extern crate peroxide;
697 /// use peroxide::fuga::*;
698 /// use peroxide::complex::matrix::*;
699 ///
700 /// fn main() {
701 /// let mut a = ml_cmatrix("1.0+1.0i 2.0+2.0i 3.0+3.0i;
702 /// 4.0+4.0i 5.0+5.0i 6.0+6.0i;
703 /// 7.0+7.0i 8.0+8.0i 9.0+9.0i");
704 /// let b = cmatrix(vec![C64::new(1f64, 1f64),
705 /// C64::new(2f64, 2f64),
706 /// C64::new(3f64, 3f64),
707 /// C64::new(4f64, 4f64)],
708 /// 2, 2, Row);
709 /// let c = ml_cmatrix("1.0+1.0i 2.0+2.0i 3.0+3.0i;
710 /// 4.0+4.0i 1.0+1.0i 2.0+2.0i;
711 /// 7.0+7.0i 3.0+3.0i 4.0+4.0i");
712 /// a.subs_mat((1,1), (2,2), &b);
713 /// assert_eq!(a, c);
714 /// }
715 /// ```
716 fn subs_mat(&mut self, start: (usize, usize), end: (usize, usize), m: &ComplexMatrix) {
717 let row = end.0 - start.0 + 1;
718 let col = end.1 - start.1 + 1;
719 for i in 0..row {
720 for j in 0..col {
721 self[(start.0 + i, start.1 + j)] = m[(i, j)];
722 }
723 }
724 }
725}
726
727// =============================================================================
728// Mathematics for Matrix
729// =============================================================================
730impl Vector for ComplexMatrix {
731 type Scalar = C64;
732
733 fn add_vec(&self, other: &Self) -> Self {
734 assert_eq!(self.row, other.row);
735 assert_eq!(self.col, other.col);
736
737 let mut result = cmatrix(self.data.clone(), self.row, self.col, self.shape);
738 for i in 0..self.row {
739 for j in 0..self.col {
740 result[(i, j)] += other[(i, j)];
741 }
742 }
743 result
744 }
745
746 fn sub_vec(&self, other: &Self) -> Self {
747 assert_eq!(self.row, other.row);
748 assert_eq!(self.col, other.col);
749
750 let mut result = cmatrix(self.data.clone(), self.row, self.col, self.shape);
751 for i in 0..self.row {
752 for j in 0..self.col {
753 result[(i, j)] -= other[(i, j)];
754 }
755 }
756 result
757 }
758
759 fn mul_scalar(&self, other: Self::Scalar) -> Self {
760 let scalar = other;
761 self.fmap(|x| x * scalar)
762 }
763}
764
765impl Normed for ComplexMatrix {
766 type UnsignedScalar = f64;
767
768 fn norm(&self, kind: Norm) -> Self::UnsignedScalar {
769 match kind {
770 Norm::F => {
771 let mut s = Complex::zero();
772 for i in 0..self.data.len() {
773 s += self.data[i].powi(2);
774 }
775 s.sqrt().re
776 }
777 Norm::Lpq(p, q) => {
778 let mut s = Complex::zero();
779 for j in 0..self.col {
780 let mut s_row = Complex::zero();
781 for i in 0..self.row {
782 s_row += self[(i, j)].powi(p as i32);
783 }
784 s += s_row.powf(q / p);
785 }
786 s.powf(1f64 / q).re
787 }
788 Norm::L1 => {
789 let mut m = Complex::zero();
790 match self.shape {
791 Shape::Row => self.change_shape().norm(Norm::L1),
792 Shape::Col => {
793 for c in 0..self.col {
794 let s: C64 = self.col(c).iter().sum();
795 if s.re > m.re {
796 m = s;
797 }
798 }
799 m.re
800 }
801 }
802 }
803 Norm::LInf => {
804 let mut m = Complex::zero();
805 match self.shape {
806 Shape::Col => self.change_shape().norm(Norm::LInf),
807 Shape::Row => {
808 for r in 0..self.row {
809 let s: C64 = self.row(r).iter().sum();
810 if s.re > m.re {
811 m = s;
812 }
813 }
814 m.re
815 }
816 }
817 }
818 Norm::L2 => {
819 unimplemented!()
820 }
821 Norm::Lp(_) => unimplemented!(),
822 }
823 }
824
825 fn normalize(&self, _kind: Norm) -> Self
826 where
827 Self: Sized,
828 {
829 unimplemented!()
830 }
831}
832
833/// Frobenius inner product
834impl InnerProduct for ComplexMatrix {
835 fn dot(&self, rhs: &Self) -> C64 {
836 if self.shape == rhs.shape {
837 self.data.dot(&rhs.data)
838 } else {
839 self.data.dot(&rhs.change_shape().data)
840 }
841 }
842}
843
844/// TODO: Transpose
845
846/// Matrix as Linear operator for Vector
847#[allow(non_snake_case)]
848impl LinearOp<Vec<C64>, Vec<C64>> for ComplexMatrix {
849 fn apply(&self, other: &Vec<C64>) -> Vec<C64> {
850 assert_eq!(self.col, other.len());
851 let mut c = vec![Complex::zero(); self.row];
852 cgemv(Complex::one(), self, other, Complex::zero(), &mut c);
853 c
854 }
855}
856
857/// R like cbind - concatenate two comlex matrix by column direction
858pub fn complex_cbind(m1: ComplexMatrix, m2: ComplexMatrix) -> Result<ComplexMatrix> {
859 let mut temp = m1;
860 if temp.shape != Shape::Col {
861 temp = temp.change_shape();
862 }
863
864 let mut temp2 = m2;
865 if temp2.shape != Shape::Col {
866 temp2 = temp2.change_shape();
867 }
868
869 let mut v = temp.data;
870 let mut c = temp.col;
871 let r = temp.row;
872
873 if r != temp2.row {
874 bail!(ConcatenateError::DifferentLength);
875 }
876 v.extend_from_slice(&temp2.data[..]);
877 c += temp2.col;
878 Ok(cmatrix(v, r, c, Shape::Col))
879}
880
881/// R like rbind - concatenate two complex matrix by row direction
882pub fn complex_rbind(m1: ComplexMatrix, m2: ComplexMatrix) -> Result<ComplexMatrix> {
883 let mut temp = m1;
884 if temp.shape != Shape::Row {
885 temp = temp.change_shape();
886 }
887
888 let mut temp2 = m2;
889 if temp2.shape != Shape::Row {
890 temp2 = temp2.change_shape();
891 }
892
893 let mut v = temp.data;
894 let c = temp.col;
895 let mut r = temp.row;
896
897 if c != temp2.col {
898 bail!(ConcatenateError::DifferentLength);
899 }
900 v.extend_from_slice(&temp2.data[..]);
901 r += temp2.row;
902 Ok(cmatrix(v, r, c, Shape::Row))
903}
904
905impl MatrixProduct for ComplexMatrix {
906 fn kronecker(&self, other: &Self) -> Self {
907 let r1 = self.row;
908 let c1 = self.col;
909
910 let mut result = self[(0, 0)] * other;
911
912 for j in 1..c1 {
913 let n = self[(0, j)] * other;
914 result = complex_cbind(result, n).unwrap();
915 }
916
917 for i in 1..r1 {
918 let mut m = self[(i, 0)] * other;
919 for j in 1..c1 {
920 let n = self[(i, j)] * other;
921 m = complex_cbind(m, n).unwrap();
922 }
923 result = complex_rbind(result, m).unwrap();
924 }
925 result
926 }
927
928 fn hadamard(&self, other: &Self) -> Self {
929 assert_eq!(self.row, other.row);
930 assert_eq!(self.col, other.col);
931
932 let r = self.row;
933 let c = self.col;
934
935 let mut m = cmatrix(vec![Complex::zero(); r * c], r, c, self.shape);
936 for i in 0..r {
937 for j in 0..c {
938 m[(i, j)] = self[(i, j)] * other[(i, j)]
939 }
940 }
941 m
942 }
943}
944
945// =============================================================================
946// Common Properties of Matrix & Vec<f64>
947// =============================================================================
948/// `Complex Matrix` to `Vec<C64>`
949impl Into<Vec<C64>> for ComplexMatrix {
950 fn into(self) -> Vec<C64> {
951 self.data
952 }
953}
954
955/// `&ComplexMatrix` to `&Vec<C64>`
956impl<'a> Into<&'a Vec<C64>> for &'a ComplexMatrix {
957 fn into(self) -> &'a Vec<C64> {
958 &self.data
959 }
960}
961
962/// `Vec<C64>` to `ComplexMatrix`
963impl Into<ComplexMatrix> for Vec<C64> {
964 fn into(self) -> ComplexMatrix {
965 let l = self.len();
966 cmatrix(self, l, 1, Shape::Col)
967 }
968}
969
970/// `&Vec<C64>` to `ComplexMatrix`
971impl Into<ComplexMatrix> for &Vec<C64> {
972 fn into(self) -> ComplexMatrix {
973 let l = self.len();
974 cmatrix(self.clone(), l, 1, Shape::Col)
975 }
976}
977
978// =============================================================================
979// Standard Operation for Complex Matrix (ADD)
980// =============================================================================
981
982/// Element-wise addition of Complex Matrix
983///
984/// # Caution
985/// > You should remember ownership.
986/// > If you use ComplexMatrix `a,b` then you can't use them after.
987impl Add<ComplexMatrix> for ComplexMatrix {
988 type Output = Self;
989
990 fn add(self, other: Self) -> Self {
991 assert_eq!(&self.row, &other.row);
992 assert_eq!(&self.col, &other.col);
993
994 let mut result = cmatrix(self.data.clone(), self.row, self.col, self.shape);
995 for i in 0..self.row {
996 for j in 0..self.col {
997 result[(i, j)] += other[(i, j)];
998 }
999 }
1000 result
1001 }
1002}
1003
1004impl<'a, 'b> Add<&'b ComplexMatrix> for &'a ComplexMatrix {
1005 type Output = ComplexMatrix;
1006
1007 fn add(self, rhs: &'b ComplexMatrix) -> Self::Output {
1008 self.add_vec(rhs)
1009 }
1010}
1011
1012/// Element-wise addition between Complex Matrix & C64
1013///
1014/// # Examples
1015/// ```rust
1016/// #[macro_use]
1017/// extern crate peroxide;
1018/// use peroxide::fuga::*;
1019/// use peroxide::complex::matrix::*;
1020///
1021/// fn main() {
1022/// let mut a = ml_cmatrix("1.0+1.0i 2.0+2.0i;
1023/// 4.0+4.0i 5.0+5.0i");
1024/// let a_exp = ml_cmatrix("2.0+2.0i 3.0+3.0i;
1025/// 5.0+5.0i 6.0+6.0i");
1026/// assert_eq!(a + C64::new(1_f64, 1_f64), a_exp);
1027/// }
1028/// ```
1029impl<T> Add<T> for ComplexMatrix
1030where
1031 T: Into<C64> + Copy,
1032{
1033 type Output = Self;
1034 fn add(self, other: T) -> Self {
1035 self.fmap(|x| x + other.into())
1036 }
1037}
1038
1039/// Element-wise addition between &ComplexMatrix & C64
1040impl<'a, T> Add<T> for &'a ComplexMatrix
1041where
1042 T: Into<C64> + Copy,
1043{
1044 type Output = ComplexMatrix;
1045
1046 fn add(self, other: T) -> Self::Output {
1047 self.fmap(|x| x + other.into())
1048 }
1049}
1050
1051// Element-wise addition between C64 & ComplexMatrix
1052///
1053/// # Examples
1054///
1055/// ```rust
1056/// #[macro_use]
1057/// extern crate peroxide;
1058/// use peroxide::fuga::*;
1059/// use peroxide::complex::matrix::*;
1060///
1061/// fn main() {
1062/// let mut a = ml_cmatrix("1.0+1.0i 2.0+2.0i;
1063/// 4.0+4.0i 5.0+5.0i");
1064/// let a_exp = ml_cmatrix("2.0+2.0i 3.0+3.0i;
1065/// 5.0+5.0i 6.0+6.0i");
1066/// assert_eq!(C64::new(1_f64, 1_f64) + a, a_exp);
1067/// }
1068/// ```
1069impl Add<ComplexMatrix> for C64 {
1070 type Output = ComplexMatrix;
1071
1072 fn add(self, other: ComplexMatrix) -> Self::Output {
1073 other.add(self)
1074 }
1075}
1076
1077/// Element-wise addition between C64 & &ComplexMatrix
1078impl<'a> Add<&'a ComplexMatrix> for C64 {
1079 type Output = ComplexMatrix;
1080
1081 fn add(self, other: &'a ComplexMatrix) -> Self::Output {
1082 other.add(self)
1083 }
1084}
1085
1086// =============================================================================
1087// Standard Operation for Matrix (Neg)
1088// =============================================================================
1089/// Negation of Complex Matrix
1090///
1091/// # Examples
1092/// ```rust
1093/// extern crate peroxide;
1094/// use peroxide::fuga::*;
1095/// use peroxide::complex::matrix::*;
1096///
1097/// let a = cmatrix(vec![C64::new(1f64, 1f64),
1098/// C64::new(2f64, 2f64),
1099/// C64::new(3f64, 3f64),
1100/// C64::new(4f64, 4f64)],
1101/// 2, 2, Row);
1102/// let a_neg = cmatrix(vec![C64::new(-1f64, -1f64),
1103/// C64::new(-2f64, -2f64),
1104/// C64::new(-3f64, -3f64),
1105/// C64::new(-4f64, -4f64)],
1106/// 2, 2, Row);
1107/// assert_eq!(-a, a_neg);
1108/// ```
1109impl Neg for ComplexMatrix {
1110 type Output = Self;
1111
1112 fn neg(self) -> Self {
1113 cmatrix(
1114 self.data.into_iter().map(|x: C64| -x).collect::<Vec<C64>>(),
1115 self.row,
1116 self.col,
1117 self.shape,
1118 )
1119 }
1120}
1121
1122/// Negation of &'a Complex Matrix
1123impl<'a> Neg for &'a ComplexMatrix {
1124 type Output = ComplexMatrix;
1125
1126 fn neg(self) -> Self::Output {
1127 cmatrix(
1128 self.data
1129 .clone()
1130 .into_iter()
1131 .map(|x: C64| -x)
1132 .collect::<Vec<C64>>(),
1133 self.row,
1134 self.col,
1135 self.shape,
1136 )
1137 }
1138}
1139
1140// =============================================================================
1141// Standard Operation for Matrix (Sub)
1142// =============================================================================
1143/// Subtraction between Complex Matrix
1144///
1145/// # Examples
1146/// ```rust
1147/// #[macro_use]
1148/// extern crate peroxide;
1149/// use peroxide::fuga::*;
1150/// use peroxide::complex::matrix::*;
1151///
1152/// fn main() {
1153/// let a = ml_cmatrix("10.0+10.0i 20.0+20.0i;
1154/// 40.0+40.0i 50.0+50.0i");
1155/// let b = ml_cmatrix("1.0+1.0i 2.0+2.0i;
1156/// 4.0+4.0i 5.0+5.0i");
1157/// let diff = ml_cmatrix("9.0+9.0i 18.0+18.0i;
1158/// 36.0+36.0i 45.0+45.0i");
1159/// assert_eq!(a-b, diff);
1160/// }
1161/// ```
1162impl Sub<ComplexMatrix> for ComplexMatrix {
1163 type Output = Self;
1164
1165 fn sub(self, other: Self) -> Self::Output {
1166 assert_eq!(&self.row, &other.row);
1167 assert_eq!(&self.col, &other.col);
1168 let mut result = cmatrix(self.data.clone(), self.row, self.col, self.shape);
1169 for i in 0..self.row {
1170 for j in 0..self.col {
1171 result[(i, j)] -= other[(i, j)];
1172 }
1173 }
1174 result
1175 }
1176}
1177
1178impl<'a, 'b> Sub<&'b ComplexMatrix> for &'a ComplexMatrix {
1179 type Output = ComplexMatrix;
1180
1181 fn sub(self, rhs: &'b ComplexMatrix) -> Self::Output {
1182 self.sub_vec(rhs)
1183 }
1184}
1185
1186/// Subtraction between Complex Matrix & C64
1187impl<T> Sub<T> for ComplexMatrix
1188where
1189 T: Into<C64> + Copy,
1190{
1191 type Output = Self;
1192
1193 fn sub(self, other: T) -> Self::Output {
1194 self.fmap(|x| x - other.into())
1195 }
1196}
1197
1198/// Subtraction between &Complex Matrix & C64
1199impl<'a, T> Sub<T> for &'a ComplexMatrix
1200where
1201 T: Into<C64> + Copy,
1202{
1203 type Output = ComplexMatrix;
1204
1205 fn sub(self, other: T) -> Self::Output {
1206 self.fmap(|x| x - other.into())
1207 }
1208}
1209
1210/// Subtraction Complex Matrix with C64
1211///
1212/// # Examples
1213/// ```rust
1214/// #[macro_use]
1215/// extern crate peroxide;
1216/// use peroxide::fuga::*;
1217/// use peroxide::complex::matrix::*;
1218///
1219/// fn main() {
1220/// let mut a = ml_cmatrix("1.0+1.0i 2.0+2.0i;
1221/// 4.0+4.0i 5.0+5.0i");
1222/// let a_exp = ml_cmatrix("0.0+0.0i 1.0+1.0i;
1223/// 3.0+3.0i 4.0+4.0i");
1224/// assert_eq!(a - C64::new(1_f64, 1_f64), a_exp);
1225/// }
1226/// ```
1227impl Sub<ComplexMatrix> for C64 {
1228 type Output = ComplexMatrix;
1229
1230 fn sub(self, other: ComplexMatrix) -> Self::Output {
1231 -other.sub(self)
1232 }
1233}
1234
1235impl<'a> Sub<&'a ComplexMatrix> for f64 {
1236 type Output = ComplexMatrix;
1237
1238 fn sub(self, other: &'a ComplexMatrix) -> Self::Output {
1239 -other.sub(self)
1240 }
1241}
1242
1243// =============================================================================
1244// Multiplication for Complex Matrix
1245// =============================================================================
1246/// Element-wise multiplication between Complex Matrix vs C64
1247impl Mul<C64> for ComplexMatrix {
1248 type Output = Self;
1249
1250 fn mul(self, other: C64) -> Self::Output {
1251 self.fmap(|x| x * other)
1252 }
1253}
1254
1255impl Mul<ComplexMatrix> for C64 {
1256 type Output = ComplexMatrix;
1257
1258 fn mul(self, other: ComplexMatrix) -> Self::Output {
1259 other.mul(self)
1260 }
1261}
1262
1263impl<'a> Mul<&'a ComplexMatrix> for C64 {
1264 type Output = ComplexMatrix;
1265
1266 fn mul(self, other: &'a ComplexMatrix) -> Self::Output {
1267 other.mul_scalar(self)
1268 }
1269}
1270
1271/// Matrix Multiplication
1272///
1273/// # Examples
1274/// ```rust
1275/// #[macro_use]
1276/// extern crate peroxide;
1277/// use peroxide::fuga::*;
1278/// use peroxide::complex::matrix::*;
1279///
1280/// fn main() {
1281/// let mut a = ml_cmatrix("1.0+1.0i 2.0+2.0i;
1282/// 4.0+4.0i 5.0+5.0i");
1283/// let mut b = ml_cmatrix("2.0+2.0i 2.0+2.0i;
1284/// 5.0+5.0i 5.0+5.0i");
1285/// let prod = ml_cmatrix("0.0+24.0i 0.0+24.0i;
1286/// 0.0+66.0i 0.0+66.0i");
1287/// assert_eq!(a * b, prod);
1288/// }
1289/// ```
1290impl Mul<ComplexMatrix> for ComplexMatrix {
1291 type Output = Self;
1292
1293 fn mul(self, other: Self) -> Self::Output {
1294 cmatmul(&self, &other)
1295 }
1296}
1297
1298impl<'a, 'b> Mul<&'b ComplexMatrix> for &'a ComplexMatrix {
1299 type Output = ComplexMatrix;
1300
1301 fn mul(self, other: &'b ComplexMatrix) -> Self::Output {
1302 cmatmul(self, other)
1303 }
1304}
1305
1306#[allow(non_snake_case)]
1307impl Mul<Vec<C64>> for ComplexMatrix {
1308 type Output = Vec<C64>;
1309
1310 fn mul(self, other: Vec<C64>) -> Self::Output {
1311 self.apply(&other)
1312 }
1313}
1314
1315#[allow(non_snake_case)]
1316impl<'a, 'b> Mul<&'b Vec<C64>> for &'a ComplexMatrix {
1317 type Output = Vec<C64>;
1318
1319 fn mul(self, other: &'b Vec<C64>) -> Self::Output {
1320 self.apply(other)
1321 }
1322}
1323
1324/// Matrix multiplication for `Vec<C64>` vs `ComplexMatrix`
1325impl Mul<ComplexMatrix> for Vec<C64> {
1326 type Output = Vec<C64>;
1327
1328 fn mul(self, other: ComplexMatrix) -> Self::Output {
1329 assert_eq!(self.len(), other.row);
1330 let mut c = vec![Complex::zero(); other.col];
1331 complex_gevm(Complex::one(), &self, &other, Complex::zero(), &mut c);
1332 c
1333 }
1334}
1335
1336impl<'a, 'b> Mul<&'b ComplexMatrix> for &'a Vec<C64> {
1337 type Output = Vec<C64>;
1338
1339 fn mul(self, other: &'b ComplexMatrix) -> Self::Output {
1340 assert_eq!(self.len(), other.row);
1341 let mut c = vec![Complex::zero(); other.col];
1342 complex_gevm(Complex::one(), self, other, Complex::zero(), &mut c);
1343 c
1344 }
1345}
1346
1347// =============================================================================
1348// Standard Operation for Matrix (DIV)
1349// =============================================================================
1350/// Element-wise division between Complex Matrix vs C64
1351impl Div<C64> for ComplexMatrix {
1352 type Output = Self;
1353
1354 fn div(self, other: C64) -> Self::Output {
1355 self.fmap(|x| x / other)
1356 }
1357}
1358
1359impl<'a> Div<C64> for &'a ComplexMatrix {
1360 type Output = ComplexMatrix;
1361
1362 fn div(self, other: C64) -> Self::Output {
1363 self.fmap(|x| x / other)
1364 }
1365}
1366
1367/// Index for Complex Matrix
1368///
1369/// `(usize, usize) -> C64`
1370///
1371/// # Examples
1372/// ```rust
1373/// extern crate peroxide;
1374/// use peroxide::fuga::*;
1375/// use peroxide::complex::matrix::*;
1376///
1377/// let a = cmatrix(vec![C64::new(1f64, 1f64),
1378/// C64::new(2f64, 2f64),
1379/// C64::new(3f64, 3f64),
1380/// C64::new(4f64, 4f64)],
1381/// 2, 2, Row
1382/// );
1383/// assert_eq!(a[(0,1)], C64::new(2f64, 2f64));
1384/// ```
1385impl Index<(usize, usize)> for ComplexMatrix {
1386 type Output = C64;
1387
1388 fn index(&self, pair: (usize, usize)) -> &C64 {
1389 let p = self.ptr();
1390 let i = pair.0;
1391 let j = pair.1;
1392 assert!(i < self.row && j < self.col, "Index out of range");
1393 match self.shape {
1394 Shape::Row => unsafe { &*p.add(i * self.col + j) },
1395 Shape::Col => unsafe { &*p.add(i + j * self.row) },
1396 }
1397 }
1398}
1399
1400/// IndexMut for Complex Matrix (Assign)
1401///
1402/// `(usize, usize) -> C64`
1403///
1404/// # Examples
1405/// ```rust
1406/// extern crate peroxide;
1407/// use peroxide::fuga::*;
1408/// use peroxide::complex::matrix::*;
1409///
1410/// let mut a = cmatrix(vec![C64::new(1f64, 1f64),
1411/// C64::new(2f64, 2f64),
1412/// C64::new(3f64, 3f64),
1413/// C64::new(4f64, 4f64)],
1414/// 2, 2, Row
1415/// );
1416/// assert_eq!(a[(0,1)], C64::new(2f64, 2f64));
1417/// ```
1418impl IndexMut<(usize, usize)> for ComplexMatrix {
1419 fn index_mut(&mut self, pair: (usize, usize)) -> &mut C64 {
1420 let i = pair.0;
1421 let j = pair.1;
1422 let r = self.row;
1423 let c = self.col;
1424 assert!(i < self.row && j < self.col, "Index out of range");
1425 let p = self.mut_ptr();
1426 match self.shape {
1427 Shape::Row => {
1428 let idx = i * c + j;
1429 unsafe { &mut *p.add(idx) }
1430 }
1431 Shape::Col => {
1432 let idx = i + j * r;
1433 unsafe { &mut *p.add(idx) }
1434 }
1435 }
1436 }
1437}
1438
1439// =============================================================================
1440// Functional Programming Tools (Hand-written)
1441// =============================================================================
1442
1443impl FPMatrix for ComplexMatrix {
1444 type Scalar = C64;
1445
1446 fn take_row(&self, n: usize) -> Self {
1447 if n >= self.row {
1448 return self.clone();
1449 }
1450 match self.shape {
1451 Shape::Row => {
1452 let new_data = self
1453 .data
1454 .clone()
1455 .into_iter()
1456 .take(n * self.col)
1457 .collect::<Vec<C64>>();
1458 cmatrix(new_data, n, self.col, Shape::Row)
1459 }
1460 Shape::Col => {
1461 let mut temp_data: Vec<C64> = Vec::new();
1462 for i in 0..n {
1463 temp_data.extend(self.row(i));
1464 }
1465 cmatrix(temp_data, n, self.col, Shape::Row)
1466 }
1467 }
1468 }
1469
1470 fn take_col(&self, n: usize) -> Self {
1471 if n >= self.col {
1472 return self.clone();
1473 }
1474 match self.shape {
1475 Shape::Col => {
1476 let new_data = self
1477 .data
1478 .clone()
1479 .into_iter()
1480 .take(n * self.row)
1481 .collect::<Vec<C64>>();
1482 cmatrix(new_data, self.row, n, Shape::Col)
1483 }
1484 Shape::Row => {
1485 let mut temp_data: Vec<C64> = Vec::new();
1486 for i in 0..n {
1487 temp_data.extend(self.col(i));
1488 }
1489 cmatrix(temp_data, self.row, n, Shape::Col)
1490 }
1491 }
1492 }
1493
1494 fn skip_row(&self, n: usize) -> Self {
1495 assert!(n < self.row, "Skip range is larger than row of matrix");
1496
1497 let mut temp_data: Vec<C64> = Vec::new();
1498 for i in n..self.row {
1499 temp_data.extend(self.row(i));
1500 }
1501 cmatrix(temp_data, self.row - n, self.col, Shape::Row)
1502 }
1503
1504 fn skip_col(&self, n: usize) -> Self {
1505 assert!(n < self.col, "Skip range is larger than col of matrix");
1506
1507 let mut temp_data: Vec<C64> = Vec::new();
1508 for i in n..self.col {
1509 temp_data.extend(self.col(i));
1510 }
1511 cmatrix(temp_data, self.row, self.col - n, Shape::Col)
1512 }
1513
1514 fn fmap<F>(&self, f: F) -> Self
1515 where
1516 F: Fn(C64) -> C64,
1517 {
1518 let result = self.data.iter().map(|x| f(*x)).collect::<Vec<C64>>();
1519 cmatrix(result, self.row, self.col, self.shape)
1520 }
1521
1522 /// Column map
1523 ///
1524 /// # Example
1525 /// ```rust
1526 /// use peroxide::fuga::*;
1527 /// use peroxide::complex::matrix::*;
1528 /// use peroxide::traits::fp::FPMatrix;
1529 ///
1530 /// fn main() {
1531 /// let x = cmatrix(vec![C64::new(1f64, 1f64),
1532 /// C64::new(2f64, 2f64),
1533 /// C64::new(3f64, 3f64),
1534 /// C64::new(4f64, 4f64)],
1535 /// 2, 2, Row
1536 /// );
1537 /// let y = x.col_map(|r| r.fmap(|t| t + r[0]));
1538 ///
1539 /// let y_col_map = cmatrix(vec![C64::new(2f64, 2f64),
1540 /// C64::new(4f64, 4f64),
1541 /// C64::new(4f64, 4f64),
1542 /// C64::new(6f64, 6f64)],
1543 /// 2, 2, Col
1544 /// );
1545 ///
1546 /// assert_eq!(y, y_col_map);
1547 /// }
1548 /// ```
1549 fn col_map<F>(&self, f: F) -> ComplexMatrix
1550 where
1551 F: Fn(Vec<C64>) -> Vec<C64>,
1552 {
1553 let mut result = cmatrix(
1554 vec![Complex::zero(); self.row * self.col],
1555 self.row,
1556 self.col,
1557 Shape::Col,
1558 );
1559
1560 for i in 0..self.col {
1561 result.subs_col(i, &f(self.col(i)));
1562 }
1563
1564 result
1565 }
1566
1567 /// Row map
1568 ///
1569 /// # Example
1570 /// ```rust
1571 /// use peroxide::fuga::*;
1572 /// use peroxide::complex::matrix::*;
1573 /// use peroxide::traits::fp::FPMatrix;
1574 ///
1575 /// fn main() {
1576 /// let x = cmatrix(vec![C64::new(1f64, 1f64),
1577 /// C64::new(2f64, 2f64),
1578 /// C64::new(3f64, 3f64),
1579 /// C64::new(4f64, 4f64)],
1580 /// 2, 2, Row
1581 /// );
1582 /// let y = x.row_map(|r| r.fmap(|t| t + r[0]));
1583 ///
1584 /// let y_row_map = cmatrix(vec![C64::new(2f64, 2f64),
1585 /// C64::new(3f64, 3f64),
1586 /// C64::new(6f64, 6f64),
1587 /// C64::new(7f64, 7f64)],
1588 /// 2, 2, Row
1589 /// );
1590 ///
1591 /// assert_eq!(y, y_row_map);
1592 /// }
1593 /// ```
1594 fn row_map<F>(&self, f: F) -> ComplexMatrix
1595 where
1596 F: Fn(Vec<C64>) -> Vec<C64>,
1597 {
1598 let mut result = cmatrix(
1599 vec![Complex::zero(); self.row * self.col],
1600 self.row,
1601 self.col,
1602 Shape::Row,
1603 );
1604
1605 for i in 0..self.row {
1606 result.subs_row(i, &f(self.row(i)));
1607 }
1608
1609 result
1610 }
1611
1612 fn col_mut_map<F>(&mut self, f: F)
1613 where
1614 F: Fn(Vec<C64>) -> Vec<C64>,
1615 {
1616 for i in 0..self.col {
1617 unsafe {
1618 let mut p = self.col_mut(i);
1619 let fv = f(self.col(i));
1620 for j in 0..p.len() {
1621 *p[j] = fv[j];
1622 }
1623 }
1624 }
1625 }
1626
1627 fn row_mut_map<F>(&mut self, f: F)
1628 where
1629 F: Fn(Vec<C64>) -> Vec<C64>,
1630 {
1631 for i in 0..self.col {
1632 unsafe {
1633 let mut p = self.row_mut(i);
1634 let fv = f(self.row(i));
1635 for j in 0..p.len() {
1636 *p[j] = fv[j];
1637 }
1638 }
1639 }
1640 }
1641
1642 fn reduce<F, T>(&self, init: T, f: F) -> C64
1643 where
1644 F: Fn(C64, C64) -> C64,
1645 T: Into<C64>,
1646 {
1647 self.data.iter().fold(init.into(), |x, y| f(x, *y))
1648 }
1649
1650 fn zip_with<F>(&self, f: F, other: &ComplexMatrix) -> Self
1651 where
1652 F: Fn(C64, C64) -> C64,
1653 {
1654 assert_eq!(self.data.len(), other.data.len());
1655 let mut a = other.clone();
1656 if self.shape != other.shape {
1657 a = a.change_shape();
1658 }
1659 let result = self
1660 .data
1661 .iter()
1662 .zip(a.data.iter())
1663 .map(|(x, y)| f(*x, *y))
1664 .collect::<Vec<C64>>();
1665 cmatrix(result, self.row, self.col, self.shape)
1666 }
1667
1668 fn col_reduce<F>(&self, f: F) -> Vec<C64>
1669 where
1670 F: Fn(Vec<C64>) -> C64,
1671 {
1672 let mut v = vec![Complex::zero(); self.col];
1673 for i in 0..self.col {
1674 v[i] = f(self.col(i));
1675 }
1676 v
1677 }
1678
1679 fn row_reduce<F>(&self, f: F) -> Vec<C64>
1680 where
1681 F: Fn(Vec<C64>) -> C64,
1682 {
1683 let mut v = vec![Complex::zero(); self.row];
1684 for i in 0..self.row {
1685 v[i] = f(self.row(i));
1686 }
1687 v
1688 }
1689}
1690
1691pub fn cdiag(n: usize) -> ComplexMatrix {
1692 let mut v: Vec<C64> = vec![Complex::zero(); n * n];
1693 for i in 0..n {
1694 let idx = i * (n + 1);
1695 v[idx] = Complex::one();
1696 }
1697 cmatrix(v, n, n, Shape::Row)
1698}
1699
1700impl PQLU<ComplexMatrix> {
1701 /// Extract PQLU
1702 ///
1703 /// # Usage
1704 /// ```rust
1705 /// extern crate peroxide;
1706 /// use peroxide::fuga::*;
1707 ///
1708 /// let a = cmatrix(vec![C64::new(1f64, 1f64),
1709 /// C64::new(2f64, 2f64),
1710 /// C64::new(3f64, 3f64),
1711 /// C64::new(4f64, 4f64)],
1712 /// 2, 2, Row
1713 /// );
1714 /// let pqlu = a.lu();
1715 /// let (p, q, l, u) = pqlu.extract();
1716 /// // p, q are permutations
1717 /// // l, u are matrices
1718 /// println!("{}", l); // lower triangular
1719 /// println!("{}", u); // upper triangular
1720 /// ```
1721 pub fn extract(&self) -> (Vec<usize>, Vec<usize>, ComplexMatrix, ComplexMatrix) {
1722 (
1723 self.p.clone(),
1724 self.q.clone(),
1725 self.l.clone(),
1726 self.u.clone(),
1727 )
1728 }
1729
1730 pub fn det(&self) -> C64 {
1731 // sgn of perms
1732 let mut sgn_p = 1f64;
1733 let mut sgn_q = 1f64;
1734 for (i, &j) in self.p.iter().enumerate() {
1735 if i != j {
1736 sgn_p *= -1f64;
1737 }
1738 }
1739 for (i, &j) in self.q.iter().enumerate() {
1740 if i != j {
1741 sgn_q *= -1f64;
1742 }
1743 }
1744
1745 self.u.diag().reduce(Complex::one(), |x, y| x * y) * sgn_p * sgn_q
1746 }
1747
1748 pub fn inv(&self) -> ComplexMatrix {
1749 let (p, q, l, u) = self.extract();
1750 let mut m = complex_inv_u(u) * complex_inv_l(l);
1751 // Q = Q1 Q2 Q3 ..
1752 for (idx1, idx2) in q.into_iter().enumerate().rev() {
1753 unsafe {
1754 m.swap(idx1, idx2, Shape::Row);
1755 }
1756 }
1757 // P = Pn-1 .. P3 P2 P1
1758 for (idx1, idx2) in p.into_iter().enumerate().rev() {
1759 unsafe {
1760 m.swap(idx1, idx2, Shape::Col);
1761 }
1762 }
1763 m
1764 }
1765}
1766
1767/// MATLAB like eye - Identity matrix
1768pub fn ceye(n: usize) -> ComplexMatrix {
1769 let mut m = cmatrix(vec![Complex::zero(); n * n], n, n, Shape::Row);
1770 for i in 0..n {
1771 m[(i, i)] = Complex::one();
1772 }
1773 m
1774}
1775
1776// =============================================================================
1777// Linear Algebra
1778// =============================================================================
1779
1780impl LinearAlgebra<ComplexMatrix> for ComplexMatrix {
1781 /// Backward Substitution for Upper Triangular
1782 fn back_subs(&self, b: &[C64]) -> Vec<C64> {
1783 let n = self.col;
1784 let mut y = vec![Complex::zero(); n];
1785 y[n - 1] = b[n - 1] / self[(n - 1, n - 1)];
1786 for i in (0..n - 1).rev() {
1787 let mut s = Complex::zero();
1788 for j in i + 1..n {
1789 s += self[(i, j)] * y[j];
1790 }
1791 y[i] = 1f64 / self[(i, i)] * (b[i] - s);
1792 }
1793 y
1794 }
1795
1796 /// Forward substitution for Lower Triangular
1797 fn forward_subs(&self, b: &[C64]) -> Vec<C64> {
1798 let n = self.col;
1799 let mut y = vec![Complex::zero(); n];
1800 y[0] = b[0] / self[(0, 0)];
1801 for i in 1..n {
1802 let mut s = Complex::zero();
1803 for j in 0..i {
1804 s += self[(i, j)] * y[j];
1805 }
1806 y[i] = 1f64 / self[(i, i)] * (b[i] - s);
1807 }
1808 y
1809 }
1810
1811 /// LU Decomposition Implements (Complete Pivot)
1812 ///
1813 /// # Description
1814 /// It use complete pivoting LU decomposition.
1815 /// You can get two permutations, and LU matrices.
1816 ///
1817 /// # Caution
1818 /// It returns `Option<PQLU>` - You should unwrap to obtain real value.
1819 /// `PQLU` has four field - `p`, `q`, `l`, `u`.
1820 /// `p`, `q` are permutations.
1821 /// `l`, `u` are matrices.
1822 ///
1823 /// # Examples
1824 /// ```
1825 /// #[macro_use]
1826 /// use peroxide::fuga::*;
1827 /// use peroxide::complex::matrix::*;
1828 ///
1829 /// fn main() {
1830 /// let a = cmatrix(vec![
1831 /// C64::new(1f64, 1f64),
1832 /// C64::new(2f64, 2f64),
1833 /// C64::new(3f64, 3f64),
1834 /// C64::new(4f64, 4f64)
1835 /// ],
1836 /// 2, 2, Row
1837 /// );
1838 ///
1839 /// let l_exp = cmatrix(vec![
1840 /// C64::new(1f64, 0f64),
1841 /// C64::new(0f64, 0f64),
1842 /// C64::new(0.5f64, -0.0f64),
1843 /// C64::new(1f64, 0f64)
1844 /// ],
1845 /// 2, 2, Row
1846 /// );
1847 ///
1848 /// let u_exp = cmatrix(vec![
1849 /// C64::new(4f64, 4f64),
1850 /// C64::new(3f64, 3f64),
1851 /// C64::new(0f64, 0f64),
1852 /// C64::new(-0.5f64, -0.5f64)
1853 /// ],
1854 /// 2, 2, Row
1855 /// );
1856 /// let pqlu = a.lu();
1857 /// let (p,q,l,u) = (pqlu.p, pqlu.q, pqlu.l, pqlu.u);
1858 /// assert_eq!(p, vec![1]); // swap 0 & 1 (Row)
1859 /// assert_eq!(q, vec![1]); // swap 0 & 1 (Col)
1860 /// assert_eq!(l, l_exp);
1861 /// assert_eq!(u, u_exp);
1862 /// }
1863 /// ```
1864 fn lu(&self) -> PQLU<ComplexMatrix> {
1865 assert_eq!(self.col, self.row);
1866 let n = self.row;
1867 let len: usize = n * n;
1868
1869 let mut l = ceye(n);
1870 let mut u = cmatrix(vec![Complex::zero(); len], n, n, self.shape);
1871
1872 let mut temp = self.clone();
1873 let (p, q) = gecp(&mut temp);
1874 for i in 0..n {
1875 for j in 0..i {
1876 // Inverse multiplier
1877 l[(i, j)] = -temp[(i, j)];
1878 }
1879 for j in i..n {
1880 u[(i, j)] = temp[(i, j)];
1881 }
1882 }
1883 // Pivoting L
1884 for i in 0..n - 1 {
1885 unsafe {
1886 let l_i = l.col_mut(i);
1887 for j in i + 1..l.col - 1 {
1888 let dst = p[j];
1889 std::ptr::swap(l_i[j], l_i[dst]);
1890 }
1891 }
1892 }
1893 PQLU { p, q, l, u }
1894 }
1895
1896 fn waz(&self, _d_form: Form) -> Option<WAZD<ComplexMatrix>> {
1897 unimplemented!()
1898 }
1899
1900 fn qr(&self) -> QR<ComplexMatrix> {
1901 unimplemented!()
1902 }
1903
1904 fn svd(&self) -> SVD<ComplexMatrix> {
1905 unimplemented!()
1906 }
1907
1908 #[cfg(feature = "O3")]
1909 fn cholesky(&self, uplo: UPLO) -> ComplexMatrix {
1910 unimplemented!()
1911 }
1912
1913 fn rref(&self) -> ComplexMatrix {
1914 unimplemented!()
1915 }
1916
1917 /// Determinant
1918 ///
1919 /// # Examples
1920 /// ```
1921 /// #[macro_use]
1922 /// use peroxide::fuga::*;
1923 /// use peroxide::complex::matrix::*;
1924 ///
1925 /// fn main() {
1926 /// let a = cmatrix(vec![
1927 /// C64::new(1f64, 1f64),
1928 /// C64::new(2f64, 2f64),
1929 /// C64::new(3f64, 3f64),
1930 /// C64::new(4f64, 4f64)
1931 /// ],
1932 /// 2, 2, Row
1933 /// );
1934 /// assert_eq!(a.det().norm(), 4f64);
1935 /// }
1936 /// ```
1937 fn det(&self) -> C64 {
1938 assert_eq!(self.row, self.col);
1939 self.lu().det()
1940 }
1941
1942 /// Block Partition
1943 ///
1944 /// # Examples
1945 /// ```rust
1946 /// #[macro_use]
1947 /// extern crate peroxide;
1948 /// use peroxide::fuga::*;
1949 /// use peroxide::complex::matrix::*;
1950 ///
1951 /// fn main() {
1952 /// let a = cmatrix(vec![
1953 /// C64::new(1f64, 1f64),
1954 /// C64::new(2f64, 2f64),
1955 /// C64::new(3f64, 3f64),
1956 /// C64::new(4f64, 4f64)
1957 /// ],
1958 /// 2, 2, Row
1959 /// );
1960 /// let (m1, m2, m3, m4) = a.block();
1961 /// assert_eq!(m1, ml_cmatrix("1.0+1.0i"));
1962 /// assert_eq!(m2, ml_cmatrix("2.0+2.0i"));
1963 /// assert_eq!(m3, ml_cmatrix("3.0+3.0i"));
1964 /// assert_eq!(m4, ml_cmatrix("4.0+4.0i"));
1965 /// }
1966 /// ```
1967 fn block(&self) -> (Self, Self, Self, Self) {
1968 let r = self.row;
1969 let c = self.col;
1970 let l_r = self.row / 2;
1971 let l_c = self.col / 2;
1972 let r_l = r - l_r;
1973 let c_l = c - l_c;
1974
1975 let mut m1 = cmatrix(vec![Complex::zero(); l_r * l_c], l_r, l_c, self.shape);
1976 let mut m2 = cmatrix(vec![Complex::zero(); l_r * c_l], l_r, c_l, self.shape);
1977 let mut m3 = cmatrix(vec![Complex::zero(); r_l * l_c], r_l, l_c, self.shape);
1978 let mut m4 = cmatrix(vec![Complex::zero(); r_l * c_l], r_l, c_l, self.shape);
1979
1980 for idx_row in 0..r {
1981 for idx_col in 0..c {
1982 match (idx_row, idx_col) {
1983 (i, j) if (i < l_r) && (j < l_c) => {
1984 m1[(i, j)] = self[(i, j)];
1985 }
1986 (i, j) if (i < l_r) && (j >= l_c) => {
1987 m2[(i, j - l_c)] = self[(i, j)];
1988 }
1989 (i, j) if (i >= l_r) && (j < l_c) => {
1990 m3[(i - l_r, j)] = self[(i, j)];
1991 }
1992 (i, j) if (i >= l_r) && (j >= l_c) => {
1993 m4[(i - l_r, j - l_c)] = self[(i, j)];
1994 }
1995 _ => (),
1996 }
1997 }
1998 }
1999 (m1, m2, m3, m4)
2000 }
2001
2002 /// Inverse of Matrix
2003 ///
2004 /// # Caution
2005 ///
2006 /// `inv` function returns `Option<Matrix>`
2007 /// Thus, you should use pattern matching or `unwrap` to obtain inverse.
2008 ///
2009 /// # Examples
2010 /// ```
2011 /// #[macro_use]
2012 /// extern crate peroxide;
2013 /// use peroxide::fuga::*;
2014 /// use peroxide::complex::matrix::*;
2015 ///
2016 /// fn main() {
2017 /// // Non-singular
2018 /// let a = cmatrix(vec![
2019 /// C64::new(1f64, 1f64),
2020 /// C64::new(2f64, 2f64),
2021 /// C64::new(3f64, 3f64),
2022 /// C64::new(4f64, 4f64)
2023 /// ],
2024 /// 2, 2, Row
2025 /// );
2026 ///
2027 /// let a_inv_exp = cmatrix(vec![
2028 /// C64::new(-1.0f64, 1f64),
2029 /// C64::new(0.5f64, -0.5f64),
2030 /// C64::new(0.75f64, -0.75f64),
2031 /// C64::new(-0.25f64, 0.25f64)
2032 /// ],
2033 /// 2, 2, Row
2034 /// );
2035 /// assert_eq!(a.inv(), a_inv_exp);
2036 /// }
2037 /// ```
2038 fn inv(&self) -> Self {
2039 self.lu().inv()
2040 }
2041
2042 fn pseudo_inv(&self) -> ComplexMatrix {
2043 unimplemented!()
2044 }
2045
2046 /// Solve with Vector
2047 ///
2048 /// # Solve options
2049 ///
2050 /// * LU: Gaussian elimination with Complete pivoting LU (GECP)
2051 /// * WAZ: Solve with WAZ decomposition
2052 fn solve(&self, b: &[C64], sk: SolveKind) -> Vec<C64> {
2053 match sk {
2054 SolveKind::LU => {
2055 let lu = self.lu();
2056 let (p, q, l, u) = lu.extract();
2057 let mut v = b.to_vec();
2058 v.swap_with_perm(&p.into_iter().enumerate().collect());
2059 let z = l.forward_subs(&v);
2060 let mut y = u.back_subs(&z);
2061 y.swap_with_perm(&q.into_iter().enumerate().rev().collect());
2062 y
2063 }
2064 SolveKind::WAZ => {
2065 unimplemented!()
2066 }
2067 }
2068 }
2069
2070 fn solve_mat(&self, m: &ComplexMatrix, sk: SolveKind) -> ComplexMatrix {
2071 match sk {
2072 SolveKind::LU => {
2073 let lu = self.lu();
2074 let (p, q, l, u) = lu.extract();
2075 let mut x = cmatrix(
2076 vec![Complex::zero(); self.col * m.col],
2077 self.col,
2078 m.col,
2079 Shape::Col,
2080 );
2081 for i in 0..m.col {
2082 let mut v = m.col(i).clone();
2083 for (r, &s) in p.iter().enumerate() {
2084 v.swap(r, s);
2085 }
2086 let z = l.forward_subs(&v);
2087 let mut y = u.back_subs(&z);
2088 for (r, &s) in q.iter().enumerate() {
2089 y.swap(r, s);
2090 }
2091 unsafe {
2092 let mut c = x.col_mut(i);
2093 copy_vec_ptr(&mut c, &y);
2094 }
2095 }
2096 x
2097 }
2098 SolveKind::WAZ => {
2099 unimplemented!()
2100 }
2101 }
2102 }
2103
2104 fn is_symmetric(&self) -> bool {
2105 if self.row != self.col {
2106 return false;
2107 }
2108
2109 for i in 0..self.row {
2110 for j in i..self.col {
2111 if (!nearly_eq(self[(i, j)].re, self[(j, i)].re))
2112 && (!nearly_eq(self[(i, j)].im, self[(j, i)].im))
2113 {
2114 return false;
2115 }
2116 }
2117 }
2118 true
2119 }
2120}
2121
2122#[allow(non_snake_case)]
2123pub fn csolve(A: &ComplexMatrix, b: &ComplexMatrix, sk: SolveKind) -> ComplexMatrix {
2124 A.solve_mat(b, sk)
2125}
2126
2127impl MutMatrix for ComplexMatrix {
2128 type Scalar = C64;
2129
2130 unsafe fn col_mut(&mut self, idx: usize) -> Vec<*mut C64> {
2131 assert!(idx < self.col, "Index out of range");
2132 match self.shape {
2133 Shape::Col => {
2134 let mut v: Vec<*mut C64> = vec![&mut Complex::zero(); self.row];
2135 let start_idx = idx * self.row;
2136 let p = self.mut_ptr();
2137 for (i, j) in (start_idx..start_idx + v.len()).enumerate() {
2138 v[i] = p.add(j);
2139 }
2140 v
2141 }
2142 Shape::Row => {
2143 let mut v: Vec<*mut C64> = vec![&mut Complex::zero(); self.row];
2144 let p = self.mut_ptr();
2145 for i in 0..v.len() {
2146 v[i] = p.add(idx + i * self.col);
2147 }
2148 v
2149 }
2150 }
2151 }
2152
2153 unsafe fn row_mut(&mut self, idx: usize) -> Vec<*mut C64> {
2154 assert!(idx < self.row, "Index out of range");
2155 match self.shape {
2156 Shape::Row => {
2157 let mut v: Vec<*mut C64> = vec![&mut Complex::zero(); self.col];
2158 let start_idx = idx * self.col;
2159 let p = self.mut_ptr();
2160 for (i, j) in (start_idx..start_idx + v.len()).enumerate() {
2161 v[i] = p.add(j);
2162 }
2163 v
2164 }
2165 Shape::Col => {
2166 let mut v: Vec<*mut C64> = vec![&mut Complex::zero(); self.col];
2167 let p = self.mut_ptr();
2168 for i in 0..v.len() {
2169 v[i] = p.add(idx + i * self.row);
2170 }
2171 v
2172 }
2173 }
2174 }
2175
2176 unsafe fn swap(&mut self, idx1: usize, idx2: usize, shape: Shape) {
2177 match shape {
2178 Shape::Col => swap_vec_ptr(&mut self.col_mut(idx1), &mut self.col_mut(idx2)),
2179 Shape::Row => swap_vec_ptr(&mut self.row_mut(idx1), &mut self.row_mut(idx2)),
2180 }
2181 }
2182
2183 unsafe fn swap_with_perm(&mut self, p: &Vec<(usize, usize)>, shape: Shape) {
2184 for (i, j) in p.iter() {
2185 self.swap(*i, *j, shape)
2186 }
2187 }
2188}
2189
2190impl ExpLogOps for ComplexMatrix {
2191 type Float = C64;
2192
2193 fn exp(&self) -> Self {
2194 self.fmap(|x| x.exp())
2195 }
2196 fn ln(&self) -> Self {
2197 self.fmap(|x| x.ln())
2198 }
2199 fn log(&self, base: Self::Float) -> Self {
2200 self.fmap(|x| x.ln() / base.ln()) // Using `Log: change of base` formula
2201 }
2202 fn log2(&self) -> Self {
2203 self.fmap(|x| x.ln() / 2.0.ln()) // Using `Log: change of base` formula
2204 }
2205 fn log10(&self) -> Self {
2206 self.fmap(|x| x.ln() / 10.0.ln()) // Using `Log: change of base` formula
2207 }
2208}
2209
2210impl PowOps for ComplexMatrix {
2211 type Float = C64;
2212
2213 fn powi(&self, n: i32) -> Self {
2214 self.fmap(|x| x.powi(n))
2215 }
2216
2217 fn powf(&self, f: Self::Float) -> Self {
2218 self.fmap(|x| x.powc(f))
2219 }
2220
2221 fn pow(&self, _f: Self) -> Self {
2222 unimplemented!()
2223 }
2224
2225 fn sqrt(&self) -> Self {
2226 self.fmap(|x| x.sqrt())
2227 }
2228}
2229
2230impl TrigOps for ComplexMatrix {
2231 fn sin_cos(&self) -> (Self, Self) {
2232 let (sin, cos) = self.data.iter().map(|x| (x.sin(), x.cos())).unzip();
2233 (
2234 cmatrix(sin, self.row, self.col, self.shape),
2235 cmatrix(cos, self.row, self.col, self.shape),
2236 )
2237 }
2238
2239 fn sin(&self) -> Self {
2240 self.fmap(|x| x.sin())
2241 }
2242
2243 fn cos(&self) -> Self {
2244 self.fmap(|x| x.cos())
2245 }
2246
2247 fn tan(&self) -> Self {
2248 self.fmap(|x| x.tan())
2249 }
2250
2251 fn sinh(&self) -> Self {
2252 self.fmap(|x| x.sinh())
2253 }
2254
2255 fn cosh(&self) -> Self {
2256 self.fmap(|x| x.cosh())
2257 }
2258
2259 fn tanh(&self) -> Self {
2260 self.fmap(|x| x.tanh())
2261 }
2262
2263 fn asin(&self) -> Self {
2264 self.fmap(|x| x.asin())
2265 }
2266
2267 fn acos(&self) -> Self {
2268 self.fmap(|x| x.acos())
2269 }
2270
2271 fn atan(&self) -> Self {
2272 self.fmap(|x| x.atan())
2273 }
2274
2275 fn asinh(&self) -> Self {
2276 self.fmap(|x| x.asinh())
2277 }
2278
2279 fn acosh(&self) -> Self {
2280 self.fmap(|x| x.acosh())
2281 }
2282
2283 fn atanh(&self) -> Self {
2284 self.fmap(|x| x.atanh())
2285 }
2286}
2287
2288// =============================================================================
2289// Back-end Utils
2290// =============================================================================
2291/// Combine separated Complex Matrix to one Complex Matrix
2292///
2293/// # Examples
2294/// ```rust
2295/// use peroxide::fuga::*;
2296/// use peroxide::complex::matrix::*;
2297/// use peroxide::traits::fp::FPMatrix;
2298///
2299/// fn main() {
2300/// let x1 = cmatrix(vec![C64::new(1f64, 1f64)], 1, 1, Row);
2301/// let x2 = cmatrix(vec![C64::new(2f64, 2f64)], 1, 1, Row);
2302/// let x3 = cmatrix(vec![C64::new(3f64, 3f64)], 1, 1, Row);
2303/// let x4 = cmatrix(vec![C64::new(4f64, 4f64)], 1, 1, Row);
2304///
2305/// let y = complex_combine(x1, x2, x3, x4);
2306///
2307/// let y_exp = cmatrix(vec![C64::new(1f64, 1f64),
2308/// C64::new(2f64, 2f64),
2309/// C64::new(3f64, 3f64),
2310/// C64::new(4f64, 4f64)],
2311/// 2, 2, Row
2312/// );
2313///
2314/// assert_eq!(y, y_exp);
2315/// }
2316/// ```
2317pub fn complex_combine(
2318 m1: ComplexMatrix,
2319 m2: ComplexMatrix,
2320 m3: ComplexMatrix,
2321 m4: ComplexMatrix,
2322) -> ComplexMatrix {
2323 let l_r = m1.row;
2324 let l_c = m1.col;
2325 let c_l = m2.col;
2326 let r_l = m3.row;
2327
2328 let r = l_r + r_l;
2329 let c = l_c + c_l;
2330
2331 let mut m = cmatrix(vec![Complex::zero(); r * c], r, c, m1.shape);
2332
2333 for idx_row in 0..r {
2334 for idx_col in 0..c {
2335 match (idx_row, idx_col) {
2336 (i, j) if (i < l_r) && (j < l_c) => {
2337 m[(i, j)] = m1[(i, j)];
2338 }
2339 (i, j) if (i < l_r) && (j >= l_c) => {
2340 m[(i, j)] = m2[(i, j - l_c)];
2341 }
2342 (i, j) if (i >= l_r) && (j < l_c) => {
2343 m[(i, j)] = m3[(i - l_r, j)];
2344 }
2345 (i, j) if (i >= l_r) && (j >= l_c) => {
2346 m[(i, j)] = m4[(i - l_r, j - l_c)];
2347 }
2348 _ => (),
2349 }
2350 }
2351 }
2352 m
2353}
2354
2355/// Inverse of Lower matrix
2356///
2357/// # Examples
2358/// ```rust
2359/// #[macro_use]
2360/// extern crate peroxide;
2361/// use peroxide::fuga::*;
2362/// use peroxide::complex::matrix::*;
2363///
2364/// fn main() {
2365/// let a = ml_cmatrix("2.0+2.0i 0.0+0.0i;
2366/// 2.0+2.0i 1.0+1.0i");
2367/// let b = cmatrix(vec![C64::new(2f64, 2f64),
2368/// C64::new(0f64, 0f64),
2369/// C64::new(-2f64, -2f64),
2370/// C64::new(1f64, 1f64)],
2371/// 2, 2, Row
2372/// );
2373/// assert_eq!(complex_inv_l(a), b);
2374/// }
2375/// ```
2376pub fn complex_inv_l(l: ComplexMatrix) -> ComplexMatrix {
2377 let mut m = l.clone();
2378
2379 match l.row {
2380 1 => l,
2381 2 => {
2382 m[(1, 0)] = -m[(1, 0)];
2383 m
2384 }
2385 _ => {
2386 let (l1, l2, l3, l4) = l.block();
2387
2388 let m1 = complex_inv_l(l1);
2389 let m2 = l2;
2390 let m4 = complex_inv_l(l4);
2391 let m3 = -(&(&m4 * &l3) * &m1);
2392
2393 complex_combine(m1, m2, m3, m4)
2394 }
2395 }
2396}
2397
2398/// Inverse of upper triangular matrix
2399///
2400/// # Examples
2401/// ```rust
2402/// #[macro_use]
2403/// extern crate peroxide;
2404/// use peroxide::fuga::*;
2405/// use peroxide::complex::matrix::*;
2406///
2407/// fn main() {
2408/// let a = ml_cmatrix("2.0+2.0i 2.0+2.0i;
2409/// 0.0+0.0i 1.0+1.0i");
2410/// let b = cmatrix(vec![C64::new(0.25f64, -0.25f64),
2411/// C64::new(-0.5f64, 0.5f64),
2412/// C64::new(0.0f64, 0.0f64),
2413/// C64::new(0.5f64, -0.5f64)],
2414/// 2, 2, Row
2415/// );
2416/// assert_eq!(complex_inv_u(a), b);
2417/// }
2418/// ```
2419pub fn complex_inv_u(u: ComplexMatrix) -> ComplexMatrix {
2420 let mut w = u.clone();
2421
2422 match u.row {
2423 1 => {
2424 w[(0, 0)] = 1f64 / w[(0, 0)];
2425 w
2426 }
2427 2 => {
2428 let a = w[(0, 0)];
2429 let b = w[(0, 1)];
2430 let c = w[(1, 1)];
2431 let d = a * c;
2432
2433 w[(0, 0)] = 1f64 / a;
2434 w[(0, 1)] = -b / d;
2435 w[(1, 1)] = 1f64 / c;
2436 w
2437 }
2438 _ => {
2439 let (u1, u2, u3, u4) = u.block();
2440 let m1 = complex_inv_u(u1);
2441 let m3 = u3;
2442 let m4 = complex_inv_u(u4);
2443 let m2 = -(m1.clone() * u2 * m4.clone());
2444
2445 complex_combine(m1, m2, m3, m4)
2446 }
2447 }
2448}
2449
2450/// Matrix multiply back-ends
2451pub fn cmatmul(a: &ComplexMatrix, b: &ComplexMatrix) -> ComplexMatrix {
2452 assert_eq!(a.col, b.row);
2453 let mut c = cmatrix(vec![Complex::zero(); a.row * b.col], a.row, b.col, a.shape);
2454 cgemm(Complex::one(), a, b, Complex::zero(), &mut c);
2455 c
2456}
2457
2458/// GEMM wrapper for Matrixmultiply
2459///
2460/// # Examples
2461/// ```rust
2462/// #[macro_use]
2463/// extern crate peroxide;
2464/// use peroxide::fuga::*;
2465///
2466/// use peroxide::complex::matrix::*;
2467///
2468/// fn main() {
2469/// let a = ml_cmatrix("1.0+1.0i 2.0+2.0i;
2470/// 0.0+0.0i 1.0+1.0i");
2471/// let b = ml_cmatrix("1.0+1.0i 0.0+0.0i;
2472/// 2.0+2.0i 1.0+1.0i");
2473/// let mut c1 = ml_cmatrix("1.0+1.0i 1.0+1.0i;
2474/// 1.0+1.0i 1.0+1.0i");
2475/// let mul_val = ml_cmatrix("-10.0+10.0i -4.0+4.0i;
2476/// -4.0+4.0i -2.0+2.0i");
2477///
2478/// cgemm(C64::new(1.0, 1.0), &a, &b, C64::new(0.0, 0.0), &mut c1);
2479/// assert_eq!(c1, mul_val);
2480/// }
2481pub fn cgemm(alpha: C64, a: &ComplexMatrix, b: &ComplexMatrix, beta: C64, c: &mut ComplexMatrix) {
2482 let m = a.row;
2483 let k = a.col;
2484 let n = b.col;
2485 let (rsa, csa) = match a.shape {
2486 Shape::Row => (a.col as isize, 1isize),
2487 Shape::Col => (1isize, a.row as isize),
2488 };
2489 let (rsb, csb) = match b.shape {
2490 Shape::Row => (b.col as isize, 1isize),
2491 Shape::Col => (1isize, b.row as isize),
2492 };
2493 let (rsc, csc) = match c.shape {
2494 Shape::Row => (c.col as isize, 1isize),
2495 Shape::Col => (1isize, c.row as isize),
2496 };
2497
2498 unsafe {
2499 matrixmultiply::zgemm(
2500 // Requires crate feature "cgemm"
2501 CGemmOption::Standard,
2502 CGemmOption::Standard,
2503 m,
2504 k,
2505 n,
2506 [alpha.re, alpha.im],
2507 a.ptr() as *const _,
2508 rsa,
2509 csa,
2510 b.ptr() as *const _,
2511 rsb,
2512 csb,
2513 [beta.re, beta.im],
2514 c.mut_ptr() as *mut _,
2515 rsc,
2516 csc,
2517 )
2518 }
2519}
2520
2521/// General Matrix-Vector multiplication
2522pub fn cgemv(alpha: C64, a: &ComplexMatrix, b: &Vec<C64>, beta: C64, c: &mut Vec<C64>) {
2523 let m = a.row;
2524 let k = a.col;
2525 let n = 1usize;
2526 let (rsa, csa) = match a.shape {
2527 Shape::Row => (a.col as isize, 1isize),
2528 Shape::Col => (1isize, a.row as isize),
2529 };
2530 let (rsb, csb) = (1isize, 1isize);
2531 let (rsc, csc) = (1isize, 1isize);
2532
2533 unsafe {
2534 matrixmultiply::zgemm(
2535 // Requires crate feature "cgemm"
2536 CGemmOption::Standard,
2537 CGemmOption::Standard,
2538 m,
2539 k,
2540 n,
2541 [alpha.re, alpha.im],
2542 a.ptr() as *const _,
2543 rsa,
2544 csa,
2545 b.as_ptr() as *const _,
2546 rsb,
2547 csb,
2548 [beta.re, beta.im],
2549 c.as_mut_ptr() as *mut _,
2550 rsc,
2551 csc,
2552 )
2553 }
2554}
2555
2556/// General Vector-Matrix multiplication
2557pub fn complex_gevm(alpha: C64, a: &Vec<C64>, b: &ComplexMatrix, beta: C64, c: &mut Vec<C64>) {
2558 let m = 1usize;
2559 let k = a.len();
2560 let n = b.col;
2561 let (rsa, csa) = (1isize, 1isize);
2562 let (rsb, csb) = match b.shape {
2563 Shape::Row => (b.col as isize, 1isize),
2564 Shape::Col => (1isize, b.row as isize),
2565 };
2566 let (rsc, csc) = (1isize, 1isize);
2567
2568 unsafe {
2569 matrixmultiply::zgemm(
2570 // Requires crate feature "cgemm"
2571 CGemmOption::Standard,
2572 CGemmOption::Standard,
2573 m,
2574 k,
2575 n,
2576 [alpha.re, alpha.im],
2577 a.as_ptr() as *const _,
2578 rsa,
2579 csa,
2580 b.ptr() as *const _,
2581 rsb,
2582 csb,
2583 [beta.re, beta.im],
2584 c.as_mut_ptr() as *mut _,
2585 rsc,
2586 csc,
2587 )
2588 }
2589}
2590
2591/// LU via Gaussian Elimination with Partial Pivoting
2592#[allow(dead_code)]
2593fn gepp(m: &mut ComplexMatrix) -> Vec<usize> {
2594 let mut r = vec![0usize; m.col - 1];
2595 for k in 0..(m.col - 1) {
2596 // Find the pivot row
2597 let r_k = m
2598 .col(k)
2599 .into_iter()
2600 .skip(k)
2601 .enumerate()
2602 .max_by(|x1, x2| x1.1.norm().partial_cmp(&x2.1.norm()).unwrap())
2603 .unwrap()
2604 .0
2605 + k;
2606 r[k] = r_k;
2607
2608 // Interchange the rows r_k and k
2609 for j in k..m.col {
2610 unsafe {
2611 std::ptr::swap(&mut m[(k, j)], &mut m[(r_k, j)]);
2612 println!("Swap! k:{}, r_k:{}", k, r_k);
2613 }
2614 }
2615 // Form the multipliers
2616 for i in k + 1..m.col {
2617 m[(i, k)] = -m[(i, k)] / m[(k, k)];
2618 }
2619 // Update the entries
2620 for i in k + 1..m.col {
2621 for j in k + 1..m.col {
2622 let local_m = m[(i, k)] * m[(k, j)];
2623 m[(i, j)] += local_m;
2624 }
2625 }
2626 }
2627 r
2628}
2629
2630/// LU via Gauss Elimination with Complete Pivoting
2631fn gecp(m: &mut ComplexMatrix) -> (Vec<usize>, Vec<usize>) {
2632 let n = m.col;
2633 let mut r = vec![0usize; n - 1];
2634 let mut s = vec![0usize; n - 1];
2635 for k in 0..n - 1 {
2636 // Find pivot
2637 let (r_k, s_k) = match m.shape {
2638 Shape::Col => {
2639 let mut row_ics = 0usize;
2640 let mut col_ics = 0usize;
2641 let mut max_val = 0f64;
2642 for i in k..n {
2643 let c = m
2644 .col(i)
2645 .into_iter()
2646 .skip(k)
2647 .enumerate()
2648 .max_by(|x1, x2| x1.1.norm().partial_cmp(&x2.1.norm()).unwrap())
2649 .unwrap();
2650 let c_ics = c.0 + k;
2651 let c_val = c.1.norm();
2652 if c_val > max_val {
2653 row_ics = c_ics;
2654 col_ics = i;
2655 max_val = c_val;
2656 }
2657 }
2658 (row_ics, col_ics)
2659 }
2660 Shape::Row => {
2661 let mut row_ics = 0usize;
2662 let mut col_ics = 0usize;
2663 let mut max_val = 0f64;
2664 for i in k..n {
2665 let c = m
2666 .row(i)
2667 .into_iter()
2668 .skip(k)
2669 .enumerate()
2670 .max_by(|x1, x2| x1.1.norm().partial_cmp(&x2.1.norm()).unwrap())
2671 .unwrap();
2672 let c_ics = c.0 + k;
2673 let c_val = c.1.norm();
2674 if c_val > max_val {
2675 col_ics = c_ics;
2676 row_ics = i;
2677 max_val = c_val;
2678 }
2679 }
2680 (row_ics, col_ics)
2681 }
2682 };
2683 r[k] = r_k;
2684 s[k] = s_k;
2685
2686 // Interchange rows
2687 for j in k..n {
2688 unsafe {
2689 std::ptr::swap(&mut m[(k, j)], &mut m[(r_k, j)]);
2690 }
2691 }
2692
2693 // Interchange cols
2694 for i in 0..n {
2695 unsafe {
2696 std::ptr::swap(&mut m[(i, k)], &mut m[(i, s_k)]);
2697 }
2698 }
2699
2700 // Form the multipliers
2701 for i in k + 1..n {
2702 m[(i, k)] = -m[(i, k)] / m[(k, k)];
2703 for j in k + 1..n {
2704 let local_m = m[(i, k)] * m[(k, j)];
2705 m[(i, j)] += local_m;
2706 }
2707 }
2708 }
2709 (r, s)
2710}