1#[cfg(feature = "serde-serialize-no-std")]
2use serde::{Deserialize, Serialize};
3
4use approx::AbsDiffEq;
5use num::Zero;
6
7use crate::allocator::Allocator;
8use crate::base::{DefaultAllocator, Matrix2, OMatrix, OVector, SquareMatrix, Vector2};
9use crate::dimension::{Dim, DimDiff, DimSub, U1};
10use crate::storage::Storage;
11use simba::scalar::ComplexField;
12
13use crate::linalg::SymmetricTridiagonal;
14use crate::linalg::givens::GivensRotation;
15
/// Eigendecomposition of a symmetric matrix.
///
/// Holds the two outputs of the decomposition: a square matrix of eigenvectors and a vector of
/// real eigenvalues. The `i`-th eigenvalue is associated with the `i`-th column of
/// `eigenvectors` (this is the pairing `recompose` relies on).
#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(serialize = "DefaultAllocator: Allocator<D, D> +
                           Allocator<D>,
         OVector<T::RealField, D>: Serialize,
         OMatrix<T, D, D>: Serialize"))
)]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(deserialize = "DefaultAllocator: Allocator<D, D> +
                           Allocator<D>,
         OVector<T::RealField, D>: Deserialize<'de>,
         OMatrix<T, D, D>: Deserialize<'de>"))
)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Debug)]
pub struct SymmetricEigen<T: ComplexField, D: Dim>
where
    DefaultAllocator: Allocator<D, D> + Allocator<D>,
{
    /// The eigenvectors of the decomposed matrix, stored as columns.
    pub eigenvectors: OMatrix<T, D, D>,

    /// The eigenvalues of the decomposed matrix. They are stored in the order in which the
    /// algorithm produced them; no sorting is applied.
    pub eigenvalues: OVector<T::RealField, D>,
}
44
// `Copy` is implemented by hand rather than derived so that it is available exactly when both
// stored fields (the eigenvector matrix and the eigenvalue vector) are themselves `Copy`.
impl<T: ComplexField, D: Dim> Copy for SymmetricEigen<T, D>
where
    DefaultAllocator: Allocator<D, D> + Allocator<D>,
    OMatrix<T, D, D>: Copy,
    OVector<T::RealField, D>: Copy,
{
}
52
53impl<T: ComplexField, D: Dim> SymmetricEigen<T, D>
54where
55 DefaultAllocator: Allocator<D, D> + Allocator<D>,
56{
57 pub fn new(m: OMatrix<T, D, D>) -> Self
61 where
62 D: DimSub<U1>,
63 DefaultAllocator: Allocator<DimDiff<D, U1>> + Allocator<DimDiff<D, U1>>,
64 {
65 Self::try_new(m, T::RealField::default_epsilon(), 0).unwrap()
66 }
67
68 pub fn try_new(m: OMatrix<T, D, D>, eps: T::RealField, max_niter: usize) -> Option<Self>
80 where
81 D: DimSub<U1>,
82 DefaultAllocator: Allocator<DimDiff<D, U1>> + Allocator<DimDiff<D, U1>>,
83 {
84 Self::do_decompose(m, true, eps, max_niter).map(|(vals, vecs)| SymmetricEigen {
85 eigenvectors: vecs.unwrap(),
86 eigenvalues: vals,
87 })
88 }
89
90 fn do_decompose(
91 mut matrix: OMatrix<T, D, D>,
92 eigenvectors: bool,
93 eps: T::RealField,
94 max_niter: usize,
95 ) -> Option<(OVector<T::RealField, D>, Option<OMatrix<T, D, D>>)>
96 where
97 D: DimSub<U1>,
98 DefaultAllocator: Allocator<DimDiff<D, U1>> + Allocator<DimDiff<D, U1>>,
99 {
100 assert!(
101 matrix.is_square(),
102 "Unable to compute the eigendecomposition of a non-square matrix."
103 );
104 let dim = matrix.nrows();
105 let m_amax = matrix.camax();
106
107 if !m_amax.is_zero() {
108 matrix.unscale_mut(m_amax.clone());
109 }
110
111 let (mut q_mat, mut diag, mut off_diag);
112
113 if eigenvectors {
114 let res = SymmetricTridiagonal::new(matrix).unpack();
115 q_mat = Some(res.0);
116 diag = res.1;
117 off_diag = res.2;
118 } else {
119 let res = SymmetricTridiagonal::new(matrix).unpack_tridiagonal();
120 q_mat = None;
121 diag = res.0;
122 off_diag = res.1;
123 }
124
125 if dim == 1 {
126 diag.scale_mut(m_amax);
127 return Some((diag, q_mat));
128 }
129
130 let mut niter = 0;
131 let (mut start, mut end) =
132 Self::delimit_subproblem(&diag, &mut off_diag, dim - 1, eps.clone());
133
134 while end != start {
135 let subdim = end - start + 1;
136
137 #[allow(clippy::comparison_chain)]
138 if subdim > 2 {
139 let m = end - 1;
140 let n = end;
141
142 let mut vec = Vector2::new(
143 diag[start].clone()
144 - wilkinson_shift(
145 diag[m].clone().clone(),
146 diag[n].clone(),
147 off_diag[m].clone().clone(),
148 ),
149 off_diag[start].clone(),
150 );
151
152 for i in start..n {
153 let j = i + 1;
154
155 match GivensRotation::cancel_y(&vec) {
156 Some((rot, norm)) => {
157 if i > start {
158 off_diag[i - 1] = norm;
160 }
161
162 let mii = diag[i].clone();
163 let mjj = diag[j].clone();
164 let mij = off_diag[i].clone();
165
166 let cc = rot.c() * rot.c();
167 let ss = rot.s() * rot.s();
168 let cs = rot.c() * rot.s();
169
170 let b = cs.clone() * crate::convert(2.0) * mij.clone();
171
172 diag[i] =
173 (cc.clone() * mii.clone() + ss.clone() * mjj.clone()) - b.clone();
174 diag[j] = (ss.clone() * mii.clone() + cc.clone() * mjj.clone()) + b;
175 off_diag[i] = cs * (mii - mjj) + mij * (cc - ss);
176
177 if i != n - 1 {
178 vec.x = off_diag[i].clone();
179 vec.y = -rot.s() * off_diag[i + 1].clone();
180 off_diag[i + 1] *= rot.c();
181 }
182
183 if let Some(ref mut q) = q_mat {
184 let rot =
185 GivensRotation::new_unchecked(rot.c(), T::from_real(rot.s()));
186 rot.inverse().rotate_rows(&mut q.fixed_columns_mut::<2>(i));
187 }
188 }
189 None => {
190 break;
191 }
192 }
193 }
194
195 if off_diag[m].clone().norm1()
196 <= eps.clone() * (diag[m].clone().norm1() + diag[n].clone().norm1())
197 {
198 end -= 1;
199 }
200 } else if subdim == 2 {
201 let m = Matrix2::new(
202 diag[start].clone(),
203 off_diag[start].clone().conjugate(),
204 off_diag[start].clone(),
205 diag[start + 1].clone(),
206 );
207 let eigvals = m.eigenvalues().unwrap();
208
209 let basis = if (eigvals.x.clone() - diag[start + 1].clone()).abs()
211 > (eigvals.x.clone() - diag[start].clone()).abs()
212 {
213 Vector2::new(
214 eigvals.x.clone() - diag[start + 1].clone(),
215 off_diag[start].clone(),
216 )
217 } else {
218 Vector2::new(
219 off_diag[start].clone(),
220 eigvals.x.clone() - diag[start].clone(),
221 )
222 };
223
224 diag[start] = eigvals[0].clone();
225 diag[start + 1] = eigvals[1].clone();
226
227 if let Some(ref mut q) = q_mat {
228 if let Some((rot, _)) =
229 GivensRotation::try_new(basis.x.clone(), basis.y.clone(), eps.clone())
230 {
231 let rot = GivensRotation::new_unchecked(rot.c(), T::from_real(rot.s()));
232 rot.rotate_rows(&mut q.fixed_columns_mut::<2>(start));
233 }
234 }
235
236 end -= 1;
237 }
238
239 let sub = Self::delimit_subproblem(&diag, &mut off_diag, end, eps.clone());
241
242 start = sub.0;
243 end = sub.1;
244
245 niter += 1;
246 if niter == max_niter {
247 return None;
248 }
249 }
250
251 diag.scale_mut(m_amax);
252
253 Some((diag, q_mat))
254 }
255
256 fn delimit_subproblem(
257 diag: &OVector<T::RealField, D>,
258 off_diag: &mut OVector<T::RealField, DimDiff<D, U1>>,
259 end: usize,
260 eps: T::RealField,
261 ) -> (usize, usize)
262 where
263 D: DimSub<U1>,
264 DefaultAllocator: Allocator<DimDiff<D, U1>>,
265 {
266 let mut n = end;
267
268 while n > 0 {
269 let m = n - 1;
270
271 if off_diag[m].clone().norm1()
272 > eps.clone() * (diag[n].clone().norm1() + diag[m].clone().norm1())
273 {
274 break;
275 }
276
277 n -= 1;
278 }
279
280 if n == 0 {
281 return (0, 0);
282 }
283
284 let mut new_start = n - 1;
285 while new_start > 0 {
286 let m = new_start - 1;
287
288 if off_diag[m].clone().is_zero()
289 || off_diag[m].clone().norm1()
290 <= eps.clone() * (diag[new_start].clone().norm1() + diag[m].clone().norm1())
291 {
292 off_diag[m] = T::RealField::zero();
293 break;
294 }
295
296 new_start -= 1;
297 }
298
299 (new_start, n)
300 }
301
302 #[must_use]
306 pub fn recompose(&self) -> OMatrix<T, D, D> {
307 let mut u_t = self.eigenvectors.clone();
308 for i in 0..self.eigenvalues.len() {
309 let val = self.eigenvalues[i].clone();
310 u_t.column_mut(i).scale_mut(val);
311 }
312 u_t.adjoint_mut();
313 &self.eigenvectors * u_t
314 }
315}
316
317pub fn wilkinson_shift<T: ComplexField>(tmm: T, tnn: T, tmn: T) -> T {
324 let sq_tmn = tmn.clone() * tmn;
325 if !sq_tmn.is_zero() {
326 let d = (tmm - tnn.clone()) * crate::convert(0.5);
328 tnn - sq_tmn.clone() / (d.clone() + d.clone().signum() * (d.clone() * d + sq_tmn).sqrt())
329 } else {
330 tnn
331 }
332}
333
334impl<T: ComplexField, D: DimSub<U1>, S: Storage<T, D, D>> SquareMatrix<T, D, S>
340where
341 DefaultAllocator:
342 Allocator<D, D> + Allocator<DimDiff<D, U1>> + Allocator<D> + Allocator<DimDiff<D, U1>>,
343{
344 #[must_use]
348 pub fn symmetric_eigenvalues(&self) -> OVector<T::RealField, D> {
349 SymmetricEigen::do_decompose(
350 self.clone_owned(),
351 false,
352 T::RealField::default_epsilon(),
353 0,
354 )
355 .unwrap()
356 .0
357 }
358}
359
#[cfg(test)]
mod test {
    use crate::base::{Matrix2, Matrix4};

    /// Regression test for issue 1109: decomposing and recomposing this badly scaled symmetric
    /// matrix must reproduce its lower triangle.
    #[test]
    fn symmetric_eigen_regression_issue_1109() {
        #[rustfmt::skip]
        let m = Matrix4::new(
            -19884.07f64, -10.07188,  11.277279, -188560.63,
            -10.07188,     12.518197, 1.3770627, -102.97504,
             11.277279,    1.3770627, 14.587362,  113.26099,
            -188560.63,   -102.97504, 113.26099, -1788112.3,
        );

        let expected = m.lower_triangle();
        let rebuilt = m.symmetric_eigen().recompose().lower_triangle();
        assert!(relative_eq!(expected, rebuilt, epsilon = 1.0e-5));
    }

    /// Reference value for the shift: the eigenvalue of `m` that lies closest to `m.m22`.
    fn expected_shift(m: Matrix2<f64>) -> f64 {
        let vals = m.eigenvalues().unwrap();
        let distance_to_m22 = |value: f64| (value - m.m22).abs();

        if distance_to_m22(vals.x) < distance_to_m22(vals.y) {
            vals.x
        } else {
            vals.y
        }
    }

    /// Asserts, with the default `relative_eq!` tolerance, that `wilkinson_shift` agrees with
    /// `expected_shift` on `m`.
    #[track_caller]
    fn assert_shift_matches(m: Matrix2<f64>) {
        let expected = expected_shift(m);
        let computed = super::wilkinson_shift(m.m11, m.m22, m.m12);
        assert!(relative_eq!(expected, computed));
    }

    #[cfg(feature = "rand")]
    #[test]
    fn wilkinson_shift_random() {
        for _ in 0..1000 {
            // `m * mᵀ` is symmetric by construction.
            let factor = Matrix2::<f64>::new_random();
            let m = factor * factor.transpose();

            let expected = expected_shift(m);
            let computed = super::wilkinson_shift(m.m11, m.m22, m.m12);
            assert!(relative_eq!(expected, computed, epsilon = 1.0e-7));
        }
    }

    #[test]
    fn wilkinson_shift_zero() {
        assert_shift_matches(Matrix2::new(0.0, 0.0, 0.0, 0.0));
    }

    #[test]
    fn wilkinson_shift_zero_diagonal() {
        assert_shift_matches(Matrix2::new(0.0, 42.0, 42.0, 0.0));
    }

    #[test]
    fn wilkinson_shift_zero_off_diagonal() {
        assert_shift_matches(Matrix2::new(42.0, 0.0, 0.0, 64.0));
    }

    #[test]
    fn wilkinson_shift_zero_trace() {
        assert_shift_matches(Matrix2::new(42.0, 20.0, 20.0, -42.0));
    }

    #[test]
    fn wilkinson_shift_zero_diag_diff_and_zero_off_diagonal() {
        assert_shift_matches(Matrix2::new(42.0, 0.0, 0.0, 42.0));
    }

    #[test]
    fn wilkinson_shift_zero_det() {
        assert_shift_matches(Matrix2::new(2.0, 4.0, 4.0, 8.0));
    }
}
469}