parry3d/utils/
sdp_matrix.rs

1use crate::math::{Matrix2, Matrix3, Real, Vector2, Vector3};
2use core::ops::{Add, Div, Mul, Neg, Sub};
3use num_traits::{One, Zero};
4
/// A 2x2 symmetric positive-definite ("SDP") matrix.
///
/// Only the diagonal components and the single off-diagonal component `m12` are stored.
/// The component at the second row and first column is implied to equal `m12` by symmetry.
/// The fields are public and unchecked, so this type does not itself enforce positive-definiteness.
#[derive(Copy, Clone, Debug, PartialEq, Default)]
#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)
)]
pub struct SdpMatrix2<N> {
    /// The component at the first row and first column of this matrix.
    pub m11: N,
    /// The component at the first row and second column of this matrix.
    ///
    /// By symmetry, this is also the component at the second row and first column.
    pub m12: N,
    /// The component at the second row and second column of this matrix.
    pub m22: N,
}
20
21impl<
22        N: Copy
23            + Zero
24            + One
25            + Add<Output = N>
26            + Sub<Output = N>
27            + Mul<Output = N>
28            + Div<Output = N>
29            + Neg<Output = N>,
30    > SdpMatrix2<N>
31{
32    /// A new SDP 2x2 matrix with the given components.
33    ///
34    /// Because the matrix is symmetric, only the lower off-diagonal component is required.
35    pub fn new(m11: N, m12: N, m22: N) -> Self {
36        Self { m11, m12, m22 }
37    }
38
39    /// Create a new SDP matrix filled with zeros.
40    pub fn zero() -> Self {
41        Self {
42            m11: N::zero(),
43            m12: N::zero(),
44            m22: N::zero(),
45        }
46    }
47
48    /// Create a new SDP matrix with its diagonal filled with `val`, and its off-diagonal elements set to zero.
49    pub fn diagonal(val: N) -> Self {
50        Self {
51            m11: val,
52            m12: N::zero(),
53            m22: val,
54        }
55    }
56
57    /// Adds `val` to the diagonal components of `self`.
58    pub fn add_diagonal(&mut self, elt: N) -> Self {
59        Self {
60            m11: self.m11 + elt,
61            m12: self.m12,
62            m22: self.m22 + elt,
63        }
64    }
65
66    /// Compute the inverse of this SDP matrix without performing any inversibility check.
67    pub fn inverse_unchecked(&self) -> Self {
68        self.inverse_and_get_determinant_unchecked().0
69    }
70
71    /// Compute the inverse of this SDP matrix without performing any inversibility check.
72    pub fn inverse_and_get_determinant_unchecked(&self) -> (Self, N) {
73        let determinant = self.m11 * self.m22 - self.m12 * self.m12;
74        let m11 = self.m22 / determinant;
75        let m12 = -self.m12 / determinant;
76        let m22 = self.m11 / determinant;
77
78        (Self { m11, m12, m22 }, determinant)
79    }
80}
81
82impl SdpMatrix2<Real> {
83    /// Build an `SdpMatrix2` structure from a plain matrix, assuming it is SDP.
84    pub fn from_sdp_matrix(mat: Matrix2) -> Self {
85        let cols = mat.to_cols_array_2d();
86        Self {
87            m11: cols[0][0],
88            m12: cols[1][0],
89            m22: cols[1][1],
90        }
91    }
92
93    /// Convert this SDP matrix to a regular matrix representation.
94    pub fn into_matrix(self) -> Matrix2 {
95        Matrix2::from_cols(
96            Vector2::new(self.m11, self.m12),
97            Vector2::new(self.m12, self.m22),
98        )
99    }
100
101    /// Multiply this matrix by a vector.
102    pub fn mul_vec(&self, rhs: Vector2) -> Vector2 {
103        Vector2::new(
104            self.m11 * rhs.x + self.m12 * rhs.y,
105            self.m12 * rhs.x + self.m22 * rhs.y,
106        )
107    }
108}
109
110impl<N: Add<Output = N>> Add<SdpMatrix2<N>> for SdpMatrix2<N> {
111    type Output = Self;
112
113    fn add(self, rhs: SdpMatrix2<N>) -> Self {
114        Self {
115            m11: self.m11 + rhs.m11,
116            m12: self.m12 + rhs.m12,
117            m22: self.m22 + rhs.m22,
118        }
119    }
120}
121
122impl Mul<Vector2> for SdpMatrix2<Real> {
123    type Output = Vector2;
124
125    fn mul(self, rhs: Vector2) -> Self::Output {
126        self.mul_vec(rhs)
127    }
128}
129
130impl Mul<Real> for SdpMatrix2<Real> {
131    type Output = SdpMatrix2<Real>;
132
133    fn mul(self, rhs: Real) -> Self::Output {
134        SdpMatrix2 {
135            m11: self.m11 * rhs,
136            m12: self.m12 * rhs,
137            m22: self.m22 * rhs,
138        }
139    }
140}
141
/// A 3x3 symmetric positive-definite ("SDP") matrix.
///
/// Only the diagonal and the upper-triangular components are stored. The components below the
/// diagonal are implied by symmetry: row 2 column 1 is `m12`, row 3 column 1 is `m13`, and
/// row 3 column 2 is `m23`. The fields are public and unchecked, so this type does not itself
/// enforce positive-definiteness.
#[derive(Copy, Clone, Debug, PartialEq, Default)]
#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)
)]
pub struct SdpMatrix3<N> {
    /// The component at the first row and first column of this matrix.
    pub m11: N,
    /// The component at the first row and second column of this matrix.
    pub m12: N,
    /// The component at the first row and third column of this matrix.
    pub m13: N,
    /// The component at the second row and second column of this matrix.
    pub m22: N,
    /// The component at the second row and third column of this matrix.
    pub m23: N,
    /// The component at the third row and third column of this matrix.
    pub m33: N,
}
163
164impl<
165        N: Copy
166            + Zero
167            + One
168            + Add<Output = N>
169            + Sub<Output = N>
170            + Mul<Output = N>
171            + Div<Output = N>
172            + Neg<Output = N>
173            + PartialEq,
174    > SdpMatrix3<N>
175{
176    /// A new SDP 3x3 matrix with the given components.
177    ///
178    /// Because the matrix is symmetric, only the lower off-diagonal components is required.
179    pub fn new(m11: N, m12: N, m13: N, m22: N, m23: N, m33: N) -> Self {
180        Self {
181            m11,
182            m12,
183            m13,
184            m22,
185            m23,
186            m33,
187        }
188    }
189
190    /// Create a new SDP matrix filled with zeros.
191    pub fn zero() -> Self {
192        Self {
193            m11: N::zero(),
194            m12: N::zero(),
195            m13: N::zero(),
196            m22: N::zero(),
197            m23: N::zero(),
198            m33: N::zero(),
199        }
200    }
201
202    /// Create a new SDP matrix with its diagonal filled with `val`, and its off-diagonal elements set to zero.
203    pub fn diagonal(val: N) -> Self {
204        Self {
205            m11: val,
206            m12: N::zero(),
207            m13: N::zero(),
208            m22: val,
209            m23: N::zero(),
210            m33: val,
211        }
212    }
213
214    /// Are all components of this matrix equal to zero?
215    pub fn is_zero(&self) -> bool {
216        self.m11 == N::zero()
217            && self.m12 == N::zero()
218            && self.m13 == N::zero()
219            && self.m22 == N::zero()
220            && self.m23 == N::zero()
221            && self.m33 == N::zero()
222    }
223
224    /// Compute the inverse of this SDP matrix without performing any inversibility check.
225    pub fn inverse_unchecked(&self) -> Self {
226        let minor_m12_m23 = self.m22 * self.m33 - self.m23 * self.m23;
227        let minor_m11_m23 = self.m12 * self.m33 - self.m13 * self.m23;
228        let minor_m11_m22 = self.m12 * self.m23 - self.m13 * self.m22;
229
230        let determinant =
231            self.m11 * minor_m12_m23 - self.m12 * minor_m11_m23 + self.m13 * minor_m11_m22;
232        let inv_det = N::one() / determinant;
233
234        SdpMatrix3 {
235            m11: minor_m12_m23 * inv_det,
236            m12: -minor_m11_m23 * inv_det,
237            m13: minor_m11_m22 * inv_det,
238            m22: (self.m11 * self.m33 - self.m13 * self.m13) * inv_det,
239            m23: (self.m13 * self.m12 - self.m23 * self.m11) * inv_det,
240            m33: (self.m11 * self.m22 - self.m12 * self.m12) * inv_det,
241        }
242    }
243
244    /// Adds `elt` to the diagonal components of `self`.
245    pub fn add_diagonal(&self, elt: N) -> Self {
246        Self {
247            m11: self.m11 + elt,
248            m12: self.m12,
249            m13: self.m13,
250            m22: self.m22 + elt,
251            m23: self.m23,
252            m33: self.m33 + elt,
253        }
254    }
255}
256
257impl SdpMatrix3<Real> {
258    /// Build an `SdpMatrix3` structure from a plain matrix, assuming it is SDP.
259    pub fn from_sdp_matrix(mat: Matrix3) -> Self {
260        let cols = mat.to_cols_array_2d();
261        Self {
262            m11: cols[0][0],
263            m12: cols[1][0],
264            m13: cols[2][0],
265            m22: cols[1][1],
266            m23: cols[2][1],
267            m33: cols[2][2],
268        }
269    }
270
271    /// Multiply this matrix by a vector.
272    pub fn mul_vec(&self, rhs: Vector3) -> Vector3 {
273        let x = self.m11 * rhs.x + self.m12 * rhs.y + self.m13 * rhs.z;
274        let y = self.m12 * rhs.x + self.m22 * rhs.y + self.m23 * rhs.z;
275        let z = self.m13 * rhs.x + self.m23 * rhs.y + self.m33 * rhs.z;
276        Vector3::new(x, y, z)
277    }
278
279    /// Multiply this matrix by a 3x3 matrix.
280    pub fn mul_mat(&self, rhs: Matrix3) -> Matrix3 {
281        let cols = rhs.to_cols_array_2d();
282        let col0 = self.mul_vec(Vector3::new(cols[0][0], cols[0][1], cols[0][2]));
283        let col1 = self.mul_vec(Vector3::new(cols[1][0], cols[1][1], cols[1][2]));
284        let col2 = self.mul_vec(Vector3::new(cols[2][0], cols[2][1], cols[2][2]));
285        Matrix3::from_cols(col0, col1, col2)
286    }
287
288    /// Compute the quadratic form `m.transpose() * self * m`.
289    pub fn quadform(&self, m: &Matrix3) -> Self {
290        let sm = self.mul_mat(*m);
291        let result = m.transpose() * sm;
292        Self::from_sdp_matrix(result)
293    }
294
295    /// Compute the quadratic form `m.transpose() * self * m` for a 3x2 matrix.
296    pub fn quadform3x2(
297        &self,
298        m11: Real,
299        m12: Real,
300        m21: Real,
301        m22: Real,
302        m31: Real,
303        m32: Real,
304    ) -> SdpMatrix2<Real> {
305        let x0 = self.m11 * m11 + self.m12 * m21 + self.m13 * m31;
306        let y0 = self.m12 * m11 + self.m22 * m21 + self.m23 * m31;
307        let z0 = self.m13 * m11 + self.m23 * m21 + self.m33 * m31;
308
309        let x1 = self.m11 * m12 + self.m12 * m22 + self.m13 * m32;
310        let y1 = self.m12 * m12 + self.m22 * m22 + self.m23 * m32;
311        let z1 = self.m13 * m12 + self.m23 * m22 + self.m33 * m32;
312
313        let r11 = m11 * x0 + m21 * y0 + m31 * z0;
314        let r12 = m11 * x1 + m21 * y1 + m31 * z1;
315        let r22 = m12 * x1 + m22 * y1 + m32 * z1;
316
317        SdpMatrix2 {
318            m11: r11,
319            m12: r12,
320            m22: r22,
321        }
322    }
323}
324
325impl<N: Add<Output = N>> Add<SdpMatrix3<N>> for SdpMatrix3<N> {
326    type Output = SdpMatrix3<N>;
327
328    fn add(self, rhs: SdpMatrix3<N>) -> Self::Output {
329        SdpMatrix3 {
330            m11: self.m11 + rhs.m11,
331            m12: self.m12 + rhs.m12,
332            m13: self.m13 + rhs.m13,
333            m22: self.m22 + rhs.m22,
334            m23: self.m23 + rhs.m23,
335            m33: self.m33 + rhs.m33,
336        }
337    }
338}
339
340impl Mul<Real> for SdpMatrix3<Real> {
341    type Output = SdpMatrix3<Real>;
342
343    fn mul(self, rhs: Real) -> Self::Output {
344        SdpMatrix3 {
345            m11: self.m11 * rhs,
346            m12: self.m12 * rhs,
347            m13: self.m13 * rhs,
348            m22: self.m22 * rhs,
349            m23: self.m23 * rhs,
350            m33: self.m33 * rhs,
351        }
352    }
353}
354
355impl Mul<Vector3> for SdpMatrix3<Real> {
356    type Output = Vector3;
357
358    fn mul(self, rhs: Vector3) -> Self::Output {
359        self.mul_vec(rhs)
360    }
361}
362
363impl Mul<Matrix3> for SdpMatrix3<Real> {
364    type Output = Matrix3;
365
366    fn mul(self, rhs: Matrix3) -> Self::Output {
367        self.mul_mat(rhs)
368    }
369}
370
371impl<T> From<[SdpMatrix3<Real>; 4]> for SdpMatrix3<T>
372where
373    T: From<[Real; 4]>,
374{
375    fn from(data: [SdpMatrix3<Real>; 4]) -> Self {
376        SdpMatrix3 {
377            m11: T::from([data[0].m11, data[1].m11, data[2].m11, data[3].m11]),
378            m12: T::from([data[0].m12, data[1].m12, data[2].m12, data[3].m12]),
379            m13: T::from([data[0].m13, data[1].m13, data[2].m13, data[3].m13]),
380            m22: T::from([data[0].m22, data[1].m22, data[2].m22, data[3].m22]),
381            m23: T::from([data[0].m23, data[1].m23, data[2].m23, data[3].m23]),
382            m33: T::from([data[0].m33, data[1].m33, data[2].m33, data[3].m33]),
383        }
384    }
385}