1use crate::{
4 ComplexField, OMatrix, RealField,
5 base::{
6 DefaultAllocator,
7 allocator::Allocator,
8 dimension::{Const, Dim, DimMin, DimMinimum},
9 },
10 convert, try_convert,
11};
12
13use crate::num::Zero;
14
/// Lookup table of factorials: `FACTORIAL[n] == n!` for `n` in `0..=34`.
///
/// `34!` is the largest factorial representable in a `u128`, hence the 35
/// entries. The table is filled at compile time from the recurrence
/// `n! = (n - 1)! * n`; an overflow would be rejected by the compiler.
const FACTORIAL: [u128; 35] = {
    let mut table = [1_u128; 35];
    let mut n = 1;
    while n < 35 {
        table[n] = table[n - 1] * n as u128;
        n += 1;
    }
    table
};
55
/// Working state for evaluating Padé approximants of a square matrix `a`.
///
/// Even powers of `a` and the norm-derived quantities built from them are
/// computed on demand and cached in `Option` fields, so each is evaluated at
/// most once per helper.
struct ExpmPadeHelper<T, D>
where
    T: ComplexField,
    D: DimMin<D>,
    DefaultAllocator: Allocator<D, D> + Allocator<DimMinimum<D, D>>,
{
    /// When `true`, the `d*_loose` accessors delegate to the `d*_tight` ones.
    use_exact_norm: bool,
    /// Identity matrix with the same shape as `a`.
    ident: OMatrix<T, D, D>,

    /// The input matrix.
    a: OMatrix<T, D, D>,
    // Cached even powers of `a`; `None` until the matching `calc_a*` has run.
    a2: Option<OMatrix<T, D, D>>,
    a4: Option<OMatrix<T, D, D>>,
    a6: Option<OMatrix<T, D, D>>,
    a8: Option<OMatrix<T, D, D>>,
    a10: Option<OMatrix<T, D, D>>,

    // Cached `one_norm(a^k)^(1/k)` values stored by the `d*_tight` accessors.
    d4_exact: Option<T::RealField>,
    d6_exact: Option<T::RealField>,
    d8_exact: Option<T::RealField>,
    d10_exact: Option<T::RealField>,

    // Values stored by the `d*_loose` accessors when `use_exact_norm` is
    // `false`; they are currently computed with the same formula as the
    // exact ones.
    d4_approx: Option<T::RealField>,
    d6_approx: Option<T::RealField>,
    d8_approx: Option<T::RealField>,
    d10_approx: Option<T::RealField>,
}
83
84impl<T, D> ExpmPadeHelper<T, D>
85where
86 T: ComplexField,
87 D: DimMin<D>,
88 DefaultAllocator: Allocator<D, D> + Allocator<DimMinimum<D, D>>,
89{
90 fn new(a: OMatrix<T, D, D>, use_exact_norm: bool) -> Self {
91 let (nrows, ncols) = a.shape_generic();
92 ExpmPadeHelper {
93 use_exact_norm,
94 ident: OMatrix::<T, D, D>::identity_generic(nrows, ncols),
95 a,
96 a2: None,
97 a4: None,
98 a6: None,
99 a8: None,
100 a10: None,
101 d4_exact: None,
102 d6_exact: None,
103 d8_exact: None,
104 d10_exact: None,
105 d4_approx: None,
106 d6_approx: None,
107 d8_approx: None,
108 d10_approx: None,
109 }
110 }
111
112 fn calc_a2(&mut self) {
113 if self.a2.is_none() {
114 self.a2 = Some(&self.a * &self.a);
115 }
116 }
117
118 fn calc_a4(&mut self) {
119 if self.a4.is_none() {
120 self.calc_a2();
121 let a2 = self.a2.as_ref().unwrap();
122 self.a4 = Some(a2 * a2);
123 }
124 }
125
126 fn calc_a6(&mut self) {
127 if self.a6.is_none() {
128 self.calc_a2();
129 self.calc_a4();
130 let a2 = self.a2.as_ref().unwrap();
131 let a4 = self.a4.as_ref().unwrap();
132 self.a6 = Some(a4 * a2);
133 }
134 }
135
136 fn calc_a8(&mut self) {
137 if self.a8.is_none() {
138 self.calc_a2();
139 self.calc_a6();
140 let a2 = self.a2.as_ref().unwrap();
141 let a6 = self.a6.as_ref().unwrap();
142 self.a8 = Some(a6 * a2);
143 }
144 }
145
146 fn calc_a10(&mut self) {
147 if self.a10.is_none() {
148 self.calc_a4();
149 self.calc_a6();
150 let a4 = self.a4.as_ref().unwrap();
151 let a6 = self.a6.as_ref().unwrap();
152 self.a10 = Some(a6 * a4);
153 }
154 }
155
156 fn d4_tight(&mut self) -> T::RealField {
157 if self.d4_exact.is_none() {
158 self.calc_a4();
159 self.d4_exact = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25)));
160 }
161 self.d4_exact.clone().unwrap()
162 }
163
164 fn d6_tight(&mut self) -> T::RealField {
165 if self.d6_exact.is_none() {
166 self.calc_a6();
167 self.d6_exact = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0)));
168 }
169 self.d6_exact.clone().unwrap()
170 }
171
172 fn d8_tight(&mut self) -> T::RealField {
173 if self.d8_exact.is_none() {
174 self.calc_a8();
175 self.d8_exact = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0)));
176 }
177 self.d8_exact.clone().unwrap()
178 }
179
180 fn d10_tight(&mut self) -> T::RealField {
181 if self.d10_exact.is_none() {
182 self.calc_a10();
183 self.d10_exact = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0)));
184 }
185 self.d10_exact.clone().unwrap()
186 }
187
188 fn d4_loose(&mut self) -> T::RealField {
189 if self.use_exact_norm {
190 return self.d4_tight();
191 }
192
193 if self.d4_exact.is_some() {
194 return self.d4_exact.clone().unwrap();
195 }
196
197 if self.d4_approx.is_none() {
198 self.calc_a4();
199 self.d4_approx = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25)));
200 }
201
202 self.d4_approx.clone().unwrap()
203 }
204
205 fn d6_loose(&mut self) -> T::RealField {
206 if self.use_exact_norm {
207 return self.d6_tight();
208 }
209
210 if self.d6_exact.is_some() {
211 return self.d6_exact.clone().unwrap();
212 }
213
214 if self.d6_approx.is_none() {
215 self.calc_a6();
216 self.d6_approx = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0)));
217 }
218
219 self.d6_approx.clone().unwrap()
220 }
221
222 fn d8_loose(&mut self) -> T::RealField {
223 if self.use_exact_norm {
224 return self.d8_tight();
225 }
226
227 if self.d8_exact.is_some() {
228 return self.d8_exact.clone().unwrap();
229 }
230
231 if self.d8_approx.is_none() {
232 self.calc_a8();
233 self.d8_approx = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0)));
234 }
235
236 self.d8_approx.clone().unwrap()
237 }
238
239 fn d10_loose(&mut self) -> T::RealField {
240 if self.use_exact_norm {
241 return self.d10_tight();
242 }
243
244 if self.d10_exact.is_some() {
245 return self.d10_exact.clone().unwrap();
246 }
247
248 if self.d10_approx.is_none() {
249 self.calc_a10();
250 self.d10_approx = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0)));
251 }
252
253 self.d10_approx.clone().unwrap()
254 }
255
256 fn pade3(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
257 let b: [T; 4] = [convert(120.0), convert(60.0), convert(12.0), convert(1.0)];
258 self.calc_a2();
259 let a2 = self.a2.as_ref().unwrap();
260 let u = &self.a * (a2 * b[3].clone() + &self.ident * b[1].clone());
261 let v = a2 * b[2].clone() + &self.ident * b[0].clone();
262 (u, v)
263 }
264
265 fn pade5(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
266 let b: [T; 6] = [
267 convert(30240.0),
268 convert(15120.0),
269 convert(3360.0),
270 convert(420.0),
271 convert(30.0),
272 convert(1.0),
273 ];
274 self.calc_a2();
275 self.calc_a6();
276 let u = &self.a
277 * (self.a4.as_ref().unwrap() * b[5].clone()
278 + self.a2.as_ref().unwrap() * b[3].clone()
279 + &self.ident * b[1].clone());
280 let v = self.a4.as_ref().unwrap() * b[4].clone()
281 + self.a2.as_ref().unwrap() * b[2].clone()
282 + &self.ident * b[0].clone();
283 (u, v)
284 }
285
286 fn pade7(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
287 let b: [T; 8] = [
288 convert(17_297_280.0),
289 convert(8_648_640.0),
290 convert(1_995_840.0),
291 convert(277_200.0),
292 convert(25_200.0),
293 convert(1_512.0),
294 convert(56.0),
295 convert(1.0),
296 ];
297 self.calc_a2();
298 self.calc_a4();
299 self.calc_a6();
300 let u = &self.a
301 * (self.a6.as_ref().unwrap() * b[7].clone()
302 + self.a4.as_ref().unwrap() * b[5].clone()
303 + self.a2.as_ref().unwrap() * b[3].clone()
304 + &self.ident * b[1].clone());
305 let v = self.a6.as_ref().unwrap() * b[6].clone()
306 + self.a4.as_ref().unwrap() * b[4].clone()
307 + self.a2.as_ref().unwrap() * b[2].clone()
308 + &self.ident * b[0].clone();
309 (u, v)
310 }
311
312 fn pade9(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
313 let b: [T; 10] = [
314 convert(17_643_225_600.0),
315 convert(8_821_612_800.0),
316 convert(2_075_673_600.0),
317 convert(302_702_400.0),
318 convert(30_270_240.0),
319 convert(2_162_160.0),
320 convert(110_880.0),
321 convert(3_960.0),
322 convert(90.0),
323 convert(1.0),
324 ];
325 self.calc_a2();
326 self.calc_a4();
327 self.calc_a6();
328 self.calc_a8();
329 let u = &self.a
330 * (self.a8.as_ref().unwrap() * b[9].clone()
331 + self.a6.as_ref().unwrap() * b[7].clone()
332 + self.a4.as_ref().unwrap() * b[5].clone()
333 + self.a2.as_ref().unwrap() * b[3].clone()
334 + &self.ident * b[1].clone());
335 let v = self.a8.as_ref().unwrap() * b[8].clone()
336 + self.a6.as_ref().unwrap() * b[6].clone()
337 + self.a4.as_ref().unwrap() * b[4].clone()
338 + self.a2.as_ref().unwrap() * b[2].clone()
339 + &self.ident * b[0].clone();
340 (u, v)
341 }
342
343 fn pade13_scaled(&mut self, s: u64) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
344 let b: [T; 14] = [
345 convert(64_764_752_532_480_000.0),
346 convert(32_382_376_266_240_000.0),
347 convert(7_771_770_303_897_600.0),
348 convert(1_187_353_796_428_800.0),
349 convert(129_060_195_264_000.0),
350 convert(10_559_470_521_600.0),
351 convert(670_442_572_800.0),
352 convert(33_522_128_640.0),
353 convert(1_323_241_920.0),
354 convert(40_840_800.0),
355 convert(960_960.0),
356 convert(16_380.0),
357 convert(182.0),
358 convert(1.0),
359 ];
360 let s = s as f64;
361
362 let mb = &self.a * convert::<f64, T>(2.0_f64.powf(-s));
363 self.calc_a2();
364 self.calc_a4();
365 self.calc_a6();
366 let mb2 = self.a2.as_ref().unwrap() * convert::<f64, T>(2.0_f64.powf(-2.0 * s));
367 let mb4 = self.a4.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-4.0 * s));
368 let mb6 = self.a6.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-6.0 * s));
369
370 let u2 = &mb6 * (&mb6 * b[13].clone() + &mb4 * b[11].clone() + &mb2 * b[9].clone());
371 let u = &mb
372 * (&u2
373 + &mb6 * b[7].clone()
374 + &mb4 * b[5].clone()
375 + &mb2 * b[3].clone()
376 + &self.ident * b[1].clone());
377 let v2 = &mb6 * (&mb6 * b[12].clone() + &mb4 * b[10].clone() + &mb2 * b[8].clone());
378 let v = v2
379 + &mb6 * b[6].clone()
380 + &mb4 * b[4].clone()
381 + &mb2 * b[2].clone()
382 + &self.ident * b[0].clone();
383 (u, v)
384 }
385}
386
387#[inline(always)]
389fn factorial(n: usize) -> u128 {
390 match FACTORIAL.get(n) {
391 Some(f) => *f,
392 None => panic!("{n}! is greater than u128::MAX"),
393 }
394}
395
396fn onenorm_matrix_power_nonm<T, D>(a: &OMatrix<T, D, D>, p: usize) -> T
398where
399 T: RealField,
400 D: Dim,
401 DefaultAllocator: Allocator<D, D> + Allocator<D>,
402{
403 let nrows = a.shape_generic().0;
404 let mut v = crate::OVector::<T, D>::repeat_generic(nrows, Const::<1>, convert(1.0));
405 let m = a.transpose();
406
407 for _ in 0..p {
408 v = &m * v;
409 }
410
411 v.max()
412}
413
414fn ell<T, D>(a: &OMatrix<T, D, D>, m: usize) -> u64
415where
416 T: ComplexField,
417 D: Dim,
418 DefaultAllocator: Allocator<D, D> + Allocator<D> + Allocator<D> + Allocator<D, D>,
419{
420 let a_abs = a.map(|x| x.abs());
421
422 let a_abs_onenorm = onenorm_matrix_power_nonm(&a_abs, 2 * m + 1);
423
424 if a_abs_onenorm == <T as ComplexField>::RealField::zero() {
425 return 0;
426 }
427
428 let m_factorial = factorial(m);
430 let choose_2m_m = factorial(2 * m) / (m_factorial * m_factorial);
431
432 let abs_c_recip = choose_2m_m * factorial(2 * m + 1);
433 let alpha = a_abs_onenorm / one_norm(a);
434 let alpha: f64 = try_convert::<_, f64>(alpha).unwrap() / abs_c_recip as f64;
435
436 let u = 2_f64.powf(-53.0);
437 let log2_alpha_div_u = (alpha / u).log2();
438 let value = (log2_alpha_div_u / (2.0 * m as f64)).ceil();
439 if value > 0.0 { value as u64 } else { 0 }
440}
441
442fn solve_p_q<T, D>(u: OMatrix<T, D, D>, v: OMatrix<T, D, D>) -> OMatrix<T, D, D>
443where
444 T: ComplexField,
445 D: DimMin<D, Output = D>,
446 DefaultAllocator: Allocator<D, D> + Allocator<DimMinimum<D, D>>,
447{
448 let p = &u + &v;
449 let q = &v - &u;
450
451 q.lu().solve(&p).unwrap()
452}
453
454fn one_norm<T, D>(m: &OMatrix<T, D, D>) -> T::RealField
455where
456 T: ComplexField,
457 D: Dim,
458 DefaultAllocator: Allocator<D, D>,
459{
460 let mut max = <T as ComplexField>::RealField::zero();
461
462 for i in 0..m.ncols() {
463 let col = m.column(i);
464 max = max.max(
465 col.iter()
466 .fold(<T as ComplexField>::RealField::zero(), |a, b| {
467 a + b.clone().abs()
468 }),
469 );
470 }
471
472 max
473}
474
475impl<T: ComplexField, D> OMatrix<T, D, D>
476where
477 D: DimMin<D, Output = D>,
478 DefaultAllocator: Allocator<D, D>
479 + Allocator<DimMinimum<D, D>>
480 + Allocator<D>
481 + Allocator<D>
482 + Allocator<D, D>,
483{
484 #[must_use]
486 pub fn exp(&self) -> Self {
487 if self.nrows() == 1 {
489 return self.map(|v| v.exp());
490 }
491
492 let mut helper = ExpmPadeHelper::new(self.clone(), true);
493
494 let eta_1 = T::RealField::max(helper.d4_loose(), helper.d6_loose());
495 if eta_1 < convert(1.495_585_217_958_292e-2) && ell(&helper.a, 3) == 0 {
496 let (u, v) = helper.pade3();
497 return solve_p_q(u, v);
498 }
499
500 let eta_2 = T::RealField::max(helper.d4_tight(), helper.d6_loose());
501 if eta_2 < convert(2.539_398_330_063_23e-1) && ell(&helper.a, 5) == 0 {
502 let (u, v) = helper.pade5();
503 return solve_p_q(u, v);
504 }
505
506 let eta_3 = T::RealField::max(helper.d6_tight(), helper.d8_loose());
507 if eta_3 < convert(9.504_178_996_162_932e-1) && ell(&helper.a, 7) == 0 {
508 let (u, v) = helper.pade7();
509 return solve_p_q(u, v);
510 }
511 if eta_3 < convert(2.097_847_961_257_068e0) && ell(&helper.a, 9) == 0 {
512 let (u, v) = helper.pade9();
513 return solve_p_q(u, v);
514 }
515
516 let eta_4 = T::RealField::max(helper.d8_loose(), helper.d10_loose());
517 let eta_5 = T::RealField::min(eta_3, eta_4);
518 let theta_13 = convert(4.25);
519
520 let mut s = if eta_5 == T::RealField::zero() {
521 0
522 } else {
523 let l2 = try_convert::<_, f64>((eta_5 / theta_13).log2().ceil()).unwrap();
524
525 if l2 < 0.0 { 0 } else { l2 as u64 }
526 };
527
528 s += ell(
529 &(&helper.a * convert::<f64, T>(2.0_f64.powf(-(s as f64)))),
530 13,
531 );
532
533 let (u, v) = helper.pade13_scaled(s);
534 let mut x = solve_p_q(u, v);
535
536 for _ in 0..s {
537 x = &x * &x;
538 }
539 x
540 }
541}
542
#[cfg(test)]
mod tests {
    /// The 1-norm of a 3x3 matrix is its largest column sum of magnitudes.
    #[test]
    #[allow(clippy::float_cmp)]
    fn one_norm() {
        // Column magnitude sums are 5, 13 and 19.
        let matrix = crate::Matrix3::new(-3.0, 5.0, 7.0, 2.0, 6.0, 4.0, 0.0, 2.0, 8.0);
        let norm = super::one_norm(&matrix);

        assert_eq!(norm, 19.0);
    }
}
553}