nalgebra/linalg/
schur.rs

1#![allow(clippy::suspicious_operation_groupings)]
2#[cfg(feature = "serde-serialize-no-std")]
3use serde::{Deserialize, Serialize};
4
5use approx::AbsDiffEq;
6use num_complex::Complex as NumComplex;
7use num_traits::identities::Zero;
8use simba::scalar::{ComplexField, RealField};
9use std::cmp;
10
11use crate::allocator::Allocator;
12use crate::base::dimension::{Const, Dim, DimDiff, DimSub, Dyn, U1, U2};
13use crate::base::storage::Storage;
14use crate::base::{DefaultAllocator, OMatrix, OVector, SquareMatrix, Unit, Vector2, Vector3};
15
16use crate::geometry::Reflection;
17use crate::linalg::Hessenberg;
18use crate::linalg::givens::GivensRotation;
19use crate::linalg::householder;
20use crate::{Matrix, UninitVector};
21use std::mem::MaybeUninit;
22
23/// Schur decomposition of a square matrix.
24///
25/// If this is a real matrix, this will be a `RealField` Schur decomposition.
26#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
27#[cfg_attr(
28    feature = "serde-serialize-no-std",
29    serde(bound(serialize = "DefaultAllocator: Allocator<D, D>,
30         OMatrix<T, D, D>: Serialize"))
31)]
32#[cfg_attr(
33    feature = "serde-serialize-no-std",
34    serde(bound(deserialize = "DefaultAllocator: Allocator<D, D>,
35         OMatrix<T, D, D>: Deserialize<'de>"))
36)]
37#[cfg_attr(feature = "defmt", derive(defmt::Format))]
38#[derive(Clone, Debug)]
39pub struct Schur<T: ComplexField, D: Dim>
40where
41    DefaultAllocator: Allocator<D, D>,
42{
43    q: OMatrix<T, D, D>,
44    t: OMatrix<T, D, D>,
45}
46
47impl<T: ComplexField, D: Dim> Copy for Schur<T, D>
48where
49    DefaultAllocator: Allocator<D, D>,
50    OMatrix<T, D, D>: Copy,
51{
52}
53
54impl<T: ComplexField, D: Dim> Schur<T, D>
55where
56    D: DimSub<U1>, // For Hessenberg.
57    DefaultAllocator:
58        Allocator<D, DimDiff<D, U1>> + Allocator<DimDiff<D, U1>> + Allocator<D, D> + Allocator<D>,
59{
60    /// Computes the Schur decomposition of a square matrix.
61    pub fn new(m: OMatrix<T, D, D>) -> Self {
62        Self::try_new(m, T::RealField::default_epsilon(), 0).unwrap()
63    }
64
65    /// Attempts to compute the Schur decomposition of a square matrix.
66    ///
67    /// If only eigenvalues are needed, it is more efficient to call the matrix method
68    /// `.eigenvalues()` instead.
69    ///
70    /// # Arguments
71    ///
72    /// * `eps`       − tolerance used to determine when a value converged to 0.
73    /// * `max_niter` − maximum total number of iterations performed by the algorithm. If this
74    ///   number of iteration is exceeded, `None` is returned. If `niter == 0`, then the algorithm
75    ///   continues indefinitely until convergence.
76    pub fn try_new(m: OMatrix<T, D, D>, eps: T::RealField, max_niter: usize) -> Option<Self> {
77        let mut work = Matrix::zeros_generic(m.shape_generic().0, Const::<1>);
78
79        Self::do_decompose(m, &mut work, eps, max_niter, true)
80            .map(|(q, t)| Schur { q: q.unwrap(), t })
81    }
82
    /// Shared implementation of the Schur decomposition.
    ///
    /// # Arguments
    ///
    /// * `m`         − the matrix to decompose (consumed).
    /// * `work`      − scratch vector handed to the Hessenberg reduction and to the row
    ///   reflections.
    /// * `eps`       − tolerance forwarded to `delimit_subproblem`.
    /// * `max_niter` − `None` is returned as soon as the iteration counter equals this value.
    ///   The counter is incremented before being compared, so `0` never matches.
    /// * `compute_q` − whether the transformations must be accumulated into `Q`.
    ///
    /// # Returns
    ///
    /// `Some((q, t))` on success, where `q` is `None` when `compute_q` is `false` (except for
    /// an empty matrix, for which an empty `q` is always returned).
    ///
    /// # Panics
    ///
    /// Panics if `m` is not square.
    fn do_decompose(
        mut m: OMatrix<T, D, D>,
        work: &mut OVector<T, D>,
        eps: T::RealField,
        max_niter: usize,
        compute_q: bool,
    ) -> Option<(Option<OMatrix<T, D, D>>, OMatrix<T, D, D>)> {
        assert!(
            m.is_square(),
            "Unable to compute the eigenvectors and eigenvalues of a non-square matrix."
        );

        let dim = m.shape_generic().0;

        // Specialization would make this easier.
        if dim.value() == 0 {
            // Empty matrix: both factors are empty.
            let vecs = Some(OMatrix::from_element_generic(dim, dim, T::zero()));
            let vals = OMatrix::from_element_generic(dim, dim, T::zero());
            return Some((vecs, vals));
        } else if dim.value() == 1 {
            // 1x1 matrix: `m` is returned unchanged as `T`, and `Q` is the 1x1 matrix `[1]`.
            if compute_q {
                let q = OMatrix::from_element_generic(dim, dim, T::one());
                return Some((Some(q), m));
            } else {
                return Some((None, m));
            }
        } else if dim.value() == 2 {
            // 2x2 matrices are handled by a dedicated routine.
            return decompose_2x2(m, compute_q);
        }

        // Normalize by the value returned by `camax()`; this is undone at the very end.
        let amax_m = m.camax();
        // if amax_m == 0 (i.e. the matrix is the zero matrix),
        // then the unscale_mut call will turn the entire matrix into NaNs
        // see https://github.com/dimforge/nalgebra/issues/1291
        if !amax_m.is_zero() {
            m.unscale_mut(amax_m.clone());
        }

        // Start from the Hessenberg form of the (scaled) matrix.
        let hess = Hessenberg::new_with_workspace(m, work);
        let mut q;
        let mut t;

        if compute_q {
            // TODO: could we work without unpacking? Using only the internal representation of
            // hessenberg decomposition.
            let (vecs, vals) = hess.unpack();
            q = Some(vecs);
            t = vals;
        } else {
            q = None;
            t = hess.unpack_h()
        }

        // Implicit double-shift QR method.
        let mut niter = 0;
        // `[start, end]` is the range of diagonal indices of the block currently worked on.
        let (mut start, mut end) = Self::delimit_subproblem(&mut t, eps.clone(), dim.value() - 1);

        // `delimit_subproblem` returns `start == end` once nothing is left to process.
        while end != start {
            let subdim = end - start + 1;

            if subdim > 2 {
                // Indices of the last two diagonal entries of the active block.
                let m = end - 1;
                let n = end;

                // Entries at the top-left corner of the active block.
                let h11 = t[(start, start)].clone();
                let h12 = t[(start, start + 1)].clone();
                let h21 = t[(start + 1, start)].clone();
                let h22 = t[(start + 1, start + 1)].clone();
                let h32 = t[(start + 2, start + 1)].clone();

                // Entries of the trailing 2x2 block.
                let hnn = t[(n, n)].clone();
                let hmm = t[(m, m)].clone();
                let hnm = t[(n, m)].clone();
                let hmn = t[(m, n)].clone();

                // Trace and determinant of the trailing 2x2 block.
                let tra = hnn.clone() + hmm.clone();
                let det = hnn * hmm - hnm * hmn;

                // Initial 3-vector from which the first reflection of the sweep is built.
                let mut axis = Vector3::new(
                    h11.clone() * h11.clone() + h12 * h21.clone() - tra.clone() * h11.clone() + det,
                    h21.clone() * (h11 + h22 - tra),
                    h21 * h32,
                );

                // Sweep down the active block with reflections acting on 3 rows/columns.
                for k in start..n - 1 {
                    let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);

                    if not_zero {
                        if k > start {
                            // Write the reflected column `k - 1` directly: `norm` followed by
                            // two zeros.
                            t[(k, k - 1)] = norm;
                            t[(k + 1, k - 1)] = T::zero();
                            t[(k + 2, k - 1)] = T::zero();
                        }

                        let refl = Reflection::new(Unit::new_unchecked(axis.clone()), T::zero());

                        {
                            // Number of leading rows covered by the row reflection below.
                            let krows = cmp::min(k + 4, end + 1);
                            let mut work = work.rows_mut(0, krows);
                            // Rows `k..k + 3`, columns `k..dim`.
                            refl.reflect(
                                &mut t.generic_view_mut((k, k), (Const::<3>, Dyn(dim.value() - k))),
                            );
                            // Rows `0..krows`, columns `k..k + 3`.
                            refl.reflect_rows(
                                &mut t.generic_view_mut((0, k), (Dyn(krows), Const::<3>)),
                                &mut work,
                            );
                        }

                        // Accumulate the same reflection into columns `k..k + 3` of `Q`.
                        if let Some(ref mut q) = q {
                            refl.reflect_rows(
                                &mut q.generic_view_mut((0, k), (dim, Const::<3>)),
                                work,
                            );
                        }
                    }

                    // Load the next axis from column `k`, below the diagonal.
                    axis.x = t[(k + 1, k)].clone();
                    axis.y = t[(k + 2, k)].clone();

                    // The third component only exists while row `k + 3` is inside the block.
                    if k < n - 2 {
                        axis.z = t[(k + 3, k)].clone();
                    }
                }

                // Final reflection of the sweep, acting on 2 rows/columns only.
                let mut axis = Vector2::new(axis.x.clone(), axis.y.clone());
                let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);

                if not_zero {
                    let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());

                    t[(m, m - 1)] = norm;
                    t[(n, m - 1)] = T::zero();

                    {
                        let mut work = work.rows_mut(0, end + 1);
                        // Rows `m..m + 2`, columns `m..dim`.
                        refl.reflect(
                            &mut t.generic_view_mut((m, m), (Const::<2>, Dyn(dim.value() - m))),
                        );
                        // Rows `0..=end`, columns `m..m + 2`.
                        refl.reflect_rows(
                            &mut t.generic_view_mut((0, m), (Dyn(end + 1), Const::<2>)),
                            &mut work,
                        );
                    }

                    // Accumulate the same reflection into columns `m..m + 2` of `Q`.
                    if let Some(ref mut q) = q {
                        refl.reflect_rows(&mut q.generic_view_mut((0, m), (dim, Const::<2>)), work);
                    }
                }
            } else {
                // Decouple the 2x2 block if it has real eigenvalues.
                if let Some(rot) = compute_2x2_basis(&t.fixed_view::<2, 2>(start, start)) {
                    let inv_rot = rot.inverse();
                    // Rows `start..start + 2`, columns `start..dim`.
                    inv_rot.rotate(
                        &mut t.generic_view_mut(
                            (start, start),
                            (Const::<2>, Dyn(dim.value() - start)),
                        ),
                    );
                    // Rows `0..=end`, columns `start..start + 2`.
                    rot.rotate_rows(
                        &mut t.generic_view_mut((0, start), (Dyn(end + 1), Const::<2>)),
                    );
                    // Force the sub-diagonal entry of the block to an exact zero.
                    t[(end, start)] = T::zero();

                    if let Some(ref mut q) = q {
                        rot.rotate_rows(&mut q.generic_view_mut((0, start), (dim, Const::<2>)));
                    }
                }

                // Check if we reached the beginning of the matrix.
                if end > 2 {
                    end -= 2;
                } else {
                    break;
                }
            }

            // Locate the next block to work on, searching upward from `end`.
            let sub = Self::delimit_subproblem(&mut t, eps.clone(), end);

            start = sub.0;
            end = sub.1;

            // Give up once the iteration budget is reached (never triggers for
            // `max_niter == 0` since `niter` is at least 1 here).
            niter += 1;
            if niter == max_niter {
                return None;
            }
        }

        // Undo the initial normalization.
        t.scale_mut(amax_m);

        Some((q, t))
    }
274
275    /// Computes the eigenvalues of the decomposed matrix.
276    fn do_eigenvalues(t: &OMatrix<T, D, D>, out: &mut OVector<T, D>) -> bool {
277        let dim = t.nrows();
278        let mut m = 0;
279
280        while m < dim - 1 {
281            let n = m + 1;
282
283            if t[(n, m)].is_zero() {
284                out[m] = t[(m, m)].clone();
285                m += 1;
286            } else {
287                // Complex eigenvalue.
288                return false;
289            }
290        }
291
292        if m == dim - 1 {
293            out[m] = t[(m, m)].clone();
294        }
295
296        true
297    }
298
299    /// Computes the complex eigenvalues of the decomposed matrix.
300    fn do_complex_eigenvalues(t: &OMatrix<T, D, D>, out: &mut UninitVector<NumComplex<T>, D>)
301    where
302        T: RealField,
303        DefaultAllocator: Allocator<D>,
304    {
305        let dim = t.nrows();
306        let mut m = 0;
307
308        while m < dim - 1 {
309            let n = m + 1;
310
311            if t[(n, m)].is_zero() {
312                out[m] = MaybeUninit::new(NumComplex::new(t[(m, m)].clone(), T::zero()));
313                m += 1;
314            } else {
315                // Solve the 2x2 eigenvalue subproblem.
316                let hmm = t[(m, m)].clone();
317                let hnm = t[(n, m)].clone();
318                let hmn = t[(m, n)].clone();
319                let hnn = t[(n, n)].clone();
320
321                // NOTE: use the same algorithm as in compute_2x2_eigvals.
322                let val = (hmm.clone() - hnn.clone()) * crate::convert(0.5);
323                let discr = hnm * hmn + val.clone() * val;
324
325                // All 2x2 blocks have negative discriminant because we already decoupled those
326                // with positive eigenvalues.
327                let sqrt_discr = NumComplex::new(T::zero(), (-discr).sqrt());
328
329                let half_tra = (hnn + hmm) * crate::convert(0.5);
330                out[m] = MaybeUninit::new(
331                    NumComplex::new(half_tra.clone(), T::zero()) + sqrt_discr.clone(),
332                );
333                out[m + 1] =
334                    MaybeUninit::new(NumComplex::new(half_tra, T::zero()) - sqrt_discr.clone());
335
336                m += 2;
337            }
338        }
339
340        if m == dim - 1 {
341            out[m] = MaybeUninit::new(NumComplex::new(t[(m, m)].clone(), T::zero()));
342        }
343    }
344
345    fn delimit_subproblem(t: &mut OMatrix<T, D, D>, eps: T::RealField, end: usize) -> (usize, usize)
346    where
347        D: DimSub<U1>,
348        DefaultAllocator: Allocator<DimDiff<D, U1>>,
349    {
350        let mut n = end;
351        // Equivalent to SMLNUM in LAPACK DLAHQR. Since SMLNUM depends on the machine
352        // precision, we use eps^2 here as best approximation.
353        // This is justified because the relative threshold use eps * ( t[n,n] + t[m,m] )
354        let absolute_threshold = eps.clone() * eps.clone();
355
356        while n > 0 {
357            let m = n - 1;
358            let off_diag_norm = t[(n, m)].clone().norm1();
359
360            if off_diag_norm <= absolute_threshold
361                || off_diag_norm
362                    <= eps.clone() * (t[(n, n)].clone().norm1() + t[(m, m)].clone().norm1())
363            {
364                t[(n, m)] = T::zero();
365            } else {
366                break;
367            }
368
369            n -= 1;
370        }
371
372        if n == 0 {
373            return (0, 0);
374        }
375
376        let mut new_start = n - 1;
377        while new_start > 0 {
378            let m = new_start - 1;
379
380            let off_diag = t[(new_start, m)].clone();
381            let off_diag_norm = off_diag.clone().norm1();
382
383            if off_diag.is_zero()
384                || off_diag_norm <= absolute_threshold
385                || off_diag_norm
386                    <= eps.clone()
387                        * (t[(new_start, new_start)].clone().norm1() + t[(m, m)].clone().norm1())
388            {
389                t[(new_start, m)] = T::zero();
390                break;
391            }
392
393            new_start -= 1;
394        }
395
396        (new_start, n)
397    }
398
399    /// Retrieves the unitary matrix `Q` and the upper-quasitriangular matrix `T` such that the
400    /// decomposed matrix equals `Q * T * Q.transpose()`.
401    pub fn unpack(self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
402        (self.q, self.t)
403    }
404
405    /// Computes the real eigenvalues of the decomposed matrix.
406    ///
407    /// Return `None` if some eigenvalues are complex.
408    #[must_use]
409    pub fn eigenvalues(&self) -> Option<OVector<T, D>> {
410        let mut out = Matrix::zeros_generic(self.t.shape_generic().0, Const::<1>);
411        if Self::do_eigenvalues(&self.t, &mut out) {
412            Some(out)
413        } else {
414            None
415        }
416    }
417
418    /// Computes the complex eigenvalues of the decomposed matrix.
419    #[must_use]
420    pub fn complex_eigenvalues(&self) -> OVector<NumComplex<T>, D>
421    where
422        T: RealField,
423        DefaultAllocator: Allocator<D>,
424    {
425        let mut out = Matrix::uninit(self.t.shape_generic().0, Const::<1>);
426        Self::do_complex_eigenvalues(&self.t, &mut out);
427        // Safety: out has been fully initialized by do_complex_eigenvalues.
428        unsafe { out.assume_init() }
429    }
430}
431
432fn decompose_2x2<T: ComplexField, D: Dim>(
433    mut m: OMatrix<T, D, D>,
434    compute_q: bool,
435) -> Option<(Option<OMatrix<T, D, D>>, OMatrix<T, D, D>)>
436where
437    DefaultAllocator: Allocator<D, D>,
438{
439    let dim = m.shape_generic().0;
440    let mut q = None;
441    match compute_2x2_basis(&m.fixed_view::<2, 2>(0, 0)) {
442        Some(rot) => {
443            let mut m = m.fixed_view_mut::<2, 2>(0, 0);
444            let inv_rot = rot.inverse();
445            inv_rot.rotate(&mut m);
446            rot.rotate_rows(&mut m);
447            m[(1, 0)] = T::zero();
448
449            if compute_q {
450                // XXX: we have to build the matrix manually because
451                // rot.to_rotation_matrix().unwrap() causes an ICE.
452                let c = T::from_real(rot.c());
453                q = Some(OMatrix::from_column_slice_generic(
454                    dim,
455                    dim,
456                    &[c.clone(), rot.s(), -rot.s().conjugate(), c],
457                ));
458            }
459        }
460        None => {
461            if compute_q {
462                q = Some(OMatrix::identity_generic(dim, dim));
463            }
464        }
465    };
466
467    Some((q, m))
468}
469
470fn compute_2x2_eigvals<T: ComplexField, S: Storage<T, U2, U2>>(
471    m: &SquareMatrix<T, U2, S>,
472) -> Option<(T, T)> {
473    // Solve the 2x2 eigenvalue subproblem.
474    let h00 = m[(0, 0)].clone();
475    let h10 = m[(1, 0)].clone();
476    let h01 = m[(0, 1)].clone();
477    let h11 = m[(1, 1)].clone();
478
479    // NOTE: this discriminant computation is more stable than the
480    // one based on the trace and determinant: 0.25 * tra * tra - det
481    // because it ensures positiveness for symmetric matrices.
482    let val = (h00.clone() - h11.clone()) * crate::convert(0.5);
483    let discr = h10 * h01 + val.clone() * val;
484
485    discr.try_sqrt().map(|sqrt_discr| {
486        let half_tra = (h00 + h11) * crate::convert(0.5);
487        (half_tra.clone() + sqrt_discr.clone(), half_tra - sqrt_discr)
488    })
489}
490
491// Computes the 2x2 transformation that upper-triangulates a 2x2 matrix with real eigenvalues.
492/// Computes the singular vectors for a 2x2 matrix.
493///
494/// Returns `None` if the matrix has complex eigenvalues, or is upper-triangular. In both case,
495/// the basis is the identity.
496fn compute_2x2_basis<T: ComplexField, S: Storage<T, U2, U2>>(
497    m: &SquareMatrix<T, U2, S>,
498) -> Option<GivensRotation<T>> {
499    let h10 = m[(1, 0)].clone();
500
501    if h10.is_zero() {
502        return None;
503    }
504
505    let (eigval1, eigval2) = compute_2x2_eigvals(m)?;
506    let x1 = eigval1 - m[(1, 1)].clone();
507    let x2 = eigval2 - m[(1, 1)].clone();
508
509    // NOTE: Choose the one that yields a larger x component.
510    // This is necessary for numerical stability of the normalization of the complex
511    // number.
512    if x1.clone().norm1() > x2.clone().norm1() {
513        Some(GivensRotation::new(x1, h10).0)
514    } else {
515        Some(GivensRotation::new(x2, h10).0)
516    }
517}
518
519impl<T: ComplexField, D: Dim, S: Storage<T, D, D>> SquareMatrix<T, D, S>
520where
521    D: DimSub<U1>, // For Hessenberg.
522    DefaultAllocator:
523        Allocator<D, DimDiff<D, U1>> + Allocator<DimDiff<D, U1>> + Allocator<D, D> + Allocator<D>,
524{
525    /// Computes the eigenvalues of this matrix.
526    #[must_use]
527    pub fn eigenvalues(&self) -> Option<OVector<T, D>> {
528        assert!(
529            self.is_square(),
530            "Unable to compute eigenvalues of a non-square matrix."
531        );
532
533        let mut work = Matrix::zeros_generic(self.shape_generic().0, Const::<1>);
534
535        // Special case for 2x2 matrices.
536        if self.nrows() == 2 {
537            // TODO: can we avoid this slicing
538            // (which is needed here just to transform D to U2)?
539            let me = self.fixed_view::<2, 2>(0, 0);
540            return match compute_2x2_eigvals(&me) {
541                Some((a, b)) => {
542                    work[0] = a;
543                    work[1] = b;
544                    Some(work)
545                }
546                None => None,
547            };
548        }
549
550        // TODO: add balancing?
551        let schur = Schur::do_decompose(
552            self.clone_owned(),
553            &mut work,
554            T::RealField::default_epsilon(),
555            0,
556            false,
557        )
558        .unwrap();
559
560        if Schur::do_eigenvalues(&schur.1, &mut work) {
561            Some(work)
562        } else {
563            None
564        }
565    }
566
567    /// Computes the eigenvalues of this matrix.
568    #[must_use]
569    pub fn complex_eigenvalues(&self) -> OVector<NumComplex<T>, D>
570    // TODO: add balancing?
571    where
572        T: RealField,
573        DefaultAllocator: Allocator<D>,
574    {
575        let dim = self.shape_generic().0;
576        let mut work = Matrix::zeros_generic(dim, Const::<1>);
577
578        let schur = Schur::do_decompose(
579            self.clone_owned(),
580            &mut work,
581            T::default_epsilon(),
582            0,
583            false,
584        )
585        .unwrap();
586        let mut eig = Matrix::uninit(dim, Const::<1>);
587        Schur::do_complex_eigenvalues(&schur.1, &mut eig);
588        // Safety: eig has been fully initialized by do_complex_eigenvalues.
589        unsafe { eig.assume_init() }
590    }
591}