obvhs/cwbvh/
mod.rs

1//! An eight-way compressed wide BVH8 builder.
2
3pub mod builder;
4pub mod bvh2_to_cwbvh;
5pub mod node;
6
7#[cfg(all(
8    any(target_arch = "x86", target_arch = "x86_64"),
9    target_feature = "sse2"
10))]
11pub mod simd;
12pub mod traverse_macro;
13
14use std::{
15    collections::{HashMap, HashSet},
16    fmt,
17};
18
19use glam::{UVec2, UVec3, Vec3A, uvec2};
20use node::CwBvhNode;
21
22use crate::{
23    Boundable, PerComponent,
24    aabb::Aabb,
25    faststack::{FastStack, StackStack},
26    ray::{Ray, RayHit},
27};
28
29pub const BRANCHING: usize = 8;
30
31// Corresponds directly to the number of bit patterns created for child ordering
32const DIRECTIONS: usize = 8;
33
34const INVALID: u32 = u32::MAX;
35
36const NQ: u32 = 8;
37const NQ_SCALE: f32 = ((1 << NQ) - 1) as f32; //255.0
38const DENOM: f32 = 1.0 / NQ_SCALE; // 1.0 / 255.0
39
/// A Compressed Wide BVH8
#[derive(Clone, Default, PartialEq, Debug)]
#[repr(C)]
pub struct CwBvh {
    /// Compressed 8-wide BVH nodes. Node 0 is the root.
    pub nodes: Vec<CwBvhNode>,
    /// Maps the BVH's internal primitive ordering back to original primitive indices.
    pub primitive_indices: Vec<u32>,
    /// Bounding box enclosing the entire hierarchy.
    pub total_aabb: Aabb,
    /// Optional uncompressed per-node AABBs. When present, `node_aabb()` prefers these
    /// over the AABB decompressed from the node's quantized bounds.
    pub exact_node_aabbs: Option<Vec<Aabb>>,

    /// Indicates that this BVH is using spatial splits. Large triangles are split into multiple smaller Aabbs, so
    /// primitives will extend outside the leaf in some cases.
    /// If the bvh uses splits, a primitive can show up in multiple leaf nodes so there won't be a 1 to 1 correlation
    /// between the total number of primitives in leaf nodes and in Bvh2::primitive_indices, vs the input triangles.
    /// If spatial splits are used, some validation steps have to be skipped.
    pub uses_spatial_splits: bool,
}
56
57// BVH8's tend to be shallow. A stack of 32 would be very deep even for a large scene with no TLAS.
58// A BVH that deep would perform very slowly and would likely indicate that the geometry is degenerate in some way.
59// CwBvh::validate() will assert the CwBvh depth is less than TRAVERSAL_STACK_SIZE
60const TRAVERSAL_STACK_SIZE: usize = 32;
61
/// Holds Ray traversal state to allow for dynamic traversal (yield on hit)
pub struct RayTraversal {
    /// Deferred node groups still to be visited (x: child base index, y: hit/imask bits).
    pub stack: StackStack<UVec2, TRAVERSAL_STACK_SIZE>,
    /// Node group being processed: x is the child base node index, y holds the
    /// per-child hit bits in the high byte and the node's imask in the low bits.
    pub current_group: UVec2,
    /// Primitive group being processed: x is the primitive base index, y the hit bits.
    pub primitive_group: UVec2,
    /// Child-ordering key derived from the ray direction via `ray_get_octant_inv4`;
    /// used to pick the traversal order of intersected children.
    pub oct_inv4: u32,
    /// The ray being traced; `tmax` shrinks as closer hits are recorded.
    pub ray: Ray,
}
70
71impl RayTraversal {
72    #[inline(always)]
73    /// Reinitialize traversal state with new ray.
74    pub fn reinit(&mut self, ray: Ray) {
75        self.stack.clear();
76        self.current_group = uvec2(0, 0x80000000);
77        self.primitive_group = UVec2::ZERO;
78        self.oct_inv4 = ray_get_octant_inv4(&ray.direction);
79        self.ray = ray;
80    }
81}
82
/// Holds traversal state to allow for dynamic traversal (yield on hit)
pub struct Traversal {
    /// Deferred node groups still to be visited (x: child base index, y: hit/imask bits).
    pub stack: StackStack<UVec2, TRAVERSAL_STACK_SIZE>,
    /// Node group being processed: x is the child base node index, y holds the
    /// per-child hit bits in the high byte and the node's imask in the low bits.
    pub current_group: UVec2,
    /// Primitive group being processed: x is the primitive base index, y the hit bits.
    pub primitive_group: UVec2,
    /// Child-ordering key derived from `traversal_direction` via `ray_get_octant_inv4`.
    pub oct_inv4: u32,
    /// Direction used to determine the order of bvh node child traversal.
    /// Typically the ray direction.
    pub traversal_direction: Vec3A,
    /// Primitive index for the current traversal step (set by the `traverse!` macro).
    pub primitive_id: u32,
    /// Hit mask from the most recent node intersection test (used by `traverse!`).
    pub hitmask: u32,
}
93
94impl Default for Traversal {
95    fn default() -> Self {
96        Self {
97            stack: Default::default(),
98            current_group: uvec2(0, 0x80000000),
99            primitive_group: Default::default(),
100            oct_inv4: Default::default(),
101            traversal_direction: Default::default(),
102            primitive_id: Default::default(),
103            hitmask: Default::default(),
104        }
105    }
106}
107
108impl Traversal {
109    #[inline(always)]
110    /// Reinitialize traversal state with new traversal direction.
111    pub fn reinit(&mut self, traversal_direction: Vec3A) {
112        self.stack.clear();
113        self.current_group = uvec2(0, 0x80000000);
114        self.primitive_group = UVec2::ZERO;
115        self.oct_inv4 = ray_get_octant_inv4(&traversal_direction);
116        self.traversal_direction = traversal_direction;
117        self.primitive_id = 0;
118        self.hitmask = 0;
119    }
120}
121
122impl CwBvh {
123    #[inline(always)]
124    pub fn new_ray_traversal(&self, ray: Ray) -> RayTraversal {
125        //  BVH8's tend to be shallow. A stack of 32 would be very deep even for a large scene with no tlas.
126        let stack = StackStack::default();
127        let current_group = if self.nodes.is_empty() {
128            UVec2::ZERO
129        } else {
130            uvec2(0, 0x80000000)
131        };
132        let primitive_group = UVec2::ZERO;
133        let oct_inv4 = ray_get_octant_inv4(&ray.direction);
134
135        RayTraversal {
136            stack,
137            current_group,
138            primitive_group,
139            oct_inv4,
140            ray,
141        }
142    }
143
144    #[inline(always)]
145    /// traversal_direction is used to determine the order of bvh node child traversal. This would typically be the ray direction.
146    pub fn new_traversal(&self, traversal_direction: Vec3A) -> Traversal {
147        //  BVH8's tend to be shallow. A stack of 32 would be very deep even for a large scene with no tlas.
148        let stack = StackStack::default();
149        let current_group = if self.nodes.is_empty() {
150            UVec2::ZERO
151        } else {
152            uvec2(0, 0x80000000)
153        };
154        let primitive_group = UVec2::ZERO;
155        let oct_inv4 = ray_get_octant_inv4(&traversal_direction);
156        Traversal {
157            stack,
158            current_group,
159            primitive_group,
160            oct_inv4,
161            traversal_direction,
162            primitive_id: 0,
163            hitmask: 0,
164        }
165    }
166
    /// Traverse the BVH, finding the closest hit.
    /// Returns true if any primitive was hit.
    ///
    /// # Arguments
    /// * `ray` - The ray to be tested for intersection. Not mutated; `ray.tmax` bounds the search.
    /// * `hit` - Receives the closest hit's primitive id and distance `t`.
    /// * `intersection_fn` - Given the current ray and a primitive index, returns the hit
    ///   distance `t` (any value >= the ray's `tmax` is treated as a miss).
    pub fn ray_traverse<F: FnMut(&Ray, usize) -> f32>(
        &self,
        ray: Ray,
        hit: &mut RayHit,
        mut intersection_fn: F,
    ) -> bool {
        // Local copy so tmax can shrink as closer hits are found without
        // mutating the caller-visible ray.
        let mut traverse_ray = ray;
        let mut state = self.new_traversal(ray.direction);
        let mut node;
        crate::traverse!(
            self,
            node,
            state,
            node.intersect_ray(&traverse_ray, state.oct_inv4),
            {
                let t = intersection_fn(&traverse_ray, state.primitive_id as usize);
                if t < traverse_ray.tmax {
                    // Closer hit: record it and shrink tmax to cull farther nodes.
                    hit.primitive_id = state.primitive_id;
                    hit.t = t;
                    traverse_ray.tmax = t;
                }
            }
        );

        // Alternatively (performance seems slightly slower):
        // let mut state = self.new_ray_traversal(ray);
        // while self.ray_traverse_dynamic(&mut state, hit, &mut intersection_fn) {}

        hit.t < ray.tmax // Note this is valid since this does not mutate the ray
    }
199
    /// Traverse the bvh for a given `Ray`. Returns true if the ray missed all primitives.
    ///
    /// # Arguments
    /// * `ray` - The ray to be tested for intersection.
    /// * `intersection_fn` - Given the ray and a primitive index, returns the hit
    ///   distance `t` (any value >= `ray.tmax` is treated as a miss).
    pub fn ray_traverse_miss<F: FnMut(&Ray, usize) -> f32>(
        &self,
        ray: Ray,
        mut intersection_fn: F,
    ) -> bool {
        let mut state = self.new_traversal(ray.direction);
        let mut node;
        let mut miss = true;
        // Labeled block so the first confirmed hit can abort traversal early
        // from inside the macro body.
        'outer: {
            crate::traverse!(
                self,
                node,
                state,
                node.intersect_ray(&ray, state.oct_inv4),
                {
                    let t = intersection_fn(&ray, state.primitive_id as usize);
                    if t < ray.tmax {
                        // Any hit at all means "not a miss"; stop immediately.
                        miss = false;
                        break 'outer;
                    }
                }
            );
        }
        miss
    }
226
    /// Traverse the bvh for a given `Ray`. Intersects all primitives along ray (for things like evaluating transparency)
    ///   intersection_fn is called for all intersections. Ray is not updated to allow for evaluating at every hit.
    ///
    /// # Arguments
    /// * `ray` - The ray to be tested for intersection.
    /// * `intersection_fn` - takes the given ray and primitive index.
    pub fn ray_traverse_anyhit<F: FnMut(&Ray, usize)>(&self, ray: Ray, mut intersection_fn: F) {
        let mut state = self.new_traversal(ray.direction);
        let mut node;
        // Since tmax is never shrunk, every primitive whose node the ray
        // intersects is visited; the callback decides what to do with each.
        crate::traverse!(
            self,
            node,
            state,
            node.intersect_ray(&ray, state.oct_inv4),
            {
                intersection_fn(&ray, state.primitive_id as usize);
            }
        );
    }
246
    /// Traverse the BVH
    /// Yields at every primitive hit, returning true.
    /// Returns false when no hit is found.
    /// For basic miss test, just run until the first time it yields true.
    /// For closest hit run until it returns false and check hit.t < ray.tmax to see if it hit something
    /// For transparency, you want to hit every primitive in the ray's path, keeping track of the closest opaque hit.
    ///     and then manually setting ray.tmax to that closest opaque hit at each iteration.
    /// For best performance & customizability use the traverse! macro instead.
    ///
    /// # Arguments
    /// * `state` - Traversal state from `new_ray_traversal()`; persists across yields so the
    ///   traversal can be resumed where it left off.
    /// * `hit` - Updated with the primitive id and distance each time a hit is yielded.
    /// * `intersection_fn` - Given the ray and a primitive index, returns the hit distance
    ///   `t` (any value >= `state.ray.tmax` is treated as a miss).
    #[inline]
    pub fn ray_traverse_dynamic<F: FnMut(&Ray, usize) -> f32>(
        &self,
        state: &mut RayTraversal,
        hit: &mut RayHit,
        mut intersection_fn: F,
    ) -> bool {
        loop {
            // While the primitive group is not empty
            while state.primitive_group.y != 0 {
                let local_primitive_index = firstbithigh(state.primitive_group.y);

                // Remove primitive from current_group
                state.primitive_group.y &= !(1u32 << local_primitive_index);

                let global_primitive_index = state.primitive_group.x + local_primitive_index;
                let t = intersection_fn(&state.ray, global_primitive_index as usize);
                if t < state.ray.tmax {
                    hit.primitive_id = global_primitive_index;
                    hit.t = t;
                    // Shrink tmax so farther nodes/primitives are culled after resuming.
                    state.ray.tmax = hit.t;
                    // Yield when we hit a primitive
                    return true;
                }
            }
            state.primitive_group = UVec2::ZERO;

            // If there's remaining nodes in the current group to check
            // (the high byte of y holds the per-child node-hit bits)
            if state.current_group.y & 0xff000000 != 0 {
                let hits_imask = state.current_group.y;

                let child_index_offset = firstbithigh(hits_imask);
                let child_index_base = state.current_group.x;

                // Remove node from current_group
                state.current_group.y &= !(1u32 << child_index_offset);

                // If the node group is not yet empty, push it on the stack
                if state.current_group.y & 0xff000000 != 0 {
                    state.stack.push(state.current_group);
                }

                // Undo the octant-based child ordering to recover the child slot,
                // then count set imask bits below it to get the child's offset
                // from the node's child base index.
                let slot_index = (child_index_offset - 24) ^ (state.oct_inv4 & 0xff);
                let relative_index = (hits_imask & !(0xffffffffu32 << slot_index)).count_ones();

                let child_node_index = child_index_base + relative_index;

                let node = &self.nodes[child_node_index as usize];

                let hitmask = node.intersect_ray(&state.ray, state.oct_inv4);

                state.current_group.x = node.child_base_idx;
                state.primitive_group.x = node.primitive_base_idx;

                // High byte: node-hit bits combined with the node's imask;
                // low 24 bits: primitive-hit bits.
                state.current_group.y = (hitmask & 0xff000000) | (node.imask as u32);
                state.primitive_group.y = hitmask & 0x00ffffff;
            } else
            // There's no nodes left in the current group
            {
                // Below is only needed when using triangle postponing, which would only be helpful on the
                // GPU (it helps reduce thread divergence). Also, this isn't compatible with traversal yielding.
                // state.primitive_group = state.current_group;
                state.current_group = UVec2::ZERO;
            }

            // If there's no remaining nodes in the current group to check, pop it off the stack.
            if state.primitive_group.y == 0 && (state.current_group.y & 0xff000000) == 0 {
                // If the stack is empty, end traversal.
                if state.stack.is_empty() {
                    state.current_group.y = 0;
                    break;
                }

                state.current_group = state.stack.pop_fast();
            }
        }

        // Returns false when there are no more primitives to test.
        // This doesn't mean we never hit one along the way though. (and yielded then)
        false
    }
336
    /// This is currently mostly here just for reference. It's setup somewhat similarly to the GPU version,
    /// reusing the same stack for both BLAS and TLAS traversal. It might be better to traverse separately on
    /// the CPU using two instances of `Traversal` with `CwBvh::traverse_dynamic()` or the `traverse!` macro.
    /// I haven't benchmarked this comparison yet. This example also does not take into account transforming
    /// the ray into the local space of the blas instance. (but has comments denoting where this would happen)
    ///
    /// # Arguments
    /// * `blas` - Bottom-level BVHs; `self`'s primitive indices (via `primitive_indices`) select into this slice.
    /// * `ray` - The ray to be tested for intersection.
    /// * `hit` - Receives the closest hit's primitive id and owning mesh (geometry) id.
    /// * `intersection_fn` - Given the ray, a mesh index, and a primitive index, returns the
    ///   hit distance `t` (any value >= `ray.tmax` is treated as a miss).
    ///
    /// Returns true if any primitive was hit.
    pub fn ray_traverse_tlas_blas<F: FnMut(&Ray, usize, usize) -> f32>(
        &self,
        blas: &[CwBvh],
        mut ray: Ray,
        hit: &mut RayHit,
        mut intersection_fn: F,
    ) -> bool {
        let mut stack: StackStack<UVec2, TRAVERSAL_STACK_SIZE> = StackStack::default();
        let mut current_group;
        let mut tlas_stack_size = INVALID; // tlas_stack_size is used to indicate whether we are in the TLAS or not.
        let mut current_mesh = INVALID;
        // `bvh` is the hierarchy currently being traversed: starts at the TLAS
        // (self) and is swapped to a BLAS when a TLAS leaf is reached.
        let mut bvh = self;

        let oct_inv4 = ray_get_octant_inv4(&ray.direction);

        // Start at the root (node 0, root-present bit set).
        current_group = uvec2(0, 0x80000000);

        loop {
            let mut primitive_group = UVec2::ZERO;

            // If there's remaining nodes in the current group to check
            // (the high byte of y holds the per-child node-hit bits)
            if current_group.y & 0xff000000 != 0 {
                let hits_imask = current_group.y;

                let child_index_offset = firstbithigh(hits_imask);
                let child_index_base = current_group.x;

                // Remove node from current_group
                current_group.y &= !(1u32 << child_index_offset);

                // If the node group is not yet empty, push it on the stack
                if current_group.y & 0xff000000 != 0 {
                    stack.push(current_group);
                }

                // Undo the octant-based child ordering to recover the child slot,
                // then count set imask bits below it to get the child's offset
                // from the node's child base index.
                let slot_index = (child_index_offset - 24) ^ (oct_inv4 & 0xff);
                let relative_index = (hits_imask & !(0xffffffffu32 << slot_index)).count_ones();

                let child_node_index = child_index_base + relative_index;

                let node = &bvh.nodes[child_node_index as usize];

                let hitmask = node.intersect_ray(&ray, oct_inv4);

                current_group.x = node.child_base_idx;
                primitive_group.x = node.primitive_base_idx;

                // High byte: node-hit bits combined with the node's imask;
                // low 24 bits: primitive-hit bits.
                current_group.y = (hitmask & 0xff000000) | (node.imask as u32);
                primitive_group.y = hitmask & 0x00ffffff;
            } else
            // There's no nodes left in the current group
            {
                // Below is only needed when using triangle postponing, which would only be helpful on the
                // GPU (it helps reduce thread divergence). Also, this isn't compatible with traversal yielding.
                // primitive_group = current_group;
                current_group = UVec2::ZERO;
            }

            // While the primitive group is not empty
            while primitive_group.y != 0 {
                // https://github.com/jan-van-bergen/GPU-Raytracer/issues/24#issuecomment-1042746566
                // If tlas_stack_size is INVALID we are in the TLAS. This means use the primitive index as a mesh index.
                // (TODO: The ray is transform according to the mesh transform and) traversal is continued at the root of the Mesh's BLAS.
                if tlas_stack_size == INVALID {
                    let local_primitive_index = firstbithigh(primitive_group.y);

                    // Remove primitive from current_group
                    primitive_group.y &= !(1u32 << local_primitive_index);

                    let global_primitive_index = primitive_group.x + local_primitive_index;

                    // Save remaining TLAS work so it can be resumed after this BLAS.
                    if primitive_group.y != 0 {
                        stack.push(primitive_group);
                    }

                    if current_group.y & 0xff000000 != 0 {
                        stack.push(current_group);
                    }

                    // The value of tlas_stack_size is now set to the current size of the traversal stack.
                    tlas_stack_size = stack.len() as u32;

                    // TODO transform ray according to the mesh transform
                    // https://github.com/jan-van-bergen/GPU-Raytracer/blob/6559ae2241c8fdea0ddaec959fe1a47ec9b3ab0d/Src/CUDA/Raytracing/BVH8.h#L222

                    // For primitives, we remap them to match the cwbvh indices layout. But for tlas
                    // it would not be typically reasonable to reorder the blas and mesh buffers. So we
                    // need to look up the original index using bvh.primitive_indices[].
                    let blas_index = bvh.primitive_indices[global_primitive_index as usize];
                    bvh = &blas[blas_index as usize];
                    current_mesh = blas_index;

                    // since we assign bvh = &blas[global_primitive_index as usize] above the index is just the first node at 0
                    current_group = uvec2(0, 0x80000000);

                    break;
                } else {
                    // If tlas_stack_size is any other value we are in the BLAS. This performs the usual primitive intersection.

                    let local_primitive_index = firstbithigh(primitive_group.y);

                    // Remove primitive from current_group
                    primitive_group.y &= !(1u32 << local_primitive_index);

                    let global_primitive_index = primitive_group.x + local_primitive_index;
                    let t = intersection_fn(
                        &ray,
                        current_mesh as usize,
                        global_primitive_index as usize,
                    );

                    if t < ray.tmax {
                        // Closer hit: record which mesh/primitive and shrink tmax.
                        hit.primitive_id = global_primitive_index;
                        hit.geometry_id = current_mesh;
                        ray.tmax = t;
                    }
                }
            }

            // If there's no remaining nodes in the current group to check, pop it off the stack.
            if (current_group.y & 0xff000000) == 0 {
                // If the stack is empty, end traversal.
                if stack.is_empty() {
                    #[allow(unused)]
                    {
                        current_group.y = 0;
                    }
                    break;
                }

                // The value of tlas_stack_size is used to determine when traversal of a BLAS is finished, and we should revert back to TLAS traversal.
                if stack.len() as u32 == tlas_stack_size {
                    tlas_stack_size = INVALID;
                    current_mesh = INVALID;
                    bvh = self;
                    // TODO Reset Ray to untransformed version
                    // https://github.com/jan-van-bergen/GPU-Raytracer/blob/6559ae2241c8fdea0ddaec959fe1a47ec9b3ab0d/Src/CUDA/Raytracing/BVH8.h#L262
                }

                current_group = stack.pop_fast();
            }
        }

        // hit.primitive_id is only written when an intersection succeeded, so it
        // still being u32::MAX means nothing was hit.
        // NOTE(review): this assumes `hit` was passed in freshly initialized with
        // primitive_id == u32::MAX — confirm against callers.
        if hit.primitive_id != u32::MAX {
            hit.t = ray.tmax;
            return true;
        }

        false
    }
492
493    /// Returns the list of parents where `parent_index = parents[node_index]`
494    pub fn compute_parents(&self) -> Vec<u32> {
495        let mut parents = vec![0; self.nodes.len()];
496        parents[0] = 0;
497        self.nodes.iter().enumerate().for_each(|(i, node)| {
498            for ch in 0..8 {
499                if node.is_child_empty(ch) {
500                    continue;
501                }
502                if !node.is_leaf(ch) {
503                    parents[node.child_node_index(ch) as usize] = i as u32;
504                }
505            }
506        });
507        parents
508    }
509
510    /// Reorder the children of every BVH node. Arranges child nodes in Morton order according to their centroids
511    /// so that the order in which the intersected children are traversed can be determined by the ray octant.
512    /// This results in a slightly different order since the normal reordering during
513    /// building is using the aabb's from the Bvh2 and this uses the children node.p and node.e to compute the aabb. Traversal
514    /// seems to be a bit slower on some scenes and a bit faster on others. Note this will rearrange self.nodes. Anything that
515    /// depends on the order of self.nodes will need to be updated.
516    ///
517    /// # Arguments
518    /// * `primitives` - List of BVH primitives, implementing Boundable.
519    /// * `direct_layout` - The primitives are already laid out in bvh.primitive_indices order.
520    pub fn order_children<T: Boundable>(&mut self, primitives: &[T], direct_layout: bool) {
521        for i in 0..self.nodes.len() {
522            self.order_node_children(primitives, i, direct_layout);
523        }
524    }
525
526    /// Reorder the children of the given node_idx. Arranges child nodes in Morton order according to their centroids
527    /// so that the order in which the intersected children are traversed can be determined by the ray octant.
528    /// This results in a slightly different order since the normal reordering during
529    /// building is using the aabb's from the Bvh2 and this uses the children node.p and node.e to compute the aabb. Traversal
530    /// seems to be a bit slower on some scenes and a bit faster on others. Note this will rearrange self.nodes. Anything that
531    /// depends on the order of self.nodes will need to be updated.
532    ///
533    /// # Arguments
534    /// * `primitives` - List of BVH primitives, implementing Boundable.
535    /// * `node_idx` - Index of node to be reordered.
536    /// * `direct_layout` - The primitives are already laid out in bvh.primitive_indices order.
537    pub fn order_node_children<T: Boundable>(
538        &mut self,
539        primitives: &[T],
540        node_index: usize,
541        direct_layout: bool,
542    ) {
543        // TODO could this use ints and work in local node grid space?
544        // TODO support using exact_node_aabbs
545
546        let old_node = self.nodes[node_index];
547
548        const INVALID32: u32 = u32::MAX;
549        const INVALID_USIZE: usize = INVALID32 as usize;
550        let center = old_node.aabb().center();
551
552        let mut cost = [[f32::MAX; DIRECTIONS]; BRANCHING];
553
554        let mut child_count = 0;
555        let mut child_inner_count = 0;
556        for ch in 0..BRANCHING {
557            if !old_node.is_child_empty(ch) {
558                child_count += 1;
559                if !old_node.is_leaf(ch) {
560                    child_inner_count += 1;
561                }
562            }
563        }
564
565        let mut old_child_centers = [Vec3A::default(); 8];
566        for ch in 0..BRANCHING {
567            if old_node.is_child_empty(ch) {
568                continue;
569            }
570            if old_node.is_leaf(ch) {
571                let (child_prim_start, count) = old_node.child_primitives(ch);
572                let mut aabb = Aabb::empty();
573                for i in 0..count {
574                    let mut prim_index = (child_prim_start + i) as usize;
575                    if !direct_layout {
576                        prim_index = self.primitive_indices[prim_index] as usize;
577                    }
578                    aabb = aabb.union(&primitives[prim_index].aabb());
579                }
580                old_child_centers[ch] = aabb.center();
581            } else {
582                old_child_centers[ch] = self.nodes[old_node.child_node_index(ch) as usize]
583                    .aabb()
584                    .center();
585                let child_node_index = old_node.child_node_index(ch) as usize;
586                old_child_centers[ch] = self.node_aabb(child_node_index).center();
587            }
588        }
589
590        assert!(child_count <= BRANCHING);
591        assert!(cost.len() >= child_count);
592        // Fill cost table
593        for s in 0..DIRECTIONS {
594            let d = Vec3A::new(
595                if (s & 0b100) != 0 { -1.0 } else { 1.0 },
596                if (s & 0b010) != 0 { -1.0 } else { 1.0 },
597                if (s & 0b001) != 0 { -1.0 } else { 1.0 },
598            );
599            // We have to use BRANCHING here instead of child_count because the first slots wont be children if it was already reordered.
600            for ch in 0..BRANCHING {
601                if old_node.is_child_empty(ch) {
602                    continue;
603                }
604                let v = old_child_centers[ch] - center; //old_node.child_aabb(c).center() - center;
605                let cost_slot = unsafe { cost.get_unchecked_mut(ch).get_unchecked_mut(s) };
606                *cost_slot = d.dot(v); // No benefit from normalizing
607            }
608        }
609
610        let mut assignment = [INVALID_USIZE; BRANCHING];
611        let mut slot_filled = [false; DIRECTIONS];
612
613        // The paper suggests the auction method, but greedy is almost as good.
614        loop {
615            let mut min_cost = f32::MAX;
616
617            let mut min_slot = INVALID_USIZE;
618            let mut min_index = INVALID_USIZE;
619
620            // Find cheapest unfilled slot of any unassigned child
621            // We have to use BRANCHING here instead of child_count because the first slots wont be children if it was already reordered.
622            for ch in 0..BRANCHING {
623                if old_node.is_child_empty(ch) {
624                    continue;
625                }
626                if assignment[ch] == INVALID_USIZE {
627                    for (s, &slot_filled) in slot_filled.iter().enumerate() {
628                        let cost = unsafe { *cost.get_unchecked(ch).get_unchecked(s) };
629                        if !slot_filled && cost < min_cost {
630                            min_cost = cost;
631
632                            min_slot = s;
633                            min_index = ch;
634                        }
635                    }
636                }
637            }
638
639            if min_slot == INVALID_USIZE {
640                break;
641            }
642
643            slot_filled[min_slot] = true;
644            assignment[min_index] = min_slot;
645        }
646
647        let mut new_node = old_node;
648        new_node.imask = 0;
649
650        for ch in 0..BRANCHING {
651            new_node.child_meta[ch] = 0;
652        }
653
654        for ch in 0..BRANCHING {
655            if old_node.is_child_empty(ch) {
656                continue;
657            }
658            let new_ch = assignment[ch];
659            assert!(new_ch < BRANCHING);
660            if old_node.is_leaf(ch) {
661                new_node.child_meta[new_ch] = old_node.child_meta[ch];
662            } else {
663                new_node.imask |= 1 << new_ch;
664                new_node.child_meta[new_ch] = (24 + new_ch as u8) | 0b0010_0000;
665            }
666            new_node.child_min_x[new_ch] = old_node.child_min_x[ch];
667            new_node.child_max_x[new_ch] = old_node.child_max_x[ch];
668            new_node.child_min_y[new_ch] = old_node.child_min_y[ch];
669            new_node.child_max_y[new_ch] = old_node.child_max_y[ch];
670            new_node.child_min_z[new_ch] = old_node.child_min_z[ch];
671            new_node.child_max_z[new_ch] = old_node.child_max_z[ch];
672        }
673
674        if child_inner_count == 0 {
675            self.nodes[node_index] = new_node;
676            return;
677        }
678
679        let mut old_child_nodes = [CwBvhNode::default(); 8];
680        for ch in 0..BRANCHING {
681            if old_node.is_child_empty(ch) {
682                continue;
683            }
684            if old_node.is_leaf(ch) {
685                continue;
686            }
687            old_child_nodes[ch] = self.nodes[old_node.child_node_index(ch) as usize]
688        }
689
690        let old_child_exact_aabbs = if let Some(exact_node_aabbs) = &self.exact_node_aabbs {
691            let mut old_child_exact_aabbs = [Aabb::empty(); 8];
692            for ch in 0..BRANCHING {
693                if old_node.is_child_empty(ch) {
694                    continue;
695                }
696                if old_node.is_leaf(ch) {
697                    continue;
698                }
699                old_child_exact_aabbs[ch] =
700                    exact_node_aabbs[old_node.child_node_index(ch) as usize];
701            }
702            Some(old_child_exact_aabbs)
703        } else {
704            None
705        };
706
707        // check if this is really needed or if we can specify the offset in the child_meta out of order
708        for ch in 0..BRANCHING {
709            if old_node.is_child_empty(ch) {
710                continue;
711            }
712            if assignment[ch] == INVALID_USIZE {
713                continue;
714            }
715            let new_ch = assignment[ch];
716            assert_eq!(
717                !new_node.is_leaf(new_ch),
718                (new_node.child_meta[new_ch] & 0b11111) >= 24
719            );
720            if old_node.is_leaf(ch) {
721                continue;
722            }
723            let new_idx = new_node.child_node_index(new_ch) as usize;
724            self.nodes[new_idx] = old_child_nodes[ch];
725            if let Some(old_child_exact_aabbs) = &old_child_exact_aabbs
726                && let Some(exact_node_aabbs) = &mut self.exact_node_aabbs
727            {
728                exact_node_aabbs[new_idx] = old_child_exact_aabbs[ch];
729            }
730            assert!(new_idx >= old_node.child_base_idx as usize);
731            assert!(new_idx < old_node.child_base_idx as usize + child_inner_count);
732        }
733        self.nodes[node_index] = new_node;
734    }
735
736    /// Tries to use the exact node aabb if it is available, otherwise computes it from the compressed node min P and extent exponent.
737    #[inline(always)]
738    fn node_aabb(&self, node_index: usize) -> Aabb {
739        if let Some(exact_node_aabbs) = &self.exact_node_aabbs {
740            exact_node_aabbs[node_index]
741        } else {
742            self.nodes[node_index].aabb()
743        }
744    }
745
    /// Direct layout: The primitives are already laid out in bvh.primitive_indices order.
    ///
    /// Validates the BVH structure: checks node/primitive reachability, depth, and
    /// (when present) that exact node AABBs are consistent with the compressed bounds.
    /// Panics via assert on any violation.
    ///
    /// # Arguments
    /// * `primitives` - List of BVH primitives, implementing Boundable.
    /// * `direct_layout` - The primitives are already laid out in bvh.primitive_indices order.
    pub fn validate<T: Boundable>(
        &self,
        primitives: &[T],
        direct_layout: bool,
    ) -> CwBvhValidationResult {
        if !self.uses_spatial_splits {
            // Could still check this if duplicated were removed from self.primitive_indices first
            assert_eq!(self.primitive_indices.len(), primitives.len());
        }
        let mut result = CwBvhValidationResult {
            direct_layout,
            ..Default::default()
        };
        if !self.nodes.is_empty() {
            self.validate_impl(0, Aabb::LARGEST, &mut result, primitives);
        }
        //self.print_nodes();

        result.max_depth = self.calculate_max_depth(0, &mut result, 0);

        // When exact AABBs are maintained, each must be contained (within epsilon)
        // by both the parent's compressed child AABB and the child's own
        // decompressed AABB.
        if let Some(exact_node_aabbs) = &self.exact_node_aabbs {
            for node in &self.nodes {
                for ch in 0..8 {
                    if !node.is_leaf(ch) {
                        let child_node_index = node.child_node_index(ch) as usize;
                        let comp_aabb = node.child_aabb(ch);
                        let self_aabb = self.nodes[child_node_index].aabb();
                        let exact_aabb = exact_node_aabbs[child_node_index];

                        // TODO Could these bounds be tighter?
                        assert!(exact_aabb.min.cmpge(comp_aabb.min - 1.0e-5).all());
                        assert!(exact_aabb.max.cmple(comp_aabb.max + 1.0e-5).all());
                        assert!(exact_aabb.min.cmpge(self_aabb.min - 1.0e-5).all());
                        assert!(exact_aabb.max.cmple(self_aabb.max + 1.0e-5).all());
                    }
                }
            }
        }

        // Every node and every primitive slot must be reachable from the root.
        assert_eq!(result.discovered_nodes.len(), self.nodes.len());
        assert_eq!(
            result.discovered_primitives.len(),
            self.primitive_indices.len()
        );
        // Depth must fit the fixed traversal stack (see TRAVERSAL_STACK_SIZE).
        assert!(result.max_depth < TRAVERSAL_STACK_SIZE as u32);

        result
    }
795
    /// Recursive worker for [`CwBvh::validate`]. Checks the node at `node_idx`
    /// against `parent_bounds` (the running intersection of ancestor child
    /// AABBs), verifies per-child metadata invariants, and recurses into inner
    /// children while recording stats into `result`.
    fn validate_impl<T: Boundable>(
        &self,
        node_idx: usize,
        parent_bounds: Aabb,
        result: &mut CwBvhValidationResult,
        primitives: &[T],
    ) {
        result.discovered_nodes.insert(node_idx as u32);
        result.node_count += 1;

        let node = &self.nodes[node_idx];

        assert!(node.p.is_finite());
        assert!(parent_bounds.min.is_finite());
        assert!(parent_bounds.max.is_finite());
        // The node's decompression origin p must sit inside the parent's
        // bounds, with a small epsilon of slack.
        // TODO Could these bounds be tighter?
        assert!(node.p.cmpge((parent_bounds.min - 1.0e-5).into()).all());
        assert!(node.p.cmple((parent_bounds.max + 1.0e-5).into()).all());

        // Reconstruct the per-axis dequantization scale: each node.e[i] is an
        // 8-bit exponent that becomes a power-of-two f32 when shifted into the
        // f32 exponent field (bit 23).
        let e: UVec3 = [
            (node.e[0] as u32) << 23,
            (node.e[1] as u32) << 23,
            (node.e[2] as u32) << 23,
        ]
        .into();
        let e: Vec3A = e.per_comp(f32::from_bits);

        for ch in 0..8 {
            let child_meta = node.child_meta[ch];
            if child_meta == 0 {
                assert!(node.is_child_empty(ch));
                // Empty
                continue;
            }
            assert!(!node.is_child_empty(ch));

            result.child_count += 1;

            // Quantized child bounds, stored per-axis as bytes.
            let quantized_min = UVec3::new(
                node.child_min_x[ch] as u32,
                node.child_min_y[ch] as u32,
                node.child_min_z[ch] as u32,
            );
            let quantized_max = UVec3::new(
                node.child_max_x[ch] as u32,
                node.child_max_y[ch] as u32,
                node.child_max_z[ch] as u32,
            );

            // The raw quantized box must match the node's own accessor.
            assert_eq!(
                Aabb::new(quantized_min.as_vec3a(), quantized_max.as_vec3a()),
                node.local_child_aabb(ch)
            );

            // Dequantize: world = quantized * scale + origin. Must agree with
            // node.child_aabb(ch).
            let p = Vec3A::from(node.p);
            let quantized_min = quantized_min.as_vec3a() * e + p;
            let quantized_max = quantized_max.as_vec3a() * e + p;

            assert_eq!(Aabb::new(quantized_min, quantized_max), node.child_aabb(ch));

            // imask bit ch marks an inner child; for inner children the low 5
            // bits of child_meta encode 24 + slot index. The two encodings
            // must agree.
            let is_child_inner = (node.imask & (1 << ch)) != 0;
            assert_eq!(is_child_inner, (child_meta & 0b11111) >= 24);

            if is_child_inner {
                assert!(!node.is_leaf(ch));
                let slot_index = (child_meta & 0b11111) as usize - 24;
                // Inner children are stored contiguously from child_base_idx;
                // counting the imask bits below this slot gives the child's
                // offset within that run.
                let relative_index =
                    (node.imask as u32 & !(0xffffffffu32 << slot_index)).count_ones();
                let child_node_idx = node.child_base_idx as usize + relative_index as usize;
                // Recurse with this child's box intersected into the running
                // parent bounds.
                self.validate_impl(
                    child_node_idx,
                    Aabb {
                        min: quantized_min,
                        max: quantized_max,
                    }
                    .intersection(&parent_bounds),
                    result,
                    primitives,
                );
            } else {
                assert!(node.is_leaf(ch));
                result.leaf_count += 1;

                // For leaves the low 5 bits of child_meta are the primitive
                // offset relative to the node's primitive base index.
                let first_prim = node.primitive_base_idx + (child_meta & 0b11111) as u32;
                assert_eq!(first_prim, node.child_primitives(ch).0);
                let mut prim_count = 0;
                // Bits 5..8 of child_meta flag which of up to 3 consecutive
                // primitives (starting at first_prim) are present.
                for i in 0..3 {
                    if (child_meta & (0b1_00000 << i)) != 0 {
                        result.discovered_primitives.insert(first_prim + i);
                        result.prim_count += 1;
                        prim_count += 1;
                        let mut prim_index = (first_prim + i) as usize;
                        if !result.direct_layout {
                            // Indirect layout: map through primitive_indices.
                            prim_index = self.primitive_indices[prim_index] as usize;
                        }
                        let prim_aabb = primitives[prim_index].aabb();

                        if !self.uses_spatial_splits {
                            // TODO: option that correctly takes into account error of compressed triangle.
                            // Maybe Boundable can return an epsilon, and for compressed triangles it
                            // can take into account the edge length
                            assert!(
                                prim_aabb.min.cmpge(parent_bounds.min - 1.0e-5).all()
                                    && prim_aabb.max.cmple(parent_bounds.max + 1.0e-5).all(),
                                "Primitive {prim_index} does not fit in parent {node_idx}:\nprimitive: {prim_aabb:?}\nparent:    {parent_bounds:?}"
                            );
                        }
                    }
                }
                assert_eq!(prim_count, node.child_primitives(ch).1);
            }
        }
    }
909
910    /// Calculate the maximum depth of the BVH from this node down.
911    fn calculate_max_depth(
912        &self,
913        node_idx: usize,
914        result: &mut CwBvhValidationResult,
915        current_depth: u32,
916    ) -> u32 {
917        if self.nodes.is_empty() {
918            return 0;
919        }
920        let node = &self.nodes[node_idx];
921        let mut max_depth = current_depth;
922
923        if let Some(count) = result.nodes_at_depth.get(&current_depth) {
924            result.nodes_at_depth.insert(current_depth, count + 1);
925        } else {
926            result.nodes_at_depth.insert(current_depth, 1);
927        }
928
929        for ch in 0..8 {
930            let child_meta = node.child_meta[ch];
931            if child_meta == 0 {
932                // Empty
933                continue;
934            }
935
936            let is_child_inner = (node.imask & (1 << ch)) != 0;
937            assert_eq!(is_child_inner, (child_meta & 0b11111) >= 24);
938
939            if is_child_inner {
940                let slot_index = (child_meta & 0b11111) as usize - 24;
941                let relative_index =
942                    (node.imask as u32 & !(0xffffffffu32 << slot_index)).count_ones();
943                let child_node_idx = node.child_base_idx as usize + relative_index as usize;
944
945                let child_depth =
946                    self.calculate_max_depth(child_node_idx, result, current_depth + 1);
947
948                max_depth = max_depth.max(child_depth);
949            } else {
950                // Leaf
951                // max_depth = max_depth.max(current_depth + 1);
952
953                if let Some(count) = result.leaves_at_depth.get(&current_depth) {
954                    result.leaves_at_depth.insert(current_depth, count + 1);
955                } else {
956                    result.leaves_at_depth.insert(current_depth, 1);
957                }
958            }
959        }
960
961        max_depth
962    }
963
964    #[allow(dead_code)]
965    fn print_nodes(&self) {
966        for (i, node) in self.nodes.iter().enumerate() {
967            println!("node: {i}");
968            for ch in 0..8 {
969                let child_meta = node.child_meta[ch];
970                if child_meta == 0 {
971                    // Empty
972                    continue;
973                }
974
975                let is_child_inner = (node.imask & (1 << ch)) != 0;
976                assert_eq!(is_child_inner, (child_meta & 0b11111) >= 24);
977
978                if is_child_inner {
979                    println!("inner");
980                } else {
981                    // Leaf
982                    let mut prims = 0;
983                    for i in 0..3 {
984                        if (child_meta & (0b1_00000 << i)) != 0 {
985                            prims += 1;
986                        }
987                    }
988                    println!("leaf, prims: {prims}");
989                }
990            }
991        }
992    }
993}
994
/// Index of the highest set bit in `value`, counted from the least
/// significant bit (so `firstbithigh(1) == 0`, `firstbithigh(0x8000_0000) == 31`).
///
/// Mirrors the HLSL `firstbithigh` intrinsic: for `value == 0` this returns
/// `u32::MAX` (all bits set). The wrapping subtraction makes that behavior
/// consistent across debug and release builds — the plain `31 - lz` form
/// overflowed (and therefore panicked) in debug builds when `value == 0`.
#[inline(always)]
pub fn firstbithigh(value: u32) -> u32 {
    31u32.wrapping_sub(value.leading_zeros())
}
999
1000#[inline(always)]
1001fn ray_get_octant_inv4(dir: &Vec3A) -> u32 {
1002    // Ray octant, encoded in 3 bits
1003    // let oct = (if dir.x < 0.0 { 0b100 } else { 0 })
1004    //     | (if dir.y < 0.0 { 0b010 } else { 0 })
1005    //     | (if dir.z < 0.0 { 0b001 } else { 0 });
1006    // return (7 - oct) * 0x01010101;
1007    (if dir.x < 0.0 { 0 } else { 0x04040404 }
1008        | if dir.y < 0.0 { 0 } else { 0x02020202 }
1009        | if dir.z < 0.0 { 0 } else { 0x01010101 })
1010}
1011
/// Result of CwBvh validation. Contains various bvh stats.
#[derive(Default)]
pub struct CwBvhValidationResult {
    /// The primitives are already laid out in bvh.primitive_indices order.
    pub direct_layout: bool,
    /// Set of primitives discovered through validation traversal.
    pub discovered_primitives: HashSet<u32>,
    /// Set of nodes discovered through validation traversal.
    pub discovered_nodes: HashSet<u32>,
    /// Total number of nodes discovered through validation traversal.
    pub node_count: usize,
    /// Total number of node children discovered through validation traversal.
    pub child_count: usize,
    /// Total number of leaves discovered through validation traversal.
    pub leaf_count: usize,
    /// Total number of primitives discovered through validation traversal.
    pub prim_count: usize,
    /// Maximum hierarchical BVH depth discovered through validation traversal.
    pub max_depth: u32,
    /// Quantity of nodes found at each depth through validation traversal.
    pub nodes_at_depth: HashMap<u32, u32>,
    /// Quantity of leaves found at each depth through validation traversal.
    pub leaves_at_depth: HashMap<u32, u32>,
}
1036
1037impl fmt::Display for CwBvhValidationResult {
1038    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1039        writeln!(
1040            f,
1041            "GPU BVH Avg children/node: {:.3}, primitives/leaf: {:.3}",
1042            self.child_count as f64 / self.node_count as f64,
1043            self.prim_count as f64 / self.leaf_count as f64
1044        )?;
1045
1046        writeln!(
1047            f,
1048            "\
1049child_count: {}
1050 node_count: {}
1051 prim_count: {}
1052 leaf_count: {}",
1053            self.child_count, self.node_count, self.prim_count, self.leaf_count
1054        )?;
1055
1056        writeln!(f, "Node & Leaf counts for each depth")?;
1057        for i in 0..=self.max_depth {
1058            writeln!(
1059                f,
1060                "{:<3} {:<10} {:<10}",
1061                i,
1062                self.nodes_at_depth.get(&i).unwrap_or(&0),
1063                self.leaves_at_depth.get(&i).unwrap_or(&0)
1064            )?;
1065        }
1066
1067        Ok(())
1068    }
1069}