peroxide/numerical/
root.rs

1//! # Root Finding Methods
2//!
3//! This module provides a collection of root finding algorithms for solving nonlinear equations.
4//! It defines traits for representing root finding problems and root finding methods, and provides implementations of several common algorithms.
5//!
6//! ## Traits
7//!
8//! - `RootFindingProblem<const I: usize, const O: usize, T>`: Defines the interface for a root finding problem.
9//!   It requires implementing the `function` method to evaluate the function at a given point, and the `initial_guess` method to provide an initial guess for the root.
10//!   Optionally, the `derivative` and `hessian` methods can be implemented to provide the derivative and Hessian of the function, respectively.
11//!
12//!   - `I`: Input dimension
13//!   - `O`: Output dimension
14//!   - `T`: State type (e.g. `f64`, `(f64, f64)`, or etc.)
15//!
16//! - `RootFinder<const I: usize, const O: usize, T>`: Defines the interface for a root finding method.
17//!   It requires implementing the `find` method, which takes a `RootFindingProblem` and returns the root of the function.
18//!   The `max_iter` and `tol` methods provide the maximum number of iterations and the tolerance for the root finding algorithm.
19//!
20//! ## Root Finding Methods
21//!
22//! - `BisectionMethod`: Implements the bisection method for finding roots of continuous functions.
23//!   It requires an initial interval that brackets the root.
24//!
//!   - Type Parameters: `I=1, O=1, T=(f64, f64)`
26//!
27//! - `NewtonMethod`: Implements Newton's method for finding roots of differentiable functions.
28//!   It requires an initial guess for the root and the derivative of the function.
29//!
30//!   - Type Parameters: `I=1, O=1, T=f64`
31//!
32//! - `SecantMethod`: Implements the secant method for finding roots of differentiable functions.
33//!   It requires two initial guesses for the root.
34//!
//!   - Type Parameters: `I=1, O=1, T=(f64, f64)`
36//!
37//! - `FalsePositionMethod`: Implements the false position method for finding roots of continuous functions.
38//!   It requires an initial interval that brackets the root.
39//!
40//!   - Type Parameters: `I=1, O=1, T=(f64, f64)`
41//!
42//! - `BroydenMethod`: Implements Broyden's method for finding roots of systems of nonlinear equations.
//!   It requires two initial guesses for the first step (not an interval, just two points).
44//!
45//!   - Type Parameters: `I>=1, O>=1, T=([f64; I], [f64; I])`
46//!
47//! ## Convenient type aliases
48//!
49//! - `Pt<const N: usize>`: Represents a point in N-dimensional space. (`[f64; N]`)
//! - `Intv<const N: usize>`: Represents a pair of points in N-dimensional space, e.g. an interval or two initial guesses. (`([f64; N], [f64; N])`)
51//! - `Jaco<const R: usize, const C: usize>`: Represents the Jacobian matrix of a function. (`[[f64; C]; R]`)
52//! - `Hess<const R: usize, const C: usize>`: Represents the Hessian matrix of a function. (`[[[f64; C]; C]; R]`)
53//!
54//! ## High-level macros
55//!
56//! Peroxide also provides high-level macros for root finding.
57//! Assume `f: fn(f64) -> f64`.
58//!
59//! - `bisection!(f, (a,b), max_iter, tol)`
60//! - `newton!(f, x0, max_iter, tol)`: (**Caution**: newton macro requires `#[ad_function]` attribute)
61//! - `secant!(f, (x0, x1), max_iter, tol)`
62//! - `false_position!(f, (a,b), max_iter, tol)`
63//!
64//! ```rust
65//! #[macro_use]
66//! extern crate peroxide;
67//! use peroxide::fuga::*;
68//! use anyhow::Result;
69//!
70//! fn main() -> Result<()> {
71//!     let root_bisect = bisection!(f, (0.0, 2.0), 100, 1e-6)?;
72//!     let root_newton = newton!(f, 0.0, 100, 1e-6)?;
73//!     let root_false_pos = false_position!(f, (0.0, 2.0), 100, 1e-6)?;
74//!     let root_secant = secant!(f, (0.0, 2.0), 100, 1e-6)?;
75//!
76//!     println!("root_bisect: {}", root_bisect);
77//!     println!("root_newton: {}", root_newton);
78//!     println!("root_false_pos: {}", root_false_pos);
79//!     println!("root_secant: {}", root_secant);
80//!
81//!     assert!(f(root_bisect).abs() < 1e-6);
82//!     assert!(f(root_newton).abs() < 1e-6);
83//!     assert!(f(root_false_pos).abs() < 1e-6);
84//!     assert!(f(root_secant).abs() < 1e-6);
85//!
86//!     Ok(())
87//! }
88//!
89//! #[ad_function]
90//! fn f(x: f64) -> f64 {
91//!     (x - 1f64).powi(3)
92//! }
93//! ```
94//!
95//! ## Examples
96//!
97//! ### Finding the root of a cubic function
98//!
99//! ```rust
100//! use peroxide::fuga::*;
101//! use anyhow::Result;
102//!
103//! fn main() -> Result<()> {
104//!     let problem = Cubic;
105//!
106//!     let bisect = BisectionMethod { max_iter: 100, tol: 1e-6 };
107//!     let newton = NewtonMethod { max_iter: 100, tol: 1e-6 };
108//!     let false_pos = FalsePositionMethod { max_iter: 100, tol: 1e-6 };
109//!
110//!     let root_bisect = bisect.find(&problem)?;
111//!     let root_newton = newton.find(&problem)?;
112//!     let root_false_pos = false_pos.find(&problem)?;
113//!
114//!     let result_bisect = problem.eval(root_bisect)?[0];
115//!     let result_newton = problem.eval(root_newton)?[0];
116//!     let result_false_pos = problem.eval(root_false_pos)?[0];
117//!
118//!     assert!(result_bisect.abs() < 1e-6);
119//!     assert!(result_newton.abs() < 1e-6);
120//!     assert!(result_false_pos.abs() < 1e-6);
121//!
122//!     Ok(())
123//! }
124//!
125//! struct Cubic;
126//!
127//! impl Cubic {
128//!     fn eval(&self, x: [f64; 1]) -> Result<[f64; 1]> {
129//!         Ok([(x[0] - 1f64).powi(3)])
130//!     }
131//! }
132//!
133//! impl RootFindingProblem<1, 1, (f64, f64)> for Cubic {
134//!     fn function(&self, x: [f64; 1]) -> Result<[f64; 1]> {
135//!         self.eval(x)
136//!     }
137//!
138//!     fn initial_guess(&self) -> (f64, f64) {
139//!         (0.0, 2.0)
140//!     }
141//! }
142//!
143//! impl RootFindingProblem<1, 1, f64> for Cubic {
144//!     fn function(&self, x: [f64; 1]) -> Result<[f64; 1]> {
145//!         self.eval(x)
146//!     }
147//!
148//!     fn initial_guess(&self) -> f64 {
149//!         0.0
150//!     }
151//!
152//!     fn derivative(&self, x: [f64; 1]) -> Result<Jaco<1, 1>> {
153//!         Ok([[3.0 * (x[0] - 1f64).powi(2)]])
154//!     }
155//! }
156//! ```
157//!
158//! This example demonstrates how to find the root of a cubic function `(x - 1)^3` using various root finding methods.
159//! The `Cubic` struct implements the `RootFindingProblem` trait for both `(f64, f64)` and `f64` initial guess types, allowing the use of different root finding methods.
160//!
161//! ### Finding the root of the cosine function (error handling example)
162//!
163//! ```rust
164//! use peroxide::fuga::*;
165//! use anyhow::Result;
166//!
167//! fn main() -> Result<()> {
168//!     let problem = Cosine;
169//!     let newton = NewtonMethod { max_iter: 100, tol: 1e-6 };
170//!
171//!     let root_newton = match newton.find(&problem) {
172//!         Ok(x) => x,
173//!         Err(e) => {
174//!             println!("{:?}", e);
175//!             match e.downcast::<RootError<1>>() {
176//!                 Ok(RootError::ZeroDerivative(x)) => x,
177//!                 Ok(e) => panic!("ok but {:?}", e),
178//!                 Err(e) => panic!("err {:?}", e),
179//!             }
180//!         }
181//!     };
182//!
183//!     assert_eq!(root_newton[0], 0.0);
184//!
185//!     Ok(())
186//! }
187//!
188//! struct Cosine;
189//!
190//! impl Cosine {
191//!     fn eval(&self, x: [f64; 1]) -> Result<[f64; 1]> {
192//!         Ok([x[0].cos()])
193//!     }
194//! }
195//!
196//! impl RootFindingProblem<1, 1, f64> for Cosine {
197//!     fn function(&self, x: [f64; 1]) -> Result<[f64; 1]> {
198//!         self.eval(x)
199//!     }
200//!
201//!     fn initial_guess(&self) -> f64 {
202//!         0.0 // should fail in newton (derivative is 0)
203//!     }
204//!
205//!     fn derivative(&self, x: [f64; 1]) -> Result<Jaco<1, 1>> {
206//!         Ok([[-x[0].sin()]])
207//!     }
208//! }
209//! ```
210//!
211//! This example shows how to find the root of the cosine function using Newton's method.
212//! The `Cosine` struct implements the `RootFindingProblem` trait for the `f64` initial guess type.
213//! The initial guess is set to `0.0`, which is a point where the derivative of the cosine function is 0.
214//! This leads to the `NewtonMethod` returning a `RootError::ZeroDerivative` error, which is handled in the example.
215use anyhow::{bail, Result};
216
217use crate::traits::math::{LinearOp, Norm, Normed};
218use crate::traits::sugar::{ConvToMat, VecOps};
219use crate::util::non_macro::zeros;
220
221// ┌─────────────────────────────────────────────────────────┐
222//  High level macro
223// └─────────────────────────────────────────────────────────┘
224/// High level macro for bisection
225///
226/// # Arguments
227///
228/// - `f`: `Fn(f64) -> f64` (allow closure)
229/// - `(a, b)`: `(f64, f64)`
230/// - `max_iter`: `usize`
231/// - `tol`: `f64`
232#[macro_export]
233macro_rules! bisection {
234    ($f:expr, ($a:expr, $b:expr), $max_iter:expr, $tol:expr) => {{
235        struct BisectionProblem<F: Fn(f64) -> f64> {
236            f: F,
237        };
238
239        impl<F: Fn(f64) -> f64> RootFindingProblem<1, 1, (f64, f64)> for BisectionProblem<F> {
240            fn initial_guess(&self) -> (f64, f64) {
241                ($a, $b)
242            }
243
244            fn function(&self, x: [f64; 1]) -> Result<[f64; 1]> {
245                Ok([(self.f)(x[0])])
246            }
247        }
248
249        let problem = BisectionProblem { f: $f };
250        let bisection = BisectionMethod {
251            max_iter: $max_iter,
252            tol: $tol,
253        };
254        match bisection.find(&problem) {
255            Ok(root) => Ok(root[0]),
256            Err(e) => Err(e),
257        }
258    }};
259}
260
261/// High level macro for newton (using Automatic differentiation)
262///
263/// # Requirements
264///
265/// - This macro requires the function with `ad_function`
266///
267///   ```rust
268///   use peroxide::fuga::*;
269///
270///   #[ad_function]
271///   fn f(x: f64) -> f64 {
272///       (x - 1f64).powi(3)
273///   }
274///   ```
275///
276/// # Arguments
277///
278/// - `f`: `fn(f64) -> f64` (not allow closure)
279/// - `x`: `f64`
280/// - `max_iter`: `usize`
281/// - `tol`: `f64`
282#[macro_export]
283macro_rules! newton {
284    ($f:ident, $x:expr, $max_iter:expr, $tol:expr) => {{
285        use paste::paste;
286        struct NewtonProblem;
287
288        impl RootFindingProblem<1, 1, f64> for NewtonProblem {
289            fn initial_guess(&self) -> f64 {
290                $x
291            }
292
293            fn function(&self, x: [f64; 1]) -> Result<[f64; 1]> {
294                Ok([$f(x[0])])
295            }
296
297            fn derivative(&self, x: [f64; 1]) -> Result<Jaco<1, 1>> {
298                paste! {
299                    let x_ad = AD1(x[0], 1f64);
300                    Ok([[[<$f _ad>](x_ad).dx()]])
301                }
302            }
303        }
304
305        let problem = NewtonProblem;
306        let newton = NewtonMethod {
307            max_iter: $max_iter,
308            tol: $tol,
309        };
310        match newton.find(&problem) {
311            Ok(root) => Ok(root[0]),
312            Err(e) => Err(e),
313        }
314    }};
315}
316
317/// High level macro for false position
318///
319/// # Arguments
320///
321/// - `f`: `Fn(f64) -> f64` (allow closure)
322/// - `(a, b)`: `(f64, f64)`
323/// - `max_iter`: `usize`
324/// - `tol`: `f64`
325#[macro_export]
326macro_rules! false_position {
327    ($f:expr, ($a:expr, $b:expr), $max_iter:expr, $tol:expr) => {{
328        struct FalsePositionProblem<F: Fn(f64) -> f64> {
329            f: F,
330        };
331
332        impl<F: Fn(f64) -> f64> RootFindingProblem<1, 1, (f64, f64)> for FalsePositionProblem<F> {
333            fn initial_guess(&self) -> (f64, f64) {
334                ($a, $b)
335            }
336
337            fn function(&self, x: [f64; 1]) -> Result<[f64; 1]> {
338                Ok([(self.f)(x[0])])
339            }
340        }
341
342        let problem = FalsePositionProblem { f: $f };
343        let false_position = FalsePositionMethod {
344            max_iter: $max_iter,
345            tol: $tol,
346        };
347        match false_position.find(&problem) {
348            Ok(root) => Ok(root[0]),
349            Err(e) => Err(e),
350        }
351    }};
352}
353
354/// High level macro for secant
355///
356/// # Arguments
357///
358/// - `f`: `Fn(f64) -> f64` (allow closure)
359/// - `(a, b)`: `(f64, f64)`
360/// - `max_iter`: `usize`
361/// - `tol`: `f64`
362#[macro_export]
363macro_rules! secant {
364    ($f:expr, ($a:expr, $b:expr), $max_iter:expr, $tol:expr) => {{
365        struct SecantProblem<F: Fn(f64) -> f64> {
366            f: F,
367        };
368
369        impl<F: Fn(f64) -> f64> RootFindingProblem<1, 1, (f64, f64)> for SecantProblem<F> {
370            fn initial_guess(&self) -> (f64, f64) {
371                ($a, $b)
372            }
373
374            fn function(&self, x: [f64; 1]) -> Result<[f64; 1]> {
375                Ok([(self.f)(x[0])])
376            }
377        }
378
379        let problem = SecantProblem { f: $f };
380        let secant = SecantMethod {
381            max_iter: $max_iter,
382            tol: $tol,
383        };
384        match secant.find(&problem) {
385            Ok(root) => Ok(root[0]),
386            Err(e) => Err(e),
387        }
388    }};
389}
390
391// ┌─────────────────────────────────────────────────────────┐
392//  Type aliases
393// └─────────────────────────────────────────────────────────┘
394/// Point alias (`[f64; N]`)
395pub type Pt<const N: usize> = [f64; N];
396/// Interval alias (`([f64; N], [f64; N])`)
397pub type Intv<const N: usize> = (Pt<N>, Pt<N>);
398/// Jacobian alias (`[[f64; C]; R]`)
399pub type Jaco<const R: usize, const C: usize> = [[f64; C]; R];
400/// Hessian alias (`[[[f64; C]; C]; R]`)
401pub type Hess<const R: usize, const C: usize> = [[[f64; C]; C]; R];
402
403// ┌─────────────────────────────────────────────────────────┐
404//  Traits
405// └─────────────────────────────────────────────────────────┘
406/// Trait to define a root finding problem
407///
408/// # Type Parameters
409///
410/// - `I`: Input type (e.g. `f64`, `[f64; N]`, or etc.)
411/// - `O`: Output type (e.g. `f64`, `[f64; N]`, or etc.)
412/// - `T`: State type (e.g. `f64`, `(f64, f64)`, or etc.)
413///
414/// # Methods
415///
416/// - `function`: Function
417/// - `initial_guess`: Initial guess
418/// - `derivative`: Derivative (optional)
419/// - `hessian`: Hessian (optional)
420pub trait RootFindingProblem<const I: usize, const O: usize, T> {
421    fn function(&self, x: Pt<I>) -> Result<Pt<O>>;
422    fn initial_guess(&self) -> T;
423    #[allow(unused_variables)]
424    fn derivative(&self, x: Pt<I>) -> Result<Jaco<O, I>> {
425        unimplemented!()
426    }
427    #[allow(unused_variables)]
428    fn hessian(&self, x: Pt<I>) -> Result<Hess<O, I>> {
429        unimplemented!()
430    }
431}
432
433/// Trait to define a root finder
434///
435/// # Type Parameters
436///
437/// - `I`: Input type (e.g. `f64`, `[f64; N]`, or etc.)
438/// - `O`: Output type (e.g. `f64`, `[f64; N]`, or etc.)
439/// - `T`: State type (e.g. `f64`, `(f64, f64)`, or etc.)
440///
441/// # Methods
442///
443/// - `max_iter`: Maximum number of iterations
444/// - `tol`: Absolute tolerance
445/// - `find`: Find root
446///
447/// # Available root finders
448///
449/// - `BisectionMethod`: `I=1, O=1, T=(f64, f64)`
450/// - `FalsePositionMethod`: `I=1, O=1, T=(f64, f64)`
451/// - `NewtonMethod`: `I=1, O=1, T=f64`
452/// - `SecantMethod`: `I=1, O=1, T=(f64, f64)`
453pub trait RootFinder<const I: usize, const O: usize, T> {
454    fn max_iter(&self) -> usize;
455    fn tol(&self) -> f64;
456    fn find<P: RootFindingProblem<I, O, T>>(&self, problem: &P) -> Result<Pt<I>>;
457}
458
459#[derive(Debug, Copy, Clone)]
460pub enum RootError<const I: usize> {
461    NotConverge(Pt<I>),
462    NoRoot,
463    ZeroDerivative(Pt<I>),
464    ZeroSecant(Pt<I>, Pt<I>),
465}
466
467impl<const I: usize> std::fmt::Display for RootError<I> {
468    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469        match self {
470            RootError::NoRoot => write!(f, "There is no root in the interval"),
471            RootError::NotConverge(a) => write!(f, "Not yet converge. Our guess is {:?}", a),
472            RootError::ZeroDerivative(a) => write!(f, "Zero derivative in {:?}", a),
473            RootError::ZeroSecant(a, b) => write!(f, "Zero secant in ({:?}, {:?})", a, b),
474        }
475    }
476}
477
478/// Macro for single function
479///
480/// # Description
481///
482/// For I=1, O=1, it is bother to write below code.
483///
484/// ```ignore
485/// let fx = problem.function([x])?[0];
486/// ```
487///
488/// This macro solve this problem as follows.
489///
490/// ```ignore
491/// let fx = single_function!(problem, x);
492/// ```
493#[macro_export]
494macro_rules! single_function {
495    ($problem:expr, $x:expr) => {{
496        $problem.function([$x])?[0]
497    }};
498}
499
500/// Macro for single derivative
501///
502/// # Description
503///
504/// For I=1, O=1, it is bother to write below code.
505///
506/// ```ignore
507/// let fx = problem.derivative([x])?[0][0];
508/// ```
509///
510/// This macro solve this problem as follows.
511///
512/// ```ignore
513/// let fx = single_derivative!(problem, x);
514/// ```
515#[macro_export]
516macro_rules! single_derivative {
517    ($problem:expr, $x:expr) => {{
518        $problem.derivative([$x])?[0][0]
519    }};
520}
521
522// ┌─────────────────────────────────────────────────────────┐
523//  Bisection method
524// └─────────────────────────────────────────────────────────┘
525/// Bisection method
526///
527/// # Type for `RootFinder`
528///
529/// - `I`: 1
530/// - `O`: 1
531/// - `T`: `(f64, f64)`
532///
533/// # Arguments
534///
535/// - `max_iter`: Maximum number of iterations
536/// - `tol`: tol
537///
538/// # Caution
539///
540/// - The function should be continuous
541/// - The function should have a root in the initial interval
542pub struct BisectionMethod {
543    pub max_iter: usize,
544    pub tol: f64,
545}
546
547impl RootFinder<1, 1, (f64, f64)> for BisectionMethod {
548    fn max_iter(&self) -> usize {
549        self.max_iter
550    }
551
552    fn tol(&self) -> f64 {
553        self.tol
554    }
555
556    fn find<P: RootFindingProblem<1, 1, (f64, f64)>>(&self, problem: &P) -> Result<[f64; 1]> {
557        let state = problem.initial_guess();
558        let (mut a, mut b) = state;
559        let mut fa = single_function!(problem, a);
560        let mut fb = single_function!(problem, b);
561
562        if fa.abs() < self.tol {
563            return Ok([a]);
564        } else if fb.abs() < self.tol {
565            return Ok([b]);
566        } else if fa * fb > 0.0 {
567            bail!(RootError::<1>::NoRoot);
568        }
569
570        for _ in 0..self.max_iter {
571            let c = (a + b) / 2.0;
572            let fc = single_function!(problem, c);
573
574            if fc.abs() < self.tol {
575                return Ok([c]);
576            } else if fa * fc < 0.0 {
577                b = c;
578                fb = fc;
579            } else if fb * fc < 0.0 {
580                a = c;
581                fa = fc;
582            } else {
583                bail!(RootError::<1>::NoRoot);
584            }
585        }
586        let c = (a + b) / 2.0;
587        bail!(RootError::NotConverge([c]));
588    }
589}
590
591// ┌─────────────────────────────────────────────────────────┐
592//  Newton method
593// └─────────────────────────────────────────────────────────┘
594/// Newton method
595///
596/// # Type for `RootFinder`
597///
598/// - `I`: 1
599/// - `O`: 1
600/// - `T`: `f64`
601///
602/// # Arguments
603///
604/// - `max_iter`: Maximum number of iterations
605/// - `tol`: Absolute tolerance
606///
607/// # Caution
608///
609/// - The function should be differentiable
610/// - This method highly depends on the initial guess
611/// - This method is not guaranteed to converge
612pub struct NewtonMethod {
613    pub max_iter: usize,
614    pub tol: f64,
615}
616
617impl RootFinder<1, 1, f64> for NewtonMethod {
618    fn max_iter(&self) -> usize {
619        self.max_iter
620    }
621    fn tol(&self) -> f64 {
622        self.tol
623    }
624    fn find<P: RootFindingProblem<1, 1, f64>>(&self, problem: &P) -> Result<[f64; 1]> {
625        let mut x = problem.initial_guess();
626
627        for _ in 0..self.max_iter {
628            let f = single_function!(problem, x);
629            let df = single_derivative!(problem, x);
630
631            if f.abs() < self.tol {
632                return Ok([x]);
633            } else if df == 0.0 {
634                bail!(RootError::ZeroDerivative([x]));
635            } else {
636                x -= f / df;
637            }
638        }
639        bail!(RootError::NotConverge([x]));
640    }
641}
642
643// ┌─────────────────────────────────────────────────────────┐
644//  Secant method
645// └─────────────────────────────────────────────────────────┘
646/// Secant method
647///
648/// # Type for `RootFinder`
649///
650/// - `I`: 1
651/// - `O`: 1
652/// - `T`: `(f64, f64)`
653///
654/// # Arguments
655///
656/// - `max_iter`: Maximum number of iterations
657/// - `tol`: Absolute tolerance
658///
659/// # Caution
660///
661/// - The function should be differentiable
662/// - This method is not guaranteed to converge
663pub struct SecantMethod {
664    pub max_iter: usize,
665    pub tol: f64,
666}
667
668impl RootFinder<1, 1, (f64, f64)> for SecantMethod {
669    fn max_iter(&self) -> usize {
670        self.max_iter
671    }
672    fn tol(&self) -> f64 {
673        self.tol
674    }
675    fn find<P: RootFindingProblem<1, 1, (f64, f64)>>(&self, problem: &P) -> Result<[f64; 1]> {
676        let state = problem.initial_guess();
677        let (mut x0, mut x1) = state;
678        let mut f0 = single_function!(problem, x0);
679
680        if f0.abs() < self.tol {
681            return Ok([x0]);
682        }
683
684        for _ in 0..self.max_iter {
685            let f1 = single_function!(problem, x1);
686
687            if f1.abs() < self.tol {
688                return Ok([x1]);
689            }
690
691            if f0 == f1 {
692                bail!(RootError::ZeroSecant([x0], [x1]));
693            }
694
695            let f0_old = f0;
696            f0 = f1;
697            (x0, x1) = (x1, x1 - f1 * (x1 - x0) / (f1 - f0_old))
698        }
699        bail!(RootError::NotConverge([x1]));
700    }
701}
702
703// ┌─────────────────────────────────────────────────────────┐
704//  False position method
705// └─────────────────────────────────────────────────────────┘
706/// False position method
707///
708/// # Type for `RootFinder`
709///
710/// - `I`: 1
711/// - `O`: 1
712/// - `T`: `(f64, f64)`
713///
714/// # Arguments
715///
716/// - `max_iter`: Maximum number of iterations
717/// - `tol`: Absolute tolerance
718///
719/// # Caution
720///
721/// - The function should be continuous
722pub struct FalsePositionMethod {
723    pub max_iter: usize,
724    pub tol: f64,
725}
726
727impl RootFinder<1, 1, (f64, f64)> for FalsePositionMethod {
728    fn max_iter(&self) -> usize {
729        self.max_iter
730    }
731    fn tol(&self) -> f64 {
732        self.tol
733    }
734    fn find<P: RootFindingProblem<1, 1, (f64, f64)>>(&self, problem: &P) -> Result<[f64; 1]> {
735        let state = problem.initial_guess();
736        let (mut a, mut b) = state;
737        let mut fa = single_function!(problem, a);
738        let mut fb = single_function!(problem, b);
739
740        if fa.abs() < self.tol {
741            return Ok([a]);
742        } else if fb.abs() < self.tol {
743            return Ok([b]);
744        } else if fa * fb > 0.0 {
745            bail!(RootError::<1>::NoRoot);
746        }
747
748        for _ in 0..self.max_iter {
749            let c = (a * fb - b * fa) / (fb - fa);
750            let fc = single_function!(problem, c);
751
752            if fc.abs() < self.tol {
753                return Ok([c]);
754            } else if fa * fc < 0.0 {
755                b = c;
756                fb = fc;
757            } else if fb * fc < 0.0 {
758                a = c;
759                fa = fc;
760            } else {
761                bail!(RootError::<1>::NoRoot);
762            }
763        }
764        let c = (a * fb - b * fa) / (fb - fa);
765        bail!(RootError::NotConverge([c]));
766    }
767}
768
769// ┌─────────────────────────────────────────────────────────┐
770//  Broyden method
771// └─────────────────────────────────────────────────────────┘
772/// Broyden method
773///
774/// # Type for `RootFinder`
775///
776/// - `I`: free
777/// - `O`: free
778/// - `T`: `Intv<I>` (=`([f64; I], [f64; I])`)
779///
780/// # Arguments
781///
782/// - `max_iter`: Maximum number of iterations
783/// - `tol`: Absolute tolerance
784///
785/// # Caution
786///
787/// - The function should be differentiable
788///
789/// # Example
790///
791/// ```rust
792/// use peroxide::fuga::*;
793/// use peroxide::numerical::root::{Pt, Intv};
794///
795/// fn main() -> Result<(), Box<dyn std::error::Error>> {
796///     let problem = CircleTangentLine;
797///     let broyden = BroydenMethod { max_iter: 100, tol: 1e-6, rtol: 1e-6 };
798///
799///     let root = broyden.find(&problem)?;
800///     let result = problem.function(root)?;
801///
802///     let norm = result.to_vec().norm(Norm::L2);
803///     assert!(norm < 1e-6);
804///
805///     Ok(())
806/// }
807///
808/// struct CircleTangentLine;
809///
810/// impl RootFindingProblem<2, 2, Intv<2>> for CircleTangentLine {
811///     fn function(&self, x: Pt<2>) -> anyhow::Result<Pt<2>> {
812///         Ok([
813///             x[0] * x[0] + x[1] * x[1] - 1.0,
814///             x[0] + x[1] - 2f64.sqrt()
815///         ])
816///     }
817///
818///     fn initial_guess(&self) -> Intv<2> {
819///         ([0.0, 0.1], [-0.1, 0.2])
820///     }
821/// }
822/// ```
823
824pub struct BroydenMethod {
825    pub max_iter: usize,
826    pub tol: f64,
827    pub rtol: f64,
828}
829
830#[allow(unused_variables, non_snake_case)]
831impl<const I: usize, const O: usize> RootFinder<I, O, Intv<I>> for BroydenMethod {
832    fn max_iter(&self) -> usize {
833        self.max_iter
834    }
835    fn tol(&self) -> f64 {
836        self.tol
837    }
838    fn find<P: RootFindingProblem<I, O, Intv<I>>>(&self, problem: &P) -> Result<Pt<I>> {
839        // Init state
840        let state = problem.initial_guess();
841        let (mut x0, mut x1) = state;
842        let mut fx0 = problem.function(x0)?.to_vec();
843
844        // Initialize negative inverse jacobian as identity
845        // H = -J^{-1}
846        let mut H = zeros(I, O);
847        for i in 0..O.min(I) {
848            H[(i, i)] = 1.0;
849        }
850
851        for _ in 0..self.max_iter {
852            let fx1 = problem.function(x1)?.to_vec();
853            if fx1.norm(Norm::L2) < self.tol {
854                return Ok(x1);
855            }
856            let dx = x1
857                .iter()
858                .zip(x0.iter())
859                .map(|(x1, x0)| x1 - x0)
860                .collect::<Vec<_>>();
861            if dx.norm(Norm::L2) < self.rtol {
862                return Ok(x1);
863            }
864            let df = fx1
865                .iter()
866                .zip(fx0.iter())
867                .map(|(fx1, fx0)| fx1 - fx0)
868                .collect::<Vec<_>>();
869
870            let denom = dx.add_v(&H.apply(&df));
871            let right = &dx.to_row() * &H;
872            let num = right.apply(&df)[0];
873
874            let left = denom.div_s(num);
875
876            H = H - left.to_col() * right;
877
878            x0 = x1;
879            fx0 = fx1.clone();
880            let dx_new = H.apply(&fx1);
881            for i in 0..I {
882                x1[i] += dx_new[i];
883            }
884        }
885        bail!(RootError::NotConverge(x1));
886    }
887}