1use crate::math::{Matrix2, Matrix3, Real, Vector2, Vector3};
2use core::ops::{Add, Div, Mul, Neg, Sub};
3use num_traits::{One, Zero};
4
/// A symmetric 2x2 matrix that stores only its upper-triangular components.
///
/// The lower-left component `m21` is not stored: it is implicitly equal to `m12`.
/// (The `Sdp` prefix suggests this is intended for symmetric positive-definite
/// matrices; the representation itself only enforces symmetry.)
#[derive(Copy, Clone, Debug, PartialEq, Default)]
#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)
)]
pub struct SdpMatrix2<N> {
    /// The component at the first row and first column.
    pub m11: N,
    /// The component at the first row and second column
    /// (also the second row and first column, by symmetry).
    pub m12: N,
    /// The component at the second row and second column.
    pub m22: N,
}
20
21impl<
22 N: Copy
23 + Zero
24 + One
25 + Add<Output = N>
26 + Sub<Output = N>
27 + Mul<Output = N>
28 + Div<Output = N>
29 + Neg<Output = N>,
30 > SdpMatrix2<N>
31{
32 pub fn new(m11: N, m12: N, m22: N) -> Self {
36 Self { m11, m12, m22 }
37 }
38
39 pub fn zero() -> Self {
41 Self {
42 m11: N::zero(),
43 m12: N::zero(),
44 m22: N::zero(),
45 }
46 }
47
48 pub fn diagonal(val: N) -> Self {
50 Self {
51 m11: val,
52 m12: N::zero(),
53 m22: val,
54 }
55 }
56
57 pub fn add_diagonal(&mut self, elt: N) -> Self {
59 Self {
60 m11: self.m11 + elt,
61 m12: self.m12,
62 m22: self.m22 + elt,
63 }
64 }
65
66 pub fn inverse_unchecked(&self) -> Self {
68 self.inverse_and_get_determinant_unchecked().0
69 }
70
71 pub fn inverse_and_get_determinant_unchecked(&self) -> (Self, N) {
73 let determinant = self.m11 * self.m22 - self.m12 * self.m12;
74 let m11 = self.m22 / determinant;
75 let m12 = -self.m12 / determinant;
76 let m22 = self.m11 / determinant;
77
78 (Self { m11, m12, m22 }, determinant)
79 }
80}
81
82impl SdpMatrix2<Real> {
83 pub fn from_sdp_matrix(mat: Matrix2) -> Self {
85 let cols = mat.to_cols_array_2d();
86 Self {
87 m11: cols[0][0],
88 m12: cols[1][0],
89 m22: cols[1][1],
90 }
91 }
92
93 pub fn into_matrix(self) -> Matrix2 {
95 Matrix2::from_cols(
96 Vector2::new(self.m11, self.m12),
97 Vector2::new(self.m12, self.m22),
98 )
99 }
100
101 pub fn mul_vec(&self, rhs: Vector2) -> Vector2 {
103 Vector2::new(
104 self.m11 * rhs.x + self.m12 * rhs.y,
105 self.m12 * rhs.x + self.m22 * rhs.y,
106 )
107 }
108}
109
110impl<N: Add<Output = N>> Add<SdpMatrix2<N>> for SdpMatrix2<N> {
111 type Output = Self;
112
113 fn add(self, rhs: SdpMatrix2<N>) -> Self {
114 Self {
115 m11: self.m11 + rhs.m11,
116 m12: self.m12 + rhs.m12,
117 m22: self.m22 + rhs.m22,
118 }
119 }
120}
121
122impl Mul<Vector2> for SdpMatrix2<Real> {
123 type Output = Vector2;
124
125 fn mul(self, rhs: Vector2) -> Self::Output {
126 self.mul_vec(rhs)
127 }
128}
129
130impl Mul<Real> for SdpMatrix2<Real> {
131 type Output = SdpMatrix2<Real>;
132
133 fn mul(self, rhs: Real) -> Self::Output {
134 SdpMatrix2 {
135 m11: self.m11 * rhs,
136 m12: self.m12 * rhs,
137 m22: self.m22 * rhs,
138 }
139 }
140}
141
/// A symmetric 3x3 matrix that stores only its six upper-triangular components.
///
/// The components below the diagonal are not stored: `m21 == m12`, `m31 == m13`
/// and `m32 == m23`. (The `Sdp` prefix suggests this is intended for symmetric
/// positive-definite matrices; the representation itself only enforces symmetry.)
#[derive(Copy, Clone, Debug, PartialEq, Default)]
#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)
)]
pub struct SdpMatrix3<N> {
    /// The component at the first row and first column.
    pub m11: N,
    /// The component at the first row and second column
    /// (also the second row and first column, by symmetry).
    pub m12: N,
    /// The component at the first row and third column
    /// (also the third row and first column, by symmetry).
    pub m13: N,
    /// The component at the second row and second column.
    pub m22: N,
    /// The component at the second row and third column
    /// (also the third row and second column, by symmetry).
    pub m23: N,
    /// The component at the third row and third column.
    pub m33: N,
}
163
164impl<
165 N: Copy
166 + Zero
167 + One
168 + Add<Output = N>
169 + Sub<Output = N>
170 + Mul<Output = N>
171 + Div<Output = N>
172 + Neg<Output = N>
173 + PartialEq,
174 > SdpMatrix3<N>
175{
176 pub fn new(m11: N, m12: N, m13: N, m22: N, m23: N, m33: N) -> Self {
180 Self {
181 m11,
182 m12,
183 m13,
184 m22,
185 m23,
186 m33,
187 }
188 }
189
190 pub fn zero() -> Self {
192 Self {
193 m11: N::zero(),
194 m12: N::zero(),
195 m13: N::zero(),
196 m22: N::zero(),
197 m23: N::zero(),
198 m33: N::zero(),
199 }
200 }
201
202 pub fn diagonal(val: N) -> Self {
204 Self {
205 m11: val,
206 m12: N::zero(),
207 m13: N::zero(),
208 m22: val,
209 m23: N::zero(),
210 m33: val,
211 }
212 }
213
214 pub fn is_zero(&self) -> bool {
216 self.m11 == N::zero()
217 && self.m12 == N::zero()
218 && self.m13 == N::zero()
219 && self.m22 == N::zero()
220 && self.m23 == N::zero()
221 && self.m33 == N::zero()
222 }
223
224 pub fn inverse_unchecked(&self) -> Self {
226 let minor_m12_m23 = self.m22 * self.m33 - self.m23 * self.m23;
227 let minor_m11_m23 = self.m12 * self.m33 - self.m13 * self.m23;
228 let minor_m11_m22 = self.m12 * self.m23 - self.m13 * self.m22;
229
230 let determinant =
231 self.m11 * minor_m12_m23 - self.m12 * minor_m11_m23 + self.m13 * minor_m11_m22;
232 let inv_det = N::one() / determinant;
233
234 SdpMatrix3 {
235 m11: minor_m12_m23 * inv_det,
236 m12: -minor_m11_m23 * inv_det,
237 m13: minor_m11_m22 * inv_det,
238 m22: (self.m11 * self.m33 - self.m13 * self.m13) * inv_det,
239 m23: (self.m13 * self.m12 - self.m23 * self.m11) * inv_det,
240 m33: (self.m11 * self.m22 - self.m12 * self.m12) * inv_det,
241 }
242 }
243
244 pub fn add_diagonal(&self, elt: N) -> Self {
246 Self {
247 m11: self.m11 + elt,
248 m12: self.m12,
249 m13: self.m13,
250 m22: self.m22 + elt,
251 m23: self.m23,
252 m33: self.m33 + elt,
253 }
254 }
255}
256
257impl SdpMatrix3<Real> {
258 pub fn from_sdp_matrix(mat: Matrix3) -> Self {
260 let cols = mat.to_cols_array_2d();
261 Self {
262 m11: cols[0][0],
263 m12: cols[1][0],
264 m13: cols[2][0],
265 m22: cols[1][1],
266 m23: cols[2][1],
267 m33: cols[2][2],
268 }
269 }
270
271 pub fn mul_vec(&self, rhs: Vector3) -> Vector3 {
273 let x = self.m11 * rhs.x + self.m12 * rhs.y + self.m13 * rhs.z;
274 let y = self.m12 * rhs.x + self.m22 * rhs.y + self.m23 * rhs.z;
275 let z = self.m13 * rhs.x + self.m23 * rhs.y + self.m33 * rhs.z;
276 Vector3::new(x, y, z)
277 }
278
279 pub fn mul_mat(&self, rhs: Matrix3) -> Matrix3 {
281 let cols = rhs.to_cols_array_2d();
282 let col0 = self.mul_vec(Vector3::new(cols[0][0], cols[0][1], cols[0][2]));
283 let col1 = self.mul_vec(Vector3::new(cols[1][0], cols[1][1], cols[1][2]));
284 let col2 = self.mul_vec(Vector3::new(cols[2][0], cols[2][1], cols[2][2]));
285 Matrix3::from_cols(col0, col1, col2)
286 }
287
288 pub fn quadform(&self, m: &Matrix3) -> Self {
290 let sm = self.mul_mat(*m);
291 let result = m.transpose() * sm;
292 Self::from_sdp_matrix(result)
293 }
294
295 pub fn quadform3x2(
297 &self,
298 m11: Real,
299 m12: Real,
300 m21: Real,
301 m22: Real,
302 m31: Real,
303 m32: Real,
304 ) -> SdpMatrix2<Real> {
305 let x0 = self.m11 * m11 + self.m12 * m21 + self.m13 * m31;
306 let y0 = self.m12 * m11 + self.m22 * m21 + self.m23 * m31;
307 let z0 = self.m13 * m11 + self.m23 * m21 + self.m33 * m31;
308
309 let x1 = self.m11 * m12 + self.m12 * m22 + self.m13 * m32;
310 let y1 = self.m12 * m12 + self.m22 * m22 + self.m23 * m32;
311 let z1 = self.m13 * m12 + self.m23 * m22 + self.m33 * m32;
312
313 let r11 = m11 * x0 + m21 * y0 + m31 * z0;
314 let r12 = m11 * x1 + m21 * y1 + m31 * z1;
315 let r22 = m12 * x1 + m22 * y1 + m32 * z1;
316
317 SdpMatrix2 {
318 m11: r11,
319 m12: r12,
320 m22: r22,
321 }
322 }
323}
324
325impl<N: Add<Output = N>> Add<SdpMatrix3<N>> for SdpMatrix3<N> {
326 type Output = SdpMatrix3<N>;
327
328 fn add(self, rhs: SdpMatrix3<N>) -> Self::Output {
329 SdpMatrix3 {
330 m11: self.m11 + rhs.m11,
331 m12: self.m12 + rhs.m12,
332 m13: self.m13 + rhs.m13,
333 m22: self.m22 + rhs.m22,
334 m23: self.m23 + rhs.m23,
335 m33: self.m33 + rhs.m33,
336 }
337 }
338}
339
340impl Mul<Real> for SdpMatrix3<Real> {
341 type Output = SdpMatrix3<Real>;
342
343 fn mul(self, rhs: Real) -> Self::Output {
344 SdpMatrix3 {
345 m11: self.m11 * rhs,
346 m12: self.m12 * rhs,
347 m13: self.m13 * rhs,
348 m22: self.m22 * rhs,
349 m23: self.m23 * rhs,
350 m33: self.m33 * rhs,
351 }
352 }
353}
354
355impl Mul<Vector3> for SdpMatrix3<Real> {
356 type Output = Vector3;
357
358 fn mul(self, rhs: Vector3) -> Self::Output {
359 self.mul_vec(rhs)
360 }
361}
362
363impl Mul<Matrix3> for SdpMatrix3<Real> {
364 type Output = Matrix3;
365
366 fn mul(self, rhs: Matrix3) -> Self::Output {
367 self.mul_mat(rhs)
368 }
369}
370
371impl<T> From<[SdpMatrix3<Real>; 4]> for SdpMatrix3<T>
372where
373 T: From<[Real; 4]>,
374{
375 fn from(data: [SdpMatrix3<Real>; 4]) -> Self {
376 SdpMatrix3 {
377 m11: T::from([data[0].m11, data[1].m11, data[2].m11, data[3].m11]),
378 m12: T::from([data[0].m12, data[1].m12, data[2].m12, data[3].m12]),
379 m13: T::from([data[0].m13, data[1].m13, data[2].m13, data[3].m13]),
380 m22: T::from([data[0].m22, data[1].m22, data[2].m22, data[3].m22]),
381 m23: T::from([data[0].m23, data[1].m23, data[2].m23, data[3].m23]),
382 m33: T::from([data[0].m33, data[1].m33, data[2].m33, data[3].m33]),
383 }
384 }
385}