// peroxide/numerical/ode.rs

1//! # Ordinary Differential Equation (ODE) Solvers
2//!
3//! This module provides traits and structs for solving ordinary differential equations (ODEs).
4//!
5//! ## Overview
6//!
7//! - `ODEProblem`: Trait for defining an ODE problem.
8//! - `ODEIntegrator`: Trait for ODE integrators.
9//! - `ODESolver`: Trait for ODE solvers.
10//! - `ODEError`: Enum for ODE errors.
//!   - `ReachedMaxStepIter`: An adaptive integrator hit its maximum number of retries for a single step. (internal error)
12//!   - `ConstraintViolation(f64, Vec<f64>, Vec<f64>)`: Constraint violation. (user-defined error)
//!   - This module uses `anyhow` for error handling, so `rhs` implementations can return your own custom errors.
14//!
15//! ## Available integrators
16//!
17//! - **Explicit**
18//!   - Ralston's 3rd order (RALS3)
19//!   - Runge-Kutta 4th order (RK4)
20//!   - Ralston's 4th order (RALS4)
21//!   - Runge-Kutta 5th order (RK5)
22//! - **Embedded**
23//!   - Bogacki-Shampine 2/3rd order (BS23)
24//!   - Runge-Kutta-Fehlberg 4/5th order (RKF45)
25//!   - Dormand-Prince 4/5th order (DP45)
26//!   - Tsitouras 4/5th order (TSIT45)
27//!   - Runge-Kutta-Fehlberg 7/8th order (RKF78)
28//! - **Implicit**
29//!   - Gauss-Legendre 4th order (GL4)
30//!
31//! ## Available solvers
32//!
33//! - `BasicODESolver`: A basic ODE solver using a specified integrator.
34//!
35//! You can implement your own ODE solver by implementing the `ODESolver` trait.
36//!
37//! ## Example
38//!
39//! ```rust
40//! use peroxide::fuga::*;
41//!
42//! fn main() -> Result<(), Box<dyn Error>> {
//!     // Same as: let rkf = RKF45::new(1e-6, 0.9, 1e-6, 1e-1, 100);
44//!     let rkf = RKF45 {
45//!         tol: 1e-6,
46//!         safety_factor: 0.9,
47//!         min_step_size: 1e-6,
48//!         max_step_size: 1e-1,
49//!         max_step_iter: 100,
50//!     };
51//!     let basic_ode_solver = BasicODESolver::new(rkf);
52//!     let initial_conditions = vec![1f64];
53//!     let (t_vec, y_vec) = basic_ode_solver.solve(
54//!         &Test,
55//!         (0f64, 10f64),
56//!         0.01,
57//!         &initial_conditions,
58//!     )?;
59//!     let y_vec: Vec<f64> = y_vec.into_iter().flatten().collect();
60//!     println!("{}", y_vec.len());
61//!
62//! #   #[cfg(feature = "plot")]
63//! #   {
64//!     let mut plt = Plot2D::new();
65//!     plt
66//!         .set_domain(t_vec)
67//!         .insert_image(y_vec)
68//!         .set_xlabel(r"$t$")
69//!         .set_ylabel(r"$y$")
70//!         .set_style(PlotStyle::Nature)
71//!         .tight_layout()
72//!         .set_dpi(600)
73//!         .set_path("example_data/rkf45_test.png")
74//!         .savefig()?;
75//! #   }
76//!     Ok(())
77//! }
78//!
79//! // Extremely customizable struct
80//! struct Test;
81//!
82//! impl ODEProblem for Test {
83//!     fn rhs(&self, t: f64, y: &[f64], dy: &mut [f64]) -> anyhow::Result<()> {
84//!         Ok(dy[0] = (5f64 * t.powi(2) - y[0]) / (t + y[0]).exp())
85//!     }
86//! }
87//! ```
88
89use crate::fuga::ConvToMat;
90use crate::traits::math::{InnerProduct, Norm, Normed, Vector};
91use crate::util::non_macro::eye;
92use anyhow::{bail, Result};
93
/// Trait for defining an ODE problem `dy/dt = f(t, y)`.
///
/// Implement this trait to define your own ODE problem.
///
/// # Example
///
/// ```
/// use peroxide::fuga::*;
///
/// struct MyODEProblem;
///
/// impl ODEProblem for MyODEProblem {
///     fn rhs(&self, _t: f64, y: &[f64], dy: &mut [f64]) -> anyhow::Result<()> {
///         dy[0] = -0.5 * y[0];
///         dy[1] = y[0] - y[1];
///         Ok(())
///     }
/// }
/// ```
pub trait ODEProblem {
    /// Evaluates the right-hand side `f(t, y)` and writes the result into `dy`.
    ///
    /// - `t`: current time.
    /// - `y`: current state.
    /// - `dy`: output buffer for the derivative. The Runge-Kutta integrators in this module pass a
    ///   buffer with the same length as `y`; implementations should write every component.
    ///
    /// # Errors
    ///
    /// With the integrators and `BasicODESolver` in this module, an error returned here aborts the
    /// integration and is propagated to the caller of `solve`. `ODEError::ConstraintViolation` is a
    /// ready-made variant for reporting violated constraints.
    fn rhs(&self, t: f64, y: &[f64], dy: &mut [f64]) -> Result<()>;
}
116
/// Trait for ODE integrators.
///
/// Implement this trait to define your own ODE integrator.
/// Every `ButcherTableau` in this module implements it through a blanket implementation.
pub trait ODEIntegrator {
    /// Advances the state `y` in place by one step of size `dt`, starting at time `t`.
    ///
    /// Returns the step size to use for the next step.
    ///
    /// `BasicODESolver` records the new sample at `t + dt` as soon as this call returns. An
    /// implementation used with it should therefore leave `y` holding the state at `t + dt`.
    ///
    /// # Errors
    ///
    /// Implementations propagate errors from `ODEProblem::rhs`. The adaptive integrators in this
    /// module may also return `ODEError::ReachedMaxStepIter`.
    fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64>;
}
123
/// Enum for ODE errors.
///
/// # Variants
///
/// - `ConstraintViolation(t, y, dy)`: a user-defined constraint was violated at time `t`, with
///   state `y` and derivative `dy`. (user-defined error — return it from `ODEProblem::rhs`)
/// - `ReachedMaxStepIter`: an adaptive integrator rejected `max_step_iter` consecutive trial steps
///   without finding an acceptable step size. (internal error for integrator)
///
/// `ODEError` implements [`std::error::Error`]. It can therefore be boxed as `Box<dyn Error>`,
/// raised with `anyhow::bail!`, and recovered with `anyhow::Error::downcast_ref::<ODEError>()`.
///
/// # Example
///
/// ```no_run
/// use peroxide::fuga::*;
///
/// struct ConstrainedProblem {
///     y_constraint: f64
/// }
///
/// impl ODEProblem for ConstrainedProblem {
///     fn rhs(&self, t: f64, y: &[f64], dy: &mut [f64]) -> anyhow::Result<()> {
///         if y[0] < self.y_constraint {
///             anyhow::bail!(ODEError::ConstraintViolation(t, y.to_vec(), dy.to_vec()));
///         } else {
///             // some function
///             Ok(())
///         }
///     }
/// }
/// ```
#[derive(Debug, Clone, PartialEq)]
pub enum ODEError {
    /// A user-defined constraint was violated: `(t, y, dy)`.
    ConstraintViolation(f64, Vec<f64>, Vec<f64>),
    /// An adaptive integrator exhausted its retry budget for a single step.
    ReachedMaxStepIter,
}

impl std::fmt::Display for ODEError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ODEError::ConstraintViolation(t, y, dy) => write!(
                f,
                "Constraint violation at t = {}, y = {:?}, dy = {:?}",
                t, y, dy
            ),
            ODEError::ReachedMaxStepIter => write!(f, "Reached maximum number of steps per step"),
        }
    }
}

// No underlying cause to expose: both variants are leaf errors.
impl std::error::Error for ODEError {}
171
/// Trait for ODE solvers.
///
/// Implement this trait to define your own ODE solver.
pub trait ODESolver {
    /// Integrates `problem` over `t_span = (t_start, t_end)`, starting from `initial_conditions`.
    ///
    /// - `dt`: initial step size. Integrators may propose a different size for later steps.
    ///
    /// Returns `(t_vec, y_vec)`, where `y_vec[i]` is the state at time `t_vec[i]`.
    /// `BasicODESolver` stores the initial condition at index 0.
    ///
    /// # Errors
    ///
    /// Errors raised by the integrator or by `ODEProblem::rhs` are propagated.
    fn solve<P: ODEProblem>(
        &self,
        problem: &P,
        t_span: (f64, f64),
        dt: f64,
        initial_conditions: &[f64],
    ) -> Result<(Vec<f64>, Vec<Vec<f64>>)>;
}
184
185/// A basic ODE solver using a specified integrator.
186///
187/// # Example
188///
189/// ```
190/// use peroxide::fuga::*;
191///
192/// fn main() -> Result<(), Box<dyn Error>> {
193///     let initial_conditions = vec![1f64];
194///     let rkf = RKF45::new(1e-4, 0.9, 1e-6, 1e-1, 100);
195///     let basic_ode_solver = BasicODESolver::new(rkf);
196///     let (t_vec, y_vec) = basic_ode_solver.solve(
197///         &Test,
198///         (0f64, 10f64),
199///         0.01,
200///         &initial_conditions,
201///     )?;
202///     let y_vec: Vec<f64> = y_vec.into_iter().flatten().collect();
203///
204///     Ok(())
205/// }
206///
207/// struct Test;
208///
209/// impl ODEProblem for Test {
210///     fn rhs(&self, t: f64, y: &[f64], dy: &mut [f64]) -> anyhow::Result<()> {
211///         dy[0] = (5f64 * t.powi(2) - y[0]) / (t + y[0]).exp();
212///         Ok(())
213///     }
214/// }
215/// ```
216pub struct BasicODESolver<I: ODEIntegrator> {
217    integrator: I,
218}
219
220impl<I: ODEIntegrator> BasicODESolver<I> {
221    pub fn new(integrator: I) -> Self {
222        Self { integrator }
223    }
224}
225
226impl<I: ODEIntegrator> ODESolver for BasicODESolver<I> {
227    fn solve<P: ODEProblem>(
228        &self,
229        problem: &P,
230        t_span: (f64, f64),
231        dt: f64,
232        initial_conditions: &[f64],
233    ) -> Result<(Vec<f64>, Vec<Vec<f64>>)> {
234        let mut t = t_span.0;
235        let mut dt = dt;
236        let mut y = initial_conditions.to_vec();
237        let mut t_vec = vec![t];
238        let mut y_vec = vec![y.clone()];
239
240        while t < t_span.1 {
241            let dt_step = self.integrator.step(problem, t, &mut y, dt)?;
242            t += dt;
243            t_vec.push(t);
244            y_vec.push(y.clone());
245            dt = dt_step;
246        }
247
248        Ok((t_vec, y_vec))
249    }
250}
251
252// ┌─────────────────────────────────────────────────────────┐
253//  Butcher Tableau
254// └─────────────────────────────────────────────────────────┘
/// Trait for Butcher tableaus of explicit Runge-Kutta methods.
///
/// ```text
/// C | A
/// - - -
///   | BU (Coefficient for update)
///   | BE (Coefficient for estimate error)
/// ```
///
/// - `C[j]`: time offset of stage `j`, as a fraction of the step size.
/// - `A[j]`: coefficients applied to stages `0..j` to build the input of stage `j`.
/// - `BU`: weights of the propagated solution.
/// - `BE`: weights of the embedded solution. If `BE` is empty the method is fixed-step.
///   Otherwise the local error is estimated from `BU - BE` and the step size is adapted.
///
/// Every `ButcherTableau` is also an `ODEIntegrator` through a blanket implementation in this
/// module.
///
/// # Adaptive parameters
///
/// `tol`, `safety_factor`, `min_step_size`, `max_step_size` and `max_step_iter` are read only by
/// adaptive tableaus (non-empty `BE`). Their default implementations panic with a message naming
/// the tableau and the missing parameter. Embedded tableaus must therefore override all five;
/// fixed-step tableaus can keep the defaults.
///
/// # References
///
/// - J. R. Dormand and P. J. Prince, _A family of embedded Runge-Kutta formulae_, J. Comp. Appl. Math., 6(1), 19-26, 1980.
/// - Wikipedia: [List of Runge-Kutta methods](https://en.wikipedia.org/wiki/List_of_Runge%E2%80%93Kutta_methods)
pub trait ButcherTableau {
    /// Stage nodes `c_j`.
    const C: &'static [f64];
    /// Stage coefficients; row `j` holds the weights of stages `0..j`.
    const A: &'static [&'static [f64]];
    /// Weights of the propagated solution.
    const BU: &'static [f64];
    /// Weights of the embedded solution (empty for fixed-step methods).
    const BE: &'static [f64];

    /// Tolerance on the estimated local error.
    ///
    /// # Panics
    ///
    /// The default implementation panics; adaptive tableaus must override it.
    fn tol(&self) -> f64 {
        missing_adaptive_parameter::<Self>("tol")
    }

    /// Safety factor applied to the proposed step size.
    ///
    /// # Panics
    ///
    /// The default implementation panics; adaptive tableaus must override it.
    fn safety_factor(&self) -> f64 {
        missing_adaptive_parameter::<Self>("safety_factor")
    }

    /// Upper bound for the proposed step size.
    ///
    /// # Panics
    ///
    /// The default implementation panics; adaptive tableaus must override it.
    fn max_step_size(&self) -> f64 {
        missing_adaptive_parameter::<Self>("max_step_size")
    }

    /// Lower bound for the proposed step size.
    ///
    /// # Panics
    ///
    /// The default implementation panics; adaptive tableaus must override it.
    fn min_step_size(&self) -> f64 {
        missing_adaptive_parameter::<Self>("min_step_size")
    }

    /// Maximum number of rejected trial steps before giving up.
    ///
    /// # Panics
    ///
    /// The default implementation panics; adaptive tableaus must override it.
    fn max_step_iter(&self) -> usize {
        missing_adaptive_parameter::<Self>("max_step_iter")
    }
}

/// Panics, naming the tableau type `T` and the adaptive parameter `name` it does not provide.
#[cold]
fn missing_adaptive_parameter<T: ?Sized>(name: &str) -> ! {
    unimplemented!(
        "`{}` must implement `ButcherTableau::{}` to be used as an adaptive (embedded) integrator",
        std::any::type_name::<T>(),
        name
    )
}
294
295impl<BU: ButcherTableau> ODEIntegrator for BU {
296    fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64> {
297        let n = y.len();
298        let mut iter_count = 0usize;
299        let mut dt = dt;
300        let n_k = Self::C.len();
301
302        loop {
303            let mut k_vec = vec![vec![0.0; n]; n_k];
304            let mut y_temp = y.to_vec();
305
306            for stage in 0..n_k {
307                for i in 0..n {
308                    let mut s = 0.0;
309                    for j in 0..stage {
310                        s += Self::A[stage][j] * k_vec[j][i];
311                    }
312                    y_temp[i] = y[i] + dt * s;
313                }
314                problem.rhs(t + dt * Self::C[stage], &y_temp, &mut k_vec[stage])?;
315            }
316
317            if !Self::BE.is_empty() {
318                let mut error = 0f64;
319                for i in 0..n {
320                    let mut s = 0.0;
321                    for j in 0..n_k {
322                        s += (Self::BU[j] - Self::BE[j]) * k_vec[j][i];
323                    }
324                    error = error.max(dt * s.abs())
325                }
326
327                let factor = (self.tol() * dt / error).powf(0.2);
328                let new_dt = self.safety_factor() * dt * factor;
329                let new_dt = new_dt.clamp(self.min_step_size(), self.max_step_size());
330
331                if error < self.tol() {
332                    for i in 0..n {
333                        let mut s = 0.0;
334                        for j in 0..n_k {
335                            s += Self::BU[j] * k_vec[j][i];
336                        }
337                        y[i] += dt * s;
338                    }
339                    return Ok(new_dt);
340                } else {
341                    iter_count += 1;
342                    if iter_count >= self.max_step_iter() {
343                        bail!(ODEError::ReachedMaxStepIter);
344                    }
345                    dt = new_dt;
346                }
347            } else {
348                for i in 0..n {
349                    let mut s = 0.0;
350                    for j in 0..n_k {
351                        s += Self::BU[j] * k_vec[j][i];
352                    }
353                    y[i] += dt * s;
354                }
355                return Ok(dt);
356            }
357        }
358    }
359}
360
361// ┌─────────────────────────────────────────────────────────┐
362//  Runge-Kutta
363// └─────────────────────────────────────────────────────────┘
364/// Ralston's 3rd order integrator
365///
366/// This integrator uses the Ralston's 3rd order method to numerically integrate the ODE system.
367/// In MATLAB, it is called `ode3`.
368#[derive(Debug, Clone, Copy, Default)]
369#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
370#[cfg_attr(
371    feature = "rkyv",
372    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
373)]
374pub struct RALS3;
375
376impl ButcherTableau for RALS3 {
377    const C: &'static [f64] = &[0.0, 0.5, 0.75];
378    const A: &'static [&'static [f64]] = &[&[], &[0.5], &[0.0, 0.75]];
379    const BU: &'static [f64] = &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0];
380    const BE: &'static [f64] = &[];
381}
382
383/// Runge-Kutta 4th order integrator.
384///
385/// This integrator uses the classical 4th order Runge-Kutta method to numerically integrate the ODE system.
386/// It calculates four intermediate values (k1, k2, k3, k4) to estimate the next step solution.
387#[derive(Debug, Clone, Copy, Default)]
388#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
389#[cfg_attr(
390    feature = "rkyv",
391    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
392)]
393pub struct RK4;
394
395impl ButcherTableau for RK4 {
396    const C: &'static [f64] = &[0.0, 0.5, 0.5, 1.0];
397    const A: &'static [&'static [f64]] = &[&[], &[0.5], &[0.0, 0.5], &[0.0, 0.0, 1.0]];
398    const BU: &'static [f64] = &[1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0];
399    const BE: &'static [f64] = &[];
400}
401
402/// Ralston's 4th order integrator.
403///
404/// This fourth order method is known as minimum truncation error RK4.
405#[derive(Debug, Clone, Copy, Default)]
406#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
407#[cfg_attr(
408    feature = "rkyv",
409    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
410)]
411pub struct RALS4;
412
413impl ButcherTableau for RALS4 {
414    const C: &'static [f64] = &[0.0, 0.4, 0.45573725, 1.0];
415    const A: &'static [&'static [f64]] = &[
416        &[],
417        &[0.4],
418        &[0.29697761, 0.158575964],
419        &[0.21810040, -3.050965616, 3.83286476],
420    ];
421    const BU: &'static [f64] = &[0.17476028, -0.55148066, 1.20553560, 0.17118478];
422    const BE: &'static [f64] = &[];
423}
424
425/// Runge-Kutta 5th order integrator
426///
427/// This integrator uses the 5th order Runge-Kutta method to numerically integrate the ODE system.
428#[derive(Debug, Clone, Copy, Default)]
429#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
430#[cfg_attr(
431    feature = "rkyv",
432    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
433)]
434pub struct RK5;
435
436impl ButcherTableau for RK5 {
437    const C: &'static [f64] = &[0.0, 0.2, 0.3, 0.8, 8.0 / 9.0, 1.0, 1.0];
438    const A: &'static [&'static [f64]] = &[
439        &[],
440        &[0.2],
441        &[0.075, 0.225],
442        &[44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0],
443        &[
444            19372.0 / 6561.0,
445            -25360.0 / 2187.0,
446            64448.0 / 6561.0,
447            -212.0 / 729.0,
448        ],
449        &[
450            9017.0 / 3168.0,
451            -355.0 / 33.0,
452            46732.0 / 5247.0,
453            49.0 / 176.0,
454            -5103.0 / 18656.0,
455        ],
456        &[
457            35.0 / 384.0,
458            0.0,
459            500.0 / 1113.0,
460            125.0 / 192.0,
461            -2187.0 / 6784.0,
462            11.0 / 84.0,
463        ],
464    ];
465    const BU: &'static [f64] = &[
466        5179.0 / 57600.0,
467        0.0,
468        7571.0 / 16695.0,
469        393.0 / 640.0,
470        -92097.0 / 339200.0,
471        187.0 / 2100.0,
472        1.0 / 40.0,
473    ];
474    const BE: &'static [f64] = &[];
475}
476
477// ┌─────────────────────────────────────────────────────────┐
478//  Embedded Runge-Kutta
479// └─────────────────────────────────────────────────────────┘
/// Bogacki-Shampine 3(2) method.
///
/// Adaptive integrator that propagates a third-order solution and uses an embedded second-order
/// solution to estimate the local error. This pair is known as `ode23` in MATLAB.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct BS23 {
    /// Tolerance for the estimated local error.
    pub tol: f64,
    /// Safety factor applied when the step size is adjusted.
    pub safety_factor: f64,
    /// Smallest step size the controller may propose.
    pub min_step_size: f64,
    /// Largest step size the controller may propose.
    pub max_step_size: f64,
    /// Maximum number of rejected trials allowed for a single step.
    pub max_step_iter: usize,
}

impl BS23 {
    /// Builds a `BS23` integrator from its five step-control parameters.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        Self {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}

impl Default for BS23 {
    /// `tol = 1e-3`, `safety_factor = 0.9`, steps in `[1e-6, 1e-1]`, at most 100 trials per step.
    fn default() -> Self {
        Self::new(1e-3, 0.9, 1e-6, 1e-1, 100)
    }
}
534
535impl ButcherTableau for BS23 {
536    const C: &'static [f64] = &[0.0, 0.5, 0.75, 1.0];
537    const A: &'static [&'static [f64]] = &[
538        &[],
539        &[0.5],
540        &[0.0, 0.75],
541        &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0],
542    ];
543    const BU: &'static [f64] = &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0, 0.0];
544    const BE: &'static [f64] = &[7.0 / 24.0, 0.25, 1.0 / 3.0, 0.125];
545
546    fn tol(&self) -> f64 {
547        self.tol
548    }
549    fn safety_factor(&self) -> f64 {
550        self.safety_factor
551    }
552    fn min_step_size(&self) -> f64 {
553        self.min_step_size
554    }
555    fn max_step_size(&self) -> f64 {
556        self.max_step_size
557    }
558    fn max_step_iter(&self) -> usize {
559        self.max_step_iter
560    }
561}
562
/// Runge-Kutta-Fehlberg 4/5th order integrator.
///
/// Adaptive integrator built on the Fehlberg pair. Each trial step evaluates six stage
/// derivatives (k1..k6), which yield both the next state and an estimate of its local error.
/// The step size is then adjusted automatically to keep that error below `tol`.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct RKF45 {
    /// Tolerance for the estimated local error.
    pub tol: f64,
    /// Safety factor applied when the step size is adjusted.
    pub safety_factor: f64,
    /// Smallest step size the controller may propose.
    pub min_step_size: f64,
    /// Largest step size the controller may propose.
    pub max_step_size: f64,
    /// Maximum number of rejected trials allowed for a single step.
    pub max_step_iter: usize,
}

impl RKF45 {
    /// Builds an `RKF45` integrator from its five step-control parameters.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        Self {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}

impl Default for RKF45 {
    /// `tol = 1e-6`, `safety_factor = 0.9`, steps in `[1e-6, 1e-1]`, at most 100 trials per step.
    fn default() -> Self {
        Self::new(1e-6, 0.9, 1e-6, 1e-1, 100)
    }
}
619
620impl ButcherTableau for RKF45 {
621    const C: &'static [f64] = &[0.0, 1.0 / 4.0, 3.0 / 8.0, 12.0 / 13.0, 1.0, 1.0 / 2.0];
622    const A: &'static [&'static [f64]] = &[
623        &[],
624        &[0.25],
625        &[3.0 / 32.0, 9.0 / 32.0],
626        &[1932.0 / 2197.0, -7200.0 / 2197.0, 7296.0 / 2197.0],
627        &[439.0 / 216.0, -8.0, 3680.0 / 513.0, -845.0 / 4104.0],
628        &[
629            -8.0 / 27.0,
630            2.0,
631            -3544.0 / 2565.0,
632            1859.0 / 4104.0,
633            -11.0 / 40.0,
634        ],
635    ];
636    const BU: &'static [f64] = &[
637        16.0 / 135.0,
638        0.0,
639        6656.0 / 12825.0,
640        28561.0 / 56430.0,
641        -9.0 / 50.0,
642        2.0 / 55.0,
643    ];
644    const BE: &'static [f64] = &[
645        25.0 / 216.0,
646        0.0,
647        1408.0 / 2565.0,
648        2197.0 / 4104.0,
649        -1.0 / 5.0,
650        0.0,
651    ];
652
653    fn tol(&self) -> f64 {
654        self.tol
655    }
656    fn safety_factor(&self) -> f64 {
657        self.safety_factor
658    }
659    fn min_step_size(&self) -> f64 {
660        self.min_step_size
661    }
662    fn max_step_size(&self) -> f64 {
663        self.max_step_size
664    }
665    fn max_step_iter(&self) -> usize {
666        self.max_step_iter
667    }
668}
669
/// Dormand-Prince 5(4) method.
///
/// Adaptive integrator that advances with a fifth-order Runge-Kutta solution and estimates the
/// local error with an embedded fourth-order solution.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct DP45 {
    /// Tolerance for the estimated local error.
    pub tol: f64,
    /// Safety factor applied when the step size is adjusted.
    pub safety_factor: f64,
    /// Smallest step size the controller may propose.
    pub min_step_size: f64,
    /// Largest step size the controller may propose.
    pub max_step_size: f64,
    /// Maximum number of rejected trials allowed for a single step.
    pub max_step_iter: usize,
}

impl DP45 {
    /// Builds a `DP45` integrator from its five step-control parameters.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        Self {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}

impl Default for DP45 {
    /// `tol = 1e-6`, `safety_factor = 0.9`, steps in `[1e-6, 1e-1]`, at most 100 trials per step.
    fn default() -> Self {
        Self::new(1e-6, 0.9, 1e-6, 1e-1, 100)
    }
}
725
726impl ButcherTableau for DP45 {
727    const C: &'static [f64] = &[0.0, 0.2, 0.3, 0.8, 8.0 / 9.0, 1.0, 1.0];
728    const A: &'static [&'static [f64]] = &[
729        &[],
730        &[0.2],
731        &[0.075, 0.225],
732        &[44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0],
733        &[
734            19372.0 / 6561.0,
735            -25360.0 / 2187.0,
736            64448.0 / 6561.0,
737            -212.0 / 729.0,
738        ],
739        &[
740            9017.0 / 3168.0,
741            -355.0 / 33.0,
742            46732.0 / 5247.0,
743            49.0 / 176.0,
744            -5103.0 / 18656.0,
745        ],
746        &[
747            35.0 / 384.0,
748            0.0,
749            500.0 / 1113.0,
750            125.0 / 192.0,
751            -2187.0 / 6784.0,
752            11.0 / 84.0,
753        ],
754    ];
755    const BU: &'static [f64] = &[
756        35.0 / 384.0,
757        0.0,
758        500.0 / 1113.0,
759        125.0 / 192.0,
760        -2187.0 / 6784.0,
761        11.0 / 84.0,
762        0.0,
763    ];
764    const BE: &'static [f64] = &[
765        5179.0 / 57600.0,
766        0.0,
767        7571.0 / 16695.0,
768        393.0 / 640.0,
769        -92097.0 / 339200.0,
770        187.0 / 2100.0,
771        1.0 / 40.0,
772    ];
773
774    fn tol(&self) -> f64 {
775        self.tol
776    }
777    fn safety_factor(&self) -> f64 {
778        self.safety_factor
779    }
780    fn min_step_size(&self) -> f64 {
781        self.min_step_size
782    }
783    fn max_step_size(&self) -> f64 {
784        self.max_step_size
785    }
786    fn max_step_iter(&self) -> usize {
787        self.max_step_iter
788    }
789}
790
/// Tsitouras 5(4) method.
///
/// Adaptive integrator that advances with a fifth-order Runge-Kutta solution and estimates the
/// local error with an embedded fourth-order solution. Coefficients are from Tsitouras (2011).
///
/// # References
///
/// - Ch. Tsitouras, Comput. Math. Appl. 62 (2011) 770-780.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct TSIT45 {
    /// Tolerance for the estimated local error.
    pub tol: f64,
    /// Safety factor applied when the step size is adjusted.
    pub safety_factor: f64,
    /// Smallest step size the controller may propose.
    pub min_step_size: f64,
    /// Largest step size the controller may propose.
    pub max_step_size: f64,
    /// Maximum number of rejected trials allowed for a single step.
    pub max_step_iter: usize,
}

impl TSIT45 {
    /// Builds a `TSIT45` integrator from its five step-control parameters.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        Self {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}

impl Default for TSIT45 {
    /// `tol = 1e-6`, `safety_factor = 0.9`, steps in `[1e-6, 1e-1]`, at most 100 trials per step.
    fn default() -> Self {
        Self::new(1e-6, 0.9, 1e-6, 1e-1, 100)
    }
}
850
851impl ButcherTableau for TSIT45 {
852    const C: &'static [f64] = &[0.0, 0.161, 0.327, 0.9, 0.9800255409045097, 1.0, 1.0];
853    const A: &'static [&'static [f64]] = &[
854        &[],
855        &[Self::C[1]],
856        &[Self::C[2] - 0.335480655492357, 0.335480655492357],
857        &[
858            Self::C[3] - (-6.359448489975075 + 4.362295432869581),
859            -6.359448489975075,
860            4.362295432869581,
861        ],
862        &[
863            Self::C[4] - (-11.74888356406283 + 7.495539342889836 - 0.09249506636175525),
864            -11.74888356406283,
865            7.495539342889836,
866            -0.09249506636175525,
867        ],
868        &[
869            Self::C[5]
870                - (-12.92096931784711 + 8.159367898576159
871                    - 0.0715849732814010
872                    - 0.02826905039406838),
873            -12.92096931784711,
874            8.159367898576159,
875            -0.0715849732814010,
876            -0.02826905039406838,
877        ],
878        &[
879            Self::BU[0],
880            Self::BU[1],
881            Self::BU[2],
882            Self::BU[3],
883            Self::BU[4],
884            Self::BU[5],
885        ],
886    ];
887    const BU: &'static [f64] = &[
888        0.09646076681806523,
889        0.01,
890        0.4798896504144996,
891        1.379008574103742,
892        -3.290069515436081,
893        2.324710524099774,
894        0.0,
895    ];
896    const BE: &'static [f64] = &[
897        0.001780011052226,
898        0.000816434459657,
899        -0.007880878010262,
900        0.144711007173263,
901        -0.582357165452555,
902        0.458082105929187,
903        1.0 / 66.0,
904    ];
905
906    fn tol(&self) -> f64 {
907        self.tol
908    }
909    fn safety_factor(&self) -> f64 {
910        self.safety_factor
911    }
912    fn min_step_size(&self) -> f64 {
913        self.min_step_size
914    }
915    fn max_step_size(&self) -> f64 {
916        self.max_step_size
917    }
918    fn max_step_iter(&self) -> usize {
919        self.max_step_iter
920    }
921}
922
/// Runge-Kutta-Fehlberg 7/8th order integrator.
///
/// Adaptive integrator based on the Runge-Kutta-Fehlberg 7(8) method. It evaluates f(x,y)
/// thirteen times per step and forms embedded 7th and 8th order Runge-Kutta estimates. The 7th
/// order solution is propagated; the difference between the 8th and 7th order solutions drives
/// error estimation and step size control.
///
/// # References
/// - Meysam Mahooti (2025). [Runge-Kutta-Fehlberg (RKF78)](https://www.mathworks.com/matlabcentral/fileexchange/61130-runge-kutta-fehlberg-rkf78), MATLAB Central File Exchange.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct RKF78 {
    /// Tolerance for the estimated local error.
    pub tol: f64,
    /// Safety factor applied when the step size is adjusted.
    pub safety_factor: f64,
    /// Smallest step size the controller may propose.
    pub min_step_size: f64,
    /// Largest step size the controller may propose.
    pub max_step_size: f64,
    /// Maximum number of rejected trials allowed for a single step.
    pub max_step_iter: usize,
}

impl RKF78 {
    /// Builds an `RKF78` integrator from its five step-control parameters.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        Self {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}

impl Default for RKF78 {
    /// The defaults are tighter than those of the 4(5) pairs, to suit a higher-order method:
    /// `tol = 1e-7` and `min_step_size = 1e-10`. The remaining defaults are
    /// `safety_factor = 0.9`, `max_step_size = 1e-1` and `max_step_iter = 100`.
    fn default() -> Self {
        Self::new(1e-7, 0.9, 1e-10, 1e-1, 100)
    }
}
984
985impl ButcherTableau for RKF78 {
986    const C: &'static [f64] = &[
987        0.0,
988        2.0 / 27.0,
989        1.0 / 9.0,
990        1.0 / 6.0,
991        5.0 / 12.0,
992        1.0 / 2.0,
993        5.0 / 6.0,
994        1.0 / 6.0,
995        2.0 / 3.0,
996        1.0 / 3.0,
997        1.0,
998        0.0, // k12 is evaluated at x[i]
999        1.0, // k13 is evaluated at x[i]+h
1000    ];
1001
1002    const A: &'static [&'static [f64]] = &[
1003        // k1
1004        &[],
1005        // k2
1006        &[2.0 / 27.0],
1007        // k3
1008        &[1.0 / 36.0, 3.0 / 36.0],
1009        // k4
1010        &[1.0 / 24.0, 0.0, 3.0 / 24.0],
1011        // k5
1012        &[20.0 / 48.0, 0.0, -75.0 / 48.0, 75.0 / 48.0],
1013        // k6
1014        &[1.0 / 20.0, 0.0, 0.0, 5.0 / 20.0, 4.0 / 20.0],
1015        // k7
1016        &[
1017            -25.0 / 108.0,
1018            0.0,
1019            0.0,
1020            125.0 / 108.0,
1021            -260.0 / 108.0,
1022            250.0 / 108.0,
1023        ],
1024        // k8
1025        &[
1026            31.0 / 300.0,
1027            0.0,
1028            0.0,
1029            0.0,
1030            61.0 / 225.0,
1031            -2.0 / 9.0,
1032            13.0 / 900.0,
1033        ],
1034        // k9
1035        &[
1036            2.0,
1037            0.0,
1038            0.0,
1039            -53.0 / 6.0,
1040            704.0 / 45.0,
1041            -107.0 / 9.0,
1042            67.0 / 90.0,
1043            3.0,
1044        ],
1045        // k10
1046        &[
1047            -91.0 / 108.0,
1048            0.0,
1049            0.0,
1050            23.0 / 108.0,
1051            -976.0 / 135.0,
1052            311.0 / 54.0,
1053            -19.0 / 60.0,
1054            17.0 / 6.0,
1055            -1.0 / 12.0,
1056        ],
1057        // k11
1058        &[
1059            2383.0 / 4100.0,
1060            0.0,
1061            0.0,
1062            -341.0 / 164.0,
1063            4496.0 / 1025.0,
1064            -301.0 / 82.0,
1065            2133.0 / 4100.0,
1066            45.0 / 82.0,
1067            45.0 / 164.0,
1068            18.0 / 41.0,
1069        ],
1070        // k12
1071        &[
1072            3.0 / 205.0,
1073            0.0,
1074            0.0,
1075            0.0,
1076            0.0,
1077            -6.0 / 41.0,
1078            -3.0 / 205.0,
1079            -3.0 / 41.0,
1080            3.0 / 41.0,
1081            6.0 / 41.0,
1082            0.0,
1083        ],
1084        // k13
1085        &[
1086            -1777.0 / 4100.0,
1087            0.0,
1088            0.0,
1089            -341.0 / 164.0,
1090            4496.0 / 1025.0,
1091            -289.0 / 82.0,
1092            2193.0 / 4100.0,
1093            51.0 / 82.0,
1094            33.0 / 164.0,
1095            12.0 / 41.0,
1096            0.0,
1097            1.0,
1098        ],
1099    ];
1100
1101    // Coefficients for the 7th order solution (propagated solution)
1102    // BU_i = BE_i (8th order) - ErrorCoeff_i
1103    // ErrorCoeff_i = [-41/840, 0, ..., 0, -41/840 (for k11), 41/840 (for k12), 41/840 (for k13)]
1104    const BU: &'static [f64] = &[
1105        41.0 / 420.0, // 41/840 - (-41/840)
1106        0.0,
1107        0.0,
1108        0.0,
1109        0.0,
1110        34.0 / 105.0,
1111        9.0 / 35.0,
1112        9.0 / 35.0,
1113        9.0 / 280.0,
1114        9.0 / 280.0,
1115        41.0 / 420.0,  // 41/840 - (-41/840)
1116        -41.0 / 840.0, // 0.0 - (41/840)
1117        -41.0 / 840.0, // 0.0 - (41/840)
1118    ];
1119
1120    // Coefficients for the 8th order solution (used for error estimation)
1121    // These are from the y[i+1] formula in the MATLAB description
1122    const BE: &'static [f64] = &[
1123        41.0 / 840.0,
1124        0.0,
1125        0.0,
1126        0.0,
1127        0.0,
1128        34.0 / 105.0,
1129        9.0 / 35.0,
1130        9.0 / 35.0,
1131        9.0 / 280.0,
1132        9.0 / 280.0,
1133        41.0 / 840.0,
1134        0.0,
1135        0.0,
1136    ];
1137
1138    fn tol(&self) -> f64 {
1139        self.tol
1140    }
1141    fn safety_factor(&self) -> f64 {
1142        self.safety_factor
1143    }
1144    fn min_step_size(&self) -> f64 {
1145        self.min_step_size
1146    }
1147    fn max_step_size(&self) -> f64 {
1148        self.max_step_size
1149    }
1150    fn max_step_iter(&self) -> usize {
1151        self.max_step_iter
1152    }
1153}
1154
1155// ┌─────────────────────────────────────────────────────────┐
1156//  Gauss-Legendre 4th order
1157// └─────────────────────────────────────────────────────────┘
1158
// Butcher tableau of the two-stage Gauss-Legendre method (order 4), used by `GL4`:
//
//   C1 | A11  A12
//   C2 | A21  A22
//   ---+----------
//      | B1   B2
//
// with c = 1/2 -/+ sqrt(3)/6, A = [[1/4, 1/4 - sqrt(3)/6], [1/4 + sqrt(3)/6, 1/4]]
// and b = [1/2, 1/2]. SQRT3 is sqrt(3) written out as a literal.
const SQRT3: f64 = 1.7320508075688772;
const C1: f64 = 0.5 - SQRT3 / 6.0;
const C2: f64 = 0.5 + SQRT3 / 6.0;
const A11: f64 = 0.25;
const A12: f64 = 0.25 - SQRT3 / 6.0;
const A21: f64 = 0.25 + SQRT3 / 6.0;
const A22: f64 = 0.25;
const B1: f64 = 0.5;
const B2: f64 = 0.5;
1169
/// Nonlinear solver used by the Gauss-Legendre integrator ([`GL4`]) for its implicit
/// stage equations.
///
/// Two options are available: plain fixed-point iteration and Broyden's quasi-Newton method.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub enum ImplicitSolver {
    /// Fixed-point iteration `k <- f(t + c*dt, y + dt*A*k)` on the stage derivatives.
    FixedPoint,
    /// Broyden's ("good") quasi-Newton method on the residual `k - f(t + c*dt, y + dt*A*k)`.
    Broyden,
}
1185
1186/// Gauss-Legendre 4th order integrator.
1187///
1188/// This integrator uses the 4th order Gauss-Legendre Runge-Kutta method, which is an implicit integrator.
1189/// It requires solving a system of nonlinear equations at each step, which is done using the specified implicit solver (e.g., fixed-point iteration).
1190/// The Gauss-Legendre method has better stability properties compared to explicit methods, especially for stiff ODEs.
1191///
1192/// # Member variables
1193///
1194/// - `solver`: The implicit solver to use.
1195/// - `tol`: The tolerance for the implicit solver.
1196/// - `max_step_iter`: The maximum number of iterations for the implicit solver per step.
1197#[derive(Debug, Clone, Copy)]
1198#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
1199#[cfg_attr(
1200    feature = "rkyv",
1201    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
1202)]
1203pub struct GL4 {
1204    pub solver: ImplicitSolver,
1205    pub tol: f64,
1206    pub max_step_iter: usize,
1207}
1208
1209impl Default for GL4 {
1210    fn default() -> Self {
1211        GL4 {
1212            solver: ImplicitSolver::FixedPoint,
1213            tol: 1e-8,
1214            max_step_iter: 100,
1215        }
1216    }
1217}
1218
1219impl GL4 {
1220    pub fn new(solver: ImplicitSolver, tol: f64, max_step_iter: usize) -> Self {
1221        GL4 {
1222            solver,
1223            tol,
1224            max_step_iter,
1225        }
1226    }
1227}
1228
1229impl ODEIntegrator for GL4 {
1230    #[allow(non_snake_case)]
1231    #[inline]
1232    fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64> {
1233        let n = y.len();
1234        //let sqrt3 = 3.0_f64.sqrt();
1235        //let c = 0.5 * (3.0 - sqrt3) / 6.0;
1236        //let d = 0.5 * (3.0 + sqrt3) / 6.0;
1237        let mut k1 = vec![0.0; n];
1238        let mut k2 = vec![0.0; n];
1239
1240        // Initial guess for k1, k2.
1241        problem.rhs(t, y, &mut k1)?;
1242        k2.copy_from_slice(&k1);
1243
1244        match self.solver {
1245            ImplicitSolver::FixedPoint => {
1246                // Fixed-point iteration
1247                let mut y1 = vec![0.0; n];
1248                let mut y2 = vec![0.0; n];
1249
1250                for _ in 0..self.max_step_iter {
1251                    let k1_old = k1.clone();
1252                    let k2_old = k2.clone();
1253
1254                    for i in 0..n {
1255                        y1[i] = y[i] + dt * (A11 * k1[i] + A12 * k2[i]);
1256                        y2[i] = y[i] + dt * (A11 * k1[i] + A12 * k2[i]);
1257                    }
1258
1259                    // Compute new k1 and k2
1260                    problem.rhs(t + C1 * dt, &y1, &mut k1)?;
1261                    problem.rhs(t + C2 * dt, &y2, &mut k2)?;
1262
1263                    // Check for convergence
1264                    let mut max_diff = 0f64;
1265                    for i in 0..n {
1266                        max_diff = max_diff.max((k1[i] - k1_old[i]).abs());
1267                        max_diff = max_diff.max((k2[i] - k2_old[i]).abs());
1268                    }
1269
1270                    if max_diff < self.tol {
1271                        break;
1272                    }
1273                }
1274            }
1275            ImplicitSolver::Broyden => {
1276                let m = 2 * n;
1277                let mut U = vec![0.0; m];
1278                U[..n].copy_from_slice(&k1);
1279                U[n..].copy_from_slice(&k2);
1280
1281                // F_vec = F(U)
1282                let mut F_vec = vec![0.0; m];
1283                compute_F(problem, t, y, dt, &U, &mut F_vec)?;
1284
1285                // Initialize inverse Jacobian matrix
1286                let mut J_inv = eye(m);
1287
1288                // Repeat Broyden's method
1289                for _ in 0..self.max_step_iter {
1290                    // delta = - J_inv * F_vec
1291                    let delta = (&J_inv * &F_vec).mul_scalar(-1.0);
1292
1293                    // U <- U + delta
1294                    U.iter_mut()
1295                        .zip(delta.iter())
1296                        .for_each(|(u, d)| *u += *d);
1297
1298                    let mut F_new = vec![0.0; m];
1299                    compute_F(problem, t, y, dt, &U, &mut F_new)?;
1300
1301                    // If infinity norm of F_new is less than tol, break
1302                    if F_new.norm(Norm::LInf) < self.tol {
1303                        break;
1304                    }
1305
1306                    // Residual: delta_F = F_new - F_vec
1307                    let delta_F = F_new.sub_vec(&F_vec);
1308
1309                    // J_inv * delta_F
1310                    let J_inv_delta_F = &J_inv * &delta_F;
1311
1312                    let denom = delta.dot(&J_inv_delta_F);
1313                    if denom.abs() < 1e-12 {
1314                        break;
1315                    }
1316
1317                    // Broyden's "good" update for the inverse Jacobian
1318                    // J_inv <- J_inv + ((delta - J_inv * delta_F) * delta^T * J_inv) / denom
1319                    let delta_minus_J_inv_delta_F = delta.sub_vec(&J_inv_delta_F).to_col();
1320                    let delta_T_J_inv = &delta.to_row() * &J_inv;
1321                    let update = (delta_minus_J_inv_delta_F * delta_T_J_inv) / denom;
1322                    J_inv = J_inv + update;
1323                    F_vec = F_new;
1324                }
1325
1326                k1.copy_from_slice(&U[..n]);
1327                k2.copy_from_slice(&U[n..]);
1328            }
1329        }
1330
1331        for i in 0..n {
1332            y[i] += dt * (B1 * k1[i] + B2 * k2[i]);
1333        }
1334
1335        Ok(dt)
1336    }
1337}
1338
1339//// Helper function to compute the function F(U) for the implicit solver.
1340//// y1 = y + dt*(c*k1 + d*k2 - sqrt3/2*(k2-k1))
1341//// y2 = y + dt*(c*k1 + d*k2 + sqrt3/2*(k2-k1))
1342//#[allow(non_snake_case)]
1343//fn compute_F<P: ODEProblem>(
1344//    problem: &P,
1345//    t: f64,
1346//    y: &[f64],
1347//    dt: f64,
1348//    c: f64,
1349//    d: f64,
1350//    sqrt3: f64,
1351//    U: &[f64],
1352//    F: &mut [f64],
1353//) -> Result<()> {
1354//    let n = y.len();
1355//    let mut y1 = vec![0.0; n];
1356//    let mut y2 = vec![0.0; n];
1357//
1358//    for i in 0..n {
1359//        let k1 = U[i];
1360//        let k2 = U[n + i];
1361//        y1[i] = y[i] + dt * (c * k1 + d * k2 - sqrt3 * (k2 - k1) / 2.0);
1362//        y2[i] = y[i] + dt * (c * k1 + d * k2 + sqrt3 * (k2 - k1) / 2.0);
1363//    }
1364//
1365//    let mut f1 = vec![0.0; n];
1366//    let mut f2 = vec![0.0; n];
1367//    problem.rhs(t + c * dt, &y1, &mut f1)?;
1368//    problem.rhs(t + d * dt, &y2, &mut f2)?;
1369//
1370//    // F = [ k1 - f1, k2 - f2 ]
1371//    for i in 0..n {
1372//        F[i] = U[i] - f1[i];
1373//        F[n + i] = U[n + i] - f2[i];
1374//    }
1375//    Ok(())
1376//}
1377/// Helper function to compute the residual F(U) = U - f(y + dt*A*U)
1378#[allow(non_snake_case)]
1379fn compute_F<P: ODEProblem>(
1380    problem: &P,
1381    t: f64,
1382    y: &[f64],
1383    dt: f64,
1384    U: &[f64], // U is a concatenated vector [k1, k2]
1385    F: &mut [f64],
1386) -> Result<()> {
1387    let n = y.len();
1388    let (k1_slice, k2_slice) = U.split_at(n);
1389
1390    let mut y1 = vec![0.0; n];
1391    let mut y2 = vec![0.0; n];
1392
1393    for i in 0..n {
1394        y1[i] = y[i] + dt * (A11 * k1_slice[i] + A12 * k2_slice[i]);
1395        y2[i] = y[i] + dt * (A21 * k1_slice[i] + A22 * k2_slice[i]);
1396    }
1397    
1398    // F is an output parameter, its parts f1 and f2 are stored temporarily
1399    let (f1, f2) = F.split_at_mut(n);
1400    problem.rhs(t + C1 * dt, &y1, f1)?;
1401    problem.rhs(t + C2 * dt, &y2, f2)?;
1402
1403    // Compute final residual F = [k1 - f1, k2 - f2]
1404    for i in 0..n {
1405        f1[i] = k1_slice[i] - f1[i];
1406        f2[i] = k2_slice[i] - f2[i];
1407    }
1408    Ok(())
1409}