use crate::structure::ad::AD;
use crate::structure::matrix::{Matrix, Shape};
use crate::structure::sparse::SPMatrix;
use crate::traits::{
fp::FPVector,
math::{LinearOp, Vector},
matrix::MatrixTrait,
};
use std::ops::{Add, Deref, Div, Mul, Sub};
/// Thin owning wrapper that keeps a `Vector` value `T` on the heap.
///
/// The wrapped value is reachable through `Deref`, and the operator impls in
/// this module build fresh `Redox` values from it.
#[derive(Debug)]
pub struct Redox<T>
where
    T: Vector,
{
    data: Box<T>,
}
/// Lets a `Redox<T>` be used wherever a `&T` is expected.
impl<T> Deref for Redox<T>
where
    T: Vector,
{
    type Target = T;

    fn deref(&self) -> &T {
        self.data.as_ref()
    }
}
/// Conversions between a concrete `Redox` wrapper and the plain vector type it holds.
pub trait RedoxCommon {
    /// The plain (unwrapped) vector type.
    type ToRedox;

    /// Builds the wrapper from a plain vector.
    fn from_vec(vec: Self::ToRedox) -> Self;

    /// Consumes the wrapper and returns the plain vector.
    fn red(self) -> Self::ToRedox;
}
/// Wrapping / unwrapping for `Redox` over plain `f64` vectors.
impl RedoxCommon for Redox<Vec<f64>> {
    type ToRedox = Vec<f64>;

    /// Moves `vec` onto the heap inside a new `Redox`; no elements are copied.
    fn from_vec(vec: Self::ToRedox) -> Self {
        Self {
            data: Box::new(vec),
        }
    }

    /// Consumes the wrapper and returns the wrapped vector.
    ///
    /// Because `self` is taken by value, the `Vec` is moved out of its `Box`
    /// instead of being cloned element by element.
    fn red(self) -> Self::ToRedox {
        *self.data
    }
}
/// Wrapping / unwrapping for `Redox` over vectors of `AD` values.
impl RedoxCommon for Redox<Vec<AD>> {
    type ToRedox = Vec<AD>;

    /// Moves `vec` onto the heap inside a new `Redox`; no elements are copied.
    fn from_vec(vec: Self::ToRedox) -> Self {
        Self {
            data: Box::new(vec),
        }
    }

    /// Consumes the wrapper and returns the wrapped vector.
    ///
    /// Because `self` is taken by value, the `Vec` is moved out of its `Box`
    /// instead of being cloned element by element.
    fn red(self) -> Self::ToRedox {
        *self.data
    }
}
/// `Vector` types that can be turned into a [`Redox`] wrapper.
pub trait Oxide: Vector {
    /// Consumes `self` and produces a `Redox<Self>`.
    fn ox(self) -> Redox<Self>
    where
        Self: Sized;
}
/// `Redox + Redox`, delegating to `Vector::add_vec` on the wrapped values.
impl<T> Add<Redox<T>> for Redox<T>
where
    T: Vector,
{
    type Output = Self;

    fn add(self, rhs: Redox<T>) -> Self::Output {
        let sum = self.add_vec(&rhs.data);
        Self { data: Box::new(sum) }
    }
}
/// `Redox - Redox`: combines both operands with `FPVector::zip_with` using scalar `-`.
impl<T> Sub<Redox<T>> for Redox<T>
where
    T: Vector + FPVector,
    <T as FPVector>::Scalar: Sub<Output = <T as FPVector>::Scalar>,
{
    type Output = Self;

    fn sub(self, rhs: Redox<T>) -> Self::Output {
        let difference = self.zip_with(|a, b| a - b, &rhs.data);
        Self {
            data: Box::new(difference),
        }
    }
}
/// `Redox * Redox`: combines both operands with `FPVector::zip_with` using scalar `*`.
impl<T> Mul<Redox<T>> for Redox<T>
where
    T: Vector + FPVector,
    <T as FPVector>::Scalar: Mul<Output = <T as FPVector>::Scalar>,
{
    type Output = Self;

    fn mul(self, rhs: Redox<T>) -> Self::Output {
        let product = self.zip_with(|a, b| a * b, &rhs.data);
        Self {
            data: Box::new(product),
        }
    }
}
/// `Redox / Redox`: combines both operands with `FPVector::zip_with` using scalar `/`.
impl<T> Div<Redox<T>> for Redox<T>
where
    T: Vector + FPVector,
    <T as FPVector>::Scalar: Div<Output = <T as FPVector>::Scalar>,
{
    type Output = Self;

    fn div(self, rhs: Redox<T>) -> Self::Output {
        let quotient = self.zip_with(|a, b| a / b, &rhs.data);
        Self {
            data: Box::new(quotient),
        }
    }
}
/// `Redox + f64`: adds the scalar to every element via `FPVector::fmap`.
impl<T> Add<f64> for Redox<T>
where
    T: Vector + FPVector,
    <T as FPVector>::Scalar: Add<f64, Output = <T as FPVector>::Scalar>,
{
    type Output = Self;

    fn add(self, rhs: f64) -> Self::Output {
        let shifted = self.fmap(|elem| elem + rhs);
        Self {
            data: Box::new(shifted),
        }
    }
}
/// `Redox - f64`: subtracts the scalar from every element via `FPVector::fmap`.
impl<T> Sub<f64> for Redox<T>
where
    T: Vector + FPVector,
    <T as FPVector>::Scalar: Sub<f64, Output = <T as FPVector>::Scalar>,
{
    type Output = Self;

    fn sub(self, rhs: f64) -> Self::Output {
        let shifted = self.fmap(|elem| elem - rhs);
        Self {
            data: Box::new(shifted),
        }
    }
}
/// `Redox * f64`: multiplies every element by the scalar via `FPVector::fmap`.
impl<T> Mul<f64> for Redox<T>
where
    T: Vector + FPVector,
    <T as FPVector>::Scalar: Mul<f64, Output = <T as FPVector>::Scalar>,
{
    type Output = Self;

    fn mul(self, rhs: f64) -> Self::Output {
        let scaled = self.fmap(|elem| elem * rhs);
        Self {
            data: Box::new(scaled),
        }
    }
}
/// `Redox / f64`: divides every element by the scalar via `FPVector::fmap`.
impl<T> Div<f64> for Redox<T>
where
    T: Vector + FPVector,
    <T as FPVector>::Scalar: Div<f64, Output = <T as FPVector>::Scalar>,
{
    type Output = Self;

    fn div(self, rhs: f64) -> Self::Output {
        let scaled = self.fmap(|elem| elem / rhs);
        Self {
            data: Box::new(scaled),
        }
    }
}
/// `Matrix * Redox<Vec<f64>>`: applies the matrix to the wrapped vector via `LinearOp::apply`.
impl Mul<Redox<Vec<f64>>> for Matrix {
    type Output = Redox<Vec<f64>>;

    fn mul(self, rhs: Redox<Vec<f64>>) -> Self::Output {
        let product = self.apply(&*rhs);
        Redox {
            data: Box::new(product),
        }
    }
}
/// `&Matrix * Redox<Vec<f64>>`: same as the owned version, without consuming the matrix.
impl Mul<Redox<Vec<f64>>> for &Matrix {
    type Output = Redox<Vec<f64>>;

    fn mul(self, rhs: Redox<Vec<f64>>) -> Self::Output {
        let product = self.apply(&*rhs);
        Redox {
            data: Box::new(product),
        }
    }
}
/// `SPMatrix * Redox<Vec<f64>>`: applies the sparse matrix to the wrapped vector via `LinearOp::apply`.
impl Mul<Redox<Vec<f64>>> for SPMatrix {
    type Output = Redox<Vec<f64>>;

    fn mul(self, rhs: Redox<Vec<f64>>) -> Self::Output {
        let product = self.apply(&*rhs);
        Redox {
            data: Box::new(product),
        }
    }
}
/// `&SPMatrix * Redox<Vec<f64>>`: same as the owned version, without consuming the matrix.
impl Mul<Redox<Vec<f64>>> for &SPMatrix {
    type Output = Redox<Vec<f64>>;

    fn mul(self, rhs: Redox<Vec<f64>>) -> Self::Output {
        let product = self.apply(&*rhs);
        Redox {
            data: Box::new(product),
        }
    }
}
/// Raw-pointer access to a single row or column of a matrix.
pub trait MatrixPtr {
    /// Returns one raw pointer per element of row `idx`.
    ///
    /// # Safety
    ///
    /// The returned pointers are not tracked by the borrow checker. Presumably they
    /// point into the implementor's storage (confirm per implementation), so callers
    /// must not dereference them once that storage is dropped or reallocated.
    unsafe fn row_ptr(&self, idx: usize) -> Vec<*const f64>;

    /// Returns one raw pointer per element of column `idx`.
    ///
    /// # Safety
    ///
    /// The returned pointers are not tracked by the borrow checker. Presumably they
    /// point into the implementor's storage (confirm per implementation), so callers
    /// must not dereference them once that storage is dropped or reallocated.
    unsafe fn col_ptr(&self, idx: usize) -> Vec<*const f64>;
}
impl MatrixPtr for Matrix {
    /// Returns pointers to the `self.col` elements of row `idx`, in column order.
    ///
    /// # Panics
    ///
    /// Panics with "Index out of range" if `idx >= self.row`.
    ///
    /// # Safety
    ///
    /// Assumes `self.ptr()` points at a contiguous buffer of `self.row * self.col`
    /// elements laid out according to `self.shape` (TODO confirm against `Matrix`).
    /// The returned pointers are untracked raw pointers: callers must not
    /// dereference them after the matrix is dropped or its storage is reallocated.
    unsafe fn row_ptr(&self, idx: usize) -> Vec<*const f64> {
        // A row index is bounded by the number of rows. Checking against `col`
        // would let `add` step past the end of the buffer whenever row < col.
        assert!(idx < self.row, "Index out of range");
        let p = self.ptr();
        match self.shape {
            // Row-major: the row is a contiguous run starting at `idx * col`.
            Shape::Row => {
                let start = idx * self.col;
                (start..start + self.col).map(|j| p.add(j)).collect()
            }
            // Column-major: consecutive elements of one row are `row` apart.
            Shape::Col => (0..self.col).map(|i| p.add(idx + i * self.row)).collect(),
        }
    }

    /// Returns pointers to the `self.row` elements of column `idx`, in row order.
    ///
    /// # Panics
    ///
    /// Panics with "Index out of range" if `idx >= self.col`.
    ///
    /// # Safety
    ///
    /// Same requirements as `row_ptr` above: the buffer behind `self.ptr()` must
    /// hold `self.row * self.col` elements, and the returned pointers must not be
    /// dereferenced after the matrix is dropped or its storage is reallocated.
    unsafe fn col_ptr(&self, idx: usize) -> Vec<*const f64> {
        assert!(idx < self.col, "Index out of range");
        let p = self.ptr();
        match self.shape {
            // Column-major: the column is a contiguous run starting at `idx * row`.
            Shape::Col => {
                let start = idx * self.row;
                (start..start + self.row).map(|j| p.add(j)).collect()
            }
            // Row-major: consecutive elements of one column are `col` apart.
            Shape::Row => (0..self.row).map(|i| p.add(idx + i * self.col)).collect(),
        }
    }
}