avian3d/collider_tree/obvhs_ext.rs

1use bevy_math::Vec3A;
2use obvhs::{
3    aabb::Aabb,
4    bvh2::{Bvh2, node::Bvh2Node},
5    fast_stack,
6    faststack::FastStack,
7    ray::{INVALID_ID, safe_inverse},
8};
9
10use crate::math::Ray;
11
12/// A struct representing a sweep test, where an AABB is swept
13/// along a velocity vector to test for intersection with other AABBs.
14#[derive(Clone, Copy, Debug)]
15#[repr(C)]
16pub struct Sweep {
17    /// The AABB of the collider being swept, in its starting position.
18    pub aabb: Aabb,
19    /// The velocity vector along which the AABB is swept.
20    pub velocity: Vec3A,
21    /// The inverse of the velocity vector components. Used to avoid division in sweep/aabb tests.
22    pub inv_velocity: Vec3A,
23    /// The minimum `t` (fraction) value for the sweep.
24    pub tmin: f32,
25    /// The maximum `t` (fraction) value for the sweep.
26    pub tmax: f32,
27}
28
29impl Sweep {
30    /// Creates a new `Sweep` with the given AABB, velocity, and `t` (fraction) range.
31    pub fn new(aabb: Aabb, velocity: Vec3A, min: f32, max: f32) -> Self {
32        let sweep = Sweep {
33            aabb,
34            velocity,
35            inv_velocity: Vec3A::new(
36                safe_inverse(velocity.x),
37                safe_inverse(velocity.y),
38                safe_inverse(velocity.z),
39            ),
40            tmin: min,
41            tmax: max,
42        };
43
44        debug_assert!(sweep.inv_velocity.is_finite());
45        debug_assert!(sweep.velocity.is_finite());
46        debug_assert!(sweep.aabb.min.is_finite());
47        debug_assert!(sweep.aabb.max.is_finite());
48
49        sweep
50    }
51}
52
/// The outcome of a sweep query: which primitive was hit, and at what fraction of the sweep.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct SweepHit {
    /// Identifier of the primitive that was hit (`INVALID_ID` when nothing has been hit).
    pub primitive_id: u32,
    /// Fraction of the sweep travelled when the hit occurred.
    pub t: f32,
}
63
64impl SweepHit {
65    /// Creates a new `SweepHit` instance representing no hit.
66    pub fn none() -> Self {
67        Self {
68            primitive_id: INVALID_ID,
69            t: f32::INFINITY,
70        }
71    }
72}
73
74/// Extension trait for [`obvhs::bvh2::Bvh2`] to add additional traversal methods.
75pub trait Bvh2Ext {
76    /// Traverse the BVH by sweeping an AABB along a velocity vector. Returns the closest intersected primitive.
77    ///
78    /// # Arguments
79    /// * `sweep` - The sweep to be tested for intersection.
80    /// * `hit` - As `sweep_traverse_dynamic` intersects primitives, it will update `hit` with the closest.
81    /// * `intersection_fn` - should take the given sweep and primitive index and return the distance to the intersection, if any.
82    ///
83    /// Note the primitive index should index first into `Bvh2::primitive_indices` then that will be index of original primitive.
84    /// Various parts of the BVH building process might reorder the primitives. To avoid this indirection, reorder your
85    /// original primitives per `primitive_indices`.
86    fn sweep_traverse<F: FnMut(&Sweep, usize) -> f32>(
87        &self,
88        sweep: Sweep,
89        hit: &mut SweepHit,
90        intersection_fn: F,
91    ) -> bool;
92
93    /// Traverse the BVH by sweeping an AABB along a velocity vector. Returns true if the sweep missed all primitives.
94    ///
95    /// # Arguments
96    /// * `sweep` - The sweep to be tested for intersection.
97    /// * `intersection_fn` - should take the given sweep and primitive index and return the distance to the intersection, if any.
98    ///
99    /// Note the primitive index should index first into `Bvh2::primitive_indices` then that will be index of original primitive.
100    /// Various parts of the BVH building process might reorder the primitives. To avoid this indirection, reorder your
101    /// original primitives per `primitive_indices`.
102    fn sweep_traverse_miss<F: FnMut(&Sweep, usize) -> f32>(
103        &self,
104        sweep: Sweep,
105        intersection_fn: F,
106    ) -> bool;
107
108    /// Traverse the BVH by sweeping an AABB along a velocity vector. Intersects all primitives along the sweep
109    /// and calls `intersection_fn` for each hit. The sweep is not updated, to allow for evaluating at every hit.
110    ///
111    /// # Arguments
112    /// * `sweep` - The sweep to be tested for intersection.
113    /// * `intersection_fn` - should take the given sweep and primitive index.
114    ///
115    /// Note the primitive index should index first into `Bvh2::primitive_indices` then that will be index of original primitive.
116    /// Various parts of the BVH building process might reorder the primitives. To avoid this indirection, reorder your
117    /// original primitives per `primitive_indices`.
118    fn sweep_traverse_anyhit<F: FnMut(&Sweep, usize)>(&self, sweep: Sweep, intersection_fn: F);
119
120    /// Traverse the BVH by sweeping an AABB along a velocity vector.
121    ///
122    /// Terminates when no hits are found or when `intersection_fn` returns false for a hit.
123    ///
124    /// # Arguments
125    /// * `stack` - Stack for traversal state.
126    /// * `sweep` - The sweep to be tested for intersection.
127    /// * `hit` - As `sweep_traverse_dynamic` intersects primitives, it will update `hit` with the closest.
128    /// * `intersection_fn` - should test the primitives in the given node, update the ray.tmax, and hit info. Return
129    ///   false to halt traversal.
130    ///
131    /// Note the primitive index should index first into `Bvh2::primitive_indices` then that will be index of original primitive.
132    /// Various parts of the BVH building process might reorder the primitives. To avoid this indirection, reorder your
133    /// original primitives per `primitive_indices`.
134    fn sweep_traverse_dynamic<
135        F: FnMut(&Bvh2Node, &mut Sweep, &mut SweepHit) -> bool,
136        Stack: FastStack<u32>,
137    >(
138        &self,
139        stack: &mut Stack,
140        sweep: Sweep,
141        hit: &mut SweepHit,
142        intersection_fn: F,
143    );
144
145    /// Traverse the BVH to find the closest leaf node to a point.
146    /// Returns the primitive index and squared distance of the closest leaf, or `None` if no leaf is within `max_dist_sq`.
147    ///
148    /// # Arguments
149    /// * `stack` - Stack for traversal state.
150    /// * `point` - The query point.
151    /// * `max_dist_sq` - Maximum squared distance to search (use `f32::INFINITY` for unlimited).
152    /// * `closest_leaf` - Will be updated with the closest leaf node and distance found.
153    /// * `visit_fn` - Called for each leaf node within range. Should take the given ray and primitive index and return the squared distance
154    ///   to the primitive, if any.
155    ///
156    /// Note the primitive index should index first into `Bvh2::primitive_indices` then that will be index of original primitive.
157    /// Various parts of the BVH building process might reorder the primitives. To avoid this indirection, reorder your
158    /// original primitives per `primitive_indices`.
159    fn squared_distance_traverse<F: FnMut(Vec3A, usize) -> f32>(
160        &self,
161        point: Vec3A,
162        max_dist_sq: f32,
163        visit_fn: F,
164    ) -> Option<(u32, f32)>;
165
166    /// Traverse the BVH with a point, calling `visit_fn` for each leaf node within `max_dist_sq` of the point.
167    ///
168    /// Terminates when all nodes within `max_dist_sq` have been visited or when `visit_fn` returns false for a node.
169    ///
170    /// # Arguments
171    /// * `stack` - Stack for traversal state.
172    /// * `point` - The query point.
173    /// * `max_dist_sq` - Maximum squared distance to search (use `f32::INFINITY` for unlimited).
174    /// * `closest_leaf` - Will be updated with the closest leaf node and distance found.
175    /// * `visit_fn` - Called for each leaf node within range. Should update `max_dist_sq` and `closest_leaf`.
176    ///   Return false to halt traversal early.
177    ///
178    /// Note the primitive index should index first into `Bvh2::primitive_indices` then that will be index of original primitive.
179    /// Various parts of the BVH building process might reorder the primitives. To avoid this indirection, reorder your
180    /// original primitives per `primitive_indices`.
181    fn squared_distance_traverse_dynamic<
182        F: FnMut(&Bvh2Node, &mut f32, &mut Option<(u32, f32)>) -> bool,
183        Stack: FastStack<u32>,
184    >(
185        &self,
186        stack: &mut Stack,
187        point: Vec3A,
188        max_dist_sq: f32,
189        closest_leaf: &mut Option<(u32, f32)>,
190        visit_fn: F,
191    );
192}
193
194impl Bvh2Ext for Bvh2 {
195    #[inline(always)]
196    fn sweep_traverse<F: FnMut(&Sweep, usize) -> f32>(
197        &self,
198        sweep: Sweep,
199        hit: &mut SweepHit,
200        mut intersection_fn: F,
201    ) -> bool {
202        let mut intersect_prims = |node: &Bvh2Node, sweep: &mut Sweep, hit: &mut SweepHit| {
203            (node.first_index..node.first_index + node.prim_count).for_each(|primitive_id| {
204                let t = intersection_fn(sweep, primitive_id as usize);
205                if t < sweep.tmax {
206                    hit.primitive_id = primitive_id;
207                    hit.t = t;
208                    sweep.tmax = t;
209                }
210            });
211            true
212        };
213
214        fast_stack!(u32, (96, 192), self.max_depth, stack, {
215            Bvh2::sweep_traverse_dynamic(self, &mut stack, sweep, hit, &mut intersect_prims)
216        });
217
218        hit.t < sweep.tmax // Note this is valid since traverse_with_stack does not mutate the sweep
219    }
220
221    #[inline(always)]
222    fn sweep_traverse_miss<F: FnMut(&Sweep, usize) -> f32>(
223        &self,
224        sweep: Sweep,
225        mut intersection_fn: F,
226    ) -> bool {
227        let mut miss = true;
228        let mut intersect_prims = |node: &Bvh2Node, sweep: &mut Sweep, _hit: &mut SweepHit| {
229            for primitive_id in node.first_index..node.first_index + node.prim_count {
230                let t = intersection_fn(sweep, primitive_id as usize);
231                if t < sweep.tmax {
232                    miss = false;
233                    return false;
234                }
235            }
236            true
237        };
238
239        fast_stack!(u32, (96, 192), self.max_depth, stack, {
240            Bvh2::sweep_traverse_dynamic(
241                self,
242                &mut stack,
243                sweep,
244                &mut SweepHit::none(),
245                &mut intersect_prims,
246            )
247        });
248
249        miss
250    }
251
252    #[inline(always)]
253    fn sweep_traverse_anyhit<F: FnMut(&Sweep, usize)>(&self, sweep: Sweep, mut intersection_fn: F) {
254        let mut intersect_prims = |node: &Bvh2Node, sweep: &mut Sweep, _hit: &mut SweepHit| {
255            for primitive_id in node.first_index..node.first_index + node.prim_count {
256                intersection_fn(sweep, primitive_id as usize);
257            }
258            true
259        };
260
261        let mut hit = SweepHit::none();
262        fast_stack!(u32, (96, 192), self.max_depth, stack, {
263            self.sweep_traverse_dynamic(&mut stack, sweep, &mut hit, &mut intersect_prims)
264        });
265    }
266
267    #[inline(always)]
268    fn sweep_traverse_dynamic<
269        F: FnMut(&Bvh2Node, &mut Sweep, &mut SweepHit) -> bool,
270        Stack: FastStack<u32>,
271    >(
272        &self,
273        stack: &mut Stack,
274        mut sweep: Sweep,
275        hit: &mut SweepHit,
276        mut intersection_fn: F,
277    ) {
278        if self.nodes.is_empty() {
279            return;
280        }
281
282        let root_node = &self.nodes[0];
283        let root_aabb = root_node.aabb();
284        let hit_root = root_aabb.intersect_sweep(&sweep) < sweep.tmax;
285        if !hit_root {
286            return;
287        } else if root_node.is_leaf() {
288            intersection_fn(root_node, &mut sweep, hit);
289            return;
290        };
291
292        let mut current_node_index = root_node.first_index;
293        loop {
294            let right_index = current_node_index as usize + 1;
295            assert!(right_index < self.nodes.len());
296            let mut left_node = unsafe { self.nodes.get_unchecked(current_node_index as usize) };
297            let mut right_node = unsafe { self.nodes.get_unchecked(right_index) };
298
299            // TODO perf: could it be faster to intersect these at the same time with avx?
300            let mut left_t = left_node.aabb().intersect_sweep(&sweep);
301            let mut right_t = right_node.aabb().intersect_sweep(&sweep);
302
303            if left_t > right_t {
304                core::mem::swap(&mut left_t, &mut right_t);
305                core::mem::swap(&mut left_node, &mut right_node);
306            }
307
308            let hit_left = left_t < sweep.tmax;
309
310            let go_left = if hit_left && left_node.is_leaf() {
311                if !intersection_fn(left_node, &mut sweep, hit) {
312                    return;
313                }
314                false
315            } else {
316                hit_left
317            };
318
319            let hit_right = right_t < sweep.tmax;
320
321            let go_right = if hit_right && right_node.is_leaf() {
322                if !intersection_fn(right_node, &mut sweep, hit) {
323                    return;
324                }
325                false
326            } else {
327                hit_right
328            };
329
330            match (go_left, go_right) {
331                (true, true) => {
332                    current_node_index = left_node.first_index;
333                    stack.push(right_node.first_index);
334                }
335                (true, false) => current_node_index = left_node.first_index,
336                (false, true) => current_node_index = right_node.first_index,
337                (false, false) => {
338                    let Some(next) = stack.pop() else {
339                        hit.t = sweep.tmax;
340                        return;
341                    };
342                    current_node_index = next;
343                }
344            }
345        }
346    }
347
348    #[inline(always)]
349    fn squared_distance_traverse<F: FnMut(Vec3A, usize) -> f32>(
350        &self,
351        point: Vec3A,
352        max_dist_sq: f32,
353        mut visit_fn: F,
354    ) -> Option<(u32, f32)> {
355        let mut closest_leaf = None;
356
357        let mut visit_prims =
358            |node: &Bvh2Node, max_dist_sq: &mut f32, closest_leaf: &mut Option<(u32, f32)>| {
359                (node.first_index..node.first_index + node.prim_count).for_each(|primitive_id| {
360                    let distance_sq = visit_fn(point, primitive_id as usize);
361                    if distance_sq < *max_dist_sq {
362                        *closest_leaf = Some((primitive_id, distance_sq));
363                        *max_dist_sq = distance_sq;
364                    }
365                });
366                true
367            };
368
369        fast_stack!(u32, (96, 192), self.max_depth, stack, {
370            Bvh2::squared_distance_traverse_dynamic(
371                self,
372                &mut stack,
373                point,
374                max_dist_sq,
375                &mut closest_leaf,
376                &mut visit_prims,
377            )
378        });
379
380        closest_leaf
381    }
382
383    #[inline(always)]
384    fn squared_distance_traverse_dynamic<
385        F: FnMut(&Bvh2Node, &mut f32, &mut Option<(u32, f32)>) -> bool,
386        Stack: FastStack<u32>,
387    >(
388        &self,
389        stack: &mut Stack,
390        point: Vec3A,
391        mut max_dist_sq: f32,
392        closest_leaf: &mut Option<(u32, f32)>,
393        mut visit_fn: F,
394    ) {
395        if self.nodes.is_empty() {
396            return;
397        }
398
399        let root_node = &self.nodes[0];
400        let root_dist_sq = root_node.aabb().distance_to_point_squared(point);
401
402        if root_dist_sq > max_dist_sq {
403            return;
404        } else if root_node.is_leaf() {
405            visit_fn(root_node, &mut max_dist_sq, closest_leaf);
406            return;
407        }
408
409        let mut current_node_index = root_node.first_index;
410
411        loop {
412            let right_index = current_node_index as usize + 1;
413            assert!(right_index < self.nodes.len());
414            let mut left_node = unsafe { self.nodes.get_unchecked(current_node_index as usize) };
415            let mut right_node = unsafe { self.nodes.get_unchecked(right_index) };
416
417            // TODO perf: could it be faster to compute these at the same time with avx?
418            let mut left_dist_sq = left_node.aabb().distance_to_point_squared(point);
419            let mut right_dist_sq = right_node.aabb().distance_to_point_squared(point);
420
421            // Sort by distance (closer first)
422            if left_dist_sq > right_dist_sq {
423                core::mem::swap(&mut left_dist_sq, &mut right_dist_sq);
424                core::mem::swap(&mut left_node, &mut right_node);
425            }
426
427            let within_left = left_dist_sq <= max_dist_sq;
428
429            let go_left = if within_left && left_node.is_leaf() {
430                if !visit_fn(left_node, &mut max_dist_sq, closest_leaf) {
431                    return;
432                }
433                false
434            } else {
435                within_left
436            };
437
438            let within_right = right_dist_sq <= max_dist_sq;
439
440            let go_right = if within_right && right_node.is_leaf() {
441                if !visit_fn(right_node, &mut max_dist_sq, closest_leaf) {
442                    return;
443                }
444                false
445            } else {
446                within_right
447            };
448
449            match (go_left, go_right) {
450                (true, true) => {
451                    current_node_index = left_node.first_index;
452                    stack.push(right_node.first_index);
453                }
454                (true, false) => current_node_index = left_node.first_index,
455                (false, true) => current_node_index = right_node.first_index,
456                (false, false) => {
457                    let Some(next) = stack.pop() else {
458                        return;
459                    };
460                    current_node_index = next;
461                }
462            }
463        }
464    }
465}
466
467pub trait ObvhsAabbExt {
468    /// Computes the squared distance from a point to this AABB.
469    fn distance_to_point_squared(&self, point: Vec3A) -> f32;
470
471    /// Checks if this AABB intersects with a sweep and returns the fraction
472    /// along the sweep at which the intersection occurs.
473    ///
474    /// Returns `f32::INFINITY` if there is no intersection.
475    fn intersect_sweep(&self, sweep: &Sweep) -> f32;
476}
477
478impl ObvhsAabbExt for Aabb {
479    #[inline(always)]
480    fn distance_to_point_squared(&self, point: Vec3A) -> f32 {
481        // OBVHS may be using a different version of Glam,
482        // so we convert to our Vec3A type.
483        let min: Vec3A = self.min.to_array().into();
484        let max: Vec3A = self.max.to_array().into();
485        let point_min = min - point;
486        let point_max = max - point;
487        let dist_min = point_min.max(Vec3A::ZERO);
488        let dist_max = point_max.min(Vec3A::ZERO);
489        dist_min.length_squared().min(dist_max.length_squared())
490    }
491
492    #[inline(always)]
493    fn intersect_sweep(&self, sweep: &Sweep) -> f32 {
494        let minkowski_sum_shift = -sweep.aabb.center();
495        let minkowski_sum_margin = sweep.aabb.diagonal() * 0.5 + sweep.tmin;
496
497        // OBVHS may be using a different version of Glam,
498        // so we convert to our Vec3A type.
499        let msum_min: Vec3A = (self.min + minkowski_sum_shift - minkowski_sum_margin)
500            .to_array()
501            .into();
502        let msum_max: Vec3A = (self.max + minkowski_sum_shift + minkowski_sum_margin)
503            .to_array()
504            .into();
505
506        // Now, we cast a ray from the origin along the velocity,
507        // and intersect it with the Minkowski sum.
508        let t1 = msum_min * sweep.inv_velocity;
509        let t2 = msum_max * sweep.inv_velocity;
510
511        let tmin = t1.min(t2);
512        let tmax = t1.max(t2);
513
514        let tmin_n = tmin.max_element();
515        let tmax_n = tmax.min_element();
516
517        if tmax_n >= tmin_n && tmax_n >= 0.0 {
518            tmin_n
519        } else {
520            f32::INFINITY
521        }
522    }
523}
524
525#[inline(always)]
526pub fn obvhs_ray(ray: &Ray, max_distance: f32) -> obvhs::ray::Ray {
527    #[cfg(feature = "2d")]
528    let origin = ray.origin.extend(0.0).to_array().into();
529    #[cfg(feature = "3d")]
530    let origin = ray.origin.to_array().into();
531    #[cfg(feature = "2d")]
532    let direction = ray.direction.extend(0.0).to_array().into();
533    #[cfg(feature = "3d")]
534    let direction = ray.direction.to_array().into();
535
536    obvhs::ray::Ray::new(origin, direction, 0.0, max_distance)
537}