peroxide/statistics/
rand.rs

1//! Random number generator
2//!
3//! ## Uniform random number generator
4//!
//! * Peroxide uses the external [`rand` crate](https://crates.io/crates/rand) to generate random numbers
6//!
7//!     ```rust
8//!     use rand::prelude::*;
9//!
10//!     fn main() {
11//!         let mut rng = rand::rng();
12//!
//!         let a = rng.random_range(0f64..=1f64); // Random f64 in the closed range [0, 1]
14//!     }
15//!     ```
16//!
//! * For a more detailed explanation, see the [`rand` crate](https://crates.io/crates/rand)
18//!
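//! ## Piece-wise Rejection Sampling
//!
//! * [`prs`] and [`prs_with_rng`] draw samples from an arbitrary (possibly unnormalized)
//!   function on an interval `(a, b)`: candidates come from a piece-wise uniform
//!   proposal built by max pooling the function over `m` pieces, and are then
//!   accepted or rejected against the function itself.
//!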
22
23use rand::prelude::*;
24use rand_distr::uniform::SampleUniform;
25
26use crate::statistics::dist::{WeightedUniform, RNG};
27#[allow(unused_imports)]
28use crate::structure::matrix::*;
29
30/// Small random number generator from seed
31///
32/// # Examples
33/// ```
34/// use peroxide::fuga::*;
35///
36/// fn main() {
37///     let mut rng = smallrng_from_seed(42);
38///
39///     let n = Normal(0f64, 1f64);
40///     n.sample_with_rng(&mut rng, 10).print();
41/// }
42pub fn smallrng_from_seed(seed: u64) -> SmallRng {
43    SmallRng::seed_from_u64(seed)
44}
45
46/// Std random number generator from seed
47///
48/// # Examples
49/// ```
50/// use peroxide::fuga::*;
51///
52/// fn main() {
53///     let mut rng = stdrng_from_seed(42);
54///
55///     let n = Normal(0f64, 1f64);
56///     n.sample_with_rng(&mut rng, 10).print();
57/// }
58pub fn stdrng_from_seed(seed: u64) -> StdRng {
59    StdRng::seed_from_u64(seed)
60}
61
62/// Simple uniform random number generator with ThreadRng
63///
64/// # Examples
65/// ```
66/// use peroxide::fuga::*;
67///
68/// let mut rng = rand::rng();
69/// println!("{}", rand_num(&mut rng, 1, 7));       // Roll a die
70/// println!("{}", rand_num(&mut rng, 0f64, 1f64)); // Uniform [0,1)
71/// ```
72pub fn rand_num<T>(rng: &mut ThreadRng, start: T, end: T) -> T
73where
74    T: PartialOrd + SampleUniform + Copy,
75{
76    rng.random_range(start..=end)
77}
78
79// =============================================================================
80// Back end utils
81// =============================================================================
82
83/// Gaussian random number generator using Marsaglia polar form
84pub fn marsaglia_polar(rng: &mut ThreadRng, m: f64, s: f64) -> f64 {
85    let mut x1 = 0f64;
86    let mut x2 = 0f64;
87    let mut _y2 = 0f64;
88    let mut w = 0f64;
89
90    while w == 0. || w >= 1. {
91        x1 = 2.0 * rng.random_range(0f64..=1f64) - 1.0;
92        x2 = 2.0 * rng.random_range(0f64..=1f64) - 1.0;
93        w = x1 * x1 + x2 * x2;
94    }
95
96    w = (-2.0 * w.ln() / w).sqrt();
97    let y1 = x1 * w;
98    _y2 = x2 * w;
99
100    return m + y1 * s;
101}
102
103// =============================================================================
104// Ziggurat Table
105// =============================================================================
106
/// Position (x-coordinate) of the right-most step of the Ziggurat.
///
/// `ziggurat` uses it when a draw selects the last layer (`i == 127`), which
/// proposes candidates from the tail of the distribution.
const PARAM_R: f64 = 3.44428647676;
109
/// Tabulated values for the height of the Ziggurat levels.
///
/// Heights are on the scale of the unnormalized density `exp(-x^2 / 2)`
/// (`YTAB[0] == 1`) and decrease with the index. For a layer `i < 127`,
/// `ziggurat` draws the vertical coordinate uniformly between `YTAB[i + 1]`
/// and `YTAB[i]`.
const YTAB: [f64; 128] = [
    1.,
    0.963598623011,
    0.936280813353,
    0.913041104253,
    0.892278506696,
    0.873239356919,
    0.855496407634,
    0.838778928349,
    0.822902083699,
    0.807732738234,
    0.793171045519,
    0.779139726505,
    0.765577436082,
    0.752434456248,
    0.739669787677,
    0.727249120285,
    0.715143377413,
    0.703327646455,
    0.691780377035,
    0.68048276891,
    0.669418297233,
    0.65857233912,
    0.647931876189,
    0.637485254896,
    0.62722199145,
    0.617132611532,
    0.607208517467,
    0.597441877296,
    0.587825531465,
    0.578352913803,
    0.569017984198,
    0.559815170911,
    0.550739320877,
    0.541785656682,
    0.532949739145,
    0.524227434628,
    0.515614886373,
    0.507108489253,
    0.498704867478,
    0.490400854812,
    0.482193476986,
    0.47407993601,
    0.466057596125,
    0.458123971214,
    0.450276713467,
    0.442513603171,
    0.434832539473,
    0.427231532022,
    0.419708693379,
    0.41226223212,
    0.404890446548,
    0.397591718955,
    0.390364510382,
    0.383207355816,
    0.376118859788,
    0.369097692334,
    0.362142585282,
    0.355252328834,
    0.348425768415,
    0.341661801776,
    0.334959376311,
    0.328317486588,
    0.321735172063,
    0.31521151497,
    0.308745638367,
    0.302336704338,
    0.29598391232,
    0.289686497571,
    0.283443729739,
    0.27725491156,
    0.271119377649,
    0.265036493387,
    0.259005653912,
    0.253026283183,
    0.247097833139,
    0.241219782932,
    0.235391638239,
    0.229612930649,
    0.223883217122,
    0.218202079518,
    0.212569124201,
    0.206983981709,
    0.201446306496,
    0.195955776745,
    0.190512094256,
    0.185114984406,
    0.179764196185,
    0.174459502324,
    0.169200699492,
    0.1639876086,
    0.158820075195,
    0.153697969964,
    0.148621189348,
    0.143589656295,
    0.138603321143,
    0.133662162669,
    0.128766189309,
    0.123915440582,
    0.119109988745,
    0.114349940703,
    0.10963544023,
    0.104966670533,
    0.100343857232,
    0.0957672718266,
    0.0912372357329,
    0.0867541250127,
    0.082318375932,
    0.0779304915295,
    0.0735910494266,
    0.0693007111742,
    0.065060233529,
    0.0608704821745,
    0.056732448584,
    0.05264727098,
    0.0486162607163,
    0.0446409359769,
    0.0407230655415,
    0.0368647267386,
    0.0330683839378,
    0.0293369977411,
    0.0256741818288,
    0.0220844372634,
    0.0185735200577,
    0.0151490552854,
    0.0118216532614,
    0.00860719483079,
    0.00553245272614,
    0.00265435214565,
];
241
/// Fast-acceptance thresholds for the 24-bit x-position `j` drawn in `ziggurat`.
///
/// Writing `x[i] = 2^24 * WTAB[i]` (the layer widths), `KTAB[i]` is approximately
/// `2^24 * x[i - 1] / x[i]` for `i >= 1`. `KTAB[0] == 0`, so the top layer never
/// takes the fast path.
/// A draw with `j < KTAB[i]` is accepted immediately, without any floating point
/// evaluation of the density.
const KTAB: [u32; 128] = [
    0, 12590644, 14272653, 14988939, 15384584, 15635009, 15807561, 15933577, 16029594, 16105155,
    16166147, 16216399, 16258508, 16294295, 16325078, 16351831, 16375291, 16396026, 16414479,
    16431002, 16445880, 16459343, 16471578, 16482744, 16492970, 16502368, 16511031, 16519039,
    16526459, 16533352, 16539769, 16545755, 16551348, 16556584, 16561493, 16566101, 16570433,
    16574511, 16578353, 16581977, 16585398, 16588629, 16591685, 16594575, 16597311, 16599901,
    16602354, 16604679, 16606881, 16608968, 16610945, 16612818, 16614592, 16616272, 16617861,
    16619363, 16620782, 16622121, 16623383, 16624570, 16625685, 16626730, 16627708, 16628619,
    16629465, 16630248, 16630969, 16631628, 16632228, 16632768, 16633248, 16633671, 16634034,
    16634340, 16634586, 16634774, 16634903, 16634972, 16634980, 16634926, 16634810, 16634628,
    16634381, 16634066, 16633680, 16633222, 16632688, 16632075, 16631380, 16630598, 16629726,
    16628757, 16627686, 16626507, 16625212, 16623794, 16622243, 16620548, 16618698, 16616679,
    16614476, 16612071, 16609444, 16606571, 16603425, 16599973, 16596178, 16591995, 16587369,
    16582237, 16576520, 16570120, 16562917, 16554758, 16545450, 16534739, 16522287, 16507638,
    16490152, 16468907, 16442518, 16408804, 16364095, 16301683, 16207738, 16047994, 15704248,
    15472926,
];
261
/// Tabulated values of `2^{-24} * x[i]`, where `x[i]` is the width of layer `i`.
///
/// In `ziggurat`, `j as f64 * WTAB[i]` maps the 24-bit position `j`
/// (`0 <= j < 2^24`) onto the candidate `x` in `[0, x[i])`.
const WTAB: [f64; 128] = [
    1.62318314817e-08,
    2.16291505214e-08,
    2.54246305087e-08,
    2.84579525938e-08,
    3.10340022482e-08,
    3.33011726243e-08,
    3.53439060345e-08,
    3.72152672658e-08,
    3.8950989572e-08,
    4.05763964764e-08,
    4.21101548915e-08,
    4.35664624904e-08,
    4.49563968336e-08,
    4.62887864029e-08,
    4.75707945735e-08,
    4.88083237257e-08,
    5.00063025384e-08,
    5.11688950428e-08,
    5.22996558616e-08,
    5.34016475624e-08,
    5.44775307871e-08,
    5.55296344581e-08,
    5.65600111659e-08,
    5.75704813695e-08,
    5.85626690412e-08,
    5.95380306862e-08,
    6.04978791776e-08,
    6.14434034901e-08,
    6.23756851626e-08,
    6.32957121259e-08,
    6.42043903937e-08,
    6.51025540077e-08,
    6.59909735447e-08,
    6.68703634341e-08,
    6.77413882848e-08,
    6.8604668381e-08,
    6.94607844804e-08,
    7.03102820203e-08,
    7.11536748229e-08,
    7.1991448372e-08,
    7.2824062723e-08,
    7.36519550992e-08,
    7.44755422158e-08,
    7.52952223703e-08,
    7.61113773308e-08,
    7.69243740467e-08,
    7.77345662086e-08,
    7.85422956743e-08,
    7.93478937793e-08,
    8.01516825471e-08,
    8.09539758128e-08,
    8.17550802699e-08,
    8.25552964535e-08,
    8.33549196661e-08,
    8.41542408569e-08,
    8.49535474601e-08,
    8.57531242006e-08,
    8.65532538723e-08,
    8.73542180955e-08,
    8.8156298059e-08,
    8.89597752521e-08,
    8.97649321908e-08,
    9.05720531451e-08,
    9.138142487e-08,
    9.21933373471e-08,
    9.30080845407e-08,
    9.38259651738e-08,
    9.46472835298e-08,
    9.54723502847e-08,
    9.63014833769e-08,
    9.71350089201e-08,
    9.79732621669e-08,
    9.88165885297e-08,
    9.96653446693e-08,
    1.00519899658e-07,
    1.0138063623e-07,
    1.02247952126e-07,
    1.03122261554e-07,
    1.04003996769e-07,
    1.04893609795e-07,
    1.05791574313e-07,
    1.06698387725e-07,
    1.07614573423e-07,
    1.08540683296e-07,
    1.09477300508e-07,
    1.1042504257e-07,
    1.11384564771e-07,
    1.12356564007e-07,
    1.13341783071e-07,
    1.14341015475e-07,
    1.15355110887e-07,
    1.16384981291e-07,
    1.17431607977e-07,
    1.18496049514e-07,
    1.19579450872e-07,
    1.20683053909e-07,
    1.21808209468e-07,
    1.2295639141e-07,
    1.24129212952e-07,
    1.25328445797e-07,
    1.26556042658e-07,
    1.27814163916e-07,
    1.29105209375e-07,
    1.30431856341e-07,
    1.31797105598e-07,
    1.3320433736e-07,
    1.34657379914e-07,
    1.36160594606e-07,
    1.37718982103e-07,
    1.39338316679e-07,
    1.41025317971e-07,
    1.42787873535e-07,
    1.44635331499e-07,
    1.4657889173e-07,
    1.48632138436e-07,
    1.50811780719e-07,
    1.53138707402e-07,
    1.55639532047e-07,
    1.58348931426e-07,
    1.61313325908e-07,
    1.64596952856e-07,
    1.68292495203e-07,
    1.72541128694e-07,
    1.77574279496e-07,
    1.83813550477e-07,
    1.92166040885e-07,
    2.05295471952e-07,
    2.22600839893e-07,
];
393
394/// Gaussian random numbers using the Ziggurat Method
395///
396/// The code is based on a [C implementation][1] by Jochen Voss.
397///
398/// [1]: https://www.seehuhn.de/pages/ziggurat.html
399#[allow(unused_assignments)]
400pub fn ziggurat(rng: &mut ThreadRng, sigma: f64) -> f64 {
401    let (mut u, mut i, mut sign, mut j) = (0u32, 0usize, 0u32, 0u32);
402    let mut x = 0f64;
403    let mut y = 0f64;
404
405    loop {
406        u = rand_num(rng, u32::MIN, u32::MAX);
407        i = (u & 0x0000007F) as usize; // 7 bit to choose the step
408        sign = u & 0x00000080; // 1 bit for the sign
409        j = u >> 8; // 24 bit for the x-value
410
411        x = j as f64 * WTAB[i];
412        if j < KTAB[i] {
413            break;
414        }
415
416        if i < 127 {
417            let y0 = YTAB[i];
418            let y1 = YTAB[i + 1];
419            y = y1 + (y0 - y1) * rand_num(rng, 0f64, 1f64);
420        } else {
421            x = PARAM_R - (1.0 - rand_num(rng, 0f64, 1f64).ln()) / PARAM_R;
422            y = (-PARAM_R * (x - 0.5 * PARAM_R)).exp() * rand_num(rng, 0f64, 1f64);
423        }
424
425        if y < (-0.5 * x * x).exp() {
426            break;
427        }
428    }
429
430    if sign != 0 {
431        sigma * x
432    } else {
433        -sigma * x
434    }
435}
436
437// =============================================================================
438// Rejection Sampling
439// =============================================================================
440/// Piecewise Rejection Sampling
441///
442/// # Arguments
443/// * `f` - Function to sample (unnormalized function is allowed)
444/// * `n` - Number of samples
445/// * `(a, b)` - Range of sampling
446/// * `m` - Number of pieces
447/// * `eps` - Epsilon for max pooling
448///
449/// # Examples
450/// ```
451/// use peroxide::fuga::*;
452///
453/// fn main() -> Result<(), Box<dyn Error>> {
454///     let f = |x: f64| {
455///         if (0f64..=2f64).contains(&x) {
456///             -(x - 1f64).powi(2) + 1f64
457///         } else {
458///             0f64
459///         }
460///     };
461///
462///     let samples = prs(f, 1000, (-1f64, 3f64), 200, 1e-4)?;
463///     samples.mean().print(); // near 1
464///
465///     Ok(())
466/// }
467pub fn prs<F>(f: F, n: usize, (a, b): (f64, f64), m: usize, eps: f64) -> anyhow::Result<Vec<f64>>
468where
469    F: Fn(f64) -> f64 + Copy,
470{
471    let mut rng = rand::rng();
472
473    let mut result = vec![0f64; n];
474
475    let w = WeightedUniform::from_max_pool_1d(f, (a, b), m, eps)?;
476
477    let mut initial_x = w.sample(n);
478    let mut left_num = n;
479
480    while left_num > 0 {
481        for &x in initial_x.iter() {
482            let weight = w.weight_at(x);
483            if weight <= 0f64 {
484                continue;
485            } else {
486                let y = rng.random_range(0f64..=weight);
487
488                if y <= f(x) {
489                    result[n - left_num] = x;
490                    left_num -= 1;
491                    if left_num == 0 {
492                        return Ok(result);
493                    }
494                }
495            }
496        }
497        initial_x = w.sample(left_num);
498    }
499    panic!("Error: failed to generate {} samples", n);
500}
501
502/// Piecewise Rejection Sampling with specific Rng
503///
504/// # Arguments
505/// * `f` - Function to sample (unnormalized function is allowed)
506/// * `n` - Number of samples
507/// * `(a, b)` - Range of sampling
508/// * `m` - Number of pieces
509/// * `eps` - Epsilon for max pooling
510/// * `rng` - Random number generator
511///
512/// # Examples
513/// ```
514/// use peroxide::fuga::*;
515///
516/// fn main() -> Result<(), Box<dyn Error>> {
517///     let mut rng = smallrng_from_seed(42);
518///     let f = |x: f64| {
519///         if (0f64..=2f64).contains(&x) {
520///             -(x - 1f64).powi(2) + 1f64
521///         } else {
522///             0f64
523///         }
524///     };
525///
526///     let samples = prs_with_rng(f, 1000, (-1f64, 3f64), 200, 1e-4, &mut rng)?;
527///     assert!((samples.mean() - 1f64).abs() < 1e-1);
528///
529///     Ok(())
530/// }
531pub fn prs_with_rng<F, R: Rng + Clone>(
532    f: F,
533    n: usize,
534    (a, b): (f64, f64),
535    m: usize,
536    eps: f64,
537    rng: &mut R,
538) -> anyhow::Result<Vec<f64>>
539where
540    F: Fn(f64) -> f64 + Copy,
541{
542    let mut result = vec![0f64; n];
543
544    let w = WeightedUniform::from_max_pool_1d(f, (a, b), m, eps)?;
545
546    let mut initial_x = w.sample_with_rng(rng, n);
547    let mut left_num = n;
548
549    while left_num > 0 {
550        for &x in initial_x.iter() {
551            let weight = w.weight_at(x);
552            if weight <= 0f64 {
553                continue;
554            } else {
555                let y = rng.random_range(0f64..=weight);
556
557                if y <= f(x) {
558                    result[n - left_num] = x;
559                    left_num -= 1;
560                    if left_num == 0 {
561                        return Ok(result);
562                    }
563                }
564            }
565        }
566        initial_x = w.sample_with_rng(rng, left_num);
567    }
568    panic!("Error: failed to generate {} samples", n);
569}