// nalgebra/linalg/exp.rs

//! This module provides the matrix exponential (`exp`) function for square matrices.
//!
3use crate::{
4    ComplexField, OMatrix, RealField,
5    base::{
6        DefaultAllocator,
7        allocator::Allocator,
8        dimension::{Const, Dim, DimMin, DimMinimum},
9    },
10    convert, try_convert,
11};
12
13use crate::num::Zero;
14
/// Table of factorials: `FACTORIAL[n] == n!` for every `n` in `0..=34`.
///
/// The table stops at `34!` because `35!` does not fit into 128 bits. It is
/// built by compile-time evaluation, so no entry can contain a typo.
// TODO: find a better place for this array?
const FACTORIAL: [u128; 35] = {
    let mut table = [1_u128; 35];
    let mut k = 1;
    while k < 35 {
        // k! = (k - 1)! * k
        table[k] = table[k - 1] * k as u128;
        k += 1;
    }
    table
};
55
56// https://github.com/scipy/scipy/blob/c1372d8aa90a73d8a52f135529293ff4edb98fc8/scipy/sparse/linalg/matfuncs.py
57struct ExpmPadeHelper<T, D>
58where
59    T: ComplexField,
60    D: DimMin<D>,
61    DefaultAllocator: Allocator<D, D> + Allocator<DimMinimum<D, D>>,
62{
63    use_exact_norm: bool,
64    ident: OMatrix<T, D, D>,
65
66    a: OMatrix<T, D, D>,
67    a2: Option<OMatrix<T, D, D>>,
68    a4: Option<OMatrix<T, D, D>>,
69    a6: Option<OMatrix<T, D, D>>,
70    a8: Option<OMatrix<T, D, D>>,
71    a10: Option<OMatrix<T, D, D>>,
72
73    d4_exact: Option<T::RealField>,
74    d6_exact: Option<T::RealField>,
75    d8_exact: Option<T::RealField>,
76    d10_exact: Option<T::RealField>,
77
78    d4_approx: Option<T::RealField>,
79    d6_approx: Option<T::RealField>,
80    d8_approx: Option<T::RealField>,
81    d10_approx: Option<T::RealField>,
82}
83
84impl<T, D> ExpmPadeHelper<T, D>
85where
86    T: ComplexField,
87    D: DimMin<D>,
88    DefaultAllocator: Allocator<D, D> + Allocator<DimMinimum<D, D>>,
89{
90    fn new(a: OMatrix<T, D, D>, use_exact_norm: bool) -> Self {
91        let (nrows, ncols) = a.shape_generic();
92        ExpmPadeHelper {
93            use_exact_norm,
94            ident: OMatrix::<T, D, D>::identity_generic(nrows, ncols),
95            a,
96            a2: None,
97            a4: None,
98            a6: None,
99            a8: None,
100            a10: None,
101            d4_exact: None,
102            d6_exact: None,
103            d8_exact: None,
104            d10_exact: None,
105            d4_approx: None,
106            d6_approx: None,
107            d8_approx: None,
108            d10_approx: None,
109        }
110    }
111
112    fn calc_a2(&mut self) {
113        if self.a2.is_none() {
114            self.a2 = Some(&self.a * &self.a);
115        }
116    }
117
118    fn calc_a4(&mut self) {
119        if self.a4.is_none() {
120            self.calc_a2();
121            let a2 = self.a2.as_ref().unwrap();
122            self.a4 = Some(a2 * a2);
123        }
124    }
125
126    fn calc_a6(&mut self) {
127        if self.a6.is_none() {
128            self.calc_a2();
129            self.calc_a4();
130            let a2 = self.a2.as_ref().unwrap();
131            let a4 = self.a4.as_ref().unwrap();
132            self.a6 = Some(a4 * a2);
133        }
134    }
135
136    fn calc_a8(&mut self) {
137        if self.a8.is_none() {
138            self.calc_a2();
139            self.calc_a6();
140            let a2 = self.a2.as_ref().unwrap();
141            let a6 = self.a6.as_ref().unwrap();
142            self.a8 = Some(a6 * a2);
143        }
144    }
145
146    fn calc_a10(&mut self) {
147        if self.a10.is_none() {
148            self.calc_a4();
149            self.calc_a6();
150            let a4 = self.a4.as_ref().unwrap();
151            let a6 = self.a6.as_ref().unwrap();
152            self.a10 = Some(a6 * a4);
153        }
154    }
155
156    fn d4_tight(&mut self) -> T::RealField {
157        if self.d4_exact.is_none() {
158            self.calc_a4();
159            self.d4_exact = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25)));
160        }
161        self.d4_exact.clone().unwrap()
162    }
163
164    fn d6_tight(&mut self) -> T::RealField {
165        if self.d6_exact.is_none() {
166            self.calc_a6();
167            self.d6_exact = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0)));
168        }
169        self.d6_exact.clone().unwrap()
170    }
171
172    fn d8_tight(&mut self) -> T::RealField {
173        if self.d8_exact.is_none() {
174            self.calc_a8();
175            self.d8_exact = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0)));
176        }
177        self.d8_exact.clone().unwrap()
178    }
179
180    fn d10_tight(&mut self) -> T::RealField {
181        if self.d10_exact.is_none() {
182            self.calc_a10();
183            self.d10_exact = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0)));
184        }
185        self.d10_exact.clone().unwrap()
186    }
187
188    fn d4_loose(&mut self) -> T::RealField {
189        if self.use_exact_norm {
190            return self.d4_tight();
191        }
192
193        if self.d4_exact.is_some() {
194            return self.d4_exact.clone().unwrap();
195        }
196
197        if self.d4_approx.is_none() {
198            self.calc_a4();
199            self.d4_approx = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25)));
200        }
201
202        self.d4_approx.clone().unwrap()
203    }
204
205    fn d6_loose(&mut self) -> T::RealField {
206        if self.use_exact_norm {
207            return self.d6_tight();
208        }
209
210        if self.d6_exact.is_some() {
211            return self.d6_exact.clone().unwrap();
212        }
213
214        if self.d6_approx.is_none() {
215            self.calc_a6();
216            self.d6_approx = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0)));
217        }
218
219        self.d6_approx.clone().unwrap()
220    }
221
222    fn d8_loose(&mut self) -> T::RealField {
223        if self.use_exact_norm {
224            return self.d8_tight();
225        }
226
227        if self.d8_exact.is_some() {
228            return self.d8_exact.clone().unwrap();
229        }
230
231        if self.d8_approx.is_none() {
232            self.calc_a8();
233            self.d8_approx = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0)));
234        }
235
236        self.d8_approx.clone().unwrap()
237    }
238
239    fn d10_loose(&mut self) -> T::RealField {
240        if self.use_exact_norm {
241            return self.d10_tight();
242        }
243
244        if self.d10_exact.is_some() {
245            return self.d10_exact.clone().unwrap();
246        }
247
248        if self.d10_approx.is_none() {
249            self.calc_a10();
250            self.d10_approx = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0)));
251        }
252
253        self.d10_approx.clone().unwrap()
254    }
255
256    fn pade3(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
257        let b: [T; 4] = [convert(120.0), convert(60.0), convert(12.0), convert(1.0)];
258        self.calc_a2();
259        let a2 = self.a2.as_ref().unwrap();
260        let u = &self.a * (a2 * b[3].clone() + &self.ident * b[1].clone());
261        let v = a2 * b[2].clone() + &self.ident * b[0].clone();
262        (u, v)
263    }
264
265    fn pade5(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
266        let b: [T; 6] = [
267            convert(30240.0),
268            convert(15120.0),
269            convert(3360.0),
270            convert(420.0),
271            convert(30.0),
272            convert(1.0),
273        ];
274        self.calc_a2();
275        self.calc_a6();
276        let u = &self.a
277            * (self.a4.as_ref().unwrap() * b[5].clone()
278                + self.a2.as_ref().unwrap() * b[3].clone()
279                + &self.ident * b[1].clone());
280        let v = self.a4.as_ref().unwrap() * b[4].clone()
281            + self.a2.as_ref().unwrap() * b[2].clone()
282            + &self.ident * b[0].clone();
283        (u, v)
284    }
285
286    fn pade7(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
287        let b: [T; 8] = [
288            convert(17_297_280.0),
289            convert(8_648_640.0),
290            convert(1_995_840.0),
291            convert(277_200.0),
292            convert(25_200.0),
293            convert(1_512.0),
294            convert(56.0),
295            convert(1.0),
296        ];
297        self.calc_a2();
298        self.calc_a4();
299        self.calc_a6();
300        let u = &self.a
301            * (self.a6.as_ref().unwrap() * b[7].clone()
302                + self.a4.as_ref().unwrap() * b[5].clone()
303                + self.a2.as_ref().unwrap() * b[3].clone()
304                + &self.ident * b[1].clone());
305        let v = self.a6.as_ref().unwrap() * b[6].clone()
306            + self.a4.as_ref().unwrap() * b[4].clone()
307            + self.a2.as_ref().unwrap() * b[2].clone()
308            + &self.ident * b[0].clone();
309        (u, v)
310    }
311
312    fn pade9(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
313        let b: [T; 10] = [
314            convert(17_643_225_600.0),
315            convert(8_821_612_800.0),
316            convert(2_075_673_600.0),
317            convert(302_702_400.0),
318            convert(30_270_240.0),
319            convert(2_162_160.0),
320            convert(110_880.0),
321            convert(3_960.0),
322            convert(90.0),
323            convert(1.0),
324        ];
325        self.calc_a2();
326        self.calc_a4();
327        self.calc_a6();
328        self.calc_a8();
329        let u = &self.a
330            * (self.a8.as_ref().unwrap() * b[9].clone()
331                + self.a6.as_ref().unwrap() * b[7].clone()
332                + self.a4.as_ref().unwrap() * b[5].clone()
333                + self.a2.as_ref().unwrap() * b[3].clone()
334                + &self.ident * b[1].clone());
335        let v = self.a8.as_ref().unwrap() * b[8].clone()
336            + self.a6.as_ref().unwrap() * b[6].clone()
337            + self.a4.as_ref().unwrap() * b[4].clone()
338            + self.a2.as_ref().unwrap() * b[2].clone()
339            + &self.ident * b[0].clone();
340        (u, v)
341    }
342
343    fn pade13_scaled(&mut self, s: u64) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
344        let b: [T; 14] = [
345            convert(64_764_752_532_480_000.0),
346            convert(32_382_376_266_240_000.0),
347            convert(7_771_770_303_897_600.0),
348            convert(1_187_353_796_428_800.0),
349            convert(129_060_195_264_000.0),
350            convert(10_559_470_521_600.0),
351            convert(670_442_572_800.0),
352            convert(33_522_128_640.0),
353            convert(1_323_241_920.0),
354            convert(40_840_800.0),
355            convert(960_960.0),
356            convert(16_380.0),
357            convert(182.0),
358            convert(1.0),
359        ];
360        let s = s as f64;
361
362        let mb = &self.a * convert::<f64, T>(2.0_f64.powf(-s));
363        self.calc_a2();
364        self.calc_a4();
365        self.calc_a6();
366        let mb2 = self.a2.as_ref().unwrap() * convert::<f64, T>(2.0_f64.powf(-2.0 * s));
367        let mb4 = self.a4.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-4.0 * s));
368        let mb6 = self.a6.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-6.0 * s));
369
370        let u2 = &mb6 * (&mb6 * b[13].clone() + &mb4 * b[11].clone() + &mb2 * b[9].clone());
371        let u = &mb
372            * (&u2
373                + &mb6 * b[7].clone()
374                + &mb4 * b[5].clone()
375                + &mb2 * b[3].clone()
376                + &self.ident * b[1].clone());
377        let v2 = &mb6 * (&mb6 * b[12].clone() + &mb4 * b[10].clone() + &mb2 * b[8].clone());
378        let v = v2
379            + &mb6 * b[6].clone()
380            + &mb4 * b[4].clone()
381            + &mb2 * b[2].clone()
382            + &self.ident * b[0].clone();
383        (u, v)
384    }
385}
386
387/// Compute `n!`
388#[inline(always)]
389fn factorial(n: usize) -> u128 {
390    match FACTORIAL.get(n) {
391        Some(f) => *f,
392        None => panic!("{n}! is greater than u128::MAX"),
393    }
394}
395
396/// Compute the 1-norm of a non-negative integer power of a non-negative matrix.
397fn onenorm_matrix_power_nonm<T, D>(a: &OMatrix<T, D, D>, p: usize) -> T
398where
399    T: RealField,
400    D: Dim,
401    DefaultAllocator: Allocator<D, D> + Allocator<D>,
402{
403    let nrows = a.shape_generic().0;
404    let mut v = crate::OVector::<T, D>::repeat_generic(nrows, Const::<1>, convert(1.0));
405    let m = a.transpose();
406
407    for _ in 0..p {
408        v = &m * v;
409    }
410
411    v.max()
412}
413
414fn ell<T, D>(a: &OMatrix<T, D, D>, m: usize) -> u64
415where
416    T: ComplexField,
417    D: Dim,
418    DefaultAllocator: Allocator<D, D> + Allocator<D> + Allocator<D> + Allocator<D, D>,
419{
420    let a_abs = a.map(|x| x.abs());
421
422    let a_abs_onenorm = onenorm_matrix_power_nonm(&a_abs, 2 * m + 1);
423
424    if a_abs_onenorm == <T as ComplexField>::RealField::zero() {
425        return 0;
426    }
427
428    // 2m choose m = (2m)!/(m! * (2m-m)!) = (2m)!/((m!)^2)
429    let m_factorial = factorial(m);
430    let choose_2m_m = factorial(2 * m) / (m_factorial * m_factorial);
431
432    let abs_c_recip = choose_2m_m * factorial(2 * m + 1);
433    let alpha = a_abs_onenorm / one_norm(a);
434    let alpha: f64 = try_convert::<_, f64>(alpha).unwrap() / abs_c_recip as f64;
435
436    let u = 2_f64.powf(-53.0);
437    let log2_alpha_div_u = (alpha / u).log2();
438    let value = (log2_alpha_div_u / (2.0 * m as f64)).ceil();
439    if value > 0.0 { value as u64 } else { 0 }
440}
441
442fn solve_p_q<T, D>(u: OMatrix<T, D, D>, v: OMatrix<T, D, D>) -> OMatrix<T, D, D>
443where
444    T: ComplexField,
445    D: DimMin<D, Output = D>,
446    DefaultAllocator: Allocator<D, D> + Allocator<DimMinimum<D, D>>,
447{
448    let p = &u + &v;
449    let q = &v - &u;
450
451    q.lu().solve(&p).unwrap()
452}
453
454fn one_norm<T, D>(m: &OMatrix<T, D, D>) -> T::RealField
455where
456    T: ComplexField,
457    D: Dim,
458    DefaultAllocator: Allocator<D, D>,
459{
460    let mut max = <T as ComplexField>::RealField::zero();
461
462    for i in 0..m.ncols() {
463        let col = m.column(i);
464        max = max.max(
465            col.iter()
466                .fold(<T as ComplexField>::RealField::zero(), |a, b| {
467                    a + b.clone().abs()
468                }),
469        );
470    }
471
472    max
473}
474
475impl<T: ComplexField, D> OMatrix<T, D, D>
476where
477    D: DimMin<D, Output = D>,
478    DefaultAllocator: Allocator<D, D>
479        + Allocator<DimMinimum<D, D>>
480        + Allocator<D>
481        + Allocator<D>
482        + Allocator<D, D>,
483{
484    /// Computes exponential of this matrix
485    #[must_use]
486    pub fn exp(&self) -> Self {
487        // Simple case
488        if self.nrows() == 1 {
489            return self.map(|v| v.exp());
490        }
491
492        let mut helper = ExpmPadeHelper::new(self.clone(), true);
493
494        let eta_1 = T::RealField::max(helper.d4_loose(), helper.d6_loose());
495        if eta_1 < convert(1.495_585_217_958_292e-2) && ell(&helper.a, 3) == 0 {
496            let (u, v) = helper.pade3();
497            return solve_p_q(u, v);
498        }
499
500        let eta_2 = T::RealField::max(helper.d4_tight(), helper.d6_loose());
501        if eta_2 < convert(2.539_398_330_063_23e-1) && ell(&helper.a, 5) == 0 {
502            let (u, v) = helper.pade5();
503            return solve_p_q(u, v);
504        }
505
506        let eta_3 = T::RealField::max(helper.d6_tight(), helper.d8_loose());
507        if eta_3 < convert(9.504_178_996_162_932e-1) && ell(&helper.a, 7) == 0 {
508            let (u, v) = helper.pade7();
509            return solve_p_q(u, v);
510        }
511        if eta_3 < convert(2.097_847_961_257_068e0) && ell(&helper.a, 9) == 0 {
512            let (u, v) = helper.pade9();
513            return solve_p_q(u, v);
514        }
515
516        let eta_4 = T::RealField::max(helper.d8_loose(), helper.d10_loose());
517        let eta_5 = T::RealField::min(eta_3, eta_4);
518        let theta_13 = convert(4.25);
519
520        let mut s = if eta_5 == T::RealField::zero() {
521            0
522        } else {
523            let l2 = try_convert::<_, f64>((eta_5 / theta_13).log2().ceil()).unwrap();
524
525            if l2 < 0.0 { 0 } else { l2 as u64 }
526        };
527
528        s += ell(
529            &(&helper.a * convert::<f64, T>(2.0_f64.powf(-(s as f64)))),
530            13,
531        );
532
533        let (u, v) = helper.pade13_scaled(s);
534        let mut x = solve_p_q(u, v);
535
536        for _ in 0..s {
537            x = &x * &x;
538        }
539        x
540    }
541}
542
#[cfg(test)]
mod tests {
    use crate::Matrix3;

    /// The 1-norm is the largest column sum of absolute values.
    #[test]
    #[allow(clippy::float_cmp)] // all values are small integers, exactly representable
    fn one_norm() {
        // Absolute column sums: |-3| + 2 + 0 = 5, 5 + 6 + 2 = 13, 7 + 4 + 8 = 19.
        let m = Matrix3::new(
            -3.0, 5.0, 7.0, //
            2.0, 6.0, 4.0, //
            0.0, 2.0, 8.0,
        );
        let expected = 19.0;
        assert_eq!(super::one_norm(&m), expected);
    }
}