1use bevy_math::Vec3A;
2use obvhs::{
3 aabb::Aabb,
4 bvh2::{Bvh2, node::Bvh2Node},
5 fast_stack,
6 faststack::FastStack,
7 ray::{INVALID_ID, safe_inverse},
8};
9
10use crate::math::Ray;
11
/// A swept-AABB query: the box `aabb` tested along `velocity` against BVH node bounds.
///
/// `AABB::intersect_sweep` in this file treats the sweep as a ray cast from the
/// center of `aabb` against target boxes inflated by `aabb`'s half-extents, so
/// hit times are expressed in units of `velocity`.
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct Sweep {
    /// The box being swept.
    pub aabb: Aabb,
    /// Displacement per unit of `t`.
    pub velocity: Vec3A,
    /// Per-component `safe_inverse` of `velocity`, cached by [`Sweep::new`].
    pub inv_velocity: Vec3A,
    /// NOTE(review): in this file `tmin` is only read by `intersect_sweep`, where
    /// it is added to the half-extents as a uniform margin — confirm the intended
    /// meaning with the callers.
    pub tmin: f32,
    /// Upper bound on accepted hit times; traversal shrinks its own copy of this
    /// value as closer hits are found.
    pub tmax: f32,
}
28
29impl Sweep {
30 pub fn new(aabb: Aabb, velocity: Vec3A, min: f32, max: f32) -> Self {
32 let sweep = Sweep {
33 aabb,
34 velocity,
35 inv_velocity: Vec3A::new(
36 safe_inverse(velocity.x),
37 safe_inverse(velocity.y),
38 safe_inverse(velocity.z),
39 ),
40 tmin: min,
41 tmax: max,
42 };
43
44 debug_assert!(sweep.inv_velocity.is_finite());
45 debug_assert!(sweep.velocity.is_finite());
46 debug_assert!(sweep.aabb.min.is_finite());
47 debug_assert!(sweep.aabb.max.is_finite());
48
49 sweep
50 }
51}
52
/// Result of a sweep traversal: which primitive was hit and at what time.
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct SweepHit {
    /// Index of the hit primitive, or `INVALID_ID` when created by [`SweepHit::none`].
    pub primitive_id: u32,
    /// Hit time along the sweep; `f32::INFINITY` when created by [`SweepHit::none`].
    pub t: f32,
}
63
64impl SweepHit {
65 pub fn none() -> Self {
67 Self {
68 primitive_id: INVALID_ID,
69 t: f32::INFINITY,
70 }
71 }
72}
73
/// Extra traversal methods for [`Bvh2`]: swept-AABB queries and
/// closest-primitive-to-point queries.
///
/// In every method the `usize` passed to the callback is an index taken from a
/// leaf's `first_index..first_index + prim_count` range.
/// NOTE(review): presumably this indexes the BVH's primitive index list rather
/// than the caller's primitive array directly — confirm against obvhs.
pub trait Bvh2Ext {
    /// Finds the closest hit of `sweep` against the BVH.
    ///
    /// `intersection_fn(sweep, index)` returns the hit time for one primitive.
    /// A time below the sweep's current `tmax` is recorded in `hit` and becomes
    /// the new `tmax` for the rest of the traversal.
    ///
    /// Returns `true` when `hit.t` ends up below the original `sweep.tmax`.
    fn sweep_traverse<F: FnMut(&Sweep, usize) -> f32>(
        &self,
        sweep: Sweep,
        hit: &mut SweepHit,
        intersection_fn: F,
    ) -> bool;

    /// Tests whether `sweep` misses everything.
    ///
    /// Traversal stops at the first primitive for which `intersection_fn`
    /// returns a time below `sweep.tmax`. Returns `true` if no such primitive
    /// was found, `false` otherwise.
    fn sweep_traverse_miss<F: FnMut(&Sweep, usize) -> f32>(
        &self,
        sweep: Sweep,
        intersection_fn: F,
    ) -> bool;

    /// Calls `intersection_fn` for every primitive index in every leaf whose
    /// bounds pass the sweep test. Nothing is recorded or returned; the callback
    /// is responsible for collecting results.
    fn sweep_traverse_anyhit<F: FnMut(&Sweep, usize)>(&self, sweep: Sweep, intersection_fn: F);

    /// Sweep traversal with a caller-provided stack and a per-leaf callback.
    ///
    /// `intersection_fn(node, sweep, hit)` is called for each leaf whose bounds
    /// are entered before `sweep.tmax`. It may lower `sweep.tmax` to tighten
    /// later tests, and returns `false` to stop the traversal immediately.
    ///
    /// When traversal runs to completion with an exhausted stack, `hit.t` is set
    /// to the final `sweep.tmax`. On the early-exit paths `hit` is left as the
    /// callback wrote it.
    fn sweep_traverse_dynamic<
        F: FnMut(&Bvh2Node, &mut Sweep, &mut SweepHit) -> bool,
        Stack: FastStack<u32>,
    >(
        &self,
        stack: &mut Stack,
        sweep: Sweep,
        hit: &mut SweepHit,
        intersection_fn: F,
    );

    /// Finds the primitive closest to `point`.
    ///
    /// `visit_fn(point, index)` returns the squared distance to one primitive.
    /// Only distances strictly below `max_dist_sq` (and below the best found so
    /// far) are accepted.
    ///
    /// Returns `Some((index, squared_distance))` for the best accepted primitive,
    /// or `None` if none was accepted.
    fn squared_distance_traverse<F: FnMut(Vec3A, usize) -> f32>(
        &self,
        point: Vec3A,
        max_dist_sq: f32,
        visit_fn: F,
    ) -> Option<(u32, f32)>;

    /// Point-distance traversal with a caller-provided stack and a per-leaf
    /// callback.
    ///
    /// `visit_fn(node, max_dist_sq, closest_leaf)` is called for each leaf whose
    /// bounds are within the current `max_dist_sq`. It may lower `max_dist_sq`
    /// and update `closest_leaf`, and returns `false` to stop the traversal
    /// immediately.
    fn squared_distance_traverse_dynamic<
        F: FnMut(&Bvh2Node, &mut f32, &mut Option<(u32, f32)>) -> bool,
        Stack: FastStack<u32>,
    >(
        &self,
        stack: &mut Stack,
        point: Vec3A,
        max_dist_sq: f32,
        closest_leaf: &mut Option<(u32, f32)>,
        visit_fn: F,
    );
}
193
194impl Bvh2Ext for Bvh2 {
195 #[inline(always)]
196 fn sweep_traverse<F: FnMut(&Sweep, usize) -> f32>(
197 &self,
198 sweep: Sweep,
199 hit: &mut SweepHit,
200 mut intersection_fn: F,
201 ) -> bool {
202 let mut intersect_prims = |node: &Bvh2Node, sweep: &mut Sweep, hit: &mut SweepHit| {
203 (node.first_index..node.first_index + node.prim_count).for_each(|primitive_id| {
204 let t = intersection_fn(sweep, primitive_id as usize);
205 if t < sweep.tmax {
206 hit.primitive_id = primitive_id;
207 hit.t = t;
208 sweep.tmax = t;
209 }
210 });
211 true
212 };
213
214 fast_stack!(u32, (96, 192), self.max_depth, stack, {
215 Bvh2::sweep_traverse_dynamic(self, &mut stack, sweep, hit, &mut intersect_prims)
216 });
217
218 hit.t < sweep.tmax }
220
221 #[inline(always)]
222 fn sweep_traverse_miss<F: FnMut(&Sweep, usize) -> f32>(
223 &self,
224 sweep: Sweep,
225 mut intersection_fn: F,
226 ) -> bool {
227 let mut miss = true;
228 let mut intersect_prims = |node: &Bvh2Node, sweep: &mut Sweep, _hit: &mut SweepHit| {
229 for primitive_id in node.first_index..node.first_index + node.prim_count {
230 let t = intersection_fn(sweep, primitive_id as usize);
231 if t < sweep.tmax {
232 miss = false;
233 return false;
234 }
235 }
236 true
237 };
238
239 fast_stack!(u32, (96, 192), self.max_depth, stack, {
240 Bvh2::sweep_traverse_dynamic(
241 self,
242 &mut stack,
243 sweep,
244 &mut SweepHit::none(),
245 &mut intersect_prims,
246 )
247 });
248
249 miss
250 }
251
252 #[inline(always)]
253 fn sweep_traverse_anyhit<F: FnMut(&Sweep, usize)>(&self, sweep: Sweep, mut intersection_fn: F) {
254 let mut intersect_prims = |node: &Bvh2Node, sweep: &mut Sweep, _hit: &mut SweepHit| {
255 for primitive_id in node.first_index..node.first_index + node.prim_count {
256 intersection_fn(sweep, primitive_id as usize);
257 }
258 true
259 };
260
261 let mut hit = SweepHit::none();
262 fast_stack!(u32, (96, 192), self.max_depth, stack, {
263 self.sweep_traverse_dynamic(&mut stack, sweep, &mut hit, &mut intersect_prims)
264 });
265 }
266
    /// Core sweep traversal using a caller-provided `stack`.
    ///
    /// Starting at the root, each iteration tests both children of the current
    /// inner node against `sweep`, handles the nearer child first, and calls
    /// `intersection_fn` on any leaf entered before `sweep.tmax`. The callback
    /// may lower `sweep.tmax`; returning `false` from it stops the traversal at
    /// once.
    ///
    /// `hit.t` is overwritten with the final `sweep.tmax` only when the traversal
    /// finishes with an empty stack. It is not touched on the early-return paths
    /// (empty BVH, root missed, root is a leaf, callback returned `false`).
    ///
    /// # Panics
    ///
    /// Panics if a child pair would index past the end of `self.nodes`.
    #[inline(always)]
    fn sweep_traverse_dynamic<
        F: FnMut(&Bvh2Node, &mut Sweep, &mut SweepHit) -> bool,
        Stack: FastStack<u32>,
    >(
        &self,
        stack: &mut Stack,
        mut sweep: Sweep,
        hit: &mut SweepHit,
        mut intersection_fn: F,
    ) {
        if self.nodes.is_empty() {
            return;
        }

        let root_node = &self.nodes[0];
        let root_aabb = root_node.aabb();
        let hit_root = root_aabb.intersect_sweep(&sweep) < sweep.tmax;
        if !hit_root {
            return;
        } else if root_node.is_leaf() {
            // Single-node tree: the callback's return value is irrelevant here.
            intersection_fn(root_node, &mut sweep, hit);
            return;
        };

        // For an inner node, `first_index` is the index of its left child; the
        // right child is stored directly after it.
        let mut current_node_index = root_node.first_index;
        loop {
            let right_index = current_node_index as usize + 1;
            assert!(right_index < self.nodes.len());
            // SAFETY: `right_index` was just asserted to be in bounds, and
            // `current_node_index` is `right_index - 1` (assuming the `+ 1`
            // above did not wrap, which would need an index of `usize::MAX`).
            let mut left_node = unsafe { self.nodes.get_unchecked(current_node_index as usize) };
            // SAFETY: guarded by the assert above.
            let mut right_node = unsafe { self.nodes.get_unchecked(right_index) };

            let mut left_t = left_node.aabb().intersect_sweep(&sweep);
            let mut right_t = right_node.aabb().intersect_sweep(&sweep);

            // Order the children so that "left" is the one entered first.
            if left_t > right_t {
                core::mem::swap(&mut left_t, &mut right_t);
                core::mem::swap(&mut left_node, &mut right_node);
            }

            let hit_left = left_t < sweep.tmax;

            let go_left = if hit_left && left_node.is_leaf() {
                if !intersection_fn(left_node, &mut sweep, hit) {
                    return;
                }
                false
            } else {
                hit_left
            };

            // Evaluated after the near leaf was processed, so a hit found there
            // (which lowers `sweep.tmax`) can cull the far child.
            let hit_right = right_t < sweep.tmax;

            let go_right = if hit_right && right_node.is_leaf() {
                if !intersection_fn(right_node, &mut sweep, hit) {
                    return;
                }
                false
            } else {
                hit_right
            };

            match (go_left, go_right) {
                (true, true) => {
                    // Descend into the near child now, defer the far one.
                    current_node_index = left_node.first_index;
                    stack.push(right_node.first_index);
                }
                (true, false) => current_node_index = left_node.first_index,
                (false, true) => current_node_index = right_node.first_index,
                (false, false) => {
                    let Some(next) = stack.pop() else {
                        // Nothing left to visit: publish the final bound.
                        hit.t = sweep.tmax;
                        return;
                    };
                    current_node_index = next;
                }
            }
        }
    }
347
348 #[inline(always)]
349 fn squared_distance_traverse<F: FnMut(Vec3A, usize) -> f32>(
350 &self,
351 point: Vec3A,
352 max_dist_sq: f32,
353 mut visit_fn: F,
354 ) -> Option<(u32, f32)> {
355 let mut closest_leaf = None;
356
357 let mut visit_prims =
358 |node: &Bvh2Node, max_dist_sq: &mut f32, closest_leaf: &mut Option<(u32, f32)>| {
359 (node.first_index..node.first_index + node.prim_count).for_each(|primitive_id| {
360 let distance_sq = visit_fn(point, primitive_id as usize);
361 if distance_sq < *max_dist_sq {
362 *closest_leaf = Some((primitive_id, distance_sq));
363 *max_dist_sq = distance_sq;
364 }
365 });
366 true
367 };
368
369 fast_stack!(u32, (96, 192), self.max_depth, stack, {
370 Bvh2::squared_distance_traverse_dynamic(
371 self,
372 &mut stack,
373 point,
374 max_dist_sq,
375 &mut closest_leaf,
376 &mut visit_prims,
377 )
378 });
379
380 closest_leaf
381 }
382
    /// Core point-distance traversal using a caller-provided `stack`.
    ///
    /// Starting at the root, each iteration measures the squared distance from
    /// `point` to both children's bounds, handles the nearer child first, and
    /// calls `visit_fn` on any leaf whose bounds are within the current
    /// `max_dist_sq`. The callback may lower `max_dist_sq` and update
    /// `closest_leaf`; returning `false` from it stops the traversal at once.
    ///
    /// Nodes whose bounds lie exactly at `max_dist_sq` are still visited: the
    /// root is rejected with `>` and children are accepted with `<=`.
    ///
    /// # Panics
    ///
    /// Panics if a child pair would index past the end of `self.nodes`.
    #[inline(always)]
    fn squared_distance_traverse_dynamic<
        F: FnMut(&Bvh2Node, &mut f32, &mut Option<(u32, f32)>) -> bool,
        Stack: FastStack<u32>,
    >(
        &self,
        stack: &mut Stack,
        point: Vec3A,
        mut max_dist_sq: f32,
        closest_leaf: &mut Option<(u32, f32)>,
        mut visit_fn: F,
    ) {
        if self.nodes.is_empty() {
            return;
        }

        let root_node = &self.nodes[0];
        let root_dist_sq = root_node.aabb().distance_to_point_squared(point);

        if root_dist_sq > max_dist_sq {
            return;
        } else if root_node.is_leaf() {
            // Single-node tree: the callback's return value is irrelevant here.
            visit_fn(root_node, &mut max_dist_sq, closest_leaf);
            return;
        }

        // For an inner node, `first_index` is the index of its left child; the
        // right child is stored directly after it.
        let mut current_node_index = root_node.first_index;

        loop {
            let right_index = current_node_index as usize + 1;
            assert!(right_index < self.nodes.len());
            // SAFETY: `right_index` was just asserted to be in bounds, and
            // `current_node_index` is `right_index - 1` (assuming the `+ 1`
            // above did not wrap, which would need an index of `usize::MAX`).
            let mut left_node = unsafe { self.nodes.get_unchecked(current_node_index as usize) };
            // SAFETY: guarded by the assert above.
            let mut right_node = unsafe { self.nodes.get_unchecked(right_index) };

            let mut left_dist_sq = left_node.aabb().distance_to_point_squared(point);
            let mut right_dist_sq = right_node.aabb().distance_to_point_squared(point);

            // Order the children so that "left" is the nearer one.
            if left_dist_sq > right_dist_sq {
                core::mem::swap(&mut left_dist_sq, &mut right_dist_sq);
                core::mem::swap(&mut left_node, &mut right_node);
            }

            let within_left = left_dist_sq <= max_dist_sq;

            let go_left = if within_left && left_node.is_leaf() {
                if !visit_fn(left_node, &mut max_dist_sq, closest_leaf) {
                    return;
                }
                false
            } else {
                within_left
            };

            // Evaluated after the near leaf was visited, so a tighter
            // `max_dist_sq` found there can cull the far child.
            let within_right = right_dist_sq <= max_dist_sq;

            let go_right = if within_right && right_node.is_leaf() {
                if !visit_fn(right_node, &mut max_dist_sq, closest_leaf) {
                    return;
                }
                false
            } else {
                within_right
            };

            match (go_left, go_right) {
                (true, true) => {
                    // Descend into the near child now, defer the far one.
                    current_node_index = left_node.first_index;
                    stack.push(right_node.first_index);
                }
                (true, false) => current_node_index = left_node.first_index,
                (false, true) => current_node_index = right_node.first_index,
                (false, false) => {
                    let Some(next) = stack.pop() else {
                        return;
                    };
                    current_node_index = next;
                }
            }
        }
    }
465}
466
/// Extra geometric queries on obvhs [`Aabb`] used by the traversals in this file.
pub trait ObvhsAabbExt {
    /// Squared distance measure from this box to `point`, used to cull nodes in
    /// the point-distance traversal.
    fn distance_to_point_squared(&self, point: Vec3A) -> f32;

    /// Entry time of `sweep` into this box, or `f32::INFINITY` when the sweep
    /// does not reach it.
    fn intersect_sweep(&self, sweep: &Sweep) -> f32;
}
477
478impl ObvhsAabbExt for Aabb {
479 #[inline(always)]
480 fn distance_to_point_squared(&self, point: Vec3A) -> f32 {
481 let min: Vec3A = self.min.to_array().into();
484 let max: Vec3A = self.max.to_array().into();
485 let point_min = min - point;
486 let point_max = max - point;
487 let dist_min = point_min.max(Vec3A::ZERO);
488 let dist_max = point_max.min(Vec3A::ZERO);
489 dist_min.length_squared().min(dist_max.length_squared())
490 }
491
492 #[inline(always)]
493 fn intersect_sweep(&self, sweep: &Sweep) -> f32 {
494 let minkowski_sum_shift = -sweep.aabb.center();
495 let minkowski_sum_margin = sweep.aabb.diagonal() * 0.5 + sweep.tmin;
496
497 let msum_min: Vec3A = (self.min + minkowski_sum_shift - minkowski_sum_margin)
500 .to_array()
501 .into();
502 let msum_max: Vec3A = (self.max + minkowski_sum_shift + minkowski_sum_margin)
503 .to_array()
504 .into();
505
506 let t1 = msum_min * sweep.inv_velocity;
509 let t2 = msum_max * sweep.inv_velocity;
510
511 let tmin = t1.min(t2);
512 let tmax = t1.max(t2);
513
514 let tmin_n = tmin.max_element();
515 let tmax_n = tmax.min_element();
516
517 if tmax_n >= tmin_n && tmax_n >= 0.0 {
518 tmin_n
519 } else {
520 f32::INFINITY
521 }
522 }
523}
524
525#[inline(always)]
526pub fn obvhs_ray(ray: &Ray, max_distance: f32) -> obvhs::ray::Ray {
527 #[cfg(feature = "2d")]
528 let origin = ray.origin.extend(0.0).to_array().into();
529 #[cfg(feature = "3d")]
530 let origin = ray.origin.to_array().into();
531 #[cfg(feature = "2d")]
532 let direction = ray.direction.extend(0.0).to_array().into();
533 #[cfg(feature = "3d")]
534 let direction = ray.direction.to_array().into();
535
536 obvhs::ray::Ray::new(origin, direction, 0.0, max_distance)
537}