peroxide/numerical/
ode.rs

1//! # Ordinary Differential Equation (ODE) Solvers
2//!
3//! This module provides traits and structs for solving ordinary differential equations (ODEs).
4//!
5//! ## Overview
6//!
7//! - `ODEProblem`: Trait for defining an ODE problem.
8//! - `ODEIntegrator`: Trait for ODE integrators.
9//! - `ODESolver`: Trait for ODE solvers.
10//! - `ODEError`: Enum for ODE errors.
//!   - `ReachedMaxStepIter`: An integrator exceeded its iteration limit while completing a single step. (internal error)
12//!   - `ConstraintViolation(f64, Vec<f64>, Vec<f64>)`: Constraint violation. (user-defined error)
//!   - The ODE API uses `anyhow` for error handling, so your `rhs` can return your own custom error types.
14//!
15//! ## Available integrators
16//!
17//! - **Explicit**
18//!   - Ralston's 3rd order (RALS3)
19//!   - Runge-Kutta 4th order (RK4)
20//!   - Ralston's 4th order (RALS4)
21//!   - Runge-Kutta 5th order (RK5)
22//! - **Embedded**
23//!   - Bogacki-Shampine 2/3rd order (BS23)
24//!   - Runge-Kutta-Fehlberg 4/5th order (RKF45)
25//!   - Dormand-Prince 4/5th order (DP45)
26//!   - Tsitouras 4/5th order (TSIT45)
27//!   - Runge-Kutta-Fehlberg 7/8th order (RKF78)
28//! - **Implicit**
29//!   - Gauss-Legendre 4th order (GL4)
30//!
31//! ## Available solvers
32//!
33//! - `BasicODESolver`: A basic ODE solver using a specified integrator.
34//!
35//! You can implement your own ODE solver by implementing the `ODESolver` trait.
36//!
37//! ## Example
38//!
39//! ```rust
40//! use peroxide::fuga::*;
41//!
42//! fn main() -> Result<(), Box<dyn Error>> {
//!     // Same as: let rkf = RKF45::new(1e-6, 0.9, 1e-6, 1e-1, 100);
44//!     let rkf = RKF45 {
45//!         tol: 1e-6,
46//!         safety_factor: 0.9,
47//!         min_step_size: 1e-6,
48//!         max_step_size: 1e-1,
49//!         max_step_iter: 100,
50//!     };
51//!     let basic_ode_solver = BasicODESolver::new(rkf);
52//!     let initial_conditions = vec![1f64];
53//!     let (t_vec, y_vec) = basic_ode_solver.solve(
54//!         &Test,
55//!         (0f64, 10f64),
56//!         0.01,
57//!         &initial_conditions,
58//!     )?;
59//!     let y_vec: Vec<f64> = y_vec.into_iter().flatten().collect();
60//!     println!("{}", y_vec.len());
61//!
62//! #   #[cfg(feature = "plot")]
63//! #   {
64//!     let mut plt = Plot2D::new();
65//!     plt
66//!         .set_domain(t_vec)
67//!         .insert_image(y_vec)
68//!         .set_xlabel(r"$t$")
69//!         .set_ylabel(r"$y$")
70//!         .set_style(PlotStyle::Nature)
71//!         .tight_layout()
72//!         .set_dpi(600)
73//!         .set_path("example_data/rkf45_test.png")
74//!         .savefig()?;
75//! #   }
76//!     Ok(())
77//! }
78//!
79//! // Extremely customizable struct
80//! struct Test;
81//!
82//! impl ODEProblem for Test {
83//!     fn rhs(&self, t: f64, y: &[f64], dy: &mut [f64]) -> anyhow::Result<()> {
84//!         Ok(dy[0] = (5f64 * t.powi(2) - y[0]) / (t + y[0]).exp())
85//!     }
86//! }
87//! ```
88
89use crate::fuga::ConvToMat;
90use crate::traits::math::{InnerProduct, Norm, Normed, Vector};
91use crate::util::non_macro::eye;
92use anyhow::{bail, Result};
93
94/// Trait for defining an ODE problem.
95///
96/// Implement this trait to define your own ODE problem.
97///
98/// # Example
99///
100/// ```
101/// use peroxide::fuga::*;
102///
103/// struct MyODEProblem;
104///
105/// impl ODEProblem for MyODEProblem {
106///     fn rhs(&self, t: f64, y: &[f64], dy: &mut [f64]) -> anyhow::Result<()> {
107///         dy[0] = -0.5 * y[0];
108///         dy[1] = y[0] - y[1];
109///         Ok(())
110///     }
111/// }
112/// ```
113pub trait ODEProblem {
114    fn rhs(&self, t: f64, y: &[f64], dy: &mut [f64]) -> Result<()>;
115}
116
117/// Trait for ODE integrators.
118///
119/// Implement this trait to define your own ODE integrator.
120pub trait ODEIntegrator {
121    fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64>;
122}
123
/// Errors reported while solving an ODE.
///
/// # Variants
///
/// - `ConstraintViolation(t, y, dy)`: a user-defined constraint was violated at time `t`,
///   with state `y` and derivative `dy`. (user-defined error)
/// - `ReachedMaxStepIter`: an integrator hit its iteration limit while trying to complete a
///   single step, e.g. an adaptive Runge-Kutta method rejected the step `max_step_iter`
///   times. (internal error for integrator)
///
/// `ODEError` implements [`std::error::Error`], so it can be boxed as
/// `Box<dyn std::error::Error>` or wrapped by `anyhow` like any other standard error type.
///
/// If you define constraints in your problem, you can use this error to report constraint violations.
///
/// # Example
///
/// ```no_run
/// use peroxide::fuga::*;
///
/// struct ConstrainedProblem {
///     y_constraint: f64
/// }
///
/// impl ODEProblem for ConstrainedProblem {
///     fn rhs(&self, t: f64, y: &[f64], dy: &mut [f64]) -> anyhow::Result<()> {
///         if y[0] < self.y_constraint {
///             anyhow::bail!(ODEError::ConstraintViolation(t, y.to_vec(), dy.to_vec()));
///         } else {
///             // some function
///             Ok(())
///         }
///     }
/// }
/// ```
#[derive(Debug, Clone)]
pub enum ODEError {
    /// A constraint was violated; holds `(t, y, dy)` at the moment of the violation.
    ConstraintViolation(f64, Vec<f64>, Vec<f64>), // t, y, dy
    /// An integrator exceeded its iteration limit for a single step.
    ReachedMaxStepIter,
}

impl std::fmt::Display for ODEError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ODEError::ConstraintViolation(t, y, dy) => write!(
                f,
                "Constraint violation at t = {}, y = {:?}, dy = {:?}",
                t, y, dy
            ),
            ODEError::ReachedMaxStepIter => write!(f, "Reached maximum number of steps per step"),
        }
    }
}

// `Debug` + `Display` are all `std::error::Error` requires; implementing it lets `ODEError`
// be used wherever a standard error type is expected (`Box<dyn Error>`, `?` conversions, ...).
impl std::error::Error for ODEError {}
171
172/// Trait for ODE solvers.
173///
174/// Implement this trait to define your own ODE solver.
175pub trait ODESolver {
176    fn solve<P: ODEProblem>(
177        &self,
178        problem: &P,
179        t_span: (f64, f64),
180        dt: f64,
181        initial_conditions: &[f64],
182    ) -> Result<(Vec<f64>, Vec<Vec<f64>>)>;
183}
184
185/// A basic ODE solver using a specified integrator.
186///
187/// # Example
188///
189/// ```
190/// use peroxide::fuga::*;
191///
192/// fn main() -> Result<(), Box<dyn Error>> {
193///     let initial_conditions = vec![1f64];
194///     let rkf = RKF45::new(1e-4, 0.9, 1e-6, 1e-1, 100);
195///     let basic_ode_solver = BasicODESolver::new(rkf);
196///     let (t_vec, y_vec) = basic_ode_solver.solve(
197///         &Test,
198///         (0f64, 10f64),
199///         0.01,
200///         &initial_conditions,
201///     )?;
202///     let y_vec: Vec<f64> = y_vec.into_iter().flatten().collect();
203///
204///     Ok(())
205/// }
206///
207/// struct Test;
208///
209/// impl ODEProblem for Test {
210///     fn rhs(&self, t: f64, y: &[f64], dy: &mut [f64]) -> anyhow::Result<()> {
211///         dy[0] = (5f64 * t.powi(2) - y[0]) / (t + y[0]).exp();
212///         Ok(())
213///     }
214/// }
215/// ```
216pub struct BasicODESolver<I: ODEIntegrator> {
217    integrator: I,
218}
219
220impl<I: ODEIntegrator> BasicODESolver<I> {
221    pub fn new(integrator: I) -> Self {
222        Self { integrator }
223    }
224}
225
226impl<I: ODEIntegrator> ODESolver for BasicODESolver<I> {
227    fn solve<P: ODEProblem>(
228        &self,
229        problem: &P,
230        t_span: (f64, f64),
231        dt: f64,
232        initial_conditions: &[f64],
233    ) -> Result<(Vec<f64>, Vec<Vec<f64>>)> {
234        let mut t = t_span.0;
235        let mut dt = dt;
236        let mut y = initial_conditions.to_vec();
237        let mut t_vec = vec![t];
238        let mut y_vec = vec![y.clone()];
239
240        while t < t_span.1 {
241            let dt_step = self.integrator.step(problem, t, &mut y, dt)?;
242            t += dt;
243            t_vec.push(t);
244            y_vec.push(y.clone());
245            dt = dt_step;
246        }
247
248        Ok((t_vec, y_vec))
249    }
250}
251
252// ┌─────────────────────────────────────────────────────────┐
253//  Butcher Tableau
254// └─────────────────────────────────────────────────────────┘
/// Trait for Butcher tableau
///
/// ```text
/// C | A
/// - - -
///   | BU (Coefficient for update)
///   | BE (Coefficient for estimate error)
/// ```
///
/// Every implementor is an [`ODEIntegrator`] through a blanket implementation in this
/// module. An empty `BE` marks a fixed-step method. A non-empty `BE` enables adaptive step
/// size control, which relies on the accessors `tol`, `safety_factor`, `min_step_size`,
/// `max_step_size`, `max_step_iter` and `order`.
///
/// # References
///
/// - J. R. Dormand and P. J. Prince, _A family of embedded Runge-Kutta formulae_, J. Comp. Appl. Math., 6(1), 19-26, 1980.
/// - Wikipedia: [List of Runge-Kutta methods](https://en.wikipedia.org/wiki/List_of_Runge%E2%80%93Kutta_methods)
pub trait ButcherTableau {
    /// Nodes `c_j`: stage `j` is evaluated at time `t + c_j * dt`.
    const C: &'static [f64];
    /// Runge-Kutta matrix, one row per stage; row `j` supplies the `j` coefficients that
    /// weight the earlier stages `0..j`.
    const A: &'static [&'static [f64]];
    /// Weights of the propagated solution: `y += dt * Σ_j BU[j] * k_j`.
    const BU: &'static [f64];
    /// Weights of the embedded solution; the error estimate uses `BU - BE`.
    /// Empty for methods without an error estimate.
    const BE: &'static [f64];

    /// Tolerance on the estimated local error.
    ///
    /// # Panics
    ///
    /// The default implementation panics; adaptive (embedded) tableaus must override it.
    fn tol(&self) -> f64 {
        unimplemented!()
    }

    /// Safety factor applied to the proposed step size.
    ///
    /// # Panics
    ///
    /// The default implementation panics; adaptive (embedded) tableaus must override it.
    fn safety_factor(&self) -> f64 {
        unimplemented!()
    }

    /// Lower bound for the proposed step size.
    ///
    /// # Panics
    ///
    /// The default implementation panics; adaptive (embedded) tableaus must override it.
    fn min_step_size(&self) -> f64 {
        unimplemented!()
    }

    /// Upper bound for the proposed step size.
    ///
    /// # Panics
    ///
    /// The default implementation panics; adaptive (embedded) tableaus must override it.
    fn max_step_size(&self) -> f64 {
        unimplemented!()
    }

    /// Maximum number of rejected attempts allowed for one step.
    ///
    /// # Panics
    ///
    /// The default implementation panics; adaptive (embedded) tableaus must override it.
    fn max_step_iter(&self) -> usize {
        unimplemented!()
    }

    /// Order of the lower-order solution used for adaptive step size control;
    /// the controller uses the exponent `1/(order+1)`. Defaults to 4.
    fn order(&self) -> usize {
        4
    }
}
300
301impl<BU: ButcherTableau> ODEIntegrator for BU {
302    fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64> {
303        let n = y.len();
304        let mut iter_count = 0usize;
305        let mut dt = dt;
306        let n_k = Self::C.len();
307
308        loop {
309            let mut k_vec = vec![vec![0.0; n]; n_k];
310            let mut y_temp = y.to_vec();
311
312            for stage in 0..n_k {
313                for i in 0..n {
314                    let mut s = 0.0;
315                    for j in 0..stage {
316                        s += Self::A[stage][j] * k_vec[j][i];
317                    }
318                    y_temp[i] = y[i] + dt * s;
319                }
320                problem.rhs(t + dt * Self::C[stage], &y_temp, &mut k_vec[stage])?;
321            }
322
323            if !Self::BE.is_empty() {
324                let mut error = 0f64;
325                for i in 0..n {
326                    let mut s = 0.0;
327                    for j in 0..n_k {
328                        s += (Self::BU[j] - Self::BE[j]) * k_vec[j][i];
329                    }
330                    error = error.max(dt * s.abs())
331                }
332
333                let factor = (self.tol() / error).powf(1.0 / (self.order() as f64 + 1.0));
334                let new_dt = self.safety_factor() * dt * factor;
335                let new_dt = new_dt.clamp(self.min_step_size(), self.max_step_size());
336
337                if error < self.tol() {
338                    for i in 0..n {
339                        let mut s = 0.0;
340                        for j in 0..n_k {
341                            s += Self::BU[j] * k_vec[j][i];
342                        }
343                        y[i] += dt * s;
344                    }
345                    return Ok(new_dt);
346                } else {
347                    iter_count += 1;
348                    if iter_count >= self.max_step_iter() {
349                        bail!(ODEError::ReachedMaxStepIter);
350                    }
351                    dt = new_dt;
352                }
353            } else {
354                for i in 0..n {
355                    let mut s = 0.0;
356                    for j in 0..n_k {
357                        s += Self::BU[j] * k_vec[j][i];
358                    }
359                    y[i] += dt * s;
360                }
361                return Ok(dt);
362            }
363        }
364    }
365}
366
367// ┌─────────────────────────────────────────────────────────┐
368//  Runge-Kutta
369// └─────────────────────────────────────────────────────────┘
370/// Ralston's 3rd order integrator
371///
372/// This integrator uses the Ralston's 3rd order method to numerically integrate the ODE system.
373/// In MATLAB, it is called `ode3`.
374#[derive(Debug, Clone, Copy, Default)]
375#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
376#[cfg_attr(
377    feature = "rkyv",
378    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
379)]
380pub struct RALS3;
381
382impl ButcherTableau for RALS3 {
383    const C: &'static [f64] = &[0.0, 0.5, 0.75];
384    const A: &'static [&'static [f64]] = &[&[], &[0.5], &[0.0, 0.75]];
385    const BU: &'static [f64] = &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0];
386    const BE: &'static [f64] = &[];
387}
388
389/// Runge-Kutta 4th order integrator.
390///
391/// This integrator uses the classical 4th order Runge-Kutta method to numerically integrate the ODE system.
392/// It calculates four intermediate values (k1, k2, k3, k4) to estimate the next step solution.
393#[derive(Debug, Clone, Copy, Default)]
394#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
395#[cfg_attr(
396    feature = "rkyv",
397    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
398)]
399pub struct RK4;
400
401impl ButcherTableau for RK4 {
402    const C: &'static [f64] = &[0.0, 0.5, 0.5, 1.0];
403    const A: &'static [&'static [f64]] = &[&[], &[0.5], &[0.0, 0.5], &[0.0, 0.0, 1.0]];
404    const BU: &'static [f64] = &[1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0];
405    const BE: &'static [f64] = &[];
406}
407
408/// Ralston's 4th order integrator.
409///
410/// This fourth order method is known as minimum truncation error RK4.
411#[derive(Debug, Clone, Copy, Default)]
412#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
413#[cfg_attr(
414    feature = "rkyv",
415    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
416)]
417pub struct RALS4;
418
419impl ButcherTableau for RALS4 {
420    const C: &'static [f64] = &[0.0, 0.4, 0.45573725, 1.0];
421    const A: &'static [&'static [f64]] = &[
422        &[],
423        &[0.4],
424        &[0.29697761, 0.158575964],
425        &[0.21810040, -3.050965616, 3.83286476],
426    ];
427    const BU: &'static [f64] = &[0.17476028, -0.55148066, 1.20553560, 0.17118478];
428    const BE: &'static [f64] = &[];
429}
430
431/// Runge-Kutta 5th order integrator
432///
433/// This integrator uses the 5th order Runge-Kutta method to numerically integrate the ODE system.
434#[derive(Debug, Clone, Copy, Default)]
435#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
436#[cfg_attr(
437    feature = "rkyv",
438    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
439)]
440pub struct RK5;
441
442impl ButcherTableau for RK5 {
443    const C: &'static [f64] = &[0.0, 0.2, 0.3, 0.8, 8.0 / 9.0, 1.0, 1.0];
444    const A: &'static [&'static [f64]] = &[
445        &[],
446        &[0.2],
447        &[0.075, 0.225],
448        &[44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0],
449        &[
450            19372.0 / 6561.0,
451            -25360.0 / 2187.0,
452            64448.0 / 6561.0,
453            -212.0 / 729.0,
454        ],
455        &[
456            9017.0 / 3168.0,
457            -355.0 / 33.0,
458            46732.0 / 5247.0,
459            49.0 / 176.0,
460            -5103.0 / 18656.0,
461        ],
462        &[
463            35.0 / 384.0,
464            0.0,
465            500.0 / 1113.0,
466            125.0 / 192.0,
467            -2187.0 / 6784.0,
468            11.0 / 84.0,
469        ],
470    ];
471    const BU: &'static [f64] = &[
472        5179.0 / 57600.0,
473        0.0,
474        7571.0 / 16695.0,
475        393.0 / 640.0,
476        -92097.0 / 339200.0,
477        187.0 / 2100.0,
478        1.0 / 40.0,
479    ];
480    const BE: &'static [f64] = &[];
481}
482
483// ┌─────────────────────────────────────────────────────────┐
484//  Embedded Runge-Kutta
485// └─────────────────────────────────────────────────────────┘
486/// Bogacki-Shampine 3(2) method
487///
488/// This method is known as `ode23` in MATLAB.
489///
490/// # Member variables
491///
492/// - `tol`: The tolerance for the estimated error.
493/// - `safety_factor`: The safety factor for the step size adjustment.
494/// - `min_step_size`: The minimum step size.
495/// - `max_step_size`: The maximum step size.
496/// - `max_step_iter`: The maximum number of iterations per step.
497#[derive(Debug, Clone, Copy)]
498#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
499#[cfg_attr(
500    feature = "rkyv",
501    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
502)]
503pub struct BS23 {
504    pub tol: f64,
505    pub safety_factor: f64,
506    pub min_step_size: f64,
507    pub max_step_size: f64,
508    pub max_step_iter: usize,
509}
510
511impl Default for BS23 {
512    fn default() -> Self {
513        Self {
514            tol: 1e-3,
515            safety_factor: 0.9,
516            min_step_size: 1e-6,
517            max_step_size: 1e-1,
518            max_step_iter: 100,
519        }
520    }
521}
522
523impl BS23 {
524    pub fn new(
525        tol: f64,
526        safety_factor: f64,
527        min_step_size: f64,
528        max_step_size: f64,
529        max_step_iter: usize,
530    ) -> Self {
531        Self {
532            tol,
533            safety_factor,
534            min_step_size,
535            max_step_size,
536            max_step_iter,
537        }
538    }
539}
540
541impl ButcherTableau for BS23 {
542    const C: &'static [f64] = &[0.0, 0.5, 0.75, 1.0];
543    const A: &'static [&'static [f64]] = &[
544        &[],
545        &[0.5],
546        &[0.0, 0.75],
547        &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0],
548    ];
549    const BU: &'static [f64] = &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0, 0.0];
550    const BE: &'static [f64] = &[7.0 / 24.0, 0.25, 1.0 / 3.0, 0.125];
551
552    fn tol(&self) -> f64 {
553        self.tol
554    }
555    fn safety_factor(&self) -> f64 {
556        self.safety_factor
557    }
558    fn min_step_size(&self) -> f64 {
559        self.min_step_size
560    }
561    fn max_step_size(&self) -> f64 {
562        self.max_step_size
563    }
564    fn max_step_iter(&self) -> usize {
565        self.max_step_iter
566    }
567    fn order(&self) -> usize {
568        2
569    }
570}
571
572/// Runge-Kutta-Fehlberg 4/5th order integrator.
573///
574/// This integrator uses the Runge-Kutta-Fehlberg method, which is an adaptive step size integrator.
575/// It calculates six intermediate values (k1, k2, k3, k4, k5, k6) to estimate the next step solution and the error.
576/// The step size is automatically adjusted based on the estimated error to maintain the desired tolerance.
577///
578/// # Member variables
579///
580/// - `tol`: The tolerance for the estimated error.
581/// - `safety_factor`: The safety factor for the step size adjustment.
582/// - `min_step_size`: The minimum step size.
583/// - `max_step_size`: The maximum step size.
584/// - `max_step_iter`: The maximum number of iterations per step.
585#[derive(Debug, Clone, Copy)]
586#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
587#[cfg_attr(
588    feature = "rkyv",
589    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
590)]
591pub struct RKF45 {
592    pub tol: f64,
593    pub safety_factor: f64,
594    pub min_step_size: f64,
595    pub max_step_size: f64,
596    pub max_step_iter: usize,
597}
598
599impl Default for RKF45 {
600    fn default() -> Self {
601        Self {
602            tol: 1e-6,
603            safety_factor: 0.9,
604            min_step_size: 1e-6,
605            max_step_size: 1e-1,
606            max_step_iter: 100,
607        }
608    }
609}
610
611impl RKF45 {
612    pub fn new(
613        tol: f64,
614        safety_factor: f64,
615        min_step_size: f64,
616        max_step_size: f64,
617        max_step_iter: usize,
618    ) -> Self {
619        Self {
620            tol,
621            safety_factor,
622            min_step_size,
623            max_step_size,
624            max_step_iter,
625        }
626    }
627}
628
629impl ButcherTableau for RKF45 {
630    const C: &'static [f64] = &[0.0, 1.0 / 4.0, 3.0 / 8.0, 12.0 / 13.0, 1.0, 1.0 / 2.0];
631    const A: &'static [&'static [f64]] = &[
632        &[],
633        &[0.25],
634        &[3.0 / 32.0, 9.0 / 32.0],
635        &[1932.0 / 2197.0, -7200.0 / 2197.0, 7296.0 / 2197.0],
636        &[439.0 / 216.0, -8.0, 3680.0 / 513.0, -845.0 / 4104.0],
637        &[
638            -8.0 / 27.0,
639            2.0,
640            -3544.0 / 2565.0,
641            1859.0 / 4104.0,
642            -11.0 / 40.0,
643        ],
644    ];
645    const BU: &'static [f64] = &[
646        16.0 / 135.0,
647        0.0,
648        6656.0 / 12825.0,
649        28561.0 / 56430.0,
650        -9.0 / 50.0,
651        2.0 / 55.0,
652    ];
653    const BE: &'static [f64] = &[
654        25.0 / 216.0,
655        0.0,
656        1408.0 / 2565.0,
657        2197.0 / 4104.0,
658        -1.0 / 5.0,
659        0.0,
660    ];
661
662    fn tol(&self) -> f64 {
663        self.tol
664    }
665    fn safety_factor(&self) -> f64 {
666        self.safety_factor
667    }
668    fn min_step_size(&self) -> f64 {
669        self.min_step_size
670    }
671    fn max_step_size(&self) -> f64 {
672        self.max_step_size
673    }
674    fn max_step_iter(&self) -> usize {
675        self.max_step_iter
676    }
677}
678
679/// Dormand-Prince 5(4) method
680///
681/// This is an adaptive step size integrator based on a 5th order Runge-Kutta method with
682/// 4th order embedded error estimation.
683///
684/// # Member variables
685///
686/// - `tol`: The tolerance for the estimated error.
687/// - `safety_factor`: The safety factor for the step size adjustment.
688/// - `min_step_size`: The minimum step size.
689/// - `max_step_size`: The maximum step size.
690/// - `max_step_iter`: The maximum number of iterations per step.
691#[derive(Debug, Clone, Copy)]
692#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
693#[cfg_attr(
694    feature = "rkyv",
695    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
696)]
697pub struct DP45 {
698    pub tol: f64,
699    pub safety_factor: f64,
700    pub min_step_size: f64,
701    pub max_step_size: f64,
702    pub max_step_iter: usize,
703}
704
705impl Default for DP45 {
706    fn default() -> Self {
707        Self {
708            tol: 1e-6,
709            safety_factor: 0.9,
710            min_step_size: 1e-6,
711            max_step_size: 1e-1,
712            max_step_iter: 100,
713        }
714    }
715}
716
717impl DP45 {
718    pub fn new(
719        tol: f64,
720        safety_factor: f64,
721        min_step_size: f64,
722        max_step_size: f64,
723        max_step_iter: usize,
724    ) -> Self {
725        Self {
726            tol,
727            safety_factor,
728            min_step_size,
729            max_step_size,
730            max_step_iter,
731        }
732    }
733}
734
735impl ButcherTableau for DP45 {
736    const C: &'static [f64] = &[0.0, 0.2, 0.3, 0.8, 8.0 / 9.0, 1.0, 1.0];
737    const A: &'static [&'static [f64]] = &[
738        &[],
739        &[0.2],
740        &[0.075, 0.225],
741        &[44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0],
742        &[
743            19372.0 / 6561.0,
744            -25360.0 / 2187.0,
745            64448.0 / 6561.0,
746            -212.0 / 729.0,
747        ],
748        &[
749            9017.0 / 3168.0,
750            -355.0 / 33.0,
751            46732.0 / 5247.0,
752            49.0 / 176.0,
753            -5103.0 / 18656.0,
754        ],
755        &[
756            35.0 / 384.0,
757            0.0,
758            500.0 / 1113.0,
759            125.0 / 192.0,
760            -2187.0 / 6784.0,
761            11.0 / 84.0,
762        ],
763    ];
764    const BU: &'static [f64] = &[
765        35.0 / 384.0,
766        0.0,
767        500.0 / 1113.0,
768        125.0 / 192.0,
769        -2187.0 / 6784.0,
770        11.0 / 84.0,
771        0.0,
772    ];
773    const BE: &'static [f64] = &[
774        5179.0 / 57600.0,
775        0.0,
776        7571.0 / 16695.0,
777        393.0 / 640.0,
778        -92097.0 / 339200.0,
779        187.0 / 2100.0,
780        1.0 / 40.0,
781    ];
782
783    fn tol(&self) -> f64 {
784        self.tol
785    }
786    fn safety_factor(&self) -> f64 {
787        self.safety_factor
788    }
789    fn min_step_size(&self) -> f64 {
790        self.min_step_size
791    }
792    fn max_step_size(&self) -> f64 {
793        self.max_step_size
794    }
795    fn max_step_iter(&self) -> usize {
796        self.max_step_iter
797    }
798}
799
800/// Tsitouras 5(4) method
801///
802/// This is an adaptive step size integrator based on a 5th order Runge-Kutta method with
803/// 4th order embedded error estimation, using the coefficients from Tsitouras (2011).
804///
805/// # Member variables
806///
807/// - `tol`: The tolerance for the estimated error.
808/// - `safety_factor`: The safety factor for the step size adjustment.
809/// - `min_step_size`: The minimum step size.
810/// - `max_step_size`: The maximum step size.
811/// - `max_step_iter`: The maximum number of iterations per step.
812///
813/// # References
814///
815/// - Ch. Tsitouras, Comput. Math. Appl. 62 (2011) 770-780.
816#[derive(Debug, Clone, Copy)]
817#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
818#[cfg_attr(
819    feature = "rkyv",
820    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
821)]
822pub struct TSIT45 {
823    pub tol: f64,
824    pub safety_factor: f64,
825    pub min_step_size: f64,
826    pub max_step_size: f64,
827    pub max_step_iter: usize,
828}
829
830impl Default for TSIT45 {
831    fn default() -> Self {
832        Self {
833            tol: 1e-6,
834            safety_factor: 0.9,
835            min_step_size: 1e-6,
836            max_step_size: 1e-1,
837            max_step_iter: 100,
838        }
839    }
840}
841
842impl TSIT45 {
843    pub fn new(
844        tol: f64,
845        safety_factor: f64,
846        min_step_size: f64,
847        max_step_size: f64,
848        max_step_iter: usize,
849    ) -> Self {
850        Self {
851            tol,
852            safety_factor,
853            min_step_size,
854            max_step_size,
855            max_step_iter,
856        }
857    }
858}
859
860impl ButcherTableau for TSIT45 {
861    const C: &'static [f64] = &[0.0, 0.161, 0.327, 0.9, 0.9800255409045097, 1.0, 1.0];
862    const A: &'static [&'static [f64]] = &[
863        &[],
864        &[Self::C[1]],
865        &[Self::C[2] - 0.335480655492357, 0.335480655492357],
866        &[
867            Self::C[3] - (-6.359448489975075 + 4.362295432869581),
868            -6.359448489975075,
869            4.362295432869581,
870        ],
871        &[
872            Self::C[4] - (-11.74888356406283 + 7.495539342889836 - 0.09249506636175525),
873            -11.74888356406283,
874            7.495539342889836,
875            -0.09249506636175525,
876        ],
877        &[
878            Self::C[5]
879                - (-12.92096931784711 + 8.159367898576159
880                    - 0.0715849732814010
881                    - 0.02826905039406838),
882            -12.92096931784711,
883            8.159367898576159,
884            -0.0715849732814010,
885            -0.02826905039406838,
886        ],
887        &[
888            Self::BU[0],
889            Self::BU[1],
890            Self::BU[2],
891            Self::BU[3],
892            Self::BU[4],
893            Self::BU[5],
894        ],
895    ];
896    const BU: &'static [f64] = &[
897        0.09646076681806523,
898        0.01,
899        0.4798896504144996,
900        1.379008574103742,
901        -3.290069515436081,
902        2.324710524099774,
903        0.0,
904    ];
905    const BE: &'static [f64] = &[
906        0.001780011052226,
907        0.000816434459657,
908        -0.007880878010262,
909        0.144711007173263,
910        -0.582357165452555,
911        0.458082105929187,
912        1.0 / 66.0,
913    ];
914
915    fn tol(&self) -> f64 {
916        self.tol
917    }
918    fn safety_factor(&self) -> f64 {
919        self.safety_factor
920    }
921    fn min_step_size(&self) -> f64 {
922        self.min_step_size
923    }
924    fn max_step_size(&self) -> f64 {
925        self.max_step_size
926    }
927    fn max_step_iter(&self) -> usize {
928        self.max_step_iter
929    }
930}
931
/// Runge-Kutta-Fehlberg 7/8th order integrator.
///
/// Adaptive step size integrator built on Fehlberg's embedded 7(8) pair, which evaluates
/// f(x, y) thirteen times per step. The 8th order solution is propagated (local
/// extrapolation). The step size is controlled by Fehlberg's error estimate
/// `41/840 * h * (k1 + k11 - k12 - k13)`, the difference between the 7th and 8th order
/// solutions.
///
/// # Member variables
///
/// - `tol`: The tolerance for the estimated error.
/// - `safety_factor`: The safety factor for the step size adjustment.
/// - `min_step_size`: The minimum step size.
/// - `max_step_size`: The maximum step size.
/// - `max_step_iter`: The maximum number of iterations per step.
///
/// # Default
///
/// `tol = 1e-7`, `safety_factor = 0.9`, `min_step_size = 1e-10`, `max_step_size = 1e-1`,
/// `max_step_iter = 100`.
///
/// # References
/// - Meysam Mahooti (2025). [Runge-Kutta-Fehlberg (RKF78)](https://www.mathworks.com/matlabcentral/fileexchange/61130-runge-kutta-fehlberg-rkf78), MATLAB Central File Exchange.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct RKF78 {
    pub tol: f64,
    pub safety_factor: f64,
    pub min_step_size: f64,
    pub max_step_size: f64,
    pub max_step_iter: usize,
}

impl Default for RKF78 {
    fn default() -> Self {
        Self {
            tol: 1e-7, // Higher precision default for a higher-order method
            safety_factor: 0.9,
            min_step_size: 1e-10, // Smaller min step for higher order
            max_step_size: 1e-1,
            max_step_iter: 100,
        }
    }
}

impl RKF78 {
    /// Creates a Runge-Kutta-Fehlberg 7(8) integrator with the given step-control parameters.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        Self {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}
993
/// Butcher tableau of Fehlberg's 13-stage embedded Runge-Kutta 7(8) pair.
///
/// `BU` holds the 8th-order weights (the solution that is propagated) and
/// `BE` the embedded 7th-order weights; their difference drives error control.
impl ButcherTableau for RKF78 {
    const C: &'static [f64] = &[
        0.0,
        2.0 / 27.0,
        1.0 / 9.0,
        1.0 / 6.0,
        5.0 / 12.0,
        1.0 / 2.0,
        5.0 / 6.0,
        1.0 / 6.0,
        2.0 / 3.0,
        1.0 / 3.0,
        1.0,
        0.0, // k12 is evaluated at the start of the step (c = 0)
        1.0, // k13 is evaluated at the end of the step (c = 1)
    ];

    // Stage coupling coefficients; row i has one entry per earlier stage.
    // Each row sums to the matching entry of `C` (k12's row sums to 0).
    const A: &'static [&'static [f64]] = &[
        // k1
        &[],
        // k2
        &[2.0 / 27.0],
        // k3
        &[1.0 / 36.0, 3.0 / 36.0],
        // k4
        &[1.0 / 24.0, 0.0, 3.0 / 24.0],
        // k5
        &[20.0 / 48.0, 0.0, -75.0 / 48.0, 75.0 / 48.0],
        // k6
        &[1.0 / 20.0, 0.0, 0.0, 5.0 / 20.0, 4.0 / 20.0],
        // k7
        &[
            -25.0 / 108.0,
            0.0,
            0.0,
            125.0 / 108.0,
            -260.0 / 108.0,
            250.0 / 108.0,
        ],
        // k8
        &[
            31.0 / 300.0,
            0.0,
            0.0,
            0.0,
            61.0 / 225.0,
            -2.0 / 9.0,
            13.0 / 900.0,
        ],
        // k9
        &[
            2.0,
            0.0,
            0.0,
            -53.0 / 6.0,
            704.0 / 45.0,
            -107.0 / 9.0,
            67.0 / 90.0,
            3.0,
        ],
        // k10
        &[
            -91.0 / 108.0,
            0.0,
            0.0,
            23.0 / 108.0,
            -976.0 / 135.0,
            311.0 / 54.0,
            -19.0 / 60.0,
            17.0 / 6.0,
            -1.0 / 12.0,
        ],
        // k11
        &[
            2383.0 / 4100.0,
            0.0,
            0.0,
            -341.0 / 164.0,
            4496.0 / 1025.0,
            -301.0 / 82.0,
            2133.0 / 4100.0,
            45.0 / 82.0,
            45.0 / 164.0,
            18.0 / 41.0,
        ],
        // k12
        &[
            3.0 / 205.0,
            0.0,
            0.0,
            0.0,
            0.0,
            -6.0 / 41.0,
            -3.0 / 205.0,
            -3.0 / 41.0,
            3.0 / 41.0,
            6.0 / 41.0,
            0.0,
        ],
        // k13
        &[
            -1777.0 / 4100.0,
            0.0,
            0.0,
            -341.0 / 164.0,
            4496.0 / 1025.0,
            -289.0 / 82.0,
            2193.0 / 4100.0,
            51.0 / 82.0,
            33.0 / 164.0,
            12.0 / 41.0,
            0.0,
            1.0,
        ],
    ];

    // Coefficients for the 8th order solution (propagated via local extrapolation)
    const BU: &'static [f64] = &[
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        34.0 / 105.0,
        9.0 / 35.0,
        9.0 / 35.0,
        9.0 / 280.0,
        9.0 / 280.0,
        0.0,
        41.0 / 840.0,
        41.0 / 840.0,
    ];

    // Fehlberg's embedded 7th order weights, used only for error estimation.
    // BU - BE = 41/840 * (k12 + k13 - k1 - k11), i.e. the classical Fehlberg
    // error formula 41/840 * (k1 + k11 - k12 - k13) with opposite sign
    // (presumably only its magnitude matters for step-size control).
    const BE: &'static [f64] = &[
        41.0 / 840.0,
        0.0,
        0.0,
        0.0,
        0.0,
        34.0 / 105.0,
        9.0 / 35.0,
        9.0 / 35.0,
        9.0 / 280.0,
        9.0 / 280.0,
        41.0 / 840.0,
        0.0,
        0.0,
    ];

    fn tol(&self) -> f64 {
        self.tol
    }
    fn safety_factor(&self) -> f64 {
        self.safety_factor
    }
    fn min_step_size(&self) -> f64 {
        self.min_step_size
    }
    fn max_step_size(&self) -> f64 {
        self.max_step_size
    }
    fn max_step_iter(&self) -> usize {
        self.max_step_iter
    }
    // NOTE(review): 7 is the order of the embedded (BE) solution, although the
    // 8th-order BU solution is propagated; confirm this matches how the
    // ButcherTableau step-size controller uses `order()`.
    fn order(&self) -> usize {
        7
    }
}
1164
1165// ┌─────────────────────────────────────────────────────────┐
1166//  Gauss-Legendre 4th order
1167// └─────────────────────────────────────────────────────────┘
1168
// Butcher tableau of the two-stage Gauss-Legendre collocation method (order 4).
// The nodes are the Gauss points on [0, 1], c = 1/2 -/+ sqrt(3)/6, and each
// row of A sums to its node: A11 + A12 = C1, A21 + A22 = C2.
const SQRT3: f64 = 1.7320508075688772;
const C1: f64 = 0.5 - SQRT3 / 6.0;
const C2: f64 = 0.5 + SQRT3 / 6.0;
const A11: f64 = 0.25;
const A22: f64 = A11;
const A12: f64 = A11 - SQRT3 / 6.0;
const A21: f64 = A11 + SQRT3 / 6.0;
const B1: f64 = 0.5;
const B2: f64 = B1;
1179
/// Enum for implicit solvers.
///
/// This enum defines the available implicit solvers for the Gauss-Legendre 4th order integrator
/// (`GL4`), i.e. how the nonlinear equations for the two stage derivatives are solved at each step.
/// Currently, there are two options: fixed-point iteration and Broyden's method.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub enum ImplicitSolver {
    /// Fixed-point iteration: the stage derivatives are re-evaluated from the
    /// previous iterate until their largest absolute change falls below the
    /// integrator's `tol`.
    FixedPoint,
    /// Broyden's quasi-Newton method on the stage residual, starting from an
    /// identity inverse Jacobian and stopping once the residual's infinity
    /// norm falls below the integrator's `tol`.
    Broyden,
    // Not implemented yet:
    //TrustRegion(f64, f64),
}
1195
/// Gauss-Legendre 4th order integrator.
///
/// This integrator uses the two-stage, 4th order Gauss-Legendre Runge-Kutta method, which is an implicit integrator.
/// It requires solving a system of nonlinear equations for the two stage derivatives at each step, which is done
/// using the specified implicit solver (fixed-point iteration or Broyden's method).
/// The Gauss-Legendre method has better stability properties compared to explicit methods, especially for stiff ODEs.
/// Note, however, that fixed-point iteration typically converges only for sufficiently small step sizes on stiff
/// problems, so the stability benefit is best exploited with the Broyden solver.
///
/// The step size is fixed: every step advances by exactly the requested `dt` (no error control).
///
/// # Member variables
///
/// - `solver`: The implicit solver to use.
/// - `tol`: The tolerance for the implicit solver. For fixed-point iteration it bounds the largest absolute change
///   of the stage derivatives between iterations; for Broyden's method it bounds the infinity norm of the residual.
/// - `max_step_iter`: The maximum number of iterations for the implicit solver per step. Reaching this limit does
///   not raise an error; the last iterate is used.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct GL4 {
    pub solver: ImplicitSolver,
    pub tol: f64,
    pub max_step_iter: usize,
}
1218
1219impl Default for GL4 {
1220    fn default() -> Self {
1221        GL4 {
1222            solver: ImplicitSolver::FixedPoint,
1223            tol: 1e-8,
1224            max_step_iter: 100,
1225        }
1226    }
1227}
1228
1229impl GL4 {
1230    pub fn new(solver: ImplicitSolver, tol: f64, max_step_iter: usize) -> Self {
1231        GL4 {
1232            solver,
1233            tol,
1234            max_step_iter,
1235        }
1236    }
1237}
1238
1239impl ODEIntegrator for GL4 {
1240    #[allow(non_snake_case)]
1241    #[inline]
1242    fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64> {
1243        let n = y.len();
1244        //let sqrt3 = 3.0_f64.sqrt();
1245        //let c = 0.5 * (3.0 - sqrt3) / 6.0;
1246        //let d = 0.5 * (3.0 + sqrt3) / 6.0;
1247        let mut k1 = vec![0.0; n];
1248        let mut k2 = vec![0.0; n];
1249
1250        // Initial guess for k1, k2.
1251        problem.rhs(t, y, &mut k1)?;
1252        k2.copy_from_slice(&k1);
1253
1254        match self.solver {
1255            ImplicitSolver::FixedPoint => {
1256                // Fixed-point iteration
1257                let mut y1 = vec![0.0; n];
1258                let mut y2 = vec![0.0; n];
1259
1260                for _ in 0..self.max_step_iter {
1261                    let k1_old = k1.clone();
1262                    let k2_old = k2.clone();
1263
1264                    for i in 0..n {
1265                        y1[i] = y[i] + dt * (A11 * k1[i] + A12 * k2[i]);
1266                        y2[i] = y[i] + dt * (A11 * k1[i] + A12 * k2[i]);
1267                    }
1268
1269                    // Compute new k1 and k2
1270                    problem.rhs(t + C1 * dt, &y1, &mut k1)?;
1271                    problem.rhs(t + C2 * dt, &y2, &mut k2)?;
1272
1273                    // Check for convergence
1274                    let mut max_diff = 0f64;
1275                    for i in 0..n {
1276                        max_diff = max_diff.max((k1[i] - k1_old[i]).abs());
1277                        max_diff = max_diff.max((k2[i] - k2_old[i]).abs());
1278                    }
1279
1280                    if max_diff < self.tol {
1281                        break;
1282                    }
1283                }
1284            }
1285            ImplicitSolver::Broyden => {
1286                let m = 2 * n;
1287                let mut U = vec![0.0; m];
1288                U[..n].copy_from_slice(&k1);
1289                U[n..].copy_from_slice(&k2);
1290
1291                // F_vec = F(U)
1292                let mut F_vec = vec![0.0; m];
1293                compute_F(problem, t, y, dt, &U, &mut F_vec)?;
1294
1295                // Initialize inverse Jacobian matrix
1296                let mut J_inv = eye(m);
1297
1298                // Repeat Broyden's method
1299                for _ in 0..self.max_step_iter {
1300                    // delta = - J_inv * F_vec
1301                    let delta = (&J_inv * &F_vec).mul_scalar(-1.0);
1302
1303                    // U <- U + delta
1304                    U.iter_mut().zip(delta.iter()).for_each(|(u, d)| *u += *d);
1305
1306                    let mut F_new = vec![0.0; m];
1307                    compute_F(problem, t, y, dt, &U, &mut F_new)?;
1308
1309                    // If infinity norm of F_new is less than tol, break
1310                    if F_new.norm(Norm::LInf) < self.tol {
1311                        break;
1312                    }
1313
1314                    // Residual: delta_F = F_new - F_vec
1315                    let delta_F = F_new.sub_vec(&F_vec);
1316
1317                    // J_inv * delta_F
1318                    let J_inv_delta_F = &J_inv * &delta_F;
1319
1320                    let denom = delta.dot(&J_inv_delta_F);
1321                    if denom.abs() < 1e-12 {
1322                        break;
1323                    }
1324
1325                    // Broyden's "good" update for the inverse Jacobian
1326                    // J_inv <- J_inv + ((delta - J_inv * delta_F) * delta^T * J_inv) / denom
1327                    let delta_minus_J_inv_delta_F = delta.sub_vec(&J_inv_delta_F).to_col();
1328                    let delta_T_J_inv = &delta.to_row() * &J_inv;
1329                    let update = (delta_minus_J_inv_delta_F * delta_T_J_inv) / denom;
1330                    J_inv = J_inv + update;
1331                    F_vec = F_new;
1332                }
1333
1334                k1.copy_from_slice(&U[..n]);
1335                k2.copy_from_slice(&U[n..]);
1336            }
1337        }
1338
1339        for i in 0..n {
1340            y[i] += dt * (B1 * k1[i] + B2 * k2[i]);
1341        }
1342
1343        Ok(dt)
1344    }
1345}
1346
1347//// Helper function to compute the function F(U) for the implicit solver.
1348//// y1 = y + dt*(c*k1 + d*k2 - sqrt3/2*(k2-k1))
1349//// y2 = y + dt*(c*k1 + d*k2 + sqrt3/2*(k2-k1))
1350//#[allow(non_snake_case)]
1351//fn compute_F<P: ODEProblem>(
1352//    problem: &P,
1353//    t: f64,
1354//    y: &[f64],
1355//    dt: f64,
1356//    c: f64,
1357//    d: f64,
1358//    sqrt3: f64,
1359//    U: &[f64],
1360//    F: &mut [f64],
1361//) -> Result<()> {
1362//    let n = y.len();
1363//    let mut y1 = vec![0.0; n];
1364//    let mut y2 = vec![0.0; n];
1365//
1366//    for i in 0..n {
1367//        let k1 = U[i];
1368//        let k2 = U[n + i];
1369//        y1[i] = y[i] + dt * (c * k1 + d * k2 - sqrt3 * (k2 - k1) / 2.0);
1370//        y2[i] = y[i] + dt * (c * k1 + d * k2 + sqrt3 * (k2 - k1) / 2.0);
1371//    }
1372//
1373//    let mut f1 = vec![0.0; n];
1374//    let mut f2 = vec![0.0; n];
1375//    problem.rhs(t + c * dt, &y1, &mut f1)?;
1376//    problem.rhs(t + d * dt, &y2, &mut f2)?;
1377//
1378//    // F = [ k1 - f1, k2 - f2 ]
1379//    for i in 0..n {
1380//        F[i] = U[i] - f1[i];
1381//        F[n + i] = U[n + i] - f2[i];
1382//    }
1383//    Ok(())
1384//}
1385/// Helper function to compute the residual F(U) = U - f(y + dt*A*U)
1386#[allow(non_snake_case)]
1387fn compute_F<P: ODEProblem>(
1388    problem: &P,
1389    t: f64,
1390    y: &[f64],
1391    dt: f64,
1392    U: &[f64], // U is a concatenated vector [k1, k2]
1393    F: &mut [f64],
1394) -> Result<()> {
1395    let n = y.len();
1396    let (k1_slice, k2_slice) = U.split_at(n);
1397
1398    let mut y1 = vec![0.0; n];
1399    let mut y2 = vec![0.0; n];
1400
1401    for i in 0..n {
1402        y1[i] = y[i] + dt * (A11 * k1_slice[i] + A12 * k2_slice[i]);
1403        y2[i] = y[i] + dt * (A21 * k1_slice[i] + A22 * k2_slice[i]);
1404    }
1405
1406    // F is an output parameter, its parts f1 and f2 are stored temporarily
1407    let (f1, f2) = F.split_at_mut(n);
1408    problem.rhs(t + C1 * dt, &y1, f1)?;
1409    problem.rhs(t + C2 * dt, &y2, f2)?;
1410
1411    // Compute final residual F = [k1 - f1, k2 - f2]
1412    for i in 0..n {
1413        f1[i] = k1_slice[i] - f1[i];
1414        f2[i] = k2_slice[i] - f2[i];
1415    }
1416    Ok(())
1417}