1use std::ops::BitAnd;
4
5use bytemuck::{Pod, Zeroable};
6use glam::Vec3A;
7
8use crate::{Boundable, ray::Ray};
9
/// Axis-aligned bounding box described by its minimum and maximum corners.
///
/// `repr(C)` pins the field order (`min`, then `max`) so the value can be reinterpreted as raw
/// bytes through `bytemuck`.
///
/// Note: the derived `Default` yields `min = max = Vec3A::ZERO` (a degenerate box at the
/// origin), not the empty/inverted box `Aabb::INVALID`.
#[derive(Default, Clone, Copy, Debug, PartialEq, Zeroable)]
#[repr(C)]
pub struct Aabb {
    /// Minimum corner. The box is considered valid when `min <= max` on every axis.
    pub min: Vec3A,
    /// Maximum corner.
    pub max: Vec3A,
}

// SAFETY: `Aabb` is `repr(C)` and consists of exactly two `Vec3A` fields of equal size and
// alignment, so no padding is inserted between or after them. It is `Copy + 'static` and derives
// `Zeroable`. Soundness additionally requires that `Vec3A` itself has no uninitialized padding
// bytes and accepts every bit pattern.
// NOTE(review): that holds for glam's SIMD-backed `Vec3A` (a full 128-bit vector), but glam's
// scalar fallback is presumably a 3 x f32 struct aligned to 16 bytes, which has 4 trailing
// padding bytes. Confirm this impl is only compiled for SIMD configurations.
unsafe impl Pod for Aabb {}
19
20impl Aabb {
21 pub const INVALID: Self = Self {
24 min: Vec3A::splat(f32::MAX),
25 max: Vec3A::splat(f32::MIN),
26 };
27
28 pub const LARGEST: Self = Self {
31 min: Vec3A::splat(-f32::MAX),
32 max: Vec3A::splat(f32::MAX),
33 };
34
35 pub const INFINITY: Self = Self {
38 min: Vec3A::splat(-f32::INFINITY),
39 max: Vec3A::splat(f32::INFINITY),
40 };
41
42 #[inline(always)]
44 pub fn new(min: Vec3A, max: Vec3A) -> Self {
45 Self { min, max }
46 }
47
48 #[inline(always)]
50 pub fn from_point(point: Vec3A) -> Self {
51 Self {
52 min: point,
53 max: point,
54 }
55 }
56
57 #[inline(always)]
59 pub fn from_points(points: &[Vec3A]) -> Self {
60 let mut points = points.iter();
61 let mut aabb = Aabb::from_point(*points.next().unwrap());
62 for point in points {
63 aabb.extend(*point);
64 }
65 aabb
66 }
67
68 #[inline(always)]
70 pub fn contains_point(&self, point: Vec3A) -> bool {
71 (point.cmpge(self.min).bitand(point.cmple(self.max))).all()
72 }
73
74 #[inline(always)]
76 pub fn extend(&mut self, point: Vec3A) -> &mut Self {
77 *self = self.union(&Self::from_point(point));
78 self
79 }
80
81 #[inline(always)]
83 #[must_use]
84 pub fn union(&self, other: &Self) -> Self {
85 Aabb {
86 min: self.min.min(other.min),
87 max: self.max.max(other.max),
88 }
89 }
90
91 #[inline(always)]
98 pub fn intersection(&self, other: &Self) -> Self {
99 Aabb {
100 min: self.min.max(other.min),
101 max: self.max.min(other.max),
102 }
103 }
104
105 #[inline(always)]
107 pub fn diagonal(&self) -> Vec3A {
108 self.max - self.min
109 }
110
111 #[inline(always)]
113 pub fn center(&self) -> Vec3A {
114 (self.max + self.min) * 0.5
115 }
116
117 #[inline(always)]
119 pub fn center_axis(&self, axis: usize) -> f32 {
120 (self.max[axis] + self.min[axis]) * 0.5
121 }
122
123 #[inline]
125 pub fn largest_axis(&self) -> usize {
126 let d = self.diagonal();
127 if d.x < d.y {
128 if d.y < d.z { 2 } else { 1 }
129 } else if d.x < d.z {
130 2
131 } else {
132 0
133 }
134 }
135
136 #[inline]
138 pub fn smallest_axis(&self) -> usize {
139 let d = self.diagonal();
140 if d.x > d.y {
141 if d.y > d.z { 2 } else { 1 }
142 } else if d.x > d.z {
143 2
144 } else {
145 0
146 }
147 }
148
149 #[inline(always)]
151 pub fn half_area(&self) -> f32 {
152 let d = self.diagonal();
153 (d.x + d.y) * d.z + d.x * d.y
154 }
155
156 #[inline(always)]
158 pub fn surface_area(&self) -> f32 {
159 let d = self.diagonal();
160 2.0 * d.dot(d)
161 }
162
163 #[inline(always)]
165 pub fn empty() -> Self {
166 Self {
167 min: Vec3A::new(f32::MAX, f32::MAX, f32::MAX),
168 max: Vec3A::new(f32::MIN, f32::MIN, f32::MIN),
169 }
170 }
171
172 pub fn valid(&self) -> bool {
174 self.min.cmple(self.max).all()
175 }
176
177 #[inline(always)]
179 pub fn intersect_aabb(&self, other: &Aabb) -> bool {
180 (self.min.cmpgt(other.max) | self.max.cmplt(other.min)).bitmask() == 0
181 }
182
183 #[inline(always)]
186 pub fn intersect_ray(&self, ray: &Ray) -> f32 {
187 let t1 = (self.min - ray.origin) * ray.inv_direction;
193 let t2 = (self.max - ray.origin) * ray.inv_direction;
194
195 let tmin = t1.min(t2);
196 let tmax = t1.max(t2);
197
198 let tmin_n = tmin.max_element();
199 let tmax_n = tmax.min_element();
200
201 if tmax_n >= tmin_n && tmax_n >= 0.0 {
202 tmin_n
203 } else {
204 f32::INFINITY
205 }
206 }
207}
208
209impl Boundable for Aabb {
210 #[inline(always)]
211 fn aabb(&self) -> Aabb {
212 *self
213 }
214}
215
#[cfg(test)]
mod tests {
    //! Unit tests for `Aabb`: construction, set operations, metrics and intersection queries.

    use super::*;
    use glam::Vec3A;

    /// The unit cube `[0, 1]^3`, the fixture most tests start from.
    fn unit_cube() -> Aabb {
        Aabb::new(Vec3A::ZERO, Vec3A::ONE)
    }

    #[test]
    fn test_from_point() {
        let p = Vec3A::ONE;
        let b = Aabb::from_point(p);
        assert_eq!((b.min, b.max), (p, p));
    }

    #[test]
    fn test_from_points() {
        let b = Aabb::from_points(&[Vec3A::ZERO, Vec3A::ONE, Vec3A::splat(2.0)]);
        assert_eq!((b.min, b.max), (Vec3A::ZERO, Vec3A::splat(2.0)));
    }

    #[test]
    fn test_contains_point() {
        let b = unit_cube();
        assert!(b.contains_point(Vec3A::splat(0.5)));
        assert!(!b.contains_point(Vec3A::splat(1.5)));
    }

    #[test]
    fn test_extend() {
        let mut b = Aabb::from_point(Vec3A::ZERO);
        b.extend(Vec3A::ONE);
        assert_eq!(b, unit_cube());
    }

    #[test]
    fn test_union() {
        let other = Aabb::new(Vec3A::splat(0.5), Vec3A::splat(1.5));
        let merged = unit_cube().union(&other);
        assert_eq!(merged, Aabb::new(Vec3A::ZERO, Vec3A::splat(1.5)));
    }

    #[test]
    fn test_intersection() {
        let other = Aabb::new(Vec3A::splat(0.5), Vec3A::splat(1.5));
        let overlap = unit_cube().intersection(&other);
        assert_eq!(overlap, Aabb::new(Vec3A::splat(0.5), Vec3A::ONE));
        assert!(overlap.valid());
    }

    #[test]
    fn test_intersection_no_overlap() {
        let far = Aabb::new(Vec3A::splat(2.0), Vec3A::splat(3.0));
        let overlap = unit_cube().intersection(&far);
        // Disjoint boxes produce an inverted result.
        assert_eq!(overlap, Aabb::new(Vec3A::splat(2.0), Vec3A::ONE));
        assert!(!overlap.valid());
    }

    #[test]
    fn test_diagonal() {
        assert_eq!(unit_cube().diagonal(), Vec3A::ONE);
    }

    #[test]
    fn test_center() {
        assert_eq!(unit_cube().center(), Vec3A::splat(0.5));
    }

    #[test]
    fn test_center_axis() {
        let b = unit_cube();
        for axis in 0..3 {
            assert_eq!(b.center_axis(axis), 0.5);
        }
    }

    #[test]
    fn test_largest_axis() {
        let b = Aabb::new(Vec3A::ZERO, Vec3A::new(1.0, 2.0, 3.0));
        assert_eq!(b.largest_axis(), 2);
    }

    #[test]
    fn test_smallest_axis() {
        let b = Aabb::new(Vec3A::ZERO, Vec3A::new(1.0, 2.0, 3.0));
        assert_eq!(b.smallest_axis(), 0);
    }

    #[test]
    fn test_half_area() {
        assert_eq!(unit_cube().half_area(), 3.0);
    }

    #[test]
    fn test_surface_area() {
        assert_eq!(unit_cube().surface_area(), 6.0);
    }

    #[test]
    fn test_empty() {
        let e = Aabb::empty();
        assert_eq!(e.min, Vec3A::splat(f32::MAX));
        assert_eq!(e.max, Vec3A::splat(f32::MIN));
    }

    #[test]
    fn test_valid() {
        assert!(unit_cube().valid());
        assert!(!Aabb::new(Vec3A::splat(2.0), Vec3A::splat(1.0)).valid());
    }

    #[test]
    fn test_intersect_aabb() {
        let b = unit_cube();
        let overlapping = Aabb::new(Vec3A::splat(0.5), Vec3A::splat(1.5));
        let disjoint = Aabb::new(Vec3A::splat(1.5), Vec3A::splat(2.5));
        assert!(b.intersect_aabb(&overlapping));
        assert!(!b.intersect_aabb(&disjoint));
    }

    #[test]
    fn test_intersect_ray() {
        let b = unit_cube();

        let hit = Ray::new(Vec3A::splat(-1.0), Vec3A::ONE, 0.0, f32::MAX);
        assert_eq!(b.intersect_ray(&hit), 1.0);

        let miss = Ray::new(Vec3A::splat(2.0), Vec3A::ONE, 0.0, f32::MAX);
        assert_eq!(b.intersect_ray(&miss), f32::INFINITY);
    }
}