1use self::AD::{AD0, AD1, AD2};
109use crate::statistics::ops::C;
110use crate::traits::{fp::FPVector, math::Vector, stable::StableFn, sugar::VecOps};
111use peroxide_num::{ExpLogOps, PowOps, TrigOps};
112use std::iter::{DoubleEndedIterator, ExactSizeIterator, FromIterator};
113use std::ops::{Add, Div, Index, IndexMut, Mul, Neg, Sub};
114
/// A value bundled with zero, one or two derivative slots.
///
/// The accessors on this type name the slots `x`, `dx` and `ddx`
/// (value, first-derivative slot, second-derivative slot).
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum AD {
    /// Order 0: only the value `x`.
    AD0(f64),
    /// Order 1: `(x, dx)`.
    AD1(f64, f64),
    /// Order 2: `(x, dx, ddx)`.
    AD2(f64, f64, f64),
}
121
122impl PartialOrd for AD {
123 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
124 self.x().partial_cmp(&other.x())
125 }
126}
127
128impl std::fmt::Display for AD {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 let s = format!("{:?}", self);
131 write!(f, "{}", s)
132 }
133}
134
135impl AD {
136 pub fn to_order(&self, n: usize) -> Self {
137 if n == self.order() {
138 return *self;
139 }
140
141 let mut z = match n {
142 0 => AD0(0f64),
143 1 => AD1(0f64, 0f64),
144 2 => AD2(0f64, 0f64, 0f64),
145 _ => panic!("No more index exists"),
146 };
147
148 for i in 0..z.len().min(self.len()) {
149 z[i] = self[i];
150 }
151
152 z
153 }
154
155 pub fn order(&self) -> usize {
156 match self {
157 AD0(_) => 0,
158 AD1(_, _) => 1,
159 AD2(_, _, _) => 2,
160 }
161 }
162
163 pub fn len(&self) -> usize {
164 match self {
165 AD0(_) => 1,
166 AD1(_, _) => 2,
167 AD2(_, _, _) => 3,
168 }
169 }
170
171 pub fn is_empty(&self) -> bool {
172 self.len() == 0
173 }
174
175 pub fn iter(&self) -> ADIter {
176 self.into_iter()
177 }
178
179 pub fn iter_mut(&mut self) -> ADIterMut {
180 self.into_iter()
181 }
182
183 pub fn from_order(n: usize) -> Self {
184 match n {
185 0 => AD0(0f64),
186 1 => AD1(0f64, 0f64),
187 2 => AD2(0f64, 0f64, 0f64),
188 _ => panic!("Not yet implemented higher order AD"),
189 }
190 }
191
192 pub fn empty(&self) -> Self {
193 match self {
194 AD0(_) => AD0(0f64),
195 AD1(_, _) => AD1(0f64, 0f64),
196 AD2(_, _, _) => AD2(0f64, 0f64, 0f64),
197 }
198 }
199
200 pub fn set_x(&mut self, x: f64) {
201 match self {
202 AD0(t) => {
203 *t = x;
204 }
205 AD1(t, _) => {
206 *t = x;
207 }
208 AD2(t, _, _) => {
209 *t = x;
210 }
211 }
212 }
213
214 pub fn set_dx(&mut self, dx: f64) {
215 match self {
216 AD0(_) => panic!("Can't set dx for AD0"),
217 AD1(_, dt) => {
218 *dt = dx;
219 }
220 AD2(_, dt, _) => {
221 *dt = dx;
222 }
223 }
224 }
225
226 pub fn set_ddx(&mut self, ddx: f64) {
227 match self {
228 AD0(_) => panic!("Can't set ddx for AD0"),
229 AD1(_, _) => panic!("Can't set ddx for AD1"),
230 AD2(_, _, ddt) => {
231 *ddt = ddx;
232 }
233 }
234 }
235
236 pub fn x(&self) -> f64 {
237 match self {
238 AD0(x) => *x,
239 AD1(x, _) => *x,
240 AD2(x, _, _) => *x,
241 }
242 }
243
244 pub fn dx(&self) -> f64 {
245 match self {
246 AD0(_) => 0f64,
247 AD1(_, dx) => *dx,
248 AD2(_, dx, _) => *dx,
249 }
250 }
251
252 pub fn ddx(&self) -> f64 {
253 match self {
254 AD0(_) => 0f64,
255 AD1(_, _) => 0f64,
256 AD2(_, _, ddx) => *ddx,
257 }
258 }
259
260 pub fn x_ref(&self) -> Option<&f64> {
261 match self {
262 AD0(x) => Some(x),
263 AD1(x, _) => Some(x),
264 AD2(x, _, _) => Some(x),
265 }
266 }
267
268 pub fn dx_ref(&self) -> Option<&f64> {
269 match self {
270 AD0(_) => None,
271 AD1(_, dx) => Some(dx),
272 AD2(_, dx, _) => Some(dx),
273 }
274 }
275
276 pub fn ddx_ref(&self) -> Option<&f64> {
277 match self {
278 AD0(_) => None,
279 AD1(_, _) => None,
280 AD2(_, _, ddx) => Some(ddx),
281 }
282 }
283
284 pub fn x_mut(&mut self) -> Option<&mut f64> {
285 match self {
286 AD0(x) => Some(x),
287 AD1(x, _) => Some(x),
288 AD2(x, _, _) => Some(x),
289 }
290 }
291
292 pub fn dx_mut(&mut self) -> Option<&mut f64> {
293 match self {
294 AD0(_) => None,
295 AD1(_, dx) => Some(dx),
296 AD2(_, dx, _) => Some(dx),
297 }
298 }
299
300 pub fn ddx_mut(&mut self) -> Option<&mut f64> {
301 match self {
302 AD0(_) => None,
303 AD1(_, _) => None,
304 AD2(_, _, ddx) => Some(ddx),
305 }
306 }
307
308 #[allow(dead_code)]
309 unsafe fn x_ptr(&self) -> Option<*const f64> {
310 match self {
311 AD0(x) => Some(x),
312 AD1(x, _) => Some(x),
313 AD2(x, _, _) => Some(x),
314 }
315 }
316
317 #[allow(dead_code)]
318 unsafe fn dx_ptr(&self) -> Option<*const f64> {
319 match self {
320 AD0(_) => None,
321 AD1(_, dx) => Some(dx),
322 AD2(_, dx, _) => Some(dx),
323 }
324 }
325
326 #[allow(dead_code)]
327 unsafe fn ddx_ptr(&self) -> Option<*const f64> {
328 match self {
329 AD0(_) => None,
330 AD1(_, _) => None,
331 AD2(_, _, ddx) => Some(ddx),
332 }
333 }
334
335 unsafe fn x_mut_ptr(&mut self) -> Option<*mut f64> {
336 match self {
337 AD0(x) => Some(&mut *x),
338 AD1(x, _) => Some(&mut *x),
339 AD2(x, _, _) => Some(&mut *x),
340 }
341 }
342
343 unsafe fn dx_mut_ptr(&mut self) -> Option<*mut f64> {
344 match self {
345 AD0(_) => None,
346 AD1(_, dx) => Some(&mut *dx),
347 AD2(_, dx, _) => Some(&mut *dx),
348 }
349 }
350
351 unsafe fn ddx_mut_ptr(&mut self) -> Option<*mut f64> {
352 match self {
353 AD0(_) => None,
354 AD1(_, _) => None,
355 AD2(_, _, ddx) => Some(&mut *ddx),
356 }
357 }
358}
359
360impl Index<usize> for AD {
361 type Output = f64;
362
363 fn index(&self, index: usize) -> &Self::Output {
364 match index {
365 0 => self.x_ref().unwrap(),
366 1 => self.dx_ref().expect("AD0 has no dx"),
367 2 => self.ddx_ref().expect("AD0, AD1 have no ddx"),
368 _ => panic!("No more index exists"),
369 }
370 }
371}
372
373impl IndexMut<usize> for AD {
374 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
375 match index {
376 0 => self.x_mut().unwrap(),
377 1 => self.dx_mut().expect("AD0 has no dx"),
378 2 => self.ddx_mut().expect("AD0, AD1 have no ddx"),
379 _ => panic!("No more index exists"),
380 }
381 }
382}
383
/// By-value iterator over the components of an [`AD`].
#[derive(Debug)]
pub struct ADIntoIter {
    /// The value being iterated.
    ad: AD,
    /// Number of components already yielded from the front.
    index: usize,
    /// Number of components already yielded from the back.
    r_index: usize,
}
390
/// Iterator over shared references to the components of an [`AD`].
#[derive(Debug)]
pub struct ADIter<'a> {
    /// The borrowed value being iterated.
    ad: &'a AD,
    /// Number of components already yielded from the front.
    index: usize,
    /// Number of components already yielded from the back.
    r_index: usize,
}
397
/// Iterator over mutable references to the components of an [`AD`].
#[derive(Debug)]
pub struct ADIterMut<'a> {
    /// The mutably borrowed value being iterated.
    ad: &'a mut AD,
    /// Number of components already yielded from the front.
    index: usize,
    /// Back-side counter; it is initialised to 0 and only read by the
    /// `Iterator` impl visible in this file.
    r_index: usize,
}
404
405impl IntoIterator for AD {
406 type Item = f64;
407 type IntoIter = ADIntoIter;
408
409 fn into_iter(self) -> Self::IntoIter {
410 ADIntoIter {
411 ad: self,
412 index: 0,
413 r_index: 0,
414 }
415 }
416}
417
418impl<'a> IntoIterator for &'a AD {
419 type Item = &'a f64;
420 type IntoIter = ADIter<'a>;
421
422 fn into_iter(self) -> Self::IntoIter {
423 ADIter {
424 ad: self,
425 index: 0,
426 r_index: 0,
427 }
428 }
429}
430
431impl<'a> IntoIterator for &'a mut AD {
432 type Item = &'a mut f64;
433 type IntoIter = ADIterMut<'a>;
434
435 fn into_iter(self) -> Self::IntoIter {
436 ADIterMut {
437 ad: self,
438 index: 0,
439 r_index: 0,
440 }
441 }
442}
443
444impl Iterator for ADIntoIter {
445 type Item = f64;
446
447 fn next(&mut self) -> Option<Self::Item> {
448 let l = self.ad.len();
449 if self.index + self.r_index < l {
450 let result = match self.index {
451 0 => Some(self.ad.x()),
452 1 => match self.ad {
453 AD0(_) => None,
454 AD1(_, dx) => Some(dx),
455 AD2(_, dx, _) => Some(dx),
456 },
457 2 => match self.ad {
458 AD0(_) => None,
459 AD1(_, _) => None,
460 AD2(_, _, ddx) => Some(ddx),
461 },
462 _ => None,
463 };
464 self.index += 1;
465 result
466 } else {
467 None
468 }
469 }
470
471 fn size_hint(&self) -> (usize, Option<usize>) {
472 let lower = self.ad.len() - (self.index + self.r_index);
473 let upper = self.ad.len() - (self.index + self.r_index);
474 (lower, Some(upper))
475 }
476}
477
478impl<'a> Iterator for ADIter<'a> {
479 type Item = &'a f64;
480
481 fn next(&mut self) -> Option<Self::Item> {
482 let l = self.ad.len();
483 if self.index + self.r_index < l {
484 let result = match self.index {
485 0 => self.ad.x_ref(),
486 1 => self.ad.dx_ref(),
487 2 => self.ad.ddx_ref(),
488 _ => None,
489 };
490 self.index += 1;
491 result
492 } else {
493 None
494 }
495 }
496
497 fn size_hint(&self) -> (usize, Option<usize>) {
498 let lower = self.ad.len() - (self.index + self.r_index);
499 let upper = self.ad.len() - (self.index + self.r_index);
500 (lower, Some(upper))
501 }
502}
503
504impl<'a> Iterator for ADIterMut<'a> {
505 type Item = &'a mut f64;
506
507 fn next(&mut self) -> Option<Self::Item> {
508 let l = self.ad.len();
509 if self.index + self.r_index < l {
510 unsafe {
511 let result = match self.index {
512 0 => self.ad.x_mut_ptr(),
513 1 => self.ad.dx_mut_ptr(),
514 2 => self.ad.ddx_mut_ptr(),
515 _ => None,
516 };
517 self.index += 1;
518 match result {
519 None => None,
520 Some(ad) => Some(&mut *ad),
521 }
522 }
523 } else {
524 None
525 }
526 }
527
528 fn size_hint(&self) -> (usize, Option<usize>) {
529 let lower = self.ad.len() - (self.index + self.r_index);
530 let upper = self.ad.len() - (self.index + self.r_index);
531 (lower, Some(upper))
532 }
533}
534
535impl FromIterator<f64> for AD {
536 fn from_iter<T: IntoIterator<Item = f64>>(iter: T) -> Self {
537 let into_iter = iter.into_iter();
538 let s = into_iter.size_hint().0 - 1;
539 let mut z = match s {
540 0 => AD0(0f64),
541 1 => AD1(0f64, 0f64),
542 2 => AD2(0f64, 0f64, 0f64),
543 _ => panic!("Higher than order 3 is not allowed"),
544 };
545 for (i, elem) in into_iter.enumerate() {
546 z[i] = elem;
547 }
548 z
549 }
550}
551
552impl<'a> FromIterator<&'a f64> for AD {
553 fn from_iter<T: IntoIterator<Item = &'a f64>>(iter: T) -> Self {
554 let into_iter = iter.into_iter();
555 let s = into_iter.size_hint().0 - 1;
556 let mut z = match s {
557 0 => AD0(0f64),
558 1 => AD1(0f64, 0f64),
559 2 => AD2(0f64, 0f64, 0f64),
560 _ => panic!("Higher than order 3 is not allowed"),
561 };
562 for (i, &elem) in into_iter.enumerate() {
563 z[i] = elem;
564 }
565 z
566 }
567}
568
569impl DoubleEndedIterator for ADIntoIter {
570 fn next_back(&mut self) -> Option<Self::Item> {
571 if self.index + self.r_index == self.ad.len() {
572 return None;
573 }
574 let order = self.ad.order();
575 let result = self.ad[order - self.r_index];
576 self.r_index += 1;
577 Some(result)
578 }
579}
580
581impl<'a> DoubleEndedIterator for ADIter<'a> {
582 fn next_back(&mut self) -> Option<Self::Item> {
583 if self.index + self.r_index == self.ad.len() {
584 return None;
585 }
586 let order = self.ad.order();
587 let result = &self.ad[order - self.r_index];
588 self.r_index += 1;
589 Some(result)
590 }
591}
592
593impl ExactSizeIterator for ADIntoIter {
594 fn len(&self) -> usize {
595 self.ad.len() - (self.index + self.r_index)
596 }
597}
598
599impl<'a> ExactSizeIterator for ADIter<'a> {
600 fn len(&self) -> usize {
601 self.ad.len() - (self.index + self.r_index)
602 }
603}
604
605impl Neg for AD {
606 type Output = Self;
607
608 fn neg(self) -> Self::Output {
609 self.into_iter().map(|x| -x).collect()
610 }
611}
612
613impl Add<AD> for AD {
614 type Output = Self;
615
616 fn add(self, rhs: AD) -> Self::Output {
617 let ord = self.order().max(rhs.order());
618 let (a, b) = (self.to_order(ord), rhs.to_order(ord));
619
620 a.into_iter().zip(b).map(|(x, y)| x + y).collect()
621 }
622}
623
624impl Sub<AD> for AD {
625 type Output = Self;
626
627 fn sub(self, rhs: AD) -> Self::Output {
628 let ord = self.order().max(rhs.order());
629 let (a, b) = (self.to_order(ord), rhs.to_order(ord));
630
631 a.into_iter().zip(b).map(|(x, y)| x - y).collect()
632 }
633}
634
635impl Mul<AD> for AD {
636 type Output = Self;
637
638 fn mul(self, rhs: AD) -> Self::Output {
639 let ord = self.order().max(rhs.order());
640 let (a, b) = (self.to_order(ord), rhs.to_order(ord));
641
642 let mut z = a;
643 for t in 0..z.len() {
644 z[t] = a
645 .into_iter()
646 .take(t + 1)
647 .zip(b.into_iter().take(t + 1).rev())
648 .enumerate()
649 .fold(0f64, |s, (k, (x1, y1))| s + (C(t, k) as f64) * x1 * y1)
650 }
651 z
652 }
653}
654
655impl Div<AD> for AD {
656 type Output = Self;
657
658 fn div(self, rhs: AD) -> Self::Output {
659 let ord = self.order().max(rhs.order());
660 let (a, b) = (self.to_order(ord), rhs.to_order(ord));
661
662 let mut z = a;
663 z[0] = a[0] / b[0];
664 let y0 = 1f64 / b[0];
665 for i in 1..z.len() {
666 let mut s = 0f64;
667 for (j, (&y1, &z1)) in b
668 .iter()
669 .skip(1)
670 .take(i)
671 .zip(z.iter().take(i).rev())
672 .enumerate()
673 {
674 s += (C(i, j + 1) as f64) * y1 * z1;
675 }
676 z[i] = y0 * (a[i] - s);
677 }
678 z
679 }
680}
681
682impl ExpLogOps for AD {
683 type Float = f64;
684
685 fn exp(&self) -> Self {
686 let mut z = self.empty();
687 z[0] = self[0].exp();
688 for i in 1..z.len() {
689 z[i] = z
690 .iter()
691 .take(i)
692 .zip(self.iter().skip(1).take(i).rev())
693 .enumerate()
694 .fold(0f64, |x, (k, (&z1, &x1))| {
695 x + (C(i - 1, k) as f64) * x1 * z1
696 });
697 }
698 z
699 }
700
701 fn ln(&self) -> Self {
702 let mut z = self.empty();
703 z[0] = self[0].ln();
704 let x0 = 1f64 / self[0];
705 for i in 1..z.len() {
706 let mut s = 0f64;
707 for (k, (&z1, &x1)) in z
708 .iter()
709 .skip(1)
710 .take(i - 1)
711 .zip(self.iter().skip(1).take(i - 1).rev())
712 .enumerate()
713 {
714 s += (C(i - 1, k + 1) as f64) * z1 * x1;
715 }
716 z[i] = x0 * (self[i] - s);
717 }
718 z
719 }
720
721 fn log(&self, base: f64) -> Self {
722 self.ln().iter().map(|x| x / base.ln()).collect()
723 }
724
725 fn log2(&self) -> Self {
726 self.log(2f64)
727 }
728
729 fn log10(&self) -> Self {
730 self.log(10f64)
731 }
732}
733
734impl PowOps for AD {
735 type Float = f64;
736
737 fn powi(&self, n: i32) -> Self {
738 let mut z = *self;
739 for _i in 1..n {
740 z = z * *self;
741 }
742 z
743 }
744
745 fn powf(&self, f: f64) -> Self {
746 let ln_x = self.ln();
747 let mut z = self.empty();
748 z[0] = self.x().powf(f);
749 for i in 1..z.len() {
750 let mut s = 0f64;
751 for (j, (&z1, &ln_x1)) in z
752 .iter()
753 .skip(1)
754 .take(i - 1)
755 .zip(ln_x.iter().skip(1).take(i - 1).rev())
756 .enumerate()
757 {
758 s += (C(i - 1, j + 1) as f64) * z1 * ln_x1;
759 }
760 z[i] = f * (z[0] * ln_x[i] + s);
761 }
762 z
763 }
764
765 fn pow(&self, y: Self) -> Self {
766 let ln_x = self.ln();
767 let p = y * ln_x;
768 let mut z = self.empty();
769 z[0] = self.x().powf(y.x());
770 for n in 1..z.len() {
771 let mut s = 0f64;
772 for (k, (&z1, &p1)) in z
773 .iter()
774 .skip(1)
775 .take(n - 1)
776 .zip(p.iter().skip(1).take(n - 1).rev())
777 .enumerate()
778 {
779 s += (C(n - 1, k + 1) as f64) * z1 * p1;
780 }
781 z[n] = z[0] * p[n] + s;
782 }
783 z
784 }
785
786 fn sqrt(&self) -> Self {
787 self.powf(0.5f64)
788 }
789}
790
791impl TrigOps for AD {
792 fn sin_cos(&self) -> (Self, Self) {
793 let mut u = self.empty();
794 let mut v = self.empty();
795 u[0] = self[0].sin();
796 v[0] = self[0].cos();
797 for i in 1..u.len() {
798 u[i] = v
799 .iter()
800 .take(i)
801 .zip(self.iter().skip(1).take(i).rev())
802 .enumerate()
803 .fold(0f64, |x, (k, (&v1, &x1))| {
804 x + (C(i - 1, k) as f64) * x1 * v1
805 });
806 v[i] = u
807 .iter()
808 .take(i)
809 .zip(self.iter().skip(1).take(i).rev())
810 .enumerate()
811 .fold(0f64, |x, (k, (&u1, &x1))| {
812 x + (C(i - 1, k) as f64) * x1 * u1
813 });
814 }
815 (u, v)
816 }
817
818 fn tan(&self) -> Self {
819 let (s, c) = self.sin_cos();
820 s / c
821 }
822
823 fn sinh(&self) -> Self {
824 let mut u = self.empty();
825 let mut v = self.empty();
826 u[0] = self[0].sinh();
827 v[0] = self[0].cosh();
828 for i in 1..u.len() {
829 u[i] = v
830 .iter()
831 .take(i)
832 .zip(self.iter().skip(1).take(i).rev())
833 .enumerate()
834 .fold(0f64, |x, (k, (&v1, &x1))| {
835 x + (C(i - 1, k) as f64) * x1 * v1
836 });
837 v[i] = u
838 .iter()
839 .take(i)
840 .zip(self.iter().skip(1).take(i).rev())
841 .enumerate()
842 .fold(0f64, |x, (k, (&u1, &x1))| {
843 x + (C(i - 1, k) as f64) * x1 * u1
844 });
845 }
846 u
847 }
848
849 fn cosh(&self) -> Self {
850 let mut u = self.empty();
851 let mut v = self.empty();
852 u[0] = self[0].sinh();
853 v[0] = self[0].cosh();
854 for i in 1..u.len() {
855 u[i] = v
856 .iter()
857 .take(i)
858 .zip(self.iter().skip(1).take(i).rev())
859 .enumerate()
860 .fold(0f64, |x, (k, (&v1, &x1))| {
861 x + (C(i - 1, k) as f64) * x1 * v1
862 });
863 v[i] = u
864 .iter()
865 .take(i)
866 .zip(self.iter().skip(1).take(i).rev())
867 .enumerate()
868 .fold(0f64, |x, (k, (&u1, &x1))| {
869 x + (C(i - 1, k) as f64) * x1 * u1
870 });
871 }
872 v
873 }
874
875 fn tanh(&self) -> Self {
876 let mut u = self.empty();
877 let mut v = self.empty();
878 u[0] = self[0].sinh();
879 v[0] = self[0].cosh();
880 for i in 1..u.len() {
881 u[i] = v
882 .iter()
883 .take(i)
884 .zip(self.iter().skip(1).take(i).rev())
885 .enumerate()
886 .fold(0f64, |x, (k, (&v1, &x1))| {
887 x + (C(i - 1, k) as f64) * x1 * v1
888 });
889 v[i] = u
890 .iter()
891 .take(i)
892 .zip(self.iter().skip(1).take(i).rev())
893 .enumerate()
894 .fold(0f64, |x, (k, (&u1, &x1))| {
895 x + (C(i - 1, k) as f64) * x1 * u1
896 });
897 }
898 u / v
899 }
900
901 fn asin(&self) -> Self {
902 let dx = 1f64 / (1f64 - self.powi(2)).sqrt();
903 let mut z = self.empty();
904 z[0] = self[0].asin();
905 for n in 1..z.len() {
906 z[n] = dx
907 .iter()
908 .take(n)
909 .zip(self.iter().skip(1).take(n).rev())
910 .enumerate()
911 .fold(0f64, |s, (k, (&q1, &x1))| {
912 s + (C(n - 1, k) as f64) * x1 * q1
913 });
914 }
915 z
916 }
917
918 fn acos(&self) -> Self {
919 let dx = (-1f64) / (1f64 - self.powi(2)).sqrt();
920 let mut z = self.empty();
921 z[0] = self[0].acos();
922 for n in 1..z.len() {
923 z[n] = dx
924 .iter()
925 .take(n)
926 .zip(self.iter().skip(1).take(n).rev())
927 .enumerate()
928 .fold(0f64, |s, (k, (&q1, &x1))| {
929 s + (C(n - 1, k) as f64) * x1 * q1
930 });
931 }
932 z
933 }
934
935 fn atan(&self) -> Self {
936 let dx = 1f64 / (1f64 + self.powi(2));
937 let mut z = self.empty();
938 z[0] = self[0].atan();
939 for n in 1..z.len() {
940 z[n] = dx
941 .iter()
942 .take(n)
943 .zip(self.iter().skip(1).take(n).rev())
944 .enumerate()
945 .fold(0f64, |s, (k, (&q1, &x1))| {
946 s + (C(n - 1, k) as f64) * x1 * q1
947 });
948 }
949 z
950 }
951
952 fn asinh(&self) -> Self {
953 let dx = 1f64 / (1f64 + self.powi(2)).sqrt();
954 let mut z = self.empty();
955 z[0] = self[0].asinh();
956 for n in 1..z.len() {
957 z[n] = dx
958 .iter()
959 .take(n)
960 .zip(self.iter().skip(1).take(n).rev())
961 .enumerate()
962 .fold(0f64, |s, (k, (&q1, &x1))| {
963 s + (C(n - 1, k) as f64) * x1 * q1
964 });
965 }
966 z
967 }
968
969 fn acosh(&self) -> Self {
970 let dx = 1f64 / (self.powi(2) - 1f64).sqrt();
971 let mut z = self.empty();
972 z[0] = self[0].acosh();
973 for n in 1..z.len() {
974 z[n] = dx
975 .iter()
976 .take(n)
977 .zip(self.iter().skip(1).take(n).rev())
978 .enumerate()
979 .fold(0f64, |s, (k, (&q1, &x1))| {
980 s + (C(n - 1, k) as f64) * x1 * q1
981 });
982 }
983 z
984 }
985
986 fn atanh(&self) -> Self {
987 let dx = 1f64 / (1f64 - self.powi(2));
988 let mut z = self.empty();
989 z[0] = self[0].atanh();
990 for n in 1..z.len() {
991 z[n] = dx
992 .iter()
993 .take(n)
994 .zip(self.iter().skip(1).take(n).rev())
995 .enumerate()
996 .fold(0f64, |s, (k, (&q1, &x1))| {
997 s + (C(n - 1, k) as f64) * x1 * q1
998 });
999 }
1000 z
1001 }
1002}
1003
1004impl From<f64> for AD {
1005 fn from(other: f64) -> Self {
1006 AD0(other)
1007 }
1008}
1009
1010impl From<AD> for f64 {
1011 fn from(other: AD) -> Self {
1012 other.x()
1013 }
1014}
1015
1016impl Add<f64> for AD {
1017 type Output = Self;
1018
1019 fn add(self, rhs: f64) -> Self::Output {
1020 let mut z = self;
1021 z[0] += rhs;
1022 z
1023 }
1024}
1025
1026impl Sub<f64> for AD {
1027 type Output = Self;
1028
1029 fn sub(self, rhs: f64) -> Self::Output {
1030 let mut z = self;
1031 z[0] -= rhs;
1032 z
1033 }
1034}
1035
1036impl Mul<f64> for AD {
1037 type Output = Self;
1038
1039 fn mul(self, rhs: f64) -> Self::Output {
1040 self.iter().map(|&x| x * rhs).collect()
1041 }
1042}
1043
1044impl Div<f64> for AD {
1045 type Output = Self;
1046
1047 fn div(self, rhs: f64) -> Self::Output {
1048 self.iter().map(|&x| x / rhs).collect()
1049 }
1050}
1051
1052impl Add<AD> for f64 {
1053 type Output = AD;
1054
1055 fn add(self, rhs: AD) -> Self::Output {
1056 let mut z = rhs;
1057 z[0] += self;
1058 z
1059 }
1060}
1061
1062impl Sub<AD> for f64 {
1063 type Output = AD;
1064
1065 fn sub(self, rhs: AD) -> Self::Output {
1066 let mut z = rhs.empty();
1067 z[0] = self;
1068 z - rhs
1069 }
1070}
1071
1072impl Mul<AD> for f64 {
1073 type Output = AD;
1074
1075 fn mul(self, rhs: AD) -> Self::Output {
1076 rhs.iter().map(|&x| x * self).collect()
1077 }
1078}
1079
1080impl Div<AD> for f64 {
1081 type Output = AD;
1082
1083 fn div(self, rhs: AD) -> Self::Output {
1084 let ad0 = AD::from(self);
1085 ad0 / rhs
1086 }
1087}
1088
/// A wrapped function together with the derivative level at which it should
/// be evaluated.
pub struct ADFn<F> {
    /// The wrapped function.
    f: Box<F>,
    /// Derivative level: starts at 0 and is incremented by `grad`, which
    /// refuses to go beyond 2.
    grad_level: usize,
}
1257
1258impl<F: Clone> ADFn<F> {
1259 pub fn new(f: F) -> Self {
1260 Self {
1261 f: Box::new(f),
1262 grad_level: 0usize,
1263 }
1264 }
1265
1266 pub fn grad(&self) -> Self {
1268 assert!(self.grad_level < 2, "Higher order AD is not allowed");
1269 ADFn {
1270 f: self.f.clone(),
1271 grad_level: self.grad_level + 1,
1272 }
1273 }
1274}
1275
1276impl<F: Fn(AD) -> AD> StableFn<f64> for ADFn<F> {
1277 type Output = f64;
1278 fn call_stable(&self, target: f64) -> Self::Output {
1279 match self.grad_level {
1280 0 => (self.f)(AD::from(target)).into(),
1281 1 => (self.f)(AD1(target, 1f64)).dx(),
1282 2 => (self.f)(AD2(target, 1f64, 0f64)).ddx(),
1283 _ => panic!("Higher order AD is not allowed"),
1284 }
1285 }
1286}
1287
1288impl<F: Fn(AD) -> AD> StableFn<AD> for ADFn<F> {
1289 type Output = AD;
1290 fn call_stable(&self, target: AD) -> Self::Output {
1291 (self.f)(target)
1292 }
1293}
1294
1295impl<F: Fn(Vec<AD>) -> Vec<AD>> StableFn<Vec<f64>> for ADFn<F> {
1296 type Output = Vec<f64>;
1297 fn call_stable(&self, target: Vec<f64>) -> Self::Output {
1298 ((self.f)(target.iter().map(|&t| AD::from(t)).collect()))
1299 .iter()
1300 .map(|&t| t.x())
1301 .collect()
1302 }
1303}
1304
1305impl<F: Fn(Vec<AD>) -> Vec<AD>> StableFn<Vec<AD>> for ADFn<F> {
1306 type Output = Vec<AD>;
1307 fn call_stable(&self, target: Vec<AD>) -> Self::Output {
1308 (self.f)(target)
1309 }
1310}
1311
1312impl<'a, F: Fn(&Vec<AD>) -> Vec<AD>> StableFn<&'a Vec<f64>> for ADFn<F> {
1313 type Output = Vec<f64>;
1314 fn call_stable(&self, target: &'a Vec<f64>) -> Self::Output {
1315 ((self.f)(&target.iter().map(|&t| AD::from(t)).collect()))
1316 .iter()
1317 .map(|&t| t.x())
1318 .collect()
1319 }
1320}
1321
1322impl<'a, F: Fn(&Vec<AD>) -> Vec<AD>> StableFn<&'a Vec<AD>> for ADFn<F> {
1323 type Output = Vec<AD>;
1324 fn call_stable(&self, target: &'a Vec<AD>) -> Self::Output {
1325 (self.f)(target)
1326 }
1327}
1328
/// Conversions between vectors of plain numbers and vectors of [`AD`].
pub trait ADVec {
    /// Returns the contents as a `Vec<AD>`.
    fn to_ad_vec(&self) -> Vec<AD>;
    /// Returns the contents as a `Vec<f64>`.
    fn to_f64_vec(&self) -> Vec<f64>;
}
1341
1342impl ADVec for Vec<f64> {
1343 fn to_ad_vec(&self) -> Vec<AD> {
1344 self.iter().map(|&t| AD::from(t)).collect()
1345 }
1346
1347 fn to_f64_vec(&self) -> Vec<f64> {
1348 self.clone()
1349 }
1350}
1351
1352impl ADVec for Vec<AD> {
1353 fn to_ad_vec(&self) -> Vec<AD> {
1354 self.clone()
1355 }
1356
1357 fn to_f64_vec(&self) -> Vec<f64> {
1358 self.iter().map(|t| t.x()).collect()
1359 }
1360}
1361
1362impl FPVector for Vec<AD> {
1363 type Scalar = AD;
1364
1365 fn fmap<F>(&self, f: F) -> Self
1366 where
1367 F: Fn(Self::Scalar) -> Self::Scalar,
1368 {
1369 self.iter().map(|&x| f(x)).collect()
1370 }
1371
1372 fn reduce<F, T>(&self, init: T, f: F) -> Self::Scalar
1373 where
1374 F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
1375 T: Into<Self::Scalar>,
1376 {
1377 self.iter().fold(init.into(), |x, &y| f(x, y))
1378 }
1379
1380 fn zip_with<F>(&self, f: F, other: &Self) -> Self
1381 where
1382 F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
1383 {
1384 self.iter()
1385 .zip(other.iter())
1386 .map(|(&x, &y)| f(x, y))
1387 .collect()
1388 }
1389
1390 fn filter<F>(&self, f: F) -> Self
1391 where
1392 F: Fn(Self::Scalar) -> bool,
1393 {
1394 self.iter().filter(|&x| f(*x)).cloned().collect()
1395 }
1396
1397 fn take(&self, n: usize) -> Self {
1398 self.iter().take(n).cloned().collect()
1399 }
1400
1401 fn skip(&self, n: usize) -> Self {
1402 self.iter().skip(n).cloned().collect()
1403 }
1404
1405 fn sum(&self) -> Self::Scalar {
1406 let s = self[0];
1407 self.reduce(s, |x, y| x + y)
1408 }
1409
1410 fn prod(&self) -> Self::Scalar {
1411 let s = self[0];
1412 self.reduce(s, |x, y| x * y)
1413 }
1414}
1415
1416impl Vector for Vec<AD> {
1417 type Scalar = AD;
1418
1419 fn add_vec(&self, rhs: &Self) -> Self {
1420 self.add_v(rhs)
1421 }
1422
1423 fn sub_vec(&self, rhs: &Self) -> Self {
1424 self.sub_v(rhs)
1425 }
1426
1427 fn mul_scalar(&self, rhs: Self::Scalar) -> Self {
1428 self.mul_s(rhs)
1429 }
1430}
1431
// Empty impl: no `VecOps` item is overridden here, so `Vec<AD>` uses
// whatever the trait itself provides.
impl VecOps for Vec<AD> {}