obvhs/bvh2/
mod.rs

1//! A binary BVH
2
3pub mod builder;
4pub mod insertion_removal;
5pub mod leaf_collapser;
6
7pub mod node;
8pub mod reinsertion;
9
10use bytemuck::zeroed_vec;
11use glam::Vec3A;
12use node::Bvh2Node;
13
14use std::{
15    collections::{HashMap, HashSet},
16    fmt,
17};
18
19use reinsertion::find_reinsertion;
20
21use crate::{
22    Boundable, INVALID,
23    aabb::Aabb,
24    fast_stack,
25    faststack::FastStack,
26    ray::{Ray, RayHit},
27};
28
/// A binary BVH
#[derive(Clone)]
pub struct Bvh2 {
    /// List of nodes contained in this bvh. first_index in Bvh2Node for inner nodes indexes into this list. This list
    /// fully represents the BVH tree. The other fields in this struct provide additional information that allow the BVH
    /// to be manipulated more efficiently, but are not actually part of the BVH itself. The only other critical field is
    /// `primitive_indices`, assuming the BVH is not using a direct mapping.
    pub nodes: Vec<Bvh2Node>,

    /// Mapping from bvh primitive indices to original input indices.
    /// The reason for this mapping is that if multiple primitives are contained in a node, they need to have their
    /// indices laid out contiguously. To avoid this indirection we have two options:
    /// 1. Layout the primitives in the order of the primitive_indices mapping so that this can index directly into the
    ///    primitive list.
    /// 2. Only allow one primitive per node and write back the original mapping to the bvh node list.
    pub primitive_indices: Vec<u32>,

    /// A freelist for use when removing primitives from the bvh. These represent slots in Bvh2::primitive_indices
    /// that are available if a primitive is added to the bvh. Only currently used by Bvh2::remove_primitive() and
    /// Bvh2::insert_primitive() which are not part of the typical initial bvh generation.
    pub primitive_indices_freelist: Vec<u32>,

    /// An optional mapping from primitives back to nodes.
    /// Ex. `let node_id = primitives_to_nodes[primitive_id];`
    /// Where primitive_id is the original index of the primitive used when making the BVH and node_id is the index
    /// into Bvh2::nodes for the node of that primitive. Always use with the direct primitive id, not the one in the
    /// bvh node.
    /// See: Bvh2::init_primitives_to_nodes().
    /// If `primitives_to_nodes` is empty it's expected that it has not been initialized yet or has been invalidated.
    /// If `primitives_to_nodes` is not empty, it is expected that functions that modify the BVH will keep the mapping
    /// valid.
    pub primitives_to_nodes: Vec<u32>,

    /// An optional mapping from a given node index to that node's parent for each node in the bvh.
    /// See: Bvh2::init_parents_if_uninit().
    /// If `parents` is empty it's expected that it has not been initialized yet or has been invalidated.
    /// If `parents` is not empty it's expected that functions that modify the BVH will keep the mapping valid.
    pub parents: Vec<u32>,

    /// This is set by operations that ensure that parents have higher indices than children and unset by operations
    /// that might disturb that order. Some operations require this ordering and will reorder if this is not true.
    pub children_are_ordered_after_parents: bool,

    /// Maximum tree depth, used to size traversal stacks. Defaults to 96 or the max depth found during initial ploc
    /// building, whichever is larger. This may be larger than needed depending on what post processing steps are
    /// applied (like collapse, reinsertion, etc.), but the cost of recalculating it may not be worth it so it is not
    /// done automatically.
    pub max_depth: usize,

    /// Indicates that this BVH is using spatial splits. Large triangles are split into multiple smaller Aabbs, so
    /// primitives will extend outside the leaf in some cases.
    /// If the bvh uses splits, a primitive can show up in multiple leaf nodes so there won't be a 1 to 1 correlation
    /// between the total number of primitives in leaf nodes and in Bvh2::primitive_indices, vs the input triangles.
    /// If spatial splits are used, some validation steps have to be skipped and some features are unavailable:
    /// `Bvh2::add_leaf()`, `Bvh2::remove_leaf()`, `Bvh2::add_primitive()`, `Bvh2::remove_primitive()` as these would
    /// require a mapping from one primitive to multiple nodes in `Bvh2::primitives_to_nodes`
    pub uses_spatial_splits: bool,
}
86
87pub const DEFAULT_MAX_STACK_DEPTH: usize = 96;
88
89impl Default for Bvh2 {
90    fn default() -> Self {
91        Self {
92            nodes: Default::default(),
93            primitive_indices: Default::default(),
94            primitive_indices_freelist: Default::default(),
95            primitives_to_nodes: Default::default(),
96            parents: Default::default(),
97            children_are_ordered_after_parents: Default::default(),
98            max_depth: DEFAULT_MAX_STACK_DEPTH,
99            uses_spatial_splits: Default::default(),
100        }
101    }
102}
103
104impl Bvh2 {
105    /// Reset BVH while keeping allocations for rebuild. Note: results in an invalid bvh until rebuilt.
106    pub fn reset_for_reuse(&mut self, prim_count: usize, indices: Option<Vec<u32>>) {
107        let nodes_count = (2 * prim_count as i64 - 1).max(0) as usize;
108        self.nodes.resize(nodes_count, Default::default());
109        if let Some(indices) = indices {
110            self.primitive_indices = indices;
111        } else {
112            self.primitive_indices
113                .resize(prim_count, Default::default());
114        }
115        self.primitive_indices_freelist.clear();
116        self.primitives_to_nodes.clear();
117        self.parents.clear();
118        self.children_are_ordered_after_parents = Default::default();
119        self.max_depth = DEFAULT_MAX_STACK_DEPTH;
120        self.uses_spatial_splits = Default::default();
121    }
122
123    pub fn zeroed(prim_count: usize) -> Self {
124        let nodes_count = (2 * prim_count as i64 - 1).max(0) as usize;
125        Self {
126            nodes: zeroed_vec(nodes_count),
127            primitive_indices: zeroed_vec(prim_count),
128            primitive_indices_freelist: Default::default(),
129            primitives_to_nodes: Default::default(),
130            parents: Default::default(),
131            children_are_ordered_after_parents: Default::default(),
132            max_depth: DEFAULT_MAX_STACK_DEPTH,
133            uses_spatial_splits: Default::default(),
134        }
135    }
136
137    /// Traverse the bvh for a given `Ray`. Returns the closest intersected primitive.
138    ///
139    /// # Arguments
140    /// * `ray` - The ray to be tested for intersection.
141    /// * `hit` - As traverse_dynamic intersects primitives, it will update `hit` with the closest.
142    /// * `intersection_fn` - should take the given ray and primitive index and return the distance to the intersection, if any.
143    ///
144    /// Note the primitive index should index first into Bvh2::primitive_indices then that will be index of original primitive.
145    /// Various parts of the BVH building process might reorder the primitives. To avoid this indirection, reorder your
146    /// original primitives per primitive_indices.
147    #[inline(always)]
148    pub fn ray_traverse<F: FnMut(&Ray, usize) -> f32>(
149        &self,
150        ray: Ray,
151        hit: &mut RayHit,
152        mut intersection_fn: F,
153    ) -> bool {
154        let mut intersect_prims = |node: &Bvh2Node, ray: &mut Ray, hit: &mut RayHit| {
155            (node.first_index..node.first_index + node.prim_count).for_each(|primitive_id| {
156                let t = intersection_fn(ray, primitive_id as usize);
157                if t < ray.tmax {
158                    hit.primitive_id = primitive_id;
159                    hit.t = t;
160                    ray.tmax = t;
161                }
162            });
163            true
164        };
165
166        fast_stack!(u32, (96, 192), self.max_depth, stack, {
167            Bvh2::ray_traverse_dynamic(self, &mut stack, ray, hit, &mut intersect_prims)
168        });
169
170        hit.t < ray.tmax // Note this is valid since traverse_with_stack does not mutate the ray
171    }
172
173    /// Traverse the bvh for a given `Ray`. Returns true if the ray missed all primitives.
174    ///
175    /// # Arguments
176    /// * `ray` - The ray to be tested for intersection.
177    /// * `hit` - As traverse_dynamic intersects primitives, it will update `hit` with the closest.
178    /// * `intersection_fn` - should take the given ray and primitive index and return the distance to the intersection, if any.
179    ///
180    /// Note the primitive index should index first into Bvh2::primitive_indices then that will be index of original primitive.
181    /// Various parts of the BVH building process might reorder the primitives. To avoid this indirection, reorder your
182    /// original primitives per primitive_indices.
183    #[inline(always)]
184    pub fn ray_traverse_miss<F: FnMut(&Ray, usize) -> f32>(
185        &self,
186        ray: Ray,
187        mut intersection_fn: F,
188    ) -> bool {
189        let mut miss = true;
190        let mut intersect_prims = |node: &Bvh2Node, ray: &mut Ray, _hit: &mut RayHit| {
191            for primitive_id in node.first_index..node.first_index + node.prim_count {
192                let t = intersection_fn(ray, primitive_id as usize);
193                if t < ray.tmax {
194                    miss = false;
195                    return false;
196                }
197            }
198            true
199        };
200
201        fast_stack!(u32, (96, 192), self.max_depth, stack, {
202            Bvh2::ray_traverse_dynamic(
203                self,
204                &mut stack,
205                ray,
206                &mut RayHit::none(),
207                &mut intersect_prims,
208            )
209        });
210
211        miss
212    }
213
214    /// Traverse the bvh for a given `Ray`. Intersects all primitives along ray (for things like evaluating transparency)
215    ///   intersection_fn is called for all intersections. Ray is not updated to allow for evaluating at every hit.
216    ///
217    /// # Arguments
218    /// * `ray` - The ray to be tested for intersection.
219    /// * `intersection_fn` - takes the given ray and primitive index.
220    ///
221    /// Note the primitive index should index first into Bvh2::primitive_indices then that will be index of original primitive.
222    /// Various parts of the BVH building process might reorder the primitives. To avoid this indirection, reorder your
223    /// original primitives per primitive_indices.
224    #[inline(always)]
225    pub fn ray_traverse_anyhit<F: FnMut(&Ray, usize)>(&self, ray: Ray, mut intersection_fn: F) {
226        let mut intersect_prims = |node: &Bvh2Node, ray: &mut Ray, _hit: &mut RayHit| {
227            for primitive_id in node.first_index..node.first_index + node.prim_count {
228                intersection_fn(ray, primitive_id as usize);
229            }
230            true
231        };
232
233        let mut hit = RayHit::none();
234        fast_stack!(u32, (96, 192), self.max_depth, stack, {
235            self.ray_traverse_dynamic(&mut stack, ray, &mut hit, &mut intersect_prims)
236        });
237    }
238
    /// Traverse the BVH with an explicit traversal stack and a per-leaf callback.
    /// Consider using or referencing: Bvh2::ray_traverse(), Bvh2::ray_traverse_miss(), or Bvh2::ray_traverse_anyhit().
    ///
    /// # Arguments
    /// * `stack` - Holds the pending-subtree traversal state (a FastStack of node indices).
    /// * `ray` - Taken by value; the caller's copy is never mutated (only the internal copy's tmax shrinks).
    /// * `hit` - As traversal intersects primitives, `intersection_fn` should update `hit` with the closest.
    /// * `intersection_fn` - should test the primitives in the given node, update the ray.tmax, and hit info. Return
    ///   false to halt traversal.
    ///   For basic miss test return false on first hit to halt traversal.
    ///   For closest hit run until it returns false and check hit.t < ray.tmax to see if it hit something
    ///   For transparency, you want to hit every primitive in the ray's path, keeping track of the closest opaque hit.
    ///   and then manually setting ray.tmax to that closest opaque hit at each iteration.
    ///
    /// Note the primitive index should index first into Bvh2::primitive_indices then that will be index of original primitive.
    /// Various parts of the BVH building process might reorder the primitives. To avoid this indirection, reorder your
    /// original primitives per primitive_indices.
    #[inline(always)]
    pub fn ray_traverse_dynamic<
        F: FnMut(&Bvh2Node, &mut Ray, &mut RayHit) -> bool,
        Stack: FastStack<u32>,
    >(
        &self,
        stack: &mut Stack,
        mut ray: Ray,
        hit: &mut RayHit,
        mut intersection_fn: F,
    ) {
        if self.nodes.is_empty() {
            return;
        }

        // A leaf root has no child pair to walk, so it is handled up front.
        let root_node = &self.nodes[0];
        let hit_root = root_node.aabb().intersect_ray(&ray) < ray.tmax;
        if !hit_root {
            return;
        } else if root_node.is_leaf() {
            intersection_fn(root_node, &mut ray, hit);
            return;
        };

        // Invariant: `current_node_index` always points at the left child of a sibling
        // pair; the right sibling is stored immediately after it at index + 1.
        let mut current_node_index = root_node.first_index;
        loop {
            let right_index = current_node_index as usize + 1;
            // Hoisted bounds check covering both unchecked accesses below.
            assert!(right_index < self.nodes.len());
            // SAFETY: the assert above guarantees right_index (and therefore
            // current_node_index == right_index - 1) are in bounds.
            let mut left_node = unsafe { self.nodes.get_unchecked(current_node_index as usize) };
            let mut right_node = unsafe { self.nodes.get_unchecked(right_index) };

            // TODO perf: could it be faster to intersect these at the same time with avx?
            let mut left_t = left_node.aabb().intersect_ray(&ray);
            let mut right_t = right_node.aabb().intersect_ray(&ray);

            // Sort the pair near-to-far so the closer child is descended first,
            // which tightens ray.tmax earlier and culls more of the far subtree.
            if left_t > right_t {
                core::mem::swap(&mut left_t, &mut right_t);
                core::mem::swap(&mut left_node, &mut right_node);
            }

            let hit_left = left_t < ray.tmax;

            // Leaves are resolved immediately via the callback; only inner nodes are descended.
            let go_left = if hit_left && left_node.is_leaf() {
                if !intersection_fn(left_node, &mut ray, hit) {
                    return;
                }
                false
            } else {
                hit_left
            };

            // Re-tested after the left leaf callback may have shrunk ray.tmax.
            let hit_right = right_t < ray.tmax;

            let go_right = if hit_right && right_node.is_leaf() {
                if !intersection_fn(right_node, &mut ray, hit) {
                    return;
                }
                false
            } else {
                hit_right
            };

            match (go_left, go_right) {
                (true, true) => {
                    // Descend the nearer child now; defer the farther one on the stack.
                    current_node_index = left_node.first_index;
                    stack.push(right_node.first_index);
                }
                (true, false) => current_node_index = left_node.first_index,
                (false, true) => current_node_index = right_node.first_index,
                (false, false) => {
                    // Both subtrees culled or resolved; resume a deferred subtree if any remain.
                    let Some(next) = stack.pop() else {
                        // Traversal complete. Sync hit.t with the (possibly shrunk) internal tmax
                        // so the caller's `hit.t < original_tmax` check reflects this traversal.
                        hit.t = ray.tmax;
                        return;
                    };
                    current_node_index = next;
                }
            }
        }
    }
335
336    /// Recursively traverse the bvh for a given `Ray`.
337    /// On completion, `leaf_indices` will contain a list of the intersected leaf node indices.
338    /// This method is slower than stack traversal and only exists as a reference.
339    /// This method does not check if the primitive was intersected, only the leaf node.
340    pub fn ray_traverse_recursive(
341        &self,
342        ray: &Ray,
343        node_index: usize,
344        leaf_indices: &mut Vec<usize>,
345    ) {
346        if self.nodes.is_empty() {
347            return;
348        }
349        let node = &self.nodes[node_index];
350        if node.aabb().intersect_ray(ray) < f32::INFINITY {
351            if node.is_leaf() {
352                leaf_indices.push(node_index);
353            } else {
354                self.ray_traverse_recursive(ray, node.first_index as usize, leaf_indices);
355                self.ray_traverse_recursive(ray, node.first_index as usize + 1, leaf_indices);
356            }
357        }
358    }
359
360    /// Traverse the BVH with an Aabb. fn `eval` is called for nodes that intersect `aabb`
361    /// The bvh (self) and the current node index is passed into fn `eval`
362    /// Note each node may have multiple primitives. `node.first_index` is the index of the first primitive.
363    /// `node.prim_count` is the quantity of primitives contained in the given node.
364    /// Return false from eval to halt traversal
365    pub fn aabb_traverse<F: FnMut(&Self, u32) -> bool>(&self, aabb: Aabb, mut eval: F) {
366        if self.nodes.is_empty() {
367            return;
368        }
369
370        let root_node = &self.nodes[0];
371        if root_node.is_leaf() {
372            if root_node.aabb().intersect_aabb(&aabb) {
373                eval(self, 0);
374            }
375            return;
376        }
377
378        fast_stack!(u32, (96, 192), self.max_depth, stack, {
379            stack.push(root_node.first_index);
380            while let Some(node_index) = stack.pop() {
381                // Left
382                let node = &self.nodes[node_index as usize];
383                if node.aabb().intersect_aabb(&aabb) {
384                    if node.is_leaf() {
385                        if !eval(self, node_index) {
386                            return;
387                        }
388                    } else {
389                        stack.push(node.first_index);
390                    }
391                }
392
393                // Right
394                let node_index = node_index + 1;
395                let node = &self.nodes[node_index as usize];
396                if node.aabb().intersect_aabb(&aabb) {
397                    if node.is_leaf() {
398                        if !eval(self, node_index) {
399                            return;
400                        }
401                    } else {
402                        stack.push(node.first_index);
403                    }
404                }
405            }
406        });
407    }
408
409    /// Traverse the BVH with a point. fn `eval` is called for nodes that intersect `point`
410    /// The bvh (self) and the current node index is passed into fn `eval`
411    /// Note each node may have multiple primitives. `node.first_index` is the index of the first primitive.
412    /// `node.prim_count` is the quantity of primitives contained in the given node.
413    /// Return false from eval to halt traversal
414    pub fn point_traverse<F: FnMut(&Self, u32) -> bool>(&self, point: Vec3A, mut eval: F) {
415        if self.nodes.is_empty() {
416            return;
417        }
418
419        let root_node = &self.nodes[0];
420        if root_node.is_leaf() {
421            if root_node.aabb().contains_point(point) {
422                eval(self, 0);
423            }
424            return;
425        }
426
427        fast_stack!(u32, (96, 192), self.max_depth, stack, {
428            stack.push(root_node.first_index);
429            while let Some(node_index) = stack.pop() {
430                // Left
431                let node = &self.nodes[node_index as usize];
432                if node.aabb().contains_point(point) {
433                    if node.is_leaf() {
434                        if !eval(self, node_index) {
435                            return;
436                        }
437                    } else {
438                        stack.push(node.first_index);
439                    }
440                }
441
442                // Right
443                let node_index = node_index + 1;
444                let node = &self.nodes[node_index as usize];
445                if node.aabb().contains_point(point) {
446                    if node.is_leaf() {
447                        if !eval(self, node_index) {
448                            return;
449                        }
450                    } else {
451                        stack.push(node.first_index);
452                    }
453                }
454            }
455        });
456    }
457
    /// Order node array in stack traversal order. Ensures parents are always at lower indices than children. Fairly
    /// slow, can take around 1/3 of the time of building the same BVH from scratch with the fastest_build preset.
    /// Doesn't seem to speed up traversal much for a new BVH created from PLOC, but if it has had many
    /// removals/insertions it can help.
    pub fn reorder_in_stack_traversal_order(&mut self) {
        // A single-node (leaf-root) or empty BVH is trivially ordered.
        if self.nodes.len() < 2 {
            return;
        }
        let mut new_nodes: Vec<Bvh2Node> = Vec::with_capacity(self.nodes.len());
        let mut mapping = vec![0; self.nodes.len()]; // Map from where a node used to be to where it is now
        let mut stack = Vec::new();
        // Stack entries are old indices of left children; the right sibling is at index + 1.
        stack.push(self.nodes[0].first_index);
        new_nodes.push(self.nodes[0]);
        mapping[0] = 0;
        while let Some(current_node_index) = stack.pop() {
            let node_a = &self.nodes[current_node_index as usize];
            let node_b = &self.nodes[current_node_index as usize + 1];
            if !node_a.is_leaf() {
                stack.push(node_a.first_index);
            }
            if !node_b.is_leaf() {
                stack.push(node_b.first_index);
            }
            // Emit the sibling pair adjacently; parents were always emitted earlier,
            // which is what establishes the children-after-parents ordering.
            let new_node_idx = new_nodes.len() as u32;
            mapping[current_node_index as usize] = new_node_idx;
            mapping[current_node_index as usize + 1] = new_node_idx + 1;
            new_nodes.push(*node_a);
            new_nodes.push(*node_b);
        }
        // Rewrite child pointers (old indices) through the old->new mapping.
        for n in &mut new_nodes {
            if !n.is_leaf() {
                n.first_index = mapping[n.first_index as usize];
            }
        }
        self.nodes = new_nodes;
        // Rebuild any caches that hold node indices, since every node may have moved.
        if !self.parents.is_empty() {
            self.update_parents();
        }
        if !self.primitives_to_nodes.is_empty() {
            self.update_primitives_to_nodes();
        }
        self.children_are_ordered_after_parents = true;
    }
501
    /// Refits the whole BVH from the leaves up. If the leaves have moved very much the BVH can quickly become
    /// degenerate causing significantly higher traversal times. Consider rebuilding the BVH from scratch or running a
    /// bit of reinsertion after refit.
    /// Usage:
    /// ```
    ///    use glam::*;
    ///    use obvhs::*;
    ///    use obvhs::{ploc::*, test_util::geometry::demoscene, bvh2::builder::build_bvh2_from_tris};
    ///    use std::time::Duration;
    ///
    ///    let mut tris = demoscene(32, 0);
    ///    let mut bvh = build_bvh2_from_tris(&tris, BvhBuildParams::fastest_build(), &mut Duration::default());
    ///
    ///    bvh.init_primitives_to_nodes_if_uninit(); // Generate mapping from primitives to nodes
    ///    tris.transform(&Mat4::from_scale_rotation_translation(
    ///        Vec3::splat(1.3),
    ///        Quat::from_rotation_y(0.1),
    ///        vec3(0.33, 0.3, 0.37),
    ///    ));
    ///    for (prim_id, tri) in tris.iter().enumerate() {
    ///        bvh.nodes[bvh.primitives_to_nodes[prim_id] as usize].set_aabb(tri.aabb()); // Update aabbs
    ///    }
    ///    bvh.refit_all(); // Refit aabbs
    ///    bvh.validate(&tris, false, true); // Validate that aabbs are now fitting tightly
    /// ```
    pub fn refit_all(&mut self) {
        if self.nodes.is_empty() {
            return;
        }
        if self.children_are_ordered_after_parents {
            // If children are already ordered after parents we can update in a single sweep.
            // Around 3x faster than the fallback below.
            // Iterating indices high-to-low guarantees both children are refit before their parent.
            for node_id in (0..self.nodes.len()).rev() {
                let node = &self.nodes[node_id];
                if !node.is_leaf() {
                    let first_child_bbox = *self.nodes[node.first_index as usize].aabb();
                    let second_child_bbox = *self.nodes[node.first_index as usize + 1].aabb();
                    self.nodes[node_id].set_aabb(first_child_bbox.union(&second_child_bbox));
                }
            }
        } else {
            // If not, we need to create a safe order in which we can make updates.
            // This is much faster than reordering the whole bvh with Bvh2::reorder_in_stack_traversal_order()
            fast_stack!(u32, (96, 192), self.max_depth, stack, {
                // First pass: record nodes in parent-before-child order.
                let mut reverse_stack = Vec::with_capacity(self.nodes.len());
                stack.push(0);
                reverse_stack.push(0);
                while let Some(current_node_index) = stack.pop() {
                    let node = &self.nodes[current_node_index as usize];
                    if !node.is_leaf() {
                        reverse_stack.push(node.first_index);
                        reverse_stack.push(node.first_index + 1);
                        stack.push(node.first_index);
                        stack.push(node.first_index + 1);
                    }
                }
                // Second pass: walk the recorded order backwards so children are always
                // refit before their parent.
                for node_id in reverse_stack.iter().rev() {
                    let node = &self.nodes[*node_id as usize];
                    if !node.is_leaf() {
                        let first_child_bbox = *self.nodes[node.first_index as usize].aabb();
                        let second_child_bbox = *self.nodes[node.first_index as usize + 1].aabb();
                        self.nodes[*node_id as usize]
                            .set_aabb(first_child_bbox.union(&second_child_bbox));
                    }
                }
            });
        }
    }
570
571    /// Compute parents and update cache only if they have not already been computed
572    pub fn init_parents_if_uninit(&mut self) {
573        if self.parents.is_empty() {
574            self.update_parents();
575        }
576    }
577
578    /// Compute the mapping from a given node index to that node's parent for each node in the bvh and update local
579    /// cache.
580    pub fn update_parents(&mut self) {
581        Bvh2::compute_parents(&self.nodes, &mut self.parents);
582    }
583
    /// Compute the mapping from a given node index to that node's parent for each node in the bvh, takes a Vec to allow
    /// reusing the allocation.
    pub fn compute_parents(nodes: &[Bvh2Node], parents: &mut Vec<u32>) {
        parents.resize(nodes.len(), 0);

        if nodes.is_empty() {
            return;
        }

        // The root has no parent; by convention it records itself (index 0) as its parent.
        parents[0] = 0;

        #[cfg(not(feature = "parallel"))]
        {
            // Each inner node writes itself as the parent of both of its children.
            nodes.iter().enumerate().for_each(|(i, node)| {
                if !node.is_leaf() {
                    parents[node.first_index as usize] = i as u32;
                    parents[node.first_index as usize + 1] = i as u32;
                }
            });
        }
        // Seems around 80% faster than the serial version above.
        // TODO is there a better way to parallelize?
        #[cfg(feature = "parallel")]
        {
            use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
            use std::sync::atomic::Ordering;

            // Atomics allow shared mutable access to `parents` across rayon threads
            // (presumably each child slot is written by exactly its one parent — the
            // atomics exist to make the concurrent writes sound, not to resolve races).
            let parents = crate::as_slice_of_atomic_u32(parents);
            nodes.par_iter().enumerate().for_each(|(i, node)| {
                if !node.is_leaf() {
                    parents[node.first_index as usize].store(i as u32, Ordering::Relaxed);
                    parents[node.first_index as usize + 1].store(i as u32, Ordering::Relaxed);
                }
            });
        }
    }
620
621    /// Compute compute_primitives_to_nodes and update cache only if they have not already been computed. Not supported
622    /// if using spatial splits as it would require a mapping from one primitive to multiple nodes.
623    pub fn init_primitives_to_nodes_if_uninit(&mut self) {
624        if self.primitives_to_nodes.is_empty() {
625            self.update_primitives_to_nodes();
626        }
627    }
628
    /// Compute the mapping from primitive index to node index and update local cache. Not supported if using spatial
    /// splits as it would require a mapping from one primitive to multiple nodes.
    pub fn update_primitives_to_nodes(&mut self) {
        if self.uses_spatial_splits {
            // Warn but still proceed: the mapping stores a single node per primitive, so
            // a split primitive that appears in multiple leaves will only record one of them.
            log::warn!(
                "Calculating primitives_to_nodes while using spatial splits is currently unsupported as it would \
                require a mapping from one primitive to multiple nodes in `Bvh2::primitives_to_nodes`."
            );
        }

        Bvh2::compute_primitives_to_nodes(
            &self.nodes,
            &self.primitive_indices,
            &mut self.primitives_to_nodes,
        );
    }
645
646    /// Compute the mapping from primitive index to node index. Takes a Vec to allow reusing the allocation.
647    pub fn compute_primitives_to_nodes(
648        nodes: &[Bvh2Node],
649        primitive_indices: &[u32],
650        primitives_to_nodes: &mut Vec<u32>,
651    ) {
652        primitives_to_nodes.clear();
653        primitives_to_nodes.resize(primitive_indices.len(), INVALID);
654        for (node_id, node) in nodes.iter().enumerate() {
655            if node.is_leaf() {
656                let start = node.first_index;
657                let end = node.first_index + node.prim_count;
658                for node_prim_id in start..end {
659                    // TODO perf avoid this indirection by making self.primitive_indices optional?
660                    let prim_id = primitive_indices[node_prim_id as usize];
661                    primitives_to_nodes[prim_id as usize] = node_id as u32;
662                }
663            }
664        }
665    }
666
667    pub fn validate_parents(&self) {
668        self.nodes.iter().enumerate().for_each(|(i, node)| {
669            if !node.is_leaf() {
670                assert_eq!(self.parents[node.first_index as usize], i as u32);
671                assert_eq!(self.parents[node.first_index as usize + 1], i as u32);
672            }
673        });
674    }
675
676    pub fn validate_primitives_to_nodes(&self) {
677        self.primitives_to_nodes
678            .iter()
679            .enumerate()
680            .for_each(|(prim_id, node_id)| {
681                if *node_id != INVALID {
682                    let prim_id = prim_id as u32;
683                    let node = &self.nodes[*node_id as usize];
684                    assert!(node.is_leaf());
685                    let start = node.first_index;
686                    let end = node.first_index + node.prim_count;
687                    let mut found = false;
688                    for node_prim_id in start..end {
689                        if prim_id == self.primitive_indices[node_prim_id as usize] {
690                            found = true;
691                            break;
692                        }
693                    }
694                    assert!(found, "prim_id {prim_id} not found")
695                }
696            });
697    }
698
699    /// Refit the BVH working up the tree from this node, ignoring leaves. (TODO add a version that checks leaves)
700    /// This recomputes the Aabbs for all the parents of the given node index.
701    /// This can only be used to refit when a single node has changed or moved.
702    pub fn refit_from(&mut self, mut index: usize) {
703        self.init_parents_if_uninit();
704        loop {
705            let node = &self.nodes[index];
706            if !node.is_leaf() {
707                let first_child_bbox = *self.nodes[node.first_index as usize].aabb();
708                let second_child_bbox = *self.nodes[node.first_index as usize + 1].aabb();
709                self.nodes[index].set_aabb(first_child_bbox.union(&second_child_bbox));
710            }
711            if index == 0 {
712                break;
713            }
714            index = self.parents[index] as usize;
715        }
716    }
717
    /// Refit the BVH working up the tree from this node, ignoring leaves.
    /// This recomputes the Aabbs for the parents of the given node index.
    /// Halts if the parents are the same size. Panics in debug if some parents still needed to be resized.
    /// This can only be used to refit when a single node has changed or moved.
    pub fn refit_from_fast(&mut self, mut index: usize) {
        self.init_parents_if_uninit();
        // Count of ancestors seen so far whose recomputed aabb matched their stored aabb.
        // Note this is never reset on a mismatch, so it is cumulative rather than consecutive.
        let mut same_count = 0;
        loop {
            let node = &self.nodes[index];
            if !node.is_leaf() {
                // Recompute this inner node's box as the union of its two adjacent children.
                let first_child_bbox = self.nodes[node.first_index as usize].aabb();
                let second_child_bbox = self.nodes[node.first_index as usize + 1].aabb();
                let new_aabb = first_child_bbox.union(second_child_bbox);
                let node = &mut self.nodes[index];
                if node.aabb() == &new_aabb {
                    same_count += 1;
                    // Release builds exit early once two ancestors were already correctly sized
                    // (presumably two rather than one as a safety margin — TODO confirm).
                    // Debug builds keep walking to the root so the debug_assert below can
                    // verify that no further resizing was actually needed.
                    #[cfg(not(debug_assertions))]
                    if same_count == 2 {
                        return;
                    }
                } else {
                    // A resize after two unchanged ancestors means the early-exit assumption
                    // would have been wrong — flag it in debug builds.
                    debug_assert!(
                        same_count < 2,
                        "Some parents still needed refitting. Unideal fitting is occurring somewhere."
                    );
                }
                node.set_aabb(new_aabb);
            }
            // The root (node 0) has no parent; stop once it has been processed.
            if index == 0 {
                break;
            }
            index = self.parents[index] as usize;
        }
    }
752
753    /// Update node aabb and refit the BVH working up the tree from this node.
754    #[inline]
755    pub fn resize_node(&mut self, node_id: usize, aabb: Aabb) {
756        self.nodes[node_id].set_aabb(aabb);
757        self.refit_from_fast(node_id);
758    }
759
760    /// Find if there might be a better spot in the BVH for this node and move it there. The id of the reinserted node
761    /// does not changed.
762    #[inline]
763    pub fn reinsert_node(&mut self, node_id: usize) {
764        if node_id == 0 {
765            return;
766        }
767        let reinsertion = find_reinsertion(self, node_id);
768        if reinsertion.area_diff > 0.0 {
769            reinsertion::reinsert_node(self, reinsertion.from as usize, reinsertion.to as usize);
770            self.children_are_ordered_after_parents = false;
771        }
772    }
773
774    /// Get the count of active primitive indices.
775    /// when primitives are removed they are added to the `primitive_indices_freelist` so the
776    /// self.primitive_indices.len() may not represent the actual number of valid, active primitive_indices.
777    #[inline(always)]
778    pub fn active_primitive_indices_count(&self) -> usize {
779        self.primitive_indices.len() - self.primitive_indices_freelist.len()
780    }
781
    /// Validate the BVH: walks the whole tree checking structural invariants (containment,
    /// reachability, primitive bookkeeping, depth limits) and returns collected stats.
    /// Panics (asserts) on any violated invariant.
    ///
    /// direct_layout: The primitives are already laid out in bvh.primitive_indices order.
    /// tight_fit: Requires that children nodes and primitives fit tightly in parents. This is ignored for primitives
    ///     if the bvh uses spatial splits (tight_fit can still be set to `true`). This was added for validating
    ///     refit_all().
    pub fn validate<T: Boundable>(
        &self,
        primitives: &[T],
        direct_layout: bool,
        tight_fit: bool,
    ) -> Bvh2ValidationResult {
        let mut result = Bvh2ValidationResult {
            direct_layout,
            require_tight_fit: tight_fit,
            ..Default::default()
        };

        // An empty BVH must also have empty auxiliary mappings.
        if self.nodes.is_empty() {
            assert!(self.parents.is_empty());
            assert!(self.primitives_to_nodes.is_empty());
            return result;
        }

        // The auxiliary mappings are optional; only check them when they are initialized.
        if !self.primitives_to_nodes.is_empty() {
            self.validate_primitives_to_nodes();
        }

        if !self.parents.is_empty() {
            self.validate_parents();
        }

        // Recursive traversal from the root; node 0 is passed as its own parent.
        if !self.nodes.is_empty() {
            self.validate_impl::<T>(primitives, &mut result, 0, 0, 0);
        }
        // Every node must be reachable from the root, and visited exactly once.
        assert_eq!(result.discovered_nodes.len(), self.nodes.len());
        assert_eq!(result.node_count, self.nodes.len());

        // Ignore primitive_indices if this is a direct layout
        if !direct_layout {
            if result.discovered_primitives.is_empty() {
                assert!(self.active_primitive_indices_count() == 0)
            } else {
                if !self.uses_spatial_splits {
                    // If the bvh uses splits, a primitive can show up in multiple leaf nodes so there wont be a 1 to 1
                    // correlation between the number of discovered primitives and the quantity in bvh.primitive_indices.
                    let active_indices_count = self.active_primitive_indices_count();
                    assert_eq!(result.discovered_primitives.len(), active_indices_count);
                    assert_eq!(result.prim_count, active_indices_count);
                }
                // Check that the set of discovered_primitives is the same as the set in primitive_indices while
                // ignoring empty slots in primitive_indices.
                let primitive_indices_freeset: HashSet<&u32> =
                    HashSet::from_iter(&self.primitive_indices_freelist);
                for (slot, index) in self.primitive_indices.iter().enumerate() {
                    let slot = slot as u32;
                    if !primitive_indices_freeset.contains(&slot) {
                        assert!(result.discovered_primitives.contains(index));
                    }
                }
                // Reverse direction: everything discovered must exist in primitive_indices.
                let primitive_indices_set: HashSet<&u32> =
                    HashSet::from_iter(self.primitive_indices.iter().filter(|i| **i != INVALID));
                for discovered_prim_id in &result.discovered_primitives {
                    assert!(primitive_indices_set.contains(discovered_prim_id))
                }
            }
        }
        // The observed depth must stay strictly below the configured limit.
        assert!(
            result.max_depth < self.max_depth as u32,
            "result.max_depth ({}) must be less than self.max_depth ({})",
            result.max_depth,
            self.max_depth as u32
        );
        // Not an error, but deep trees usually indicate a pathological input scene.
        if result.max_depth > DEFAULT_MAX_STACK_DEPTH as u32 {
            log::warn!(
                "bvh depth is: {}, a depth beyond {} may be indicative of something pathological in the scene (like thousands of instances perfectly overlapping geometry) that will result in a BVH that is very slow to traverse.",
                result.max_depth,
                DEFAULT_MAX_STACK_DEPTH
            );
        }

        if self.children_are_ordered_after_parents {
            // Assert that children are always ordered after parents in self.nodes
            // (compute parents on the fly if the cached mapping is uninitialized).
            let mut temp_parents = vec![];
            let parents = if self.parents.is_empty() {
                Bvh2::compute_parents(&self.nodes, &mut temp_parents);
                &temp_parents
            } else {
                &self.parents
            };

            for node_id in (1..self.nodes.len()).rev() {
                assert!(parents[node_id] < node_id as u32);
            }
        }

        result
    }
878
879    pub fn validate_impl<T: Boundable>(
880        &self,
881        primitives: &[T],
882        result: &mut Bvh2ValidationResult,
883        node_index: u32,
884        parent_index: u32,
885        current_depth: u32,
886    ) {
887        result.max_depth = result.max_depth.max(current_depth);
888        let parent_aabb = self.nodes[parent_index as usize].aabb();
889        result.discovered_nodes.insert(node_index);
890        let node = &self.nodes[node_index as usize];
891        result.node_count += 1;
892
893        if let Some(count) = result.nodes_at_depth.get(&current_depth) {
894            result.nodes_at_depth.insert(current_depth, count + 1);
895        } else {
896            result.nodes_at_depth.insert(current_depth, 1);
897        }
898
899        assert!(
900            node.aabb().min.cmpge(parent_aabb.min).all()
901                && node.aabb().max.cmple(parent_aabb.max).all(),
902            "Child {} does not fit in parent {}:\nchild:  {:?}\nparent: {:?}",
903            node_index,
904            parent_index,
905            node.aabb(),
906            parent_aabb
907        );
908
909        if node.is_leaf() {
910            result.leaf_count += 1;
911            if let Some(count) = result.leaves_at_depth.get(&current_depth) {
912                result.leaves_at_depth.insert(current_depth, count + 1);
913            } else {
914                result.leaves_at_depth.insert(current_depth, 1);
915            }
916            let mut temp_aabb = Aabb::empty();
917            for i in 0..node.prim_count {
918                result.prim_count += 1;
919                let mut prim_index = (node.first_index + i) as usize;
920                if result.direct_layout {
921                    result.discovered_primitives.insert(prim_index as u32);
922                } else {
923                    result
924                        .discovered_primitives
925                        .insert(self.primitive_indices[prim_index]);
926                }
927                // If using splits, primitives will extend outside the leaf in some cases.
928                if !self.uses_spatial_splits {
929                    if !result.direct_layout {
930                        prim_index = self.primitive_indices[prim_index] as usize
931                    }
932                    let prim_aabb = primitives[prim_index].aabb();
933                    temp_aabb = temp_aabb.union(&prim_aabb);
934                    assert!(
935                        prim_aabb.min.cmpge(node.aabb().min).all()
936                            && prim_aabb.max.cmple(node.aabb().max).all(),
937                        "Primitive {} does not fit in parent {}:\nprimitive: {:?}\nparent:    {:?}",
938                        prim_index,
939                        parent_index,
940                        prim_aabb,
941                        node.aabb()
942                    );
943                }
944            }
945            if result.require_tight_fit && !self.uses_spatial_splits {
946                assert_eq!(
947                    temp_aabb,
948                    *node.aabb(),
949                    "Primitive do not fit in tightly in parent {node_index}",
950                );
951            }
952        } else {
953            if result.require_tight_fit {
954                let left_id = node.first_index as usize;
955                let right_id = node.first_index as usize + 1;
956                let left_child_aabb = &self.nodes[left_id];
957                let right_child_aabb = &self.nodes[right_id];
958
959                assert_eq!(
960                    left_child_aabb.aabb().union(right_child_aabb.aabb()),
961                    *node.aabb(),
962                    "Children {left_id} & {right_id} do not fit in tightly in parent {node_index}",
963                );
964            }
965
966            self.validate_impl::<T>(
967                primitives,
968                result,
969                node.first_index,
970                parent_index,
971                current_depth + 1,
972            );
973            self.validate_impl::<T>(
974                primitives,
975                result,
976                node.first_index + 1,
977                parent_index,
978                current_depth + 1,
979            );
980        }
981    }
982
983    /// Basic debug print illustrating the bvh layout
984    pub fn print_bvh(&self, node_index: usize, depth: usize) {
985        let node = &self.nodes[node_index];
986        if node.is_leaf() {
987            println!(
988                "{}{} leaf > {}",
989                " ".repeat(depth),
990                node_index,
991                node.first_index
992            )
993        } else {
994            println!(
995                "{}{} inner > {}, {}",
996                " ".repeat(depth),
997                node_index,
998                node.first_index,
999                node.first_index + 1
1000            );
1001            self.print_bvh(node.first_index as usize, depth + 1);
1002            self.print_bvh(node.first_index as usize + 1, depth + 1);
1003        }
1004    }
1005
1006    /// Get the maximum depth of the BVH from the given node
1007    pub fn depth(&self, node_index: usize) -> usize {
1008        let node = &self.nodes[node_index];
1009        if node.is_leaf() {
1010            1
1011        } else {
1012            1 + self
1013                .depth(node.first_index as usize)
1014                .max(self.depth((node.first_index + 1) as usize))
1015        }
1016    }
1017}
1018
1019/// Update the `primitives_to_nodes` mappings for primitives contained in `node_id`. Does nothing if primitives_to_nodes
1020/// is not already init.
1021// Not a member of Bvh2 because of borrow issues when a reference to other things like parents is also taken.
1022// Maybe could be cleaner as a macro?
1023#[inline]
1024fn update_primitives_to_nodes_for_node(
1025    node: &Bvh2Node,
1026    node_id: usize,
1027    primitive_indices: &[u32],
1028    primitives_to_nodes: &mut [u32],
1029) {
1030    if !primitives_to_nodes.is_empty() {
1031        let start = node.first_index;
1032        let end = start + node.prim_count;
1033        for node_prim_id in start..end {
1034            let direct_prim_id = primitive_indices[node_prim_id as usize];
1035            primitives_to_nodes[direct_prim_id as usize] = node_id as u32;
1036        }
1037    }
1038}
1039
/// Result of Bvh2 validation. Contains various bvh stats.
#[derive(Default)]
pub struct Bvh2ValidationResult {
    /// The primitives are already laid out in bvh.primitive_indices order.
    pub direct_layout: bool,
    /// Require validation to ensure aabbs tightly fit children and primitives.
    pub require_tight_fit: bool,
    /// Set of primitives discovered through validation traversal.
    pub discovered_primitives: HashSet<u32>,
    /// Set of nodes discovered through validation traversal.
    pub discovered_nodes: HashSet<u32>,
    /// Total number of nodes discovered through validation traversal.
    pub node_count: usize,
    /// Total number of leaves discovered through validation traversal.
    pub leaf_count: usize,
    /// Total number of primitives discovered through validation traversal.
    pub prim_count: usize,
    /// Maximum hierarchical BVH depth discovered through validation traversal.
    pub max_depth: u32,
    /// Quantity of nodes found at each depth through validation traversal.
    pub nodes_at_depth: HashMap<u32, u32>,
    /// Quantity of leaves found at each depth through validation traversal.
    pub leaves_at_depth: HashMap<u32, u32>,
}
1064
1065impl fmt::Display for Bvh2ValidationResult {
1066    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1067        writeln!(
1068            f,
1069            "GPU BVH Avg primitives/leaf: {:.3}",
1070            self.prim_count as f64 / self.leaf_count as f64
1071        )?;
1072
1073        writeln!(
1074            f,
1075            "\
1076node_count: {}
1077prim_count: {}
1078leaf_count: {}",
1079            self.node_count, self.prim_count, self.leaf_count
1080        )?;
1081
1082        writeln!(f, "Node & Leaf counts for each depth")?;
1083        for i in 0..=self.max_depth {
1084            writeln!(
1085                f,
1086                "{:<3} {:<10} {:<10}",
1087                i,
1088                self.nodes_at_depth.get(&i).unwrap_or(&0),
1089                self.leaves_at_depth.get(&i).unwrap_or(&0)
1090            )?;
1091        }
1092
1093        Ok(())
1094    }
1095}
1096
#[cfg(test)]
mod tests {

    use glam::*;

    use crate::{
        BvhBuildParams, Transformable,
        ploc::{PlocBuilder, PlocSearchDistance, SortPrecision},
        test_util::geometry::demoscene,
    };

    use super::builder::build_bvh2_from_tris;

    /// Build a BVH, move every primitive via a rigid-ish transform, write the new leaf aabbs
    /// through the primitives_to_nodes mapping, refit the whole tree, and validate with tight_fit.
    #[test]
    fn test_refit_all() {
        let mut tris = demoscene(32, 0);
        let mut aabbs = Vec::with_capacity(tris.len());
        let mut indices = Vec::with_capacity(tris.len());
        for (i, primitive) in tris.iter().enumerate() {
            indices.push(i as u32);
            aabbs.push(primitive.aabb());
        }

        // Test without init_primitives_to_nodes & init_parents
        let mut bvh = PlocBuilder::new().build(
            PlocSearchDistance::VeryLow,
            &aabbs,
            indices.clone(),
            SortPrecision::U64,
            1,
        );

        bvh.init_primitives_to_nodes_if_uninit();
        // Transform the scene so every primitive's aabb changes.
        tris.transform(&Mat4::from_scale_rotation_translation(
            Vec3::splat(1.3),
            Quat::from_rotation_y(0.1),
            vec3(0.33, 0.3, 0.37),
        ));
        // Write the new primitive aabbs directly into their owning leaves.
        for (prim_id, tri) in tris.iter().enumerate() {
            bvh.nodes[bvh.primitives_to_nodes[prim_id] as usize].set_aabb(tri.aabb());
        }

        bvh.refit_all();

        // tight_fit = true: after a full refit, parent aabbs should fit children exactly.
        bvh.validate(&tris, false, true);
    }

    /// Reinsert every non-root node, then check the BVH is still structurally valid.
    #[test]
    fn test_reinsert_node() {
        let tris = demoscene(32, 0);

        let mut bvh = build_bvh2_from_tris(
            &tris,
            BvhBuildParams::fastest_build(),
            &mut Default::default(),
        );

        // reinsert_node relies on both auxiliary mappings being initialized.
        bvh.init_primitives_to_nodes_if_uninit();
        bvh.init_parents_if_uninit();

        // Start at 1: the root is never reinserted (reinsert_node also guards this itself).
        for node_id in 1..bvh.nodes.len() {
            bvh.reinsert_node(node_id);
        }

        // tight_fit = false: reinsertion does not guarantee tightly fitting parents.
        bvh.validate(&tris, false, false);
    }
}