bevy_heavy/dim3/
eigen3.rs

1// The eigensolver is a Rust adaptation, with modifications, of the pseudocode and approach described in
2// "A Robust Eigensolver for 3 x 3 Symmetric Matrices" by David Eberly, Geometric Tools, Redmond WA 98052.
3// https://www.geometrictools.com/Documentation/RobustEigenSymmetric3x3.pdf
4
5use bevy_math::{ops, FloatPow, Mat3, Vec3, Vec3Swizzles};
6
7/// The [eigen decomposition] of a [symmetric] 3x3 matrix.
8///
9/// [eigen decomposition]: https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix
10/// [symmetric]: https://en.wikipedia.org/wiki/Symmetric_matrix
11#[derive(Clone, Copy, Debug, PartialEq)]
12#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
13#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
14pub struct SymmetricEigen3 {
15    /// The eigenvalues of the symmetric 3x3 matrix.
16    ///
17    /// These should be in ascending order `eigen1 <= eigen2 <= eigen3`.
18    pub eigenvalues: Vec3,
19    /// The three eigenvectors of the symmetric 3x3 matrix.
20    /// Each eigenvector should be unit length and orthogonal to the other eigenvectors.
21    ///
22    /// The eigenvectors are ordered to correspond to the eigenvalues. For example,
23    /// `eigenvectors.x_axis` corresponds to `eigenvalues.x`.
24    pub eigenvectors: Mat3,
25}
26
27impl SymmetricEigen3 {
28    /// Computes the eigen decomposition of the given symmetric 3x3 matrix.
29    ///
30    /// The eigenvalues are returned in ascending order `eigen1 <= eigen2 <= eigen3`.
31    /// This can be reversed with the [`reverse`](Self::reverse) method.
32    pub fn new(mat: Mat3) -> Self {
33        let (mut eigenvalues, is_diagonal) = Self::eigenvalues(mat);
34
35        if is_diagonal {
36            // The matrix is already diagonal. Sort the eigenvalues in ascending order,
37            // ordering the eigenvectors accordingly, and return early.
38            let mut eigenvectors = Mat3::IDENTITY;
39            if eigenvalues[0] > eigenvalues[1] {
40                core::mem::swap(&mut eigenvalues.x, &mut eigenvalues.y);
41                core::mem::swap(&mut eigenvectors.x_axis, &mut eigenvectors.y_axis);
42            }
43            if eigenvalues[1] > eigenvalues[2] {
44                core::mem::swap(&mut eigenvalues.y, &mut eigenvalues.z);
45                core::mem::swap(&mut eigenvectors.y_axis, &mut eigenvectors.z_axis);
46            }
47            return Self {
48                eigenvalues,
49                eigenvectors,
50            };
51        }
52
53        // Compute the eigenvectors corresponding to the eigenvalues.
54        let eigenvector1 = Self::eigenvector1(mat, eigenvalues.x);
55        let eigenvector2 = Self::eigenvector2(mat, eigenvector1, eigenvalues.y);
56        let eigenvector3 = Self::eigenvector3(eigenvector1, eigenvector2);
57
58        Self {
59            eigenvalues,
60            eigenvectors: Mat3::from_cols(eigenvector1, eigenvector2, eigenvector3),
61        }
62    }
63
64    /// Reverses the order of the eigenvalues and their corresponding eigenvectors.
65    pub fn reverse(&self) -> Self {
66        Self {
67            eigenvalues: self.eigenvalues.zyx(),
68            eigenvectors: Mat3::from_cols(
69                self.eigenvectors.z_axis,
70                self.eigenvectors.y_axis,
71                self.eigenvectors.x_axis,
72            ),
73        }
74    }
75
76    /// Computes the eigenvalues of a symmetric 3x3 matrix, also returning whether the input matrix is diagonal.
77    ///
78    /// If the matrix is already diagonal, the eigenvalues are returned as is without reordering.
79    /// Otherwise, the eigenvalues are computed and returned in ascending order
80    /// such that `eigen1 <= eigen2 <= eigen3`.
81    pub fn eigenvalues(mat: Mat3) -> (Vec3, bool) {
82        // Reference: https://en.wikipedia.org/wiki/Eigenvalue_algorithm#Symmetric_3%C3%973_matrices
83
84        let p1 = mat.y_axis.x.squared() + mat.z_axis.x.squared() + mat.z_axis.y.squared();
85
86        if p1 == 0.0 {
87            // The matrix is diagonal.
88            return (Vec3::new(mat.x_axis.x, mat.y_axis.y, mat.z_axis.z), true);
89        }
90
91        let q = (mat.x_axis.x + mat.y_axis.y + mat.z_axis.z) / 3.0;
92        let p2 = (mat.x_axis.x - q).squared()
93            + (mat.y_axis.y - q).squared()
94            + (mat.z_axis.z - q).squared()
95            + 2.0 * p1;
96        let p = ops::sqrt(p2 / 6.0);
97        let mat_b = 1.0 / p * (mat - q * Mat3::IDENTITY);
98        let r = mat_b.determinant() / 2.0;
99
100        // r should be in the [-1, 1] range for a symmetric matrix,
101        // but computation error can leave it slightly outside this range.
102        let phi = if r <= -1.0 {
103            core::f32::consts::FRAC_PI_3
104        } else if r >= 1.0 {
105            0.0
106        } else {
107            ops::acos(r) / 3.0
108        };
109
110        // The eigenvalues satisfy eigen3 <= eigen2 <= eigen1
111        let eigen1 = q + 2.0 * p * ops::cos(phi);
112        let eigen3 = q + 2.0 * p * ops::cos(phi + 2.0 * core::f32::consts::FRAC_PI_3);
113        let eigen2 = 3.0 * q - eigen1 - eigen3; // trace(mat) = eigen1 + eigen2 + eigen3
114        (Vec3::new(eigen3, eigen2, eigen1), false)
115    }
116
117    // TODO: Fall back to QL when the eigenvalue precision is poor.
118    /// Computes the unit-length eigenvector corresponding to the `eigenvalue1` of `mat` that was
119    /// computed from the root of a cubic polynomial with a multiplicity of 1.
120    ///
121    /// If the other two eigenvalues are well separated, this method can be used for computing
122    /// all three eigenvectors. However, to avoid numerical issues when eigenvalues are close to
123    /// each other, it's recommended to use the `eigenvector2` method for the second eigenvector.
124    ///
125    /// The third eigenvector can be computed as the cross product of the first two.
126    pub fn eigenvector1(mat: Mat3, eigenvalue1: f32) -> Vec3 {
127        let cols = mat - Mat3::from_diagonal(Vec3::splat(eigenvalue1));
128        let c0xc1 = cols.x_axis.cross(cols.y_axis);
129        let c0xc2 = cols.x_axis.cross(cols.z_axis);
130        let c1xc2 = cols.y_axis.cross(cols.z_axis);
131        let d0 = c0xc1.length_squared();
132        let d1 = c0xc2.length_squared();
133        let d2 = c1xc2.length_squared();
134
135        let mut d_max = d0;
136        let mut i_max = 0;
137
138        if d1 > d_max {
139            d_max = d1;
140            i_max = 1;
141        }
142        if d2 > d_max {
143            i_max = 2;
144        }
145        if i_max == 0 {
146            c0xc1 / ops::sqrt(d0)
147        } else if i_max == 1 {
148            c0xc2 / ops::sqrt(d1)
149        } else {
150            c1xc2 / ops::sqrt(d2)
151        }
152    }
153
154    /// Computes the unit-length eigenvector corresponding to the `eigenvalue2` of `mat` that was
155    /// computed from the root of a cubic polynomial with a potential multiplicity of 2.
156    ///
157    /// The third eigenvector can be computed as the cross product of the first two.
158    pub fn eigenvector2(mat: Mat3, eigenvector1: Vec3, eigenvalue2: f32) -> Vec3 {
159        // Compute right-handed orthonormal set { U, V, W }, where W is eigenvector1.
160        let (u, v) = eigenvector1.any_orthonormal_pair();
161
162        // The unit-length eigenvector is E = x0 * U + x1 * V. We need to compute x0 and x1.
163        //
164        // Define the symmetrix 2x2 matrix M = J^T * (mat - eigenvalue2 * I), where J = [U V]
165        // and I is a 3x3 identity matrix. This means that E = J * X, where X is a column vector
166        // with rows x0 and x1. The 3x3 linear system (mat - eigenvalue2 * I) * E = 0 reduces to
167        // the 2x2 linear system M * X = 0.
168        //
169        // When eigenvalue2 != eigenvalue3, M has rank 1 and is not the zero matrix.
170        // Otherwise, it has rank 0, and it is the zero matrix.
171
172        let au = mat * u;
173        let av = mat * v;
174
175        let mut m00 = u.dot(au) - eigenvalue2;
176        let mut m01 = u.dot(av);
177        let mut m11 = v.dot(av) - eigenvalue2;
178        let (abs_m00, abs_m01, abs_m11) = (ops::abs(m00), ops::abs(m01), ops::abs(m11));
179
180        if abs_m00 >= abs_m11 {
181            let max_abs_component = abs_m00.max(abs_m01);
182            if max_abs_component > 0.0 {
183                if abs_m00 >= abs_m01 {
184                    // m00 is the largest component of the row.
185                    // Factor it out for normalization and discard to avoid underflow or overflow.
186                    m01 /= m00;
187                    m00 = 1.0 / ops::sqrt(1.0 + m01 * m01);
188                    m01 *= m00;
189                } else {
190                    // m01 is the largest component of the row.
191                    // Factor it out for normalization and discard to avoid underflow or overflow.
192                    m00 /= m01;
193                    m01 = 1.0 / ops::sqrt(1.0 + m00 * m00);
194                    m00 *= m01;
195                }
196                return m01 * u - m00 * v;
197            }
198        } else {
199            let max_abs_component = abs_m11.max(abs_m01);
200            if max_abs_component > 0.0 {
201                if abs_m11 >= abs_m01 {
202                    // m11 is the largest component of the row.
203                    // Factor it out for normalization and discard to avoid underflow or overflow.
204                    m01 /= m11;
205                    m11 = 1.0 / ops::sqrt(1.0 + m01 * m01);
206                    m01 *= m11;
207                } else {
208                    // m01 is the largest component of the row.
209                    // Factor it out for normalization and discard to avoid underflow or overflow.
210                    m11 /= m01;
211                    m01 = 1.0 / ops::sqrt(1.0 + m11 * m11);
212                    m11 *= m01;
213                }
214                return m11 * u - m01 * v;
215            }
216        }
217
218        // M is the zero matrix, any unit-length solution suffices.
219        u
220    }
221
222    /// Computes the third eigenvector as the cross product of the first two.
223    /// If the given eigenvectors are valid, the returned vector should be unit length.
224    pub fn eigenvector3(eigenvector1: Vec3, eigenvector2: Vec3) -> Vec3 {
225        eigenvector1.cross(eigenvector2)
226    }
227}
228
#[cfg(test)]
mod test {
    use super::SymmetricEigen3;
    use approx::assert_relative_eq;
    use bevy_math::{Mat3, Vec3};
    use rand::{Rng, SeedableRng};

    /// Applies `Vec3::abs` to every column of `mat`.
    ///
    /// Eigenvectors are only determined up to sign, so comparisons are made on magnitudes.
    fn abs_columns(mat: Mat3) -> Mat3 {
        Mat3::from_cols(mat.x_axis.abs(), mat.y_axis.abs(), mat.z_axis.abs())
    }

    /// Draws x, y and z (in that order) uniformly from `-1.0..1.0` and normalizes the vector.
    fn random_unit_vector(rng: &mut impl Rng) -> Vec3 {
        let x = rng.gen_range(-1.0..1.0);
        let y = rng.gen_range(-1.0..1.0);
        let z = rng.gen_range(-1.0..1.0);
        Vec3::new(x, y, z).normalize()
    }

    #[test]
    fn eigen_3x3() {
        let mat = Mat3::from_cols_array_2d(&[[2.0, 7.0, 8.0], [7.0, 6.0, 3.0], [8.0, 3.0, 0.0]]);
        let eigen = SymmetricEigen3::new(mat);

        let expected_eigenvalues = Vec3::new(-7.605, 0.577, 15.028);
        assert_relative_eq!(eigen.eigenvalues, expected_eigenvalues, epsilon = 0.001);

        let expected_eigenvectors = Mat3::from_cols(
            Vec3::new(-1.075, 0.333, 1.0).normalize(),
            Vec3::new(0.542, -1.253, 1.0).normalize(),
            Vec3::new(1.359, 1.386, 1.0).normalize(),
        );
        assert_relative_eq!(
            abs_columns(eigen.eigenvectors),
            abs_columns(expected_eigenvectors),
            epsilon = 0.001
        );
    }

    #[test]
    fn eigen_3x3_diagonal() {
        let eigen = SymmetricEigen3::new(Mat3::from_diagonal(Vec3::new(2.0, 5.0, 3.0)));

        assert_eq!(eigen.eigenvalues, Vec3::new(2.0, 3.0, 5.0));

        // Sorting swaps the last two diagonal entries, so their eigenvectors swap as well.
        let normalized = Mat3::from_cols(
            eigen.eigenvectors.x_axis.normalize(),
            eigen.eigenvectors.y_axis.normalize(),
            eigen.eigenvectors.z_axis.normalize(),
        );
        assert_eq!(
            abs_columns(normalized),
            Mat3::from_cols(Vec3::X, Vec3::Z, Vec3::Y)
        );
    }

    #[test]
    fn eigen_3x3_reconstruction() {
        let mut rng = rand_chacha::ChaCha8Rng::from_seed(Default::default());

        // Decompose many random symmetric matrices and check that each one can be rebuilt
        // from its computed eigenvalues and eigenvectors.
        for _ in 0..10_000 {
            let diagonal = Vec3::new(
                rng.gen_range(0.1..100.0),
                rng.gen_range(0.1..100.0),
                rng.gen_range(0.1..100.0),
            );
            let basis = Mat3::from_cols(
                random_unit_vector(&mut rng),
                random_unit_vector(&mut rng),
                random_unit_vector(&mut rng),
            );

            // `basis * D * basis^T` is symmetric for any `basis`. The random columns are generally
            // not orthogonal, so `diagonal` and `basis` are NOT this matrix's eigen decomposition;
            // they only serve to generate a random symmetric positive semi-definite input.
            let original = basis * Mat3::from_diagonal(diagonal) * basis.transpose();

            let eigen = SymmetricEigen3::new(original);
            let reconstructed = eigen.eigenvectors
                * Mat3::from_diagonal(eigen.eigenvalues)
                * eigen.eigenvectors.transpose();

            // The achievable precision scales with the magnitude of the eigenvalues:
            // larger eigenvalues can lead to larger absolute error.
            assert_relative_eq!(original, reconstructed, epsilon = 1e-2);
        }
    }
}