1use crate::fuga::ConvToMat;
90use crate::traits::math::{InnerProduct, Norm, Normed, Vector};
91use crate::util::non_macro::eye;
92use anyhow::{bail, Result};
93
/// Right-hand side of a first-order ODE system `dy/dt = f(t, y)`.
pub trait ODEProblem {
    /// Writes `f(t, y)` into `dy`.
    ///
    /// The integrators in this module pass a `dy` buffer of the same length as `y`. Returning an
    /// error aborts the current integration step and is propagated to the caller.
    fn rhs(&self, t: f64, y: &[f64], dy: &mut [f64]) -> Result<()>;
}
116
/// A one-step time integrator.
pub trait ODEIntegrator {
    /// Advances `y` in place by one step of size `dt` starting at time `t`, and returns the step
    /// size to use for the next call.
    ///
    /// `BasicODESolver` records the state left in `y` at time `t + dt` and feeds the returned
    /// value back as the next `dt`. Fixed-step integrators in this module return `dt`
    /// unchanged; adaptive ones return a suggested size.
    fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64>;
}
123
/// Errors reported by the ODE integrators in this module.
#[derive(Debug, Clone)]
pub enum ODEError {
    /// A constraint was violated; carries the time `t`, the state `y` and the derivative `dy`
    /// (in that order) at which it happened.
    ConstraintViolation(f64, Vec<f64>, Vec<f64>),
    /// An adaptive integrator rejected `max_step_iter` trial steps while trying to take a step.
    ReachedMaxStepIter,
}

impl std::fmt::Display for ODEError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ODEError::ConstraintViolation(t, y, dy) => write!(
                f,
                "Constraint violation at t = {}, y = {:?}, dy = {:?}",
                t, y, dy
            ),
            ODEError::ReachedMaxStepIter => write!(f, "Reached maximum number of steps per step"),
        }
    }
}

// Implementing `std::error::Error` lets callers box, chain and downcast this type with the
// standard error machinery (e.g. `Box<dyn Error>`), not only through `anyhow`.
impl std::error::Error for ODEError {}
171
/// Integrates an [`ODEProblem`] over a whole time interval.
pub trait ODESolver {
    /// Integrates `problem` over `t_span = (start, end)` starting from `initial_conditions`,
    /// using `dt` as the size of the first step.
    ///
    /// Returns the recorded times together with the state vector at each of those times.
    fn solve<P: ODEProblem>(
        &self,
        problem: &P,
        t_span: (f64, f64),
        dt: f64,
        initial_conditions: &[f64],
    ) -> Result<(Vec<f64>, Vec<Vec<f64>>)>;
}
184
185pub struct BasicODESolver<I: ODEIntegrator> {
217 integrator: I,
218}
219
220impl<I: ODEIntegrator> BasicODESolver<I> {
221 pub fn new(integrator: I) -> Self {
222 Self { integrator }
223 }
224}
225
226impl<I: ODEIntegrator> ODESolver for BasicODESolver<I> {
227 fn solve<P: ODEProblem>(
228 &self,
229 problem: &P,
230 t_span: (f64, f64),
231 dt: f64,
232 initial_conditions: &[f64],
233 ) -> Result<(Vec<f64>, Vec<Vec<f64>>)> {
234 let mut t = t_span.0;
235 let mut dt = dt;
236 let mut y = initial_conditions.to_vec();
237 let mut t_vec = vec![t];
238 let mut y_vec = vec![y.clone()];
239
240 while t < t_span.1 {
241 let dt_step = self.integrator.step(problem, t, &mut y, dt)?;
242 t += dt;
243 t_vec.push(t);
244 y_vec.push(y.clone());
245 dt = dt_step;
246 }
247
248 Ok((t_vec, y_vec))
249 }
250}
251
/// Coefficients of an explicit Runge–Kutta method, optionally with an embedded solution used
/// for step-size control.
///
/// * `C[s]`: fraction of the step at which stage `s` is evaluated.
/// * `A[s]`: weights of the earlier stage derivatives that form the input of stage `s`.
///   Row `s` is read for indices `0..s`, so `A[0]` is empty.
/// * `BU`: weights of the solution that is propagated.
/// * `BE`: weights of the embedded solution.
///   - Empty for fixed-step methods.
///   - When non-empty, the blanket [`ODEIntegrator`] implementation estimates the local error
///     from `BU - BE` and calls the step-control methods below.
///
/// The step-control methods are only needed by tableaux with a non-empty `BE`. Their default
/// implementations panic with a message naming the missing method.
pub trait ButcherTableau {
    const C: &'static [f64];
    const A: &'static [&'static [f64]];
    const BU: &'static [f64];
    const BE: &'static [f64];

    /// Error threshold below which a trial step is accepted.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn tol(&self) -> f64 {
        unimplemented!("ButcherTableau::tol must be implemented by tableaux with a non-empty BE")
    }

    /// Factor applied to the step size proposed by the error controller.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn safety_factor(&self) -> f64 {
        unimplemented!(
            "ButcherTableau::safety_factor must be implemented by tableaux with a non-empty BE"
        )
    }

    /// Upper bound for the proposed step size.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn max_step_size(&self) -> f64 {
        unimplemented!(
            "ButcherTableau::max_step_size must be implemented by tableaux with a non-empty BE"
        )
    }

    /// Lower bound for the proposed step size.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn min_step_size(&self) -> f64 {
        unimplemented!(
            "ButcherTableau::min_step_size must be implemented by tableaux with a non-empty BE"
        )
    }

    /// Number of rejected trial steps after which a step fails.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn max_step_iter(&self) -> usize {
        unimplemented!(
            "ButcherTableau::max_step_iter must be implemented by tableaux with a non-empty BE"
        )
    }
}
294
295impl<BU: ButcherTableau> ODEIntegrator for BU {
296 fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64> {
297 let n = y.len();
298 let mut iter_count = 0usize;
299 let mut dt = dt;
300 let n_k = Self::C.len();
301
302 loop {
303 let mut k_vec = vec![vec![0.0; n]; n_k];
304 let mut y_temp = y.to_vec();
305
306 for stage in 0..n_k {
307 for i in 0..n {
308 let mut s = 0.0;
309 for j in 0..stage {
310 s += Self::A[stage][j] * k_vec[j][i];
311 }
312 y_temp[i] = y[i] + dt * s;
313 }
314 problem.rhs(t + dt * Self::C[stage], &y_temp, &mut k_vec[stage])?;
315 }
316
317 if !Self::BE.is_empty() {
318 let mut error = 0f64;
319 for i in 0..n {
320 let mut s = 0.0;
321 for j in 0..n_k {
322 s += (Self::BU[j] - Self::BE[j]) * k_vec[j][i];
323 }
324 error = error.max(dt * s.abs())
325 }
326
327 let factor = (self.tol() * dt / error).powf(0.2);
328 let new_dt = self.safety_factor() * dt * factor;
329 let new_dt = new_dt.clamp(self.min_step_size(), self.max_step_size());
330
331 if error < self.tol() {
332 for i in 0..n {
333 let mut s = 0.0;
334 for j in 0..n_k {
335 s += Self::BU[j] * k_vec[j][i];
336 }
337 y[i] += dt * s;
338 }
339 return Ok(new_dt);
340 } else {
341 iter_count += 1;
342 if iter_count >= self.max_step_iter() {
343 bail!(ODEError::ReachedMaxStepIter);
344 }
345 dt = new_dt;
346 }
347 } else {
348 for i in 0..n {
349 let mut s = 0.0;
350 for j in 0..n_k {
351 s += Self::BU[j] * k_vec[j][i];
352 }
353 y[i] += dt * s;
354 }
355 return Ok(dt);
356 }
357 }
358 }
359}
360
361#[derive(Debug, Clone, Copy, Default)]
369#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
370#[cfg_attr(
371 feature = "rkyv",
372 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
373)]
374pub struct RALS3;
375
376impl ButcherTableau for RALS3 {
377 const C: &'static [f64] = &[0.0, 0.5, 0.75];
378 const A: &'static [&'static [f64]] = &[&[], &[0.5], &[0.0, 0.75]];
379 const BU: &'static [f64] = &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0];
380 const BE: &'static [f64] = &[];
381}
382
383#[derive(Debug, Clone, Copy, Default)]
388#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
389#[cfg_attr(
390 feature = "rkyv",
391 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
392)]
393pub struct RK4;
394
395impl ButcherTableau for RK4 {
396 const C: &'static [f64] = &[0.0, 0.5, 0.5, 1.0];
397 const A: &'static [&'static [f64]] = &[&[], &[0.5], &[0.0, 0.5], &[0.0, 0.0, 1.0]];
398 const BU: &'static [f64] = &[1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0];
399 const BE: &'static [f64] = &[];
400}
401
402#[derive(Debug, Clone, Copy, Default)]
406#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
407#[cfg_attr(
408 feature = "rkyv",
409 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
410)]
411pub struct RALS4;
412
413impl ButcherTableau for RALS4 {
414 const C: &'static [f64] = &[0.0, 0.4, 0.45573725, 1.0];
415 const A: &'static [&'static [f64]] = &[
416 &[],
417 &[0.4],
418 &[0.29697761, 0.158575964],
419 &[0.21810040, -3.050965616, 3.83286476],
420 ];
421 const BU: &'static [f64] = &[0.17476028, -0.55148066, 1.20553560, 0.17118478];
422 const BE: &'static [f64] = &[];
423}
424
425#[derive(Debug, Clone, Copy, Default)]
429#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
430#[cfg_attr(
431 feature = "rkyv",
432 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
433)]
434pub struct RK5;
435
436impl ButcherTableau for RK5 {
437 const C: &'static [f64] = &[0.0, 0.2, 0.3, 0.8, 8.0 / 9.0, 1.0, 1.0];
438 const A: &'static [&'static [f64]] = &[
439 &[],
440 &[0.2],
441 &[0.075, 0.225],
442 &[44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0],
443 &[
444 19372.0 / 6561.0,
445 -25360.0 / 2187.0,
446 64448.0 / 6561.0,
447 -212.0 / 729.0,
448 ],
449 &[
450 9017.0 / 3168.0,
451 -355.0 / 33.0,
452 46732.0 / 5247.0,
453 49.0 / 176.0,
454 -5103.0 / 18656.0,
455 ],
456 &[
457 35.0 / 384.0,
458 0.0,
459 500.0 / 1113.0,
460 125.0 / 192.0,
461 -2187.0 / 6784.0,
462 11.0 / 84.0,
463 ],
464 ];
465 const BU: &'static [f64] = &[
466 5179.0 / 57600.0,
467 0.0,
468 7571.0 / 16695.0,
469 393.0 / 640.0,
470 -92097.0 / 339200.0,
471 187.0 / 2100.0,
472 1.0 / 40.0,
473 ];
474 const BE: &'static [f64] = &[];
475}
476
477#[derive(Debug, Clone, Copy)]
492#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
493#[cfg_attr(
494 feature = "rkyv",
495 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
496)]
497pub struct BS23 {
498 pub tol: f64,
499 pub safety_factor: f64,
500 pub min_step_size: f64,
501 pub max_step_size: f64,
502 pub max_step_iter: usize,
503}
504
505impl Default for BS23 {
506 fn default() -> Self {
507 Self {
508 tol: 1e-3,
509 safety_factor: 0.9,
510 min_step_size: 1e-6,
511 max_step_size: 1e-1,
512 max_step_iter: 100,
513 }
514 }
515}
516
517impl BS23 {
518 pub fn new(
519 tol: f64,
520 safety_factor: f64,
521 min_step_size: f64,
522 max_step_size: f64,
523 max_step_iter: usize,
524 ) -> Self {
525 Self {
526 tol,
527 safety_factor,
528 min_step_size,
529 max_step_size,
530 max_step_iter,
531 }
532 }
533}
534
535impl ButcherTableau for BS23 {
536 const C: &'static [f64] = &[0.0, 0.5, 0.75, 1.0];
537 const A: &'static [&'static [f64]] = &[
538 &[],
539 &[0.5],
540 &[0.0, 0.75],
541 &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0],
542 ];
543 const BU: &'static [f64] = &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0, 0.0];
544 const BE: &'static [f64] = &[7.0 / 24.0, 0.25, 1.0 / 3.0, 0.125];
545
546 fn tol(&self) -> f64 {
547 self.tol
548 }
549 fn safety_factor(&self) -> f64 {
550 self.safety_factor
551 }
552 fn min_step_size(&self) -> f64 {
553 self.min_step_size
554 }
555 fn max_step_size(&self) -> f64 {
556 self.max_step_size
557 }
558 fn max_step_iter(&self) -> usize {
559 self.max_step_iter
560 }
561}
562
563#[derive(Debug, Clone, Copy)]
577#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
578#[cfg_attr(
579 feature = "rkyv",
580 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
581)]
582pub struct RKF45 {
583 pub tol: f64,
584 pub safety_factor: f64,
585 pub min_step_size: f64,
586 pub max_step_size: f64,
587 pub max_step_iter: usize,
588}
589
590impl Default for RKF45 {
591 fn default() -> Self {
592 Self {
593 tol: 1e-6,
594 safety_factor: 0.9,
595 min_step_size: 1e-6,
596 max_step_size: 1e-1,
597 max_step_iter: 100,
598 }
599 }
600}
601
602impl RKF45 {
603 pub fn new(
604 tol: f64,
605 safety_factor: f64,
606 min_step_size: f64,
607 max_step_size: f64,
608 max_step_iter: usize,
609 ) -> Self {
610 Self {
611 tol,
612 safety_factor,
613 min_step_size,
614 max_step_size,
615 max_step_iter,
616 }
617 }
618}
619
620impl ButcherTableau for RKF45 {
621 const C: &'static [f64] = &[0.0, 1.0 / 4.0, 3.0 / 8.0, 12.0 / 13.0, 1.0, 1.0 / 2.0];
622 const A: &'static [&'static [f64]] = &[
623 &[],
624 &[0.25],
625 &[3.0 / 32.0, 9.0 / 32.0],
626 &[1932.0 / 2197.0, -7200.0 / 2197.0, 7296.0 / 2197.0],
627 &[439.0 / 216.0, -8.0, 3680.0 / 513.0, -845.0 / 4104.0],
628 &[
629 -8.0 / 27.0,
630 2.0,
631 -3544.0 / 2565.0,
632 1859.0 / 4104.0,
633 -11.0 / 40.0,
634 ],
635 ];
636 const BU: &'static [f64] = &[
637 16.0 / 135.0,
638 0.0,
639 6656.0 / 12825.0,
640 28561.0 / 56430.0,
641 -9.0 / 50.0,
642 2.0 / 55.0,
643 ];
644 const BE: &'static [f64] = &[
645 25.0 / 216.0,
646 0.0,
647 1408.0 / 2565.0,
648 2197.0 / 4104.0,
649 -1.0 / 5.0,
650 0.0,
651 ];
652
653 fn tol(&self) -> f64 {
654 self.tol
655 }
656 fn safety_factor(&self) -> f64 {
657 self.safety_factor
658 }
659 fn min_step_size(&self) -> f64 {
660 self.min_step_size
661 }
662 fn max_step_size(&self) -> f64 {
663 self.max_step_size
664 }
665 fn max_step_iter(&self) -> usize {
666 self.max_step_iter
667 }
668}
669
670#[derive(Debug, Clone, Copy)]
683#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
684#[cfg_attr(
685 feature = "rkyv",
686 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
687)]
688pub struct DP45 {
689 pub tol: f64,
690 pub safety_factor: f64,
691 pub min_step_size: f64,
692 pub max_step_size: f64,
693 pub max_step_iter: usize,
694}
695
696impl Default for DP45 {
697 fn default() -> Self {
698 Self {
699 tol: 1e-6,
700 safety_factor: 0.9,
701 min_step_size: 1e-6,
702 max_step_size: 1e-1,
703 max_step_iter: 100,
704 }
705 }
706}
707
708impl DP45 {
709 pub fn new(
710 tol: f64,
711 safety_factor: f64,
712 min_step_size: f64,
713 max_step_size: f64,
714 max_step_iter: usize,
715 ) -> Self {
716 Self {
717 tol,
718 safety_factor,
719 min_step_size,
720 max_step_size,
721 max_step_iter,
722 }
723 }
724}
725
726impl ButcherTableau for DP45 {
727 const C: &'static [f64] = &[0.0, 0.2, 0.3, 0.8, 8.0 / 9.0, 1.0, 1.0];
728 const A: &'static [&'static [f64]] = &[
729 &[],
730 &[0.2],
731 &[0.075, 0.225],
732 &[44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0],
733 &[
734 19372.0 / 6561.0,
735 -25360.0 / 2187.0,
736 64448.0 / 6561.0,
737 -212.0 / 729.0,
738 ],
739 &[
740 9017.0 / 3168.0,
741 -355.0 / 33.0,
742 46732.0 / 5247.0,
743 49.0 / 176.0,
744 -5103.0 / 18656.0,
745 ],
746 &[
747 35.0 / 384.0,
748 0.0,
749 500.0 / 1113.0,
750 125.0 / 192.0,
751 -2187.0 / 6784.0,
752 11.0 / 84.0,
753 ],
754 ];
755 const BU: &'static [f64] = &[
756 35.0 / 384.0,
757 0.0,
758 500.0 / 1113.0,
759 125.0 / 192.0,
760 -2187.0 / 6784.0,
761 11.0 / 84.0,
762 0.0,
763 ];
764 const BE: &'static [f64] = &[
765 5179.0 / 57600.0,
766 0.0,
767 7571.0 / 16695.0,
768 393.0 / 640.0,
769 -92097.0 / 339200.0,
770 187.0 / 2100.0,
771 1.0 / 40.0,
772 ];
773
774 fn tol(&self) -> f64 {
775 self.tol
776 }
777 fn safety_factor(&self) -> f64 {
778 self.safety_factor
779 }
780 fn min_step_size(&self) -> f64 {
781 self.min_step_size
782 }
783 fn max_step_size(&self) -> f64 {
784 self.max_step_size
785 }
786 fn max_step_iter(&self) -> usize {
787 self.max_step_iter
788 }
789}
790
791#[derive(Debug, Clone, Copy)]
808#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
809#[cfg_attr(
810 feature = "rkyv",
811 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
812)]
813pub struct TSIT45 {
814 pub tol: f64,
815 pub safety_factor: f64,
816 pub min_step_size: f64,
817 pub max_step_size: f64,
818 pub max_step_iter: usize,
819}
820
821impl Default for TSIT45 {
822 fn default() -> Self {
823 Self {
824 tol: 1e-6,
825 safety_factor: 0.9,
826 min_step_size: 1e-6,
827 max_step_size: 1e-1,
828 max_step_iter: 100,
829 }
830 }
831}
832
833impl TSIT45 {
834 pub fn new(
835 tol: f64,
836 safety_factor: f64,
837 min_step_size: f64,
838 max_step_size: f64,
839 max_step_iter: usize,
840 ) -> Self {
841 Self {
842 tol,
843 safety_factor,
844 min_step_size,
845 max_step_size,
846 max_step_iter,
847 }
848 }
849}
850
851impl ButcherTableau for TSIT45 {
852 const C: &'static [f64] = &[0.0, 0.161, 0.327, 0.9, 0.9800255409045097, 1.0, 1.0];
853 const A: &'static [&'static [f64]] = &[
854 &[],
855 &[Self::C[1]],
856 &[Self::C[2] - 0.335480655492357, 0.335480655492357],
857 &[
858 Self::C[3] - (-6.359448489975075 + 4.362295432869581),
859 -6.359448489975075,
860 4.362295432869581,
861 ],
862 &[
863 Self::C[4] - (-11.74888356406283 + 7.495539342889836 - 0.09249506636175525),
864 -11.74888356406283,
865 7.495539342889836,
866 -0.09249506636175525,
867 ],
868 &[
869 Self::C[5]
870 - (-12.92096931784711 + 8.159367898576159
871 - 0.0715849732814010
872 - 0.02826905039406838),
873 -12.92096931784711,
874 8.159367898576159,
875 -0.0715849732814010,
876 -0.02826905039406838,
877 ],
878 &[
879 Self::BU[0],
880 Self::BU[1],
881 Self::BU[2],
882 Self::BU[3],
883 Self::BU[4],
884 Self::BU[5],
885 ],
886 ];
887 const BU: &'static [f64] = &[
888 0.09646076681806523,
889 0.01,
890 0.4798896504144996,
891 1.379008574103742,
892 -3.290069515436081,
893 2.324710524099774,
894 0.0,
895 ];
896 const BE: &'static [f64] = &[
897 0.001780011052226,
898 0.000816434459657,
899 -0.007880878010262,
900 0.144711007173263,
901 -0.582357165452555,
902 0.458082105929187,
903 1.0 / 66.0,
904 ];
905
906 fn tol(&self) -> f64 {
907 self.tol
908 }
909 fn safety_factor(&self) -> f64 {
910 self.safety_factor
911 }
912 fn min_step_size(&self) -> f64 {
913 self.min_step_size
914 }
915 fn max_step_size(&self) -> f64 {
916 self.max_step_size
917 }
918 fn max_step_iter(&self) -> usize {
919 self.max_step_iter
920 }
921}
922
923#[derive(Debug, Clone, Copy)]
942#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
943#[cfg_attr(
944 feature = "rkyv",
945 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
946)]
947pub struct RKF78 {
948 pub tol: f64,
949 pub safety_factor: f64,
950 pub min_step_size: f64,
951 pub max_step_size: f64,
952 pub max_step_iter: usize,
953}
954
955impl Default for RKF78 {
956 fn default() -> Self {
957 Self {
958 tol: 1e-7, safety_factor: 0.9,
960 min_step_size: 1e-10, max_step_size: 1e-1,
962 max_step_iter: 100,
963 }
964 }
965}
966
967impl RKF78 {
968 pub fn new(
969 tol: f64,
970 safety_factor: f64,
971 min_step_size: f64,
972 max_step_size: f64,
973 max_step_iter: usize,
974 ) -> Self {
975 Self {
976 tol,
977 safety_factor,
978 min_step_size,
979 max_step_size,
980 max_step_iter,
981 }
982 }
983}
984
985impl ButcherTableau for RKF78 {
986 const C: &'static [f64] = &[
987 0.0,
988 2.0 / 27.0,
989 1.0 / 9.0,
990 1.0 / 6.0,
991 5.0 / 12.0,
992 1.0 / 2.0,
993 5.0 / 6.0,
994 1.0 / 6.0,
995 2.0 / 3.0,
996 1.0 / 3.0,
997 1.0,
998 0.0, 1.0, ];
1001
1002 const A: &'static [&'static [f64]] = &[
1003 &[],
1005 &[2.0 / 27.0],
1007 &[1.0 / 36.0, 3.0 / 36.0],
1009 &[1.0 / 24.0, 0.0, 3.0 / 24.0],
1011 &[20.0 / 48.0, 0.0, -75.0 / 48.0, 75.0 / 48.0],
1013 &[1.0 / 20.0, 0.0, 0.0, 5.0 / 20.0, 4.0 / 20.0],
1015 &[
1017 -25.0 / 108.0,
1018 0.0,
1019 0.0,
1020 125.0 / 108.0,
1021 -260.0 / 108.0,
1022 250.0 / 108.0,
1023 ],
1024 &[
1026 31.0 / 300.0,
1027 0.0,
1028 0.0,
1029 0.0,
1030 61.0 / 225.0,
1031 -2.0 / 9.0,
1032 13.0 / 900.0,
1033 ],
1034 &[
1036 2.0,
1037 0.0,
1038 0.0,
1039 -53.0 / 6.0,
1040 704.0 / 45.0,
1041 -107.0 / 9.0,
1042 67.0 / 90.0,
1043 3.0,
1044 ],
1045 &[
1047 -91.0 / 108.0,
1048 0.0,
1049 0.0,
1050 23.0 / 108.0,
1051 -976.0 / 135.0,
1052 311.0 / 54.0,
1053 -19.0 / 60.0,
1054 17.0 / 6.0,
1055 -1.0 / 12.0,
1056 ],
1057 &[
1059 2383.0 / 4100.0,
1060 0.0,
1061 0.0,
1062 -341.0 / 164.0,
1063 4496.0 / 1025.0,
1064 -301.0 / 82.0,
1065 2133.0 / 4100.0,
1066 45.0 / 82.0,
1067 45.0 / 164.0,
1068 18.0 / 41.0,
1069 ],
1070 &[
1072 3.0 / 205.0,
1073 0.0,
1074 0.0,
1075 0.0,
1076 0.0,
1077 -6.0 / 41.0,
1078 -3.0 / 205.0,
1079 -3.0 / 41.0,
1080 3.0 / 41.0,
1081 6.0 / 41.0,
1082 0.0,
1083 ],
1084 &[
1086 -1777.0 / 4100.0,
1087 0.0,
1088 0.0,
1089 -341.0 / 164.0,
1090 4496.0 / 1025.0,
1091 -289.0 / 82.0,
1092 2193.0 / 4100.0,
1093 51.0 / 82.0,
1094 33.0 / 164.0,
1095 12.0 / 41.0,
1096 0.0,
1097 1.0,
1098 ],
1099 ];
1100
1101 const BU: &'static [f64] = &[
1105 41.0 / 420.0, 0.0,
1107 0.0,
1108 0.0,
1109 0.0,
1110 34.0 / 105.0,
1111 9.0 / 35.0,
1112 9.0 / 35.0,
1113 9.0 / 280.0,
1114 9.0 / 280.0,
1115 41.0 / 420.0, -41.0 / 840.0, -41.0 / 840.0, ];
1119
1120 const BE: &'static [f64] = &[
1123 41.0 / 840.0,
1124 0.0,
1125 0.0,
1126 0.0,
1127 0.0,
1128 34.0 / 105.0,
1129 9.0 / 35.0,
1130 9.0 / 35.0,
1131 9.0 / 280.0,
1132 9.0 / 280.0,
1133 41.0 / 840.0,
1134 0.0,
1135 0.0,
1136 ];
1137
1138 fn tol(&self) -> f64 {
1139 self.tol
1140 }
1141 fn safety_factor(&self) -> f64 {
1142 self.safety_factor
1143 }
1144 fn min_step_size(&self) -> f64 {
1145 self.min_step_size
1146 }
1147 fn max_step_size(&self) -> f64 {
1148 self.max_step_size
1149 }
1150 fn max_step_iter(&self) -> usize {
1151 self.max_step_iter
1152 }
1153}
1154
// Coefficients of the two-stage Gauss–Legendre Runge–Kutta method, used by `GL4::step` and
// `compute_F`:
//
//     C1 | A11  A12
//     C2 | A21  A22
//     ---+----------
//        | B1   B2
//
// with C1,2 = 1/2 ∓ √3/6, A11 = A22 = 1/4, A12 = 1/4 − √3/6, A21 = 1/4 + √3/6, B1 = B2 = 1/2.

/// √3 rounded to `f64`.
const SQRT3: f64 = 1.7320508075688772;
/// Node (fraction of the step) of the first stage.
const C1: f64 = 0.5 - SQRT3 / 6.0;
/// Node (fraction of the step) of the second stage.
const C2: f64 = 0.5 + SQRT3 / 6.0;
const A11: f64 = 0.25;
const A12: f64 = 0.25 - SQRT3 / 6.0;
const A21: f64 = 0.25 + SQRT3 / 6.0;
const A22: f64 = 0.25;
/// Weight of the first stage derivative in the solution update.
const B1: f64 = 0.5;
/// Weight of the second stage derivative in the solution update.
const B2: f64 = 0.5;
1169
/// Strategy used by `GL4` to solve its implicit stage equations.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub enum ImplicitSolver {
    /// Fixed-point iteration on the two stage derivatives.
    FixedPoint,
    /// Broyden's quasi-Newton method on the stage residual, starting from an identity
    /// approximation of the inverse Jacobian.
    Broyden,
}
1185
1186#[derive(Debug, Clone, Copy)]
1198#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
1199#[cfg_attr(
1200 feature = "rkyv",
1201 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
1202)]
1203pub struct GL4 {
1204 pub solver: ImplicitSolver,
1205 pub tol: f64,
1206 pub max_step_iter: usize,
1207}
1208
1209impl Default for GL4 {
1210 fn default() -> Self {
1211 GL4 {
1212 solver: ImplicitSolver::FixedPoint,
1213 tol: 1e-8,
1214 max_step_iter: 100,
1215 }
1216 }
1217}
1218
1219impl GL4 {
1220 pub fn new(solver: ImplicitSolver, tol: f64, max_step_iter: usize) -> Self {
1221 GL4 {
1222 solver,
1223 tol,
1224 max_step_iter,
1225 }
1226 }
1227}
1228
1229impl ODEIntegrator for GL4 {
1230 #[allow(non_snake_case)]
1231 #[inline]
1232 fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64> {
1233 let n = y.len();
1234 let mut k1 = vec![0.0; n];
1238 let mut k2 = vec![0.0; n];
1239
1240 problem.rhs(t, y, &mut k1)?;
1242 k2.copy_from_slice(&k1);
1243
1244 match self.solver {
1245 ImplicitSolver::FixedPoint => {
1246 let mut y1 = vec![0.0; n];
1248 let mut y2 = vec![0.0; n];
1249
1250 for _ in 0..self.max_step_iter {
1251 let k1_old = k1.clone();
1252 let k2_old = k2.clone();
1253
1254 for i in 0..n {
1255 y1[i] = y[i] + dt * (A11 * k1[i] + A12 * k2[i]);
1256 y2[i] = y[i] + dt * (A11 * k1[i] + A12 * k2[i]);
1257 }
1258
1259 problem.rhs(t + C1 * dt, &y1, &mut k1)?;
1261 problem.rhs(t + C2 * dt, &y2, &mut k2)?;
1262
1263 let mut max_diff = 0f64;
1265 for i in 0..n {
1266 max_diff = max_diff.max((k1[i] - k1_old[i]).abs());
1267 max_diff = max_diff.max((k2[i] - k2_old[i]).abs());
1268 }
1269
1270 if max_diff < self.tol {
1271 break;
1272 }
1273 }
1274 }
1275 ImplicitSolver::Broyden => {
1276 let m = 2 * n;
1277 let mut U = vec![0.0; m];
1278 U[..n].copy_from_slice(&k1);
1279 U[n..].copy_from_slice(&k2);
1280
1281 let mut F_vec = vec![0.0; m];
1283 compute_F(problem, t, y, dt, &U, &mut F_vec)?;
1284
1285 let mut J_inv = eye(m);
1287
1288 for _ in 0..self.max_step_iter {
1290 let delta = (&J_inv * &F_vec).mul_scalar(-1.0);
1292
1293 U.iter_mut()
1295 .zip(delta.iter())
1296 .for_each(|(u, d)| *u += *d);
1297
1298 let mut F_new = vec![0.0; m];
1299 compute_F(problem, t, y, dt, &U, &mut F_new)?;
1300
1301 if F_new.norm(Norm::LInf) < self.tol {
1303 break;
1304 }
1305
1306 let delta_F = F_new.sub_vec(&F_vec);
1308
1309 let J_inv_delta_F = &J_inv * &delta_F;
1311
1312 let denom = delta.dot(&J_inv_delta_F);
1313 if denom.abs() < 1e-12 {
1314 break;
1315 }
1316
1317 let delta_minus_J_inv_delta_F = delta.sub_vec(&J_inv_delta_F).to_col();
1320 let delta_T_J_inv = &delta.to_row() * &J_inv;
1321 let update = (delta_minus_J_inv_delta_F * delta_T_J_inv) / denom;
1322 J_inv = J_inv + update;
1323 F_vec = F_new;
1324 }
1325
1326 k1.copy_from_slice(&U[..n]);
1327 k2.copy_from_slice(&U[n..]);
1328 }
1329 }
1330
1331 for i in 0..n {
1332 y[i] += dt * (B1 * k1[i] + B2 * k2[i]);
1333 }
1334
1335 Ok(dt)
1336 }
1337}
1338
1339#[allow(non_snake_case)]
1379fn compute_F<P: ODEProblem>(
1380 problem: &P,
1381 t: f64,
1382 y: &[f64],
1383 dt: f64,
1384 U: &[f64], F: &mut [f64],
1386) -> Result<()> {
1387 let n = y.len();
1388 let (k1_slice, k2_slice) = U.split_at(n);
1389
1390 let mut y1 = vec![0.0; n];
1391 let mut y2 = vec![0.0; n];
1392
1393 for i in 0..n {
1394 y1[i] = y[i] + dt * (A11 * k1_slice[i] + A12 * k2_slice[i]);
1395 y2[i] = y[i] + dt * (A21 * k1_slice[i] + A22 * k2_slice[i]);
1396 }
1397
1398 let (f1, f2) = F.split_at_mut(n);
1400 problem.rhs(t + C1 * dt, &y1, f1)?;
1401 problem.rhs(t + C2 * dt, &y2, f2)?;
1402
1403 for i in 0..n {
1405 f1[i] = k1_slice[i] - f1[i];
1406 f2[i] = k2_slice[i] - f2[i];
1407 }
1408 Ok(())
1409}