1use crate::fuga::ConvToMat;
90use crate::traits::math::{InnerProduct, Norm, Normed, Vector};
91use crate::util::non_macro::eye;
92use anyhow::{bail, Result};
93
/// A system of first-order ordinary differential equations.
///
/// The integrators in this file call [`ODEProblem::rhs`] to obtain the slope of the state
/// and advance the state by `dt` times a weighted combination of such slopes.
pub trait ODEProblem {
    /// Evaluates the right-hand side at time `t` and state `y`, writing the result into `dy`.
    ///
    /// `dy` is handed in pre-allocated; the integrators in this file pass a buffer with the
    /// same length as `y`.
    ///
    /// # Errors
    ///
    /// Any error returned here is propagated unchanged by the integrators (they call this
    /// method with `?`), which aborts the current step.
    fn rhs(&self, t: f64, y: &[f64], dy: &mut [f64]) -> Result<()>;
}
116
/// A one-step time integrator for an [`ODEProblem`].
pub trait ODEIntegrator {
    /// Advances `y` in place by one step starting from time `t`.
    ///
    /// `dt` is the requested step size. The returned value is the step size the caller is
    /// expected to pass to the next call: fixed-step integrators in this file return `dt`
    /// unchanged, adaptive ones return their newly proposed step size.
    ///
    /// # Errors
    ///
    /// Returns an error if the problem's right-hand side fails or if the integrator cannot
    /// complete the step.
    fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64>;
}
123
/// Errors raised while integrating an ODE.
#[derive(Debug, Clone)]
pub enum ODEError {
    /// A constraint was violated; the payload is printed as `t`, `y` and `dy` by `Display`.
    ConstraintViolation(f64, Vec<f64>, Vec<f64>),
    /// An adaptive step was rejected as many times as the integrator allows.
    ReachedMaxStepIter,
}

impl std::fmt::Display for ODEError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use ODEError::{ConstraintViolation, ReachedMaxStepIter};

        match self {
            ReachedMaxStepIter => f.write_str("Reached maximum number of steps per step"),
            ConstraintViolation(time, state, derivative) => write!(
                f,
                "Constraint violation at t = {}, y = {:?}, dy = {:?}",
                time, state, derivative
            ),
        }
    }
}
171
/// Drives an integration over a whole time interval.
pub trait ODESolver {
    /// Integrates `problem` over `t_span = (start, end)` beginning from `initial_conditions`.
    ///
    /// `dt` is the initial step size. On success the result is a pair
    /// `(times, states)` of recorded time points and the state at each of them.
    ///
    /// # Errors
    ///
    /// Returns an error if the integration cannot be completed.
    fn solve<P: ODEProblem>(
        &self,
        problem: &P,
        t_span: (f64, f64),
        dt: f64,
        initial_conditions: &[f64],
    ) -> Result<(Vec<f64>, Vec<Vec<f64>>)>;
}
184
185pub struct BasicODESolver<I: ODEIntegrator> {
217 integrator: I,
218}
219
220impl<I: ODEIntegrator> BasicODESolver<I> {
221 pub fn new(integrator: I) -> Self {
222 Self { integrator }
223 }
224}
225
226impl<I: ODEIntegrator> ODESolver for BasicODESolver<I> {
227 fn solve<P: ODEProblem>(
228 &self,
229 problem: &P,
230 t_span: (f64, f64),
231 dt: f64,
232 initial_conditions: &[f64],
233 ) -> Result<(Vec<f64>, Vec<Vec<f64>>)> {
234 let mut t = t_span.0;
235 let mut dt = dt;
236 let mut y = initial_conditions.to_vec();
237 let mut t_vec = vec![t];
238 let mut y_vec = vec![y.clone()];
239
240 while t < t_span.1 {
241 let dt_step = self.integrator.step(problem, t, &mut y, dt)?;
242 t += dt;
243 t_vec.push(t);
244 y_vec.push(y.clone());
245 dt = dt_step;
246 }
247
248 Ok((t_vec, y_vec))
249 }
250}
251
/// Coefficients and step-control parameters of an explicit Runge–Kutta method.
///
/// Every implementor automatically becomes an [`ODEIntegrator`] through the blanket
/// implementation in this file. When [`ButcherTableau::BE`] is empty the method is used
/// with a fixed step and none of the step-control methods below are called.
pub trait ButcherTableau {
    /// Time offsets of the stages, as fractions of the step; its length is the stage count.
    const C: &'static [f64];
    /// Stage coefficients: row `s` holds the weights of stages `0..s` used to build stage `s`.
    const A: &'static [&'static [f64]];
    /// Weights used to advance the solution.
    const BU: &'static [f64];
    /// Weights of the comparison solution; the error estimate uses `BU - BE`.
    /// Empty for fixed-step methods.
    const BE: &'static [f64];

    /// Error threshold below which an adaptive step is accepted.
    ///
    /// # Panics
    ///
    /// The default panics; adaptive methods must override it.
    fn tol(&self) -> f64 {
        unimplemented!()
    }

    /// Factor applied to the proposed next step size.
    ///
    /// # Panics
    ///
    /// The default panics; adaptive methods must override it.
    fn safety_factor(&self) -> f64 {
        unimplemented!()
    }

    /// Upper bound for the proposed step size.
    ///
    /// # Panics
    ///
    /// The default panics; adaptive methods must override it.
    fn max_step_size(&self) -> f64 {
        unimplemented!()
    }

    /// Lower bound for the proposed step size.
    ///
    /// # Panics
    ///
    /// The default panics; adaptive methods must override it.
    fn min_step_size(&self) -> f64 {
        unimplemented!()
    }

    /// Number of rejected attempts after which a step fails.
    ///
    /// # Panics
    ///
    /// The default panics; adaptive methods must override it.
    fn max_step_iter(&self) -> usize {
        unimplemented!()
    }

    /// Order used in the step-size update: the exponent is `1 / (order + 1)`.
    /// Defaults to 4.
    fn order(&self) -> usize {
        4
    }
}
300
301impl<BU: ButcherTableau> ODEIntegrator for BU {
302 fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64> {
303 let n = y.len();
304 let mut iter_count = 0usize;
305 let mut dt = dt;
306 let n_k = Self::C.len();
307
308 loop {
309 let mut k_vec = vec![vec![0.0; n]; n_k];
310 let mut y_temp = y.to_vec();
311
312 for stage in 0..n_k {
313 for i in 0..n {
314 let mut s = 0.0;
315 for j in 0..stage {
316 s += Self::A[stage][j] * k_vec[j][i];
317 }
318 y_temp[i] = y[i] + dt * s;
319 }
320 problem.rhs(t + dt * Self::C[stage], &y_temp, &mut k_vec[stage])?;
321 }
322
323 if !Self::BE.is_empty() {
324 let mut error = 0f64;
325 for i in 0..n {
326 let mut s = 0.0;
327 for j in 0..n_k {
328 s += (Self::BU[j] - Self::BE[j]) * k_vec[j][i];
329 }
330 error = error.max(dt * s.abs())
331 }
332
333 let factor = (self.tol() / error).powf(1.0 / (self.order() as f64 + 1.0));
334 let new_dt = self.safety_factor() * dt * factor;
335 let new_dt = new_dt.clamp(self.min_step_size(), self.max_step_size());
336
337 if error < self.tol() {
338 for i in 0..n {
339 let mut s = 0.0;
340 for j in 0..n_k {
341 s += Self::BU[j] * k_vec[j][i];
342 }
343 y[i] += dt * s;
344 }
345 return Ok(new_dt);
346 } else {
347 iter_count += 1;
348 if iter_count >= self.max_step_iter() {
349 bail!(ODEError::ReachedMaxStepIter);
350 }
351 dt = new_dt;
352 }
353 } else {
354 for i in 0..n {
355 let mut s = 0.0;
356 for j in 0..n_k {
357 s += Self::BU[j] * k_vec[j][i];
358 }
359 y[i] += dt * s;
360 }
361 return Ok(dt);
362 }
363 }
364 }
365}
366
/// Three-stage, fixed-step explicit Runge–Kutta integrator.
///
/// The coefficients are those of Ralston's third-order method. `BE` is empty, so no
/// step-size control is performed.
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct RALS3;

impl ButcherTableau for RALS3 {
    const C: &'static [f64] = &[0.0, 0.5, 0.75];
    const A: &'static [&'static [f64]] = &[&[], &[0.5], &[0.0, 0.75]];
    const BU: &'static [f64] = &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0];
    // Empty: fixed-step method.
    const BE: &'static [f64] = &[];
}
388
/// Four-stage, fixed-step explicit Runge–Kutta integrator.
///
/// The coefficients are those of the classical fourth-order Runge–Kutta method. `BE` is
/// empty, so no step-size control is performed.
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct RK4;

impl ButcherTableau for RK4 {
    const C: &'static [f64] = &[0.0, 0.5, 0.5, 1.0];
    const A: &'static [&'static [f64]] = &[&[], &[0.5], &[0.0, 0.5], &[0.0, 0.0, 1.0]];
    const BU: &'static [f64] = &[1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0];
    // Empty: fixed-step method.
    const BE: &'static [f64] = &[];
}
407
408#[derive(Debug, Clone, Copy, Default)]
412#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
413#[cfg_attr(
414 feature = "rkyv",
415 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
416)]
417pub struct RALS4;
418
419impl ButcherTableau for RALS4 {
420 const C: &'static [f64] = &[0.0, 0.4, 0.45573725, 1.0];
421 const A: &'static [&'static [f64]] = &[
422 &[],
423 &[0.4],
424 &[0.29697761, 0.158575964],
425 &[0.21810040, -3.050965616, 3.83286476],
426 ];
427 const BU: &'static [f64] = &[0.17476028, -0.55148066, 1.20553560, 0.17118478];
428 const BE: &'static [f64] = &[];
429}
430
431#[derive(Debug, Clone, Copy, Default)]
435#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
436#[cfg_attr(
437 feature = "rkyv",
438 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
439)]
440pub struct RK5;
441
442impl ButcherTableau for RK5 {
443 const C: &'static [f64] = &[0.0, 0.2, 0.3, 0.8, 8.0 / 9.0, 1.0, 1.0];
444 const A: &'static [&'static [f64]] = &[
445 &[],
446 &[0.2],
447 &[0.075, 0.225],
448 &[44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0],
449 &[
450 19372.0 / 6561.0,
451 -25360.0 / 2187.0,
452 64448.0 / 6561.0,
453 -212.0 / 729.0,
454 ],
455 &[
456 9017.0 / 3168.0,
457 -355.0 / 33.0,
458 46732.0 / 5247.0,
459 49.0 / 176.0,
460 -5103.0 / 18656.0,
461 ],
462 &[
463 35.0 / 384.0,
464 0.0,
465 500.0 / 1113.0,
466 125.0 / 192.0,
467 -2187.0 / 6784.0,
468 11.0 / 84.0,
469 ],
470 ];
471 const BU: &'static [f64] = &[
472 5179.0 / 57600.0,
473 0.0,
474 7571.0 / 16695.0,
475 393.0 / 640.0,
476 -92097.0 / 339200.0,
477 187.0 / 2100.0,
478 1.0 / 40.0,
479 ];
480 const BE: &'static [f64] = &[];
481}
482
/// Adaptive explicit Runge–Kutta integrator (Bogacki–Shampine 3(2) coefficients) and its
/// step-control settings.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct BS23 {
    /// Error threshold below which a step is accepted.
    pub tol: f64,
    /// Factor applied to every proposed step size.
    pub safety_factor: f64,
    /// Lower bound for proposed step sizes.
    pub min_step_size: f64,
    /// Upper bound for proposed step sizes.
    pub max_step_size: f64,
    /// Number of rejected attempts after which a step fails.
    pub max_step_iter: usize,
}

impl Default for BS23 {
    /// Defaults: `tol = 1e-3`, `safety_factor = 0.9`, `min_step_size = 1e-6`,
    /// `max_step_size = 1e-1`, `max_step_iter = 100`.
    fn default() -> Self {
        Self::new(1e-3, 0.9, 1e-6, 1e-1, 100)
    }
}

impl BS23 {
    /// Builds an integrator from explicit step-control settings; values are stored as given.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        BS23 {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}
540
/// Tableau and step-control accessors for [`BS23`].
impl ButcherTableau for BS23 {
    const C: &'static [f64] = &[0.0, 0.5, 0.75, 1.0];
    const A: &'static [&'static [f64]] = &[
        &[],
        &[0.5],
        &[0.0, 0.75],
        // Same values as the first three entries of `BU`.
        &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0],
    ];
    // Weights used to advance the solution.
    const BU: &'static [f64] = &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0, 0.0];
    // Comparison weights for the error estimate.
    const BE: &'static [f64] = &[7.0 / 24.0, 0.25, 1.0 / 3.0, 0.125];

    fn tol(&self) -> f64 {
        self.tol
    }
    fn safety_factor(&self) -> f64 {
        self.safety_factor
    }
    fn min_step_size(&self) -> f64 {
        self.min_step_size
    }
    fn max_step_size(&self) -> f64 {
        self.max_step_size
    }
    fn max_step_iter(&self) -> usize {
        self.max_step_iter
    }
    // Overrides the trait default of 4 for the step-size exponent.
    fn order(&self) -> usize {
        2
    }
}
571
/// Adaptive explicit Runge–Kutta integrator (Runge–Kutta–Fehlberg 4(5) coefficients) and
/// its step-control settings.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct RKF45 {
    /// Error threshold below which a step is accepted.
    pub tol: f64,
    /// Factor applied to every proposed step size.
    pub safety_factor: f64,
    /// Lower bound for proposed step sizes.
    pub min_step_size: f64,
    /// Upper bound for proposed step sizes.
    pub max_step_size: f64,
    /// Number of rejected attempts after which a step fails.
    pub max_step_iter: usize,
}

impl Default for RKF45 {
    /// Defaults: `tol = 1e-6`, `safety_factor = 0.9`, `min_step_size = 1e-6`,
    /// `max_step_size = 1e-1`, `max_step_iter = 100`.
    fn default() -> Self {
        Self::new(1e-6, 0.9, 1e-6, 1e-1, 100)
    }
}

impl RKF45 {
    /// Builds an integrator from explicit step-control settings; values are stored as given.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        RKF45 {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}
628
/// Tableau and step-control accessors for [`RKF45`].
///
/// `order()` is not overridden, so the trait default (4) drives the step-size exponent.
impl ButcherTableau for RKF45 {
    const C: &'static [f64] = &[0.0, 1.0 / 4.0, 3.0 / 8.0, 12.0 / 13.0, 1.0, 1.0 / 2.0];
    const A: &'static [&'static [f64]] = &[
        &[],
        &[0.25],
        &[3.0 / 32.0, 9.0 / 32.0],
        &[1932.0 / 2197.0, -7200.0 / 2197.0, 7296.0 / 2197.0],
        &[439.0 / 216.0, -8.0, 3680.0 / 513.0, -845.0 / 4104.0],
        &[
            -8.0 / 27.0,
            2.0,
            -3544.0 / 2565.0,
            1859.0 / 4104.0,
            -11.0 / 40.0,
        ],
    ];
    // Weights used to advance the solution.
    const BU: &'static [f64] = &[
        16.0 / 135.0,
        0.0,
        6656.0 / 12825.0,
        28561.0 / 56430.0,
        -9.0 / 50.0,
        2.0 / 55.0,
    ];
    // Comparison weights for the error estimate.
    const BE: &'static [f64] = &[
        25.0 / 216.0,
        0.0,
        1408.0 / 2565.0,
        2197.0 / 4104.0,
        -1.0 / 5.0,
        0.0,
    ];

    fn tol(&self) -> f64 {
        self.tol
    }
    fn safety_factor(&self) -> f64 {
        self.safety_factor
    }
    fn min_step_size(&self) -> f64 {
        self.min_step_size
    }
    fn max_step_size(&self) -> f64 {
        self.max_step_size
    }
    fn max_step_iter(&self) -> usize {
        self.max_step_iter
    }
}
678
/// Adaptive explicit Runge–Kutta integrator (Dormand–Prince 5(4) coefficients) and its
/// step-control settings.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct DP45 {
    /// Error threshold below which a step is accepted.
    pub tol: f64,
    /// Factor applied to every proposed step size.
    pub safety_factor: f64,
    /// Lower bound for proposed step sizes.
    pub min_step_size: f64,
    /// Upper bound for proposed step sizes.
    pub max_step_size: f64,
    /// Number of rejected attempts after which a step fails.
    pub max_step_iter: usize,
}

impl Default for DP45 {
    /// Defaults: `tol = 1e-6`, `safety_factor = 0.9`, `min_step_size = 1e-6`,
    /// `max_step_size = 1e-1`, `max_step_iter = 100`.
    fn default() -> Self {
        Self::new(1e-6, 0.9, 1e-6, 1e-1, 100)
    }
}

impl DP45 {
    /// Builds an integrator from explicit step-control settings; values are stored as given.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        DP45 {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}
734
/// Tableau and step-control accessors for [`DP45`].
///
/// `order()` is not overridden, so the trait default (4) drives the step-size exponent.
impl ButcherTableau for DP45 {
    const C: &'static [f64] = &[0.0, 0.2, 0.3, 0.8, 8.0 / 9.0, 1.0, 1.0];
    const A: &'static [&'static [f64]] = &[
        &[],
        &[0.2],
        &[0.075, 0.225],
        &[44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0],
        &[
            19372.0 / 6561.0,
            -25360.0 / 2187.0,
            64448.0 / 6561.0,
            -212.0 / 729.0,
        ],
        &[
            9017.0 / 3168.0,
            -355.0 / 33.0,
            46732.0 / 5247.0,
            49.0 / 176.0,
            -5103.0 / 18656.0,
        ],
        // Same values as the first six entries of `BU`.
        &[
            35.0 / 384.0,
            0.0,
            500.0 / 1113.0,
            125.0 / 192.0,
            -2187.0 / 6784.0,
            11.0 / 84.0,
        ],
    ];
    // Weights used to advance the solution.
    const BU: &'static [f64] = &[
        35.0 / 384.0,
        0.0,
        500.0 / 1113.0,
        125.0 / 192.0,
        -2187.0 / 6784.0,
        11.0 / 84.0,
        0.0,
    ];
    // Comparison weights for the error estimate.
    const BE: &'static [f64] = &[
        5179.0 / 57600.0,
        0.0,
        7571.0 / 16695.0,
        393.0 / 640.0,
        -92097.0 / 339200.0,
        187.0 / 2100.0,
        1.0 / 40.0,
    ];

    fn tol(&self) -> f64 {
        self.tol
    }
    fn safety_factor(&self) -> f64 {
        self.safety_factor
    }
    fn min_step_size(&self) -> f64 {
        self.min_step_size
    }
    fn max_step_size(&self) -> f64 {
        self.max_step_size
    }
    fn max_step_iter(&self) -> usize {
        self.max_step_iter
    }
}
799
/// Adaptive explicit Runge–Kutta integrator (Tsitouras 5(4) coefficients) and its
/// step-control settings.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct TSIT45 {
    /// Error threshold below which a step is accepted.
    pub tol: f64,
    /// Factor applied to every proposed step size.
    pub safety_factor: f64,
    /// Lower bound for proposed step sizes.
    pub min_step_size: f64,
    /// Upper bound for proposed step sizes.
    pub max_step_size: f64,
    /// Number of rejected attempts after which a step fails.
    pub max_step_iter: usize,
}

impl Default for TSIT45 {
    /// Defaults: `tol = 1e-6`, `safety_factor = 0.9`, `min_step_size = 1e-6`,
    /// `max_step_size = 1e-1`, `max_step_iter = 100`.
    fn default() -> Self {
        Self::new(1e-6, 0.9, 1e-6, 1e-1, 100)
    }
}

impl TSIT45 {
    /// Builds an integrator from explicit step-control settings; values are stored as given.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        TSIT45 {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}
859
860impl ButcherTableau for TSIT45 {
861 const C: &'static [f64] = &[0.0, 0.161, 0.327, 0.9, 0.9800255409045097, 1.0, 1.0];
862 const A: &'static [&'static [f64]] = &[
863 &[],
864 &[Self::C[1]],
865 &[Self::C[2] - 0.335480655492357, 0.335480655492357],
866 &[
867 Self::C[3] - (-6.359448489975075 + 4.362295432869581),
868 -6.359448489975075,
869 4.362295432869581,
870 ],
871 &[
872 Self::C[4] - (-11.74888356406283 + 7.495539342889836 - 0.09249506636175525),
873 -11.74888356406283,
874 7.495539342889836,
875 -0.09249506636175525,
876 ],
877 &[
878 Self::C[5]
879 - (-12.92096931784711 + 8.159367898576159
880 - 0.0715849732814010
881 - 0.02826905039406838),
882 -12.92096931784711,
883 8.159367898576159,
884 -0.0715849732814010,
885 -0.02826905039406838,
886 ],
887 &[
888 Self::BU[0],
889 Self::BU[1],
890 Self::BU[2],
891 Self::BU[3],
892 Self::BU[4],
893 Self::BU[5],
894 ],
895 ];
896 const BU: &'static [f64] = &[
897 0.09646076681806523,
898 0.01,
899 0.4798896504144996,
900 1.379008574103742,
901 -3.290069515436081,
902 2.324710524099774,
903 0.0,
904 ];
905 const BE: &'static [f64] = &[
906 0.001780011052226,
907 0.000816434459657,
908 -0.007880878010262,
909 0.144711007173263,
910 -0.582357165452555,
911 0.458082105929187,
912 1.0 / 66.0,
913 ];
914
915 fn tol(&self) -> f64 {
916 self.tol
917 }
918 fn safety_factor(&self) -> f64 {
919 self.safety_factor
920 }
921 fn min_step_size(&self) -> f64 {
922 self.min_step_size
923 }
924 fn max_step_size(&self) -> f64 {
925 self.max_step_size
926 }
927 fn max_step_iter(&self) -> usize {
928 self.max_step_iter
929 }
930}
931
/// Adaptive thirteen-stage explicit Runge–Kutta integrator (Runge–Kutta–Fehlberg 7(8)
/// coefficients) and its step-control settings.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct RKF78 {
    /// Error threshold below which a step is accepted.
    pub tol: f64,
    /// Factor applied to every proposed step size.
    pub safety_factor: f64,
    /// Lower bound for proposed step sizes.
    pub min_step_size: f64,
    /// Upper bound for proposed step sizes.
    pub max_step_size: f64,
    /// Number of rejected attempts after which a step fails.
    pub max_step_iter: usize,
}

impl Default for RKF78 {
    /// Defaults: `tol = 1e-7`, `safety_factor = 0.9`, `min_step_size = 1e-10`,
    /// `max_step_size = 1e-1`, `max_step_iter = 100`.
    fn default() -> Self {
        Self::new(1e-7, 0.9, 1e-10, 1e-1, 100)
    }
}

impl RKF78 {
    /// Builds an integrator from explicit step-control settings; values are stored as given.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        RKF78 {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}
993
/// Tableau and step-control accessors for [`RKF78`].
impl ButcherTableau for RKF78 {
    // Thirteen stages.
    const C: &'static [f64] = &[
        0.0,
        2.0 / 27.0,
        1.0 / 9.0,
        1.0 / 6.0,
        5.0 / 12.0,
        1.0 / 2.0,
        5.0 / 6.0,
        1.0 / 6.0,
        2.0 / 3.0,
        1.0 / 3.0,
        1.0,
        0.0,
        1.0,
    ];

    // Row `s` holds the weights of stages `0..s`; each row sums to `C[s]`.
    const A: &'static [&'static [f64]] = &[
        &[],
        &[2.0 / 27.0],
        &[1.0 / 36.0, 3.0 / 36.0],
        &[1.0 / 24.0, 0.0, 3.0 / 24.0],
        &[20.0 / 48.0, 0.0, -75.0 / 48.0, 75.0 / 48.0],
        &[1.0 / 20.0, 0.0, 0.0, 5.0 / 20.0, 4.0 / 20.0],
        &[
            -25.0 / 108.0,
            0.0,
            0.0,
            125.0 / 108.0,
            -260.0 / 108.0,
            250.0 / 108.0,
        ],
        &[
            31.0 / 300.0,
            0.0,
            0.0,
            0.0,
            61.0 / 225.0,
            -2.0 / 9.0,
            13.0 / 900.0,
        ],
        &[
            2.0,
            0.0,
            0.0,
            -53.0 / 6.0,
            704.0 / 45.0,
            -107.0 / 9.0,
            67.0 / 90.0,
            3.0,
        ],
        &[
            -91.0 / 108.0,
            0.0,
            0.0,
            23.0 / 108.0,
            -976.0 / 135.0,
            311.0 / 54.0,
            -19.0 / 60.0,
            17.0 / 6.0,
            -1.0 / 12.0,
        ],
        &[
            2383.0 / 4100.0,
            0.0,
            0.0,
            -341.0 / 164.0,
            4496.0 / 1025.0,
            -301.0 / 82.0,
            2133.0 / 4100.0,
            45.0 / 82.0,
            45.0 / 164.0,
            18.0 / 41.0,
        ],
        &[
            3.0 / 205.0,
            0.0,
            0.0,
            0.0,
            0.0,
            -6.0 / 41.0,
            -3.0 / 205.0,
            -3.0 / 41.0,
            3.0 / 41.0,
            6.0 / 41.0,
            0.0,
        ],
        &[
            -1777.0 / 4100.0,
            0.0,
            0.0,
            -341.0 / 164.0,
            4496.0 / 1025.0,
            -289.0 / 82.0,
            2193.0 / 4100.0,
            51.0 / 82.0,
            33.0 / 164.0,
            12.0 / 41.0,
            0.0,
            1.0,
        ],
    ];

    // Weights used to advance the solution.
    const BU: &'static [f64] = &[
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        34.0 / 105.0,
        9.0 / 35.0,
        9.0 / 35.0,
        9.0 / 280.0,
        9.0 / 280.0,
        0.0,
        41.0 / 840.0,
        41.0 / 840.0,
    ];

    // Comparison weights for the error estimate; they differ from `BU` only in the
    // first, eleventh, twelfth and thirteenth entries.
    const BE: &'static [f64] = &[
        41.0 / 840.0,
        0.0,
        0.0,
        0.0,
        0.0,
        34.0 / 105.0,
        9.0 / 35.0,
        9.0 / 35.0,
        9.0 / 280.0,
        9.0 / 280.0,
        41.0 / 840.0,
        0.0,
        0.0,
    ];

    fn tol(&self) -> f64 {
        self.tol
    }
    fn safety_factor(&self) -> f64 {
        self.safety_factor
    }
    fn min_step_size(&self) -> f64 {
        self.min_step_size
    }
    fn max_step_size(&self) -> f64 {
        self.max_step_size
    }
    fn max_step_iter(&self) -> usize {
        self.max_step_iter
    }
    // Overrides the trait default of 4 for the step-size exponent.
    fn order(&self) -> usize {
        7
    }
}
1164
// Coefficients of the two-stage implicit Runge–Kutta scheme used by `GL4` and `compute_F`
// (the values are those of the Gauss–Legendre method of order four).

/// Square root of three.
const SQRT3: f64 = 1.7320508075688772;
/// Time offset of the first stage, as a fraction of the step.
const C1: f64 = 0.5 - SQRT3 / 6.0;
/// Time offset of the second stage, as a fraction of the step.
const C2: f64 = 0.5 + SQRT3 / 6.0;
/// Weight of the first slope in the first stage state.
const A11: f64 = 0.25;
/// Weight of the second slope in the first stage state.
const A12: f64 = 0.25 - SQRT3 / 6.0;
/// Weight of the first slope in the second stage state.
const A21: f64 = 0.25 + SQRT3 / 6.0;
/// Weight of the second slope in the second stage state.
const A22: f64 = 0.25;
/// Weight of the first slope in the final update.
const B1: f64 = 0.5;
/// Weight of the second slope in the final update.
const B2: f64 = 0.5;
1179
/// Strategy used by [`GL4`] to solve the implicit stage equations of a step.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub enum ImplicitSolver {
    /// Re-evaluate the stage slopes from the previous iterate until they stop changing.
    FixedPoint,
    /// Quasi-Newton iteration on the stage residual with a rank-one update of the
    /// approximate inverse Jacobian.
    Broyden,
}
1195
1196#[derive(Debug, Clone, Copy)]
1208#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
1209#[cfg_attr(
1210 feature = "rkyv",
1211 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
1212)]
1213pub struct GL4 {
1214 pub solver: ImplicitSolver,
1215 pub tol: f64,
1216 pub max_step_iter: usize,
1217}
1218
1219impl Default for GL4 {
1220 fn default() -> Self {
1221 GL4 {
1222 solver: ImplicitSolver::FixedPoint,
1223 tol: 1e-8,
1224 max_step_iter: 100,
1225 }
1226 }
1227}
1228
1229impl GL4 {
1230 pub fn new(solver: ImplicitSolver, tol: f64, max_step_iter: usize) -> Self {
1231 GL4 {
1232 solver,
1233 tol,
1234 max_step_iter,
1235 }
1236 }
1237}
1238
1239impl ODEIntegrator for GL4 {
1240 #[allow(non_snake_case)]
1241 #[inline]
1242 fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64> {
1243 let n = y.len();
1244 let mut k1 = vec![0.0; n];
1248 let mut k2 = vec![0.0; n];
1249
1250 problem.rhs(t, y, &mut k1)?;
1252 k2.copy_from_slice(&k1);
1253
1254 match self.solver {
1255 ImplicitSolver::FixedPoint => {
1256 let mut y1 = vec![0.0; n];
1258 let mut y2 = vec![0.0; n];
1259
1260 for _ in 0..self.max_step_iter {
1261 let k1_old = k1.clone();
1262 let k2_old = k2.clone();
1263
1264 for i in 0..n {
1265 y1[i] = y[i] + dt * (A11 * k1[i] + A12 * k2[i]);
1266 y2[i] = y[i] + dt * (A11 * k1[i] + A12 * k2[i]);
1267 }
1268
1269 problem.rhs(t + C1 * dt, &y1, &mut k1)?;
1271 problem.rhs(t + C2 * dt, &y2, &mut k2)?;
1272
1273 let mut max_diff = 0f64;
1275 for i in 0..n {
1276 max_diff = max_diff.max((k1[i] - k1_old[i]).abs());
1277 max_diff = max_diff.max((k2[i] - k2_old[i]).abs());
1278 }
1279
1280 if max_diff < self.tol {
1281 break;
1282 }
1283 }
1284 }
1285 ImplicitSolver::Broyden => {
1286 let m = 2 * n;
1287 let mut U = vec![0.0; m];
1288 U[..n].copy_from_slice(&k1);
1289 U[n..].copy_from_slice(&k2);
1290
1291 let mut F_vec = vec![0.0; m];
1293 compute_F(problem, t, y, dt, &U, &mut F_vec)?;
1294
1295 let mut J_inv = eye(m);
1297
1298 for _ in 0..self.max_step_iter {
1300 let delta = (&J_inv * &F_vec).mul_scalar(-1.0);
1302
1303 U.iter_mut().zip(delta.iter()).for_each(|(u, d)| *u += *d);
1305
1306 let mut F_new = vec![0.0; m];
1307 compute_F(problem, t, y, dt, &U, &mut F_new)?;
1308
1309 if F_new.norm(Norm::LInf) < self.tol {
1311 break;
1312 }
1313
1314 let delta_F = F_new.sub_vec(&F_vec);
1316
1317 let J_inv_delta_F = &J_inv * &delta_F;
1319
1320 let denom = delta.dot(&J_inv_delta_F);
1321 if denom.abs() < 1e-12 {
1322 break;
1323 }
1324
1325 let delta_minus_J_inv_delta_F = delta.sub_vec(&J_inv_delta_F).to_col();
1328 let delta_T_J_inv = &delta.to_row() * &J_inv;
1329 let update = (delta_minus_J_inv_delta_F * delta_T_J_inv) / denom;
1330 J_inv = J_inv + update;
1331 F_vec = F_new;
1332 }
1333
1334 k1.copy_from_slice(&U[..n]);
1335 k2.copy_from_slice(&U[n..]);
1336 }
1337 }
1338
1339 for i in 0..n {
1340 y[i] += dt * (B1 * k1[i] + B2 * k2[i]);
1341 }
1342
1343 Ok(dt)
1344 }
1345}
1346
1347#[allow(non_snake_case)]
1387fn compute_F<P: ODEProblem>(
1388 problem: &P,
1389 t: f64,
1390 y: &[f64],
1391 dt: f64,
1392 U: &[f64], F: &mut [f64],
1394) -> Result<()> {
1395 let n = y.len();
1396 let (k1_slice, k2_slice) = U.split_at(n);
1397
1398 let mut y1 = vec![0.0; n];
1399 let mut y2 = vec![0.0; n];
1400
1401 for i in 0..n {
1402 y1[i] = y[i] + dt * (A11 * k1_slice[i] + A12 * k2_slice[i]);
1403 y2[i] = y[i] + dt * (A21 * k1_slice[i] + A22 * k2_slice[i]);
1404 }
1405
1406 let (f1, f2) = F.split_at_mut(n);
1408 problem.rhs(t + C1 * dt, &y1, f1)?;
1409 problem.rhs(t + C2 * dt, &y2, f2)?;
1410
1411 for i in 0..n {
1413 f1[i] = k1_slice[i] - f1[i];
1414 f2[i] = k2_slice[i] - f2[i];
1415 }
1416 Ok(())
1417}