peroxide/structure/
ad.rs

1//! Taylor mode forward automatic differentiation with const-generic `Jet<N>` type
2//!
3//! ## Overview
4//!
5//! This module provides a const-generic `Jet<N>` struct for Taylor-mode forward AD
6//! of arbitrary order N. The struct stores **normalized** Taylor coefficients:
7//!
8//! ```text
9//! Jet { value: c_0, deriv: [c_1, c_2, ..., c_N] }
10//! ```
11//!
12//! where $c_k = f^{(k)}(a) / k!$ is the $k$-th normalized Taylor coefficient evaluated
13//! at the expansion point $a$. This normalization eliminates binomial coefficients
14//! from all arithmetic recurrences.
15//!
16//! ## Type Aliases
17//!
18//! * `Dual = Jet<1>` — first-order forward AD (value + first derivative)
19//! * `HyperDual = Jet<2>` — second-order forward AD (value + first + second derivative)
20//!
21//! ## Constructors
22//!
23//! * `Jet::var(x)` — independent variable at point x (deriv\[0\] = 1)
24//! * `Jet::constant(x)` — constant (all derivatives zero)
25//! * `Jet::new(value, deriv)` — raw constructor
26//! * `ad0(x)` — `Jet<0>` constant (backward compat)
27//! * `ad1(x, dx)` — `Jet<1>` with first derivative (backward compat)
28//! * `ad2(x, dx, ddx)` — `Jet<2>` with first and second derivatives (backward compat)
29//!
30//! ## Accessors
31//!
32//! * `.value()` / `.x()` — $f(a)$
33//! * `.dx()` — $f'(a)$
34//! * `.ddx()` — $f''(a)$
35//! * `.derivative(k)` — $f^{(k)}(a)$ (raw factorial-scaled derivative)
36//! * `.taylor_coeff(k)` — normalized Taylor coefficient $c_k$
37//!
38//! ## Implemented Operations
39//!
40//! * `Add, Sub, Mul, Div` (Jet op Jet, Jet op f64, f64 op Jet)
41//! * `Neg`
42//! * `ExpLogOps`: `exp`, `ln`, `log`, `log2`, `log10`
43//! * `PowOps`: `powi`, `powf`, `pow`, `sqrt`
44//! * `TrigOps`: `sin_cos`, `sin`, `cos`, `tan`, `sinh`, `cosh`, `tanh`,
45//!              `asin`, `acos`, `atan`, `asinh`, `acosh`, `atanh`
46//!
47//! ## Usage
48//!
49//! ```
50//! extern crate peroxide;
51//! use peroxide::fuga::*;
52//!
53//! fn main() {
54//!     // First derivative of f(x) = x^2 at x = 2
55//!     let x = Jet::<1>::var(2.0);
56//!     let y = x.powi(2);
57//!     assert_eq!(y.value(), 4.0);
58//!     assert_eq!(y.dx(), 4.0);  // f'(2) = 2*2 = 4
59//!
60//!     // Second derivative using HyperDual
61//!     let x2 = HyperDual::new(2.0, [1.0, 0.0]);
62//!     let y2 = x2.powi(2);
63//!     assert_eq!(y2.value(), 4.0);
64//!     assert_eq!(y2.dx(), 4.0);   // f'(2) = 4
65//!     assert_eq!(y2.ddx(), 2.0);  // f''(2) = 2
66//! }
67//! ```
68//!
69//! ### Higher-order derivatives
70//!
71//! ```
72//! extern crate peroxide;
73//! use peroxide::fuga::*;
74//!
75//! fn main() {
76//!     // 5th derivative of x^5 at x = 1
77//!     let x = Jet::<5>::var(1.0);
78//!     let y = x.powi(5);
79//!     assert_eq!(y.derivative(5), 120.0);  // 5! = 120
80//! }
81//! ```
82//!
83//! ### Using the `#[ad_function]` macro
84//!
85//! ```
86//! extern crate peroxide;
87//! use peroxide::fuga::*;
88//!
89//! #[ad_function]
90//! fn f(x: f64) -> f64 {
91//!     x.sin() + x.powi(2)
92//! }
93//!
94//! fn main() {
95//!     // f_grad and f_hess are generated automatically
96//!     let grad = f_grad(1.0);   // f'(1) = cos(1) + 2
97//!     let hess = f_hess(1.0);   // f''(1) = -sin(1) + 2
98//!
99//!     assert!((grad - (1.0_f64.cos() + 2.0)).abs() < 1e-10);
100//!     assert!((hess - (-1.0_f64.sin() + 2.0)).abs() < 1e-10);
101//! }
102//! ```
103//!
104//! ### Generic functions with `Real` trait
105//!
106//! ```
107//! extern crate peroxide;
108//! use peroxide::fuga::*;
109//!
110//! fn quadratic<T: Real>(x: T) -> T {
111//!     x.powi(2) + x * 3.0 + T::from_f64(1.0)
112//! }
113//!
114//! fn main() {
115//!     // Works with both f64 and AD (= Jet<2>)
116//!     let val = quadratic(2.0_f64);                // 11.0
117//!     let jet = quadratic(AD1(2.0, 1.0));
118//!     assert_eq!(val, 11.0);                       // f(2) = 4 + 6 + 1
119//!     assert_eq!(jet.value(), 11.0);               // f(2) = 11
120//!     assert_eq!(jet.dx(), 7.0);                   // f'(2) = 2*2 + 3
121//! }
122//! ```
123//!
124//! ### Jacobian computation
125//!
126//! ```
127//! extern crate peroxide;
128//! use peroxide::fuga::*;
129//!
130//! fn main() {
131//!     // Jacobian of f(x,y) = [x - y, x + 2*y] at (1, 1)
132//!     let x = vec![1.0, 1.0];
133//!     let j = jacobian(f, &x);
134//!     j.print();
135//!     //       c[0] c[1]
136//!     // r[0]     1   -1
137//!     // r[1]     1    2
138//! }
139//!
140//! fn f(xs: &Vec<AD>) -> Vec<AD> {
141//!     let x = xs[0];
142//!     let y = xs[1];
143//!     vec![x - y, x + 2.0 * y]
144//! }
145//! ```
146//!
147//! ### Backward-compatible constructors
148//!
149//! ```
150//! extern crate peroxide;
151//! use peroxide::fuga::*;
152//!
153//! fn main() {
154//!     // These work just like the old AD1/AD2 constructors
155//!     let a = AD1(2.0, 1.0);   // value=2, f'=1
156//!     let b = AD2(4.0, 4.0, 2.0);  // x^2 at x=2
157//!
158//!     assert_eq!(a.x(), 2.0);
159//!     assert_eq!(b.dx(), 4.0);
160//!     assert_eq!(b.ddx(), 2.0);
161//!
162//!     // New constructors (equivalent)
163//!     let c = Jet::<1>::var(2.0);  // Same as Dual var at x=2
164//!     assert_eq!(c.dx(), 1.0);     // dx/dx = 1 for independent variable
165//! }
166//! ```
167//!
168//! ## Accuracy: Jet\<N\> vs Finite Differences
169//!
170//! `Jet<N>` computes derivatives to **machine precision** because it propagates
171//! exact Taylor coefficients through the computation graph. In contrast, finite
172//! difference methods suffer from both truncation and cancellation errors that
173//! worsen rapidly at higher derivative orders.
174//!
175//! The plot below compares the relative error of `Jet<N>` against central finite
176//! differences ($h = 10^{-4}$) for $f(x) = \sin(x)$ at $x = 1.0$, across derivative orders 1–8:
177//!
178//! ![Derivative Accuracy](https://raw.githubusercontent.com/Axect/Peroxide/master/example_data/derivative_accuracy.png)
179//!
180//! `Jet<N>` (blue) stays at $\sim 10^{-15}$ (machine epsilon) for all orders,
181//! while finite differences (green) degrade from $\sim 10^{-9}$ at order 1 to $> 10^{0}$ at order 4.
182//!
183//! ## Taylor Series Convergence
184//!
185//! Since `Jet<N>` stores normalized Taylor coefficients $c_k = f^{(k)}(a)/k!$,
186//! you can directly reconstruct the Taylor polynomial of any function:
187//!
188//! $$T_N(x) = c_0 + c_1 (x-a) + c_2 (x-a)^2 + \cdots + c_N (x-a)^N$$
189//!
190//! The plot below shows the Taylor polynomial of $\sin(x)$ around $x = 0$ for
191//! increasing truncation orders $N = 1, 3, 5, 7, 9$:
192//!
193//! ![Taylor Convergence](https://raw.githubusercontent.com/Axect/Peroxide/master/example_data/taylor_convergence.png)
194//!
195//! As $N$ increases, the Taylor polynomial converges to the exact $\sin(x)$ curve
196//! over a wider interval.
197
198use crate::traits::{fp::FPVector, math::Vector, stable::StableFn, sugar::VecOps};
199use peroxide_num::{ExpLogOps, PowOps, TrigOps};
200use std::ops::{Add, Div, Index, IndexMut, Mul, Neg, Sub};
201
202// =============================================================================
203// Jet struct
204// =============================================================================
205
/// Const-generic Taylor-mode forward AD type.
///
/// Stores the value and $N$ normalized Taylor coefficients:
/// - `value` = $f(a) = c_0$
/// - `deriv[k]` = $f^{(k+1)}(a) / (k+1)! = c_{k+1}$
///
/// So `Jet<1>` stores $(c_0, c_1) = (f(a),\, f'(a))$,
/// and `Jet<2>` stores $(c_0, c_1, c_2) = (f(a),\, f'(a),\, f''(a)/2)$.
///
/// The derived `PartialEq` compares the value and every stored coefficient.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Jet<const N: usize> {
    /// Zeroth coefficient $c_0 = f(a)$.
    value: f64,
    /// Higher coefficients: `deriv[k]` holds $c_{k+1}$.
    deriv: [f64; N],
}

impl<const N: usize> Jet<N> {
    /// Create a `Jet` from raw value and normalized Taylor coefficient array.
    ///
    /// `deriv` is stored as given (no factorial scaling is applied).
    pub fn new(value: f64, deriv: [f64; N]) -> Self {
        Self { value, deriv }
    }

    /// Create an independent variable jet at point `x`.
    /// Sets `deriv[0] = 1.0` (the 1st normalized coefficient), rest zero.
    /// For `N == 0` there is no coefficient to set and only the value is stored.
    ///
    /// # Examples
    /// ```
    /// use peroxide::fuga::*;
    ///
    /// let x = Jet::<2>::var(3.0);
    /// assert_eq!(x.value(), 3.0);
    /// assert_eq!(x.dx(), 1.0);    // dx/dx = 1
    /// assert_eq!(x.ddx(), 0.0);   // d²x/dx² = 0
    /// ```
    pub fn var(x: f64) -> Self {
        let mut deriv = [0.0f64; N];
        // `first_mut` is `None` when N == 0, so no out-of-bounds path exists.
        if let Some(first) = deriv.first_mut() {
            *first = 1.0;
        }
        Self { value: x, deriv }
    }

    /// Create a constant jet (all derivatives zero).
    pub fn constant(x: f64) -> Self {
        Self {
            value: x,
            deriv: [0.0f64; N],
        }
    }

    /// The function value $f(a)$.
    #[inline]
    pub fn value(&self) -> f64 {
        self.value
    }

    /// Alias for `value()` — backward compatibility.
    #[inline]
    pub fn x(&self) -> f64 {
        self.value
    }

    /// First derivative $f'(a)$ (returns `0.0` when `N == 0`).
    ///
    /// # Examples
    /// ```
    /// use peroxide::fuga::*;
    ///
    /// let x = Jet::<1>::var(2.0);
    /// let y = x.powi(3);    // x^3
    /// assert_eq!(y.dx(), 12.0);  // 3*x^2 = 3*4 = 12
    /// ```
    #[inline]
    pub fn dx(&self) -> f64 {
        self.deriv.first().copied().unwrap_or(0.0)
    }

    /// Second derivative $f''(a)$ (returns `0.0` when `N < 2`).
    ///
    /// The stored coefficient is $f''(a)/2!$, hence the factor of two.
    ///
    /// # Examples
    /// ```
    /// use peroxide::fuga::*;
    ///
    /// let x = Jet::<2>::var(2.0);
    /// let y = x.powi(3);     // x^3
    /// assert_eq!(y.ddx(), 12.0);  // 6*x = 6*2 = 12
    /// ```
    #[inline]
    pub fn ddx(&self) -> f64 {
        self.deriv.get(1).map_or(0.0, |c2| c2 * 2.0)
    }

    /// Returns $f^{(\mathrm{order})}(a)$, the raw (factorial-scaled) derivative of given order.
    /// - order = 0: $f(a)$
    /// - order = 1: $f'(a)$ = `deriv[0]`
    /// - order = k: $f^{(k)}(a)$ = `deriv[k-1]` $\times\, k!$
    /// - order > N: `0.0`
    ///
    /// The factorial is accumulated in `f64`, so orders above 20 (where `k!`
    /// no longer fits in a `u64`) are handled without integer overflow. The
    /// product is exact for every order up to 20; beyond order 170 the
    /// factorial itself exceeds the `f64` range and becomes infinite.
    ///
    /// # Examples
    /// ```
    /// use peroxide::fuga::*;
    ///
    /// let x = Jet::<3>::var(0.0);
    /// let y = x.exp();
    /// // All derivatives of exp at 0 are 1
    /// assert!((y.derivative(0) - 1.0).abs() < 1e-15);
    /// assert!((y.derivative(1) - 1.0).abs() < 1e-15);
    /// assert!((y.derivative(2) - 1.0).abs() < 1e-15);
    /// assert!((y.derivative(3) - 1.0).abs() < 1e-15);
    /// ```
    pub fn derivative(&self, order: usize) -> f64 {
        if order == 0 {
            return self.value;
        }
        match self.deriv.get(order - 1) {
            Some(&coefficient) => {
                // order! as a running f64 product (2 * 3 * ... * order).
                let scale = (2..=order).fold(1.0_f64, |acc, k| acc * k as f64);
                coefficient * scale
            }
            None => 0.0,
        }
    }

    /// Returns the $k$-th normalized Taylor coefficient $c_k = f^{(k)}(a) / k!$.
    /// - $k = 0$: $c_0 = f(a)$
    /// - $1 \le k \le N$: $c_k$ = `deriv[k-1]`
    /// - $k > N$: `0.0`
    pub fn taylor_coeff(&self, k: usize) -> f64 {
        self.coeff(k)
    }

    /// Internal: get the $k$-th Taylor coefficient $c_k$ (`0.0` beyond order `N`).
    #[inline]
    fn coeff(&self, k: usize) -> f64 {
        match k {
            0 => self.value,
            _ => self.deriv.get(k - 1).copied().unwrap_or(0.0),
        }
    }

    /// Internal: set the $k$-th Taylor coefficient $c_k$.
    /// Writes with `k > N` are silently ignored.
    #[inline]
    fn set_coeff(&mut self, k: usize, v: f64) {
        match k {
            0 => self.value = v,
            _ => {
                if let Some(slot) = self.deriv.get_mut(k - 1) {
                    *slot = v;
                }
            }
        }
    }

    /// Internal: create a zero jet.
    #[inline]
    fn zero() -> Self {
        Self::constant(0.0)
    }
}
371
372// =============================================================================
373// Type aliases
374// =============================================================================
375
/// First-order forward AD: a `Jet<1>` holding the value $f(a)$ and the first derivative $f'(a)$.
pub type Dual = Jet<1>;
378
/// Second-order forward AD: a `Jet<2>` holding the value $f(a)$, the first derivative $f'(a)$,
/// and the normalized second coefficient $f''(a) /\, 2!$.
pub type HyperDual = Jet<2>;
381
382// =============================================================================
383// Compatibility constructors
384// =============================================================================
385
386/// Create a `Jet<0>` constant (zero-order, value only).
387#[inline]
388pub fn ad0(x: f64) -> Jet<0> {
389    Jet { value: x, deriv: [] }
390}
391
392/// Create a `Jet<1>` with value and first derivative.
393/// `dx` is the raw first derivative $f'(a)$; stored as `deriv[0]` $= dx / 1! = dx$.
394///
395/// # Arguments
396/// * `x` - function value $f(a)$
397/// * `dx` - first derivative $f'(a)$
398///
399/// # Examples
400/// ```
401/// use peroxide::fuga::*;
402///
403/// let j = ad1(2.0, 1.0);  // variable x at x=2
404/// assert_eq!(j.value(), 2.0);
405/// assert_eq!(j.dx(), 1.0);
406/// ```
407#[inline]
408pub fn ad1(x: f64, dx: f64) -> Jet<1> {
409    Jet {
410        value: x,
411        deriv: [dx],
412    }
413}
414
415/// Create a `Jet<2>` with value, first derivative, and second derivative.
416/// `ddx` is the raw second derivative $f''(a)$; stored internally as `deriv[1]` $= f''(a) / 2!$.
417///
418/// # Arguments
419/// * `x` - function value $f(a)$
420/// * `dx` - first derivative $f'(a)$
421/// * `ddx` - second derivative $f''(a)$
422///
423/// # Examples
424/// ```
425/// use peroxide::fuga::*;
426///
427/// // Represent x^2 at x=2: f=4, f'=4, f''=2
428/// let j = ad2(4.0, 4.0, 2.0);
429/// assert_eq!(j.value(), 4.0);
430/// assert_eq!(j.dx(), 4.0);
431/// assert_eq!(j.ddx(), 2.0);
432/// ```
433#[inline]
434pub fn ad2(x: f64, dx: f64, ddx: f64) -> Jet<2> {
435    Jet {
436        value: x,
437        deriv: [dx, ddx / 2.0],
438    }
439}
440
441// =============================================================================
442// Helper: factorial
443// =============================================================================
444
/// Computes `n!` exactly as a `u64` (`0! = 1! = 1`).
///
/// # Panics
/// Panics when `n > 20`, because `21!` does not fit in a `u64`. The check is
/// explicit so that release builds fail loudly instead of silently wrapping.
#[inline]
fn factorial(n: usize) -> u64 {
    (2..=n as u64).fold(1u64, |acc, k| {
        acc.checked_mul(k)
            .expect("factorial overflows u64 for n > 20")
    })
}
453
454// =============================================================================
455// Display
456// =============================================================================
457
458impl<const N: usize> std::fmt::Display for Jet<N> {
459    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
460        write!(f, "Jet({}", self.value)?;
461        if N > 0 {
462            write!(f, "; ")?;
463            for (i, d) in self.deriv.iter().enumerate() {
464                if i > 0 {
465                    write!(f, ", ")?;
466                }
467                write!(f, "{}", d)?;
468            }
469        }
470        write!(f, ")")
471    }
472}
473
474// =============================================================================
475// PartialOrd (compare by value only)
476// =============================================================================
477
478impl<const N: usize> PartialOrd for Jet<N> {
479    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
480        self.value.partial_cmp(&other.value)
481    }
482}
483
484// =============================================================================
485// From conversions
486// =============================================================================
487
488impl<const N: usize> From<f64> for Jet<N> {
489    fn from(v: f64) -> Self {
490        Self::constant(v)
491    }
492}
493
494impl<const N: usize> From<Jet<N>> for f64 {
495    fn from(j: Jet<N>) -> f64 {
496        j.value
497    }
498}
499
500// =============================================================================
501// Index / IndexMut (backward compat: index 0 = value, index k >= 1 = deriv[k-1])
502// =============================================================================
503
504impl<const N: usize> Index<usize> for Jet<N> {
505    type Output = f64;
506
507    fn index(&self, index: usize) -> &Self::Output {
508        if index == 0 {
509            &self.value
510        } else if index <= N {
511            &self.deriv[index - 1]
512        } else {
513            panic!("Jet<{}> index {} out of bounds (max index = {})", N, index, N)
514        }
515    }
516}
517
518impl<const N: usize> IndexMut<usize> for Jet<N> {
519    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
520        if index == 0 {
521            &mut self.value
522        } else if index <= N {
523            &mut self.deriv[index - 1]
524        } else {
525            panic!("Jet<{}> index {} out of bounds (max index = {})", N, index, N)
526        }
527    }
528}
529
530// =============================================================================
531// Neg
532// =============================================================================
533
534impl<const N: usize> Neg for Jet<N> {
535    type Output = Self;
536
537    fn neg(self) -> Self::Output {
538        let mut z = self;
539        z.value = -z.value;
540        for d in z.deriv.iter_mut() {
541            *d = -*d;
542        }
543        z
544    }
545}
546
547// =============================================================================
548// Add, Sub, Mul, Div for Jet<N> op Jet<N>
549// =============================================================================
550
551impl<const N: usize> Add<Jet<N>> for Jet<N> {
552    type Output = Self;
553
554    fn add(self, rhs: Jet<N>) -> Self::Output {
555        let mut z = self;
556        z.value += rhs.value;
557        for i in 0..N {
558            z.deriv[i] += rhs.deriv[i];
559        }
560        z
561    }
562}
563
564impl<const N: usize> Sub<Jet<N>> for Jet<N> {
565    type Output = Self;
566
567    fn sub(self, rhs: Jet<N>) -> Self::Output {
568        let mut z = self;
569        z.value -= rhs.value;
570        for i in 0..N {
571            z.deriv[i] -= rhs.deriv[i];
572        }
573        z
574    }
575}
576
577impl<const N: usize> Mul<Jet<N>> for Jet<N> {
578    type Output = Self;
579
580    /// Multiplication using normalized Taylor coefficient convolution:
581    /// $z_n = \sum_{k=0}^{n} c_k \cdot d_{n-k}$.
582    /// No binomial coefficients needed due to normalization convention.
583    fn mul(self, rhs: Jet<N>) -> Self::Output {
584        let mut z = Self::zero();
585        for n in 0..=N {
586            let mut s = 0.0f64;
587            for k in 0..=n {
588                s += self.coeff(k) * rhs.coeff(n - k);
589            }
590            z.set_coeff(n, s);
591        }
592        z
593    }
594}
595
596impl<const N: usize> Div<Jet<N>> for Jet<N> {
597    type Output = Self;
598
599    /// Division using normalized Taylor coefficient recurrence:
600    /// $z_0 = a_0 / b_0$,
601    /// $z_n = \frac{1}{b_0}\left(a_n - \sum_{k=1}^{n} b_k \, z_{n-k}\right)$
602    fn div(self, rhs: Jet<N>) -> Self::Output {
603        let b0 = rhs.coeff(0);
604        let inv_b0 = 1.0 / b0;
605        let mut z = Self::zero();
606        z.set_coeff(0, self.coeff(0) * inv_b0);
607        for n in 1..=N {
608            let mut s = 0.0f64;
609            for k in 1..=n {
610                s += rhs.coeff(k) * z.coeff(n - k);
611            }
612            z.set_coeff(n, inv_b0 * (self.coeff(n) - s));
613        }
614        z
615    }
616}
617
618// =============================================================================
619// Scalar arithmetic: Jet<N> op f64
620// =============================================================================
621
622impl<const N: usize> Add<f64> for Jet<N> {
623    type Output = Self;
624
625    fn add(self, rhs: f64) -> Self::Output {
626        let mut z = self;
627        z.value += rhs;
628        z
629    }
630}
631
632impl<const N: usize> Sub<f64> for Jet<N> {
633    type Output = Self;
634
635    fn sub(self, rhs: f64) -> Self::Output {
636        let mut z = self;
637        z.value -= rhs;
638        z
639    }
640}
641
642impl<const N: usize> Mul<f64> for Jet<N> {
643    type Output = Self;
644
645    fn mul(self, rhs: f64) -> Self::Output {
646        let mut z = self;
647        z.value *= rhs;
648        for d in z.deriv.iter_mut() {
649            *d *= rhs;
650        }
651        z
652    }
653}
654
655impl<const N: usize> Div<f64> for Jet<N> {
656    type Output = Self;
657
658    fn div(self, rhs: f64) -> Self::Output {
659        let inv = 1.0 / rhs;
660        let mut z = self;
661        z.value *= inv;
662        for d in z.deriv.iter_mut() {
663            *d *= inv;
664        }
665        z
666    }
667}
668
669// =============================================================================
670// Scalar arithmetic: f64 op Jet<N>
671// =============================================================================
672
673impl<const N: usize> Add<Jet<N>> for f64 {
674    type Output = Jet<N>;
675
676    fn add(self, rhs: Jet<N>) -> Self::Output {
677        let mut z = rhs;
678        z.value += self;
679        z
680    }
681}
682
683impl<const N: usize> Sub<Jet<N>> for f64 {
684    type Output = Jet<N>;
685
686    fn sub(self, rhs: Jet<N>) -> Self::Output {
687        let mut z = -rhs;
688        z.value += self;
689        z
690    }
691}
692
693impl<const N: usize> Mul<Jet<N>> for f64 {
694    type Output = Jet<N>;
695
696    fn mul(self, rhs: Jet<N>) -> Self::Output {
697        rhs * self
698    }
699}
700
701impl<const N: usize> Div<Jet<N>> for f64 {
702    type Output = Jet<N>;
703
704    fn div(self, rhs: Jet<N>) -> Self::Output {
705        Jet::<N>::constant(self) / rhs
706    }
707}
708
709// =============================================================================
710// ExpLogOps
711// =============================================================================
712
713impl<const N: usize> ExpLogOps for Jet<N> {
714    type Float = f64;
715
716    /// $\exp(a)$ using the normalized recurrence:
717    /// $z_0 = e^{a_0}$,
718    /// $z_n = \frac{1}{n}\sum_{k=1}^{n} k\, a_k\, z_{n-k}$
719    fn exp(&self) -> Self {
720        let mut z = Self::zero();
721        z.set_coeff(0, self.coeff(0).exp());
722        for n in 1..=N {
723            let mut s = 0.0f64;
724            for k in 1..=n {
725                s += (k as f64) * self.coeff(k) * z.coeff(n - k);
726            }
727            z.set_coeff(n, s / (n as f64));
728        }
729        z
730    }
731
732    /// $\ln(a)$ using the normalized recurrence:
733    /// $z_0 = \ln(a_0)$,
734    /// $z_n = \frac{1}{a_0}\left(a_n - \frac{1}{n}\sum_{k=1}^{n-1} k\, z_k\, a_{n-k}\right)$
735    fn ln(&self) -> Self {
736        let a0 = self.coeff(0);
737        let inv_a0 = 1.0 / a0;
738        let mut z = Self::zero();
739        z.set_coeff(0, a0.ln());
740        for n in 1..=N {
741            let mut s = 0.0f64;
742            for k in 1..n {
743                s += (k as f64) * z.coeff(k) * self.coeff(n - k);
744            }
745            z.set_coeff(n, inv_a0 * (self.coeff(n) - s / (n as f64)));
746        }
747        z
748    }
749
750    fn log(&self, base: f64) -> Self {
751        let ln_base = base.ln();
752        let z = self.ln();
753        let mut result = Self::zero();
754        result.set_coeff(0, z.coeff(0) / ln_base);
755        for k in 1..=N {
756            result.set_coeff(k, z.coeff(k) / ln_base);
757        }
758        result
759    }
760
761    fn log2(&self) -> Self {
762        self.log(2.0)
763    }
764
765    fn log10(&self) -> Self {
766        self.log(10.0)
767    }
768}
769
770// =============================================================================
771// PowOps
772// =============================================================================
773
774impl<const N: usize> PowOps for Jet<N> {
775    type Float = f64;
776
777    /// Integer power via repeated multiplication.
778    fn powi(&self, n: i32) -> Self {
779        if n == 0 {
780            return Self::constant(1.0);
781        }
782        let abs_n = n.unsigned_abs() as usize;
783        let mut result = *self;
784        for _ in 1..abs_n {
785            result = result * *self;
786        }
787        if n < 0 {
788            Self::constant(1.0) / result
789        } else {
790            result
791        }
792    }
793
794    /// Float power: exp(f * ln(self))
795    fn powf(&self, f: f64) -> Self {
796        (self.ln() * f).exp()
797    }
798
799    /// Jet power: exp(rhs * ln(self))
800    fn pow(&self, rhs: Self) -> Self {
801        (self.ln() * rhs).exp()
802    }
803
804    /// Square root using direct recurrence from $z^2 = a$:
805    /// $z_0 = \sqrt{a_0}$,
806    /// $z_n = \frac{1}{2\,z_0}\left(a_n - \sum_{k=1}^{n-1} z_k\, z_{n-k}\right)$
807    fn sqrt(&self) -> Self {
808        let a0 = self.coeff(0);
809        let z0 = a0.sqrt();
810        let inv_2z0 = 1.0 / (2.0 * z0);
811        let mut z = Self::zero();
812        z.set_coeff(0, z0);
813        for n in 1..=N {
814            let mut s = 0.0f64;
815            for k in 1..n {
816                s += z.coeff(k) * z.coeff(n - k);
817            }
818            z.set_coeff(n, inv_2z0 * (self.coeff(n) - s));
819        }
820        z
821    }
822}
823
824// =============================================================================
825// TrigOps
826// =============================================================================
827
828impl<const N: usize> TrigOps for Jet<N> {
829    /// $\sin$ and $\cos$ computed together via coupled normalized recurrence:
830    /// $s_0 = \sin(a_0)$, $c_0 = \cos(a_0)$,
831    /// $s_n = \frac{1}{n}\sum_{k=1}^{n} k\, a_k\, c_{n-k}$,
832    /// $c_n = -\frac{1}{n}\sum_{k=1}^{n} k\, a_k\, s_{n-k}$
833    fn sin_cos(&self) -> (Self, Self) {
834        let mut s = Self::zero();
835        let mut c = Self::zero();
836        s.set_coeff(0, self.coeff(0).sin());
837        c.set_coeff(0, self.coeff(0).cos());
838        for n in 1..=N {
839            let mut ss = 0.0f64;
840            let mut cs = 0.0f64;
841            for k in 1..=n {
842                let ka = (k as f64) * self.coeff(k);
843                ss += ka * c.coeff(n - k);
844                cs += ka * s.coeff(n - k);
845            }
846            s.set_coeff(n, ss / (n as f64));
847            c.set_coeff(n, -cs / (n as f64));
848        }
849        (s, c)
850    }
851
852    fn sin(&self) -> Self {
853        self.sin_cos().0
854    }
855
856    fn cos(&self) -> Self {
857        self.sin_cos().1
858    }
859
860    fn tan(&self) -> Self {
861        let (s, c) = self.sin_cos();
862        s / c
863    }
864
865    /// $\sinh$ and $\cosh$ via coupled normalized recurrence (same as $\sin/\cos$ but no negative on $\cosh$):
866    /// $s_n = \frac{1}{n}\sum_{k=1}^{n} k\, a_k\, c_{n-k}$,
867    /// $c_n = \frac{1}{n}\sum_{k=1}^{n} k\, a_k\, s_{n-k}$
868    fn sinh(&self) -> Self {
869        self.sinh_cosh().0
870    }
871
872    fn cosh(&self) -> Self {
873        self.sinh_cosh().1
874    }
875
876    fn tanh(&self) -> Self {
877        let (s, c) = self.sinh_cosh();
878        s / c
879    }
880
881    fn asin(&self) -> Self {
882        // q = 1/sqrt(1 - a^2)
883        let one = Self::constant(1.0);
884        let q = (one - self.powi(2)).sqrt();
885        let q_inv = one / q;
886        self.integrate_derivative(self.coeff(0).asin(), &q_inv)
887    }
888
889    fn acos(&self) -> Self {
890        // q = -1/sqrt(1 - a^2)
891        let one = Self::constant(1.0);
892        let q = (one - self.powi(2)).sqrt();
893        let q_inv = -(one / q);
894        self.integrate_derivative(self.coeff(0).acos(), &q_inv)
895    }
896
897    fn atan(&self) -> Self {
898        // q = 1/(1 + a^2)
899        let one = Self::constant(1.0);
900        let q = one / (one + self.powi(2));
901        self.integrate_derivative(self.coeff(0).atan(), &q)
902    }
903
904    fn asinh(&self) -> Self {
905        // q = 1/sqrt(1 + a^2)
906        let one = Self::constant(1.0);
907        let q_inv = (one + self.powi(2)).sqrt();
908        let q = one / q_inv;
909        self.integrate_derivative(self.coeff(0).asinh(), &q)
910    }
911
912    fn acosh(&self) -> Self {
913        // q = 1/sqrt(a^2 - 1)
914        let one = Self::constant(1.0);
915        let q_inv = (self.powi(2) - one).sqrt();
916        let q = one / q_inv;
917        self.integrate_derivative(self.coeff(0).acosh(), &q)
918    }
919
920    fn atanh(&self) -> Self {
921        // q = 1/(1 - a^2)
922        let one = Self::constant(1.0);
923        let q = one / (one - self.powi(2));
924        self.integrate_derivative(self.coeff(0).atanh(), &q)
925    }
926}
927
928impl<const N: usize> Jet<N> {
929    /// Compute $\sinh$ and $\cosh$ together via the coupled normalized recurrence.
930    pub fn sinh_cosh(&self) -> (Self, Self) {
931        let mut s = Self::zero();
932        let mut c = Self::zero();
933        s.set_coeff(0, self.coeff(0).sinh());
934        c.set_coeff(0, self.coeff(0).cosh());
935        for n in 1..=N {
936            let mut ss = 0.0f64;
937            let mut cs = 0.0f64;
938            for k in 1..=n {
939                let ka = (k as f64) * self.coeff(k);
940                ss += ka * c.coeff(n - k);
941                cs += ka * s.coeff(n - k);
942            }
943            s.set_coeff(n, ss / (n as f64));
944            c.set_coeff(n, cs / (n as f64));  // NO negative for cosh
945        }
946        (s, c)
947    }
948
949    /// Integrate using derivative jet: used by inverse trig functions.
950    /// Given $z'(a)$ encoded as a Jet `q`, compute $z$ coefficients by:
951    /// $z_0 = z_0$,
952    /// $z_n = \frac{1}{n}\sum_{k=1}^{n} k\, a_k\, q_{n-k}$
953    fn integrate_derivative(&self, z0: f64, q: &Self) -> Self {
954        let mut z = Self::zero();
955        z.set_coeff(0, z0);
956        for n in 1..=N {
957            let mut s = 0.0f64;
958            for k in 1..=n {
959                s += (k as f64) * self.coeff(k) * q.coeff(n - k);
960            }
961            z.set_coeff(n, s / (n as f64));
962        }
963        z
964    }
965}
966
967// =============================================================================
968// ADFn — lift functions over Jet<2> to work at f64 or Jet level
969// =============================================================================
970
/// Generic AD function wrapper.
///
/// Lifts a function `F: Fn(Jet<2>) -> Jet<2>` to operate at multiple levels:
/// - `call_stable(f64)` → `f64`: evaluate function value, first derivative, or second derivative
/// - `call_stable(Jet<2>)` → `Jet<2>`: pass through
///
/// For vector functions, also lifts `F: Fn(Vec<Jet<1>>) -> Vec<Jet<1>>`.
///
/// # Examples
/// ```
/// use peroxide::fuga::*;
///
/// let f_ad = ADFn::new(|x: Jet<2>| x.powi(2));
///
/// // Value: f(3) = 9
/// assert_eq!(f_ad.call_stable(3.0), 9.0);
///
/// // Gradient: f'(3) = 6  (2*3)
/// let df = f_ad.grad();
/// assert_eq!(df.call_stable(3.0), 6.0);
///
/// // Hessian: f''(3) = 2
/// let ddf = df.grad();
/// assert_eq!(ddf.call_stable(3.0), 2.0);
/// ```
pub struct ADFn<F> {
    /// The wrapped function, boxed.
    f: Box<F>,
    /// How many times `grad` has been applied (0 = value, 1 = first, 2 = second derivative).
    grad_level: usize,
}

impl<F> ADFn<F> {
    /// Create a new `ADFn` wrapping function `f` at gradient level 0 (function evaluation).
    ///
    /// No `Clone` bound is required here; only `grad` needs to duplicate the function.
    pub fn new(f: F) -> Self {
        Self {
            f: Box::new(f),
            grad_level: 0,
        }
    }
}

impl<F: Clone> ADFn<F> {
    /// Produce the gradient version of this function: a clone of the wrapped
    /// function with `grad_level` incremented by 1.
    ///
    /// # Panics
    /// Panics if `grad_level >= 2`.
    pub fn grad(&self) -> Self {
        assert!(self.grad_level < 2, "Higher order AD is not allowed");
        ADFn {
            f: self.f.clone(),
            grad_level: self.grad_level + 1,
        }
    }
}
1020
1021/// Scalar version: F works with `Jet<2>`, target is `f64`.
1022impl<F: Fn(Jet<2>) -> Jet<2>> StableFn<f64> for ADFn<F> {
1023    type Output = f64;
1024
1025    fn call_stable(&self, target: f64) -> f64 {
1026        match self.grad_level {
1027            0 => (self.f)(Jet::<2>::constant(target)).value(),
1028            1 => (self.f)(Jet::<2>::new(target, [1.0, 0.0])).dx(),
1029            2 => (self.f)(Jet::<2>::new(target, [1.0, 0.0])).ddx(),
1030            _ => unreachable!("grad_level > 2 is not allowed"),
1031        }
1032    }
1033}
1034
1035/// Scalar version: F works with `Jet<2>`, target is `Jet<2>`.
1036impl<F: Fn(Jet<2>) -> Jet<2>> StableFn<Jet<2>> for ADFn<F> {
1037    type Output = Jet<2>;
1038
1039    fn call_stable(&self, target: Jet<2>) -> Jet<2> {
1040        (self.f)(target)
1041    }
1042}
1043
1044/// Vector version: F works with `Vec<Jet<1>>`, target is `Vec<f64>`.
1045impl<F: Fn(Vec<Jet<1>>) -> Vec<Jet<1>>> StableFn<Vec<f64>> for ADFn<F> {
1046    type Output = Vec<f64>;
1047
1048    fn call_stable(&self, target: Vec<f64>) -> Vec<f64> {
1049        (self.f)(target.into_iter().map(Jet::<1>::constant).collect())
1050            .into_iter()
1051            .map(|j| j.value())
1052            .collect()
1053    }
1054}
1055
1056/// Vector version: F works with `Vec<Jet<1>>`, target is `Vec<Jet<1>>`.
1057impl<F: Fn(Vec<Jet<1>>) -> Vec<Jet<1>>> StableFn<Vec<Jet<1>>> for ADFn<F> {
1058    type Output = Vec<Jet<1>>;
1059
1060    fn call_stable(&self, target: Vec<Jet<1>>) -> Vec<Jet<1>> {
1061        (self.f)(target)
1062    }
1063}
1064
1065/// Vector version: F works with `&Vec<Jet<1>>`, target is `&Vec<f64>`.
1066impl<'a, F: Fn(&Vec<Jet<1>>) -> Vec<Jet<1>>> StableFn<&'a Vec<f64>> for ADFn<F> {
1067    type Output = Vec<f64>;
1068
1069    fn call_stable(&self, target: &'a Vec<f64>) -> Vec<f64> {
1070        let jet_target: Vec<Jet<1>> = target.iter().map(|&x| Jet::<1>::constant(x)).collect();
1071        (self.f)(&jet_target)
1072            .into_iter()
1073            .map(|j| j.value())
1074            .collect()
1075    }
1076}
1077
1078/// Vector version: F works with `&Vec<Jet<1>>`, target is `&Vec<Jet<1>>`.
1079impl<'a, F: Fn(&Vec<Jet<1>>) -> Vec<Jet<1>>> StableFn<&'a Vec<Jet<1>>> for ADFn<F> {
1080    type Output = Vec<Jet<1>>;
1081
1082    fn call_stable(&self, target: &'a Vec<Jet<1>>) -> Vec<Jet<1>> {
1083        (self.f)(target)
1084    }
1085}
1086
1087// =============================================================================
1088// JetVec trait (replaces ADVec)
1089// =============================================================================
1090
1091/// Trait for converting between `Vec<f64>` and `Vec<Jet<1>>`.
1092pub trait JetVec {
    /// Returns the contents as a newly allocated vector of first-order jets.
1093    fn to_jet_vec(&self) -> Vec<Jet<1>>;
    /// Returns the contents as a newly allocated vector of plain `f64` values.
1094    fn to_f64_vec(&self) -> Vec<f64>;
1095}
1096
1097impl JetVec for Vec<f64> {
1098    fn to_jet_vec(&self) -> Vec<Jet<1>> {
1099        self.iter().map(|&x| Jet::<1>::constant(x)).collect()
1100    }
1101
1102    fn to_f64_vec(&self) -> Vec<f64> {
1103        self.clone()
1104    }
1105}
1106
1107impl JetVec for Vec<Jet<1>> {
1108    fn to_jet_vec(&self) -> Vec<Jet<1>> {
1109        self.clone()
1110    }
1111
1112    fn to_f64_vec(&self) -> Vec<f64> {
1113        self.iter().map(|j| j.value()).collect()
1114    }
1115}
1116
1117// =============================================================================
1118// FPVector, Vector, VecOps for Vec<Jet<1>>
1119// =============================================================================
1120
1121impl FPVector for Vec<Jet<1>> {
1122    type Scalar = Jet<1>;
1123
1124    fn fmap<F>(&self, f: F) -> Self
1125    where
1126        F: Fn(Self::Scalar) -> Self::Scalar,
1127    {
1128        self.iter().map(|&x| f(x)).collect()
1129    }
1130
1131    fn reduce<F, T>(&self, init: T, f: F) -> Self::Scalar
1132    where
1133        F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
1134        T: Into<Self::Scalar>,
1135    {
1136        self.iter().fold(init.into(), |acc, &x| f(acc, x))
1137    }
1138
1139    fn zip_with<F>(&self, f: F, other: &Self) -> Self
1140    where
1141        F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
1142    {
1143        self.iter()
1144            .zip(other.iter())
1145            .map(|(&x, &y)| f(x, y))
1146            .collect()
1147    }
1148
1149    fn filter<F>(&self, f: F) -> Self
1150    where
1151        F: Fn(Self::Scalar) -> bool,
1152    {
1153        self.iter().filter(|&&x| f(x)).cloned().collect()
1154    }
1155
1156    fn take(&self, n: usize) -> Self {
1157        self.iter().take(n).cloned().collect()
1158    }
1159
1160    fn skip(&self, n: usize) -> Self {
1161        self.iter().skip(n).cloned().collect()
1162    }
1163
1164    fn sum(&self) -> Self::Scalar {
1165        if self.is_empty() {
1166            return Jet::<1>::constant(0.0);
1167        }
1168        let s = self[0];
1169        self.reduce(s, |x, y| x + y)
1170    }
1171
1172    fn prod(&self) -> Self::Scalar {
1173        if self.is_empty() {
1174            return Jet::<1>::constant(1.0);
1175        }
1176        let s = self[0];
1177        self.reduce(s, |x, y| x * y)
1178    }
1179}
1180
// `Vector` arithmetic for `Vec<Jet<1>>`. Each method forwards to a shorter-named method
// on the same vector (`add_v`, `sub_v`, `mul_s`); those are defined outside this section
// of the file, presumably as `VecOps` methods (TODO confirm).
1181impl Vector for Vec<Jet<1>> {
1182    type Scalar = Jet<1>;
1183
    /// Adds `rhs` to `self`; forwards to `add_v`.
1184    fn add_vec(&self, rhs: &Self) -> Self {
1185        self.add_v(rhs)
1186    }
1187
    /// Subtracts `rhs` from `self`; forwards to `sub_v`.
1188    fn sub_vec(&self, rhs: &Self) -> Self {
1189        self.sub_v(rhs)
1190    }
1191
    /// Multiplies by the jet scalar `rhs`; forwards to `mul_s`.
1192    fn mul_scalar(&self, rhs: Self::Scalar) -> Self {
1193        self.mul_s(rhs)
1194    }
1195}
1196
// Empty impl: nothing is overridden, so `Vec<Jet<1>>` only gets whatever `VecOps` itself provides.
1197impl VecOps for Vec<Jet<1>> {}
1198
1199// =============================================================================
1200// Backward compatibility: AD type aliases and constructors
1201// =============================================================================
1202// Keep the old public API so existing code that uses AD, AD0, AD1, AD2, ADVec, ADFn
1203// continues to compile. AD is now an alias for Jet<2> (the highest order used in ADFn).
1204// AD0/AD1/AD2 are re-exported constructor functions.
1205
1206/// Backward compatibility alias: `AD` is now `Jet<2>`.
1207///
1208/// For new code, prefer `Dual = Jet<1>` or `HyperDual = Jet<2>` directly.
///
/// Being a plain type alias, `AD` and `Jet<2>` are interchangeable everywhere
/// (`Vec<AD>` is the same type as `Vec<Jet<2>>`).
1209pub type AD = Jet<2>;
1210
1211/// Backward compatibility constructor: `AD0(x)` creates a zero-derivative `Jet<2>` constant.
1212#[inline]
1213#[allow(non_snake_case)]
1214pub fn AD0(x: f64) -> Jet<2> {
1215    Jet::<2>::constant(x)
1216}
1217
1218/// Backward compatibility constructor: `AD1(x, dx)` creates a `Jet<2>` with given first derivative.
1219#[inline]
1220#[allow(non_snake_case)]
1221pub fn AD1(x: f64, dx: f64) -> Jet<2> {
1222    Jet::<2>::new(x, [dx, 0.0])
1223}
1224
1225/// Backward compatibility constructor: `AD2(x, dx, ddx)` creates a `Jet<2>` with given derivatives.
1226#[inline]
1227#[allow(non_snake_case)]
1228pub fn AD2(x: f64, dx: f64, ddx: f64) -> Jet<2> {
1229    Jet::<2>::new(x, [dx, ddx / 2.0])
1230}
1231
1232/// Backward compatibility trait: provides `to_ad_vec` and `to_f64_vec` on vector types.
1233///
1234/// Extends `JetVec` with the `to_ad_vec` method for converting to `Vec<AD>` (= `Vec<Jet<2>>`).
1235pub trait ADVec: JetVec {
1236    fn to_ad_vec(&self) -> Vec<AD>;
1237
1238    /// Convert to a `Vec<f64>` by extracting the value of each jet.
1239    /// (Delegates to `JetVec::to_f64_vec` — provided here so the trait is self-contained.)
1240    fn to_f64_vec_compat(&self) -> Vec<f64> {
1241        self.to_f64_vec()
1242    }
1243}
1244
1245impl ADVec for Vec<f64> {
1246    fn to_ad_vec(&self) -> Vec<AD> {
1247        self.iter().map(|&x| Jet::<2>::constant(x)).collect()
1248    }
1249}
1250
1251impl ADVec for Vec<AD> {
1252    fn to_ad_vec(&self) -> Vec<AD> {
1253        self.clone()
1254    }
1255}
1256
1257impl JetVec for Vec<AD> {
1258    fn to_jet_vec(&self) -> Vec<Jet<1>> {
1259        self.iter()
1260            .map(|j| Jet::<1>::new(j.value(), [j.dx()]))
1261            .collect()
1262    }
1263
1264    fn to_f64_vec(&self) -> Vec<f64> {
1265        self.iter().map(|j| j.value()).collect()
1266    }
1267}
1268
1269// FPVector, Vector, VecOps for Vec<AD> (= Vec<Jet<2>>)
1270impl FPVector for Vec<AD> {
1271    type Scalar = AD;
1272
1273    fn fmap<F>(&self, f: F) -> Self
1274    where
1275        F: Fn(Self::Scalar) -> Self::Scalar,
1276    {
1277        self.iter().map(|&x| f(x)).collect()
1278    }
1279
1280    fn reduce<F, T>(&self, init: T, f: F) -> Self::Scalar
1281    where
1282        F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
1283        T: Into<Self::Scalar>,
1284    {
1285        self.iter().fold(init.into(), |acc, &x| f(acc, x))
1286    }
1287
1288    fn zip_with<F>(&self, f: F, other: &Self) -> Self
1289    where
1290        F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
1291    {
1292        self.iter()
1293            .zip(other.iter())
1294            .map(|(&x, &y)| f(x, y))
1295            .collect()
1296    }
1297
1298    fn filter<F>(&self, f: F) -> Self
1299    where
1300        F: Fn(Self::Scalar) -> bool,
1301    {
1302        self.iter().filter(|&&x| f(x)).cloned().collect()
1303    }
1304
1305    fn take(&self, n: usize) -> Self {
1306        self.iter().take(n).cloned().collect()
1307    }
1308
1309    fn skip(&self, n: usize) -> Self {
1310        self.iter().skip(n).cloned().collect()
1311    }
1312
1313    fn sum(&self) -> Self::Scalar {
1314        if self.is_empty() {
1315            return Jet::<2>::constant(0.0);
1316        }
1317        let s = self[0];
1318        self.reduce(s, |x, y| x + y)
1319    }
1320
1321    fn prod(&self) -> Self::Scalar {
1322        if self.is_empty() {
1323            return Jet::<2>::constant(1.0);
1324        }
1325        let s = self[0];
1326        self.reduce(s, |x, y| x * y)
1327    }
1328}
1329
// `Vector` arithmetic for `Vec<AD>`. Each method forwards to a shorter-named method on
// the same vector (`add_v`, `sub_v`, `mul_s`); those are defined outside this section
// of the file, presumably as `VecOps` methods (TODO confirm).
1330impl Vector for Vec<AD> {
1331    type Scalar = AD;
1332
    /// Adds `rhs` to `self`; forwards to `add_v`.
1333    fn add_vec(&self, rhs: &Self) -> Self {
1334        self.add_v(rhs)
1335    }
1336
    /// Subtracts `rhs` from `self`; forwards to `sub_v`.
1337    fn sub_vec(&self, rhs: &Self) -> Self {
1338        self.sub_v(rhs)
1339    }
1340
    /// Multiplies by the jet scalar `rhs`; forwards to `mul_s`.
1341    fn mul_scalar(&self, rhs: Self::Scalar) -> Self {
1342        self.mul_s(rhs)
1343    }
1344}
1345
// Empty impl: nothing is overridden, so `Vec<AD>` only gets whatever `VecOps` itself provides.
1346impl VecOps for Vec<AD> {}