// obvhs/cwbvh/node.rs

use std::fmt::Debug;
use std::fmt::{self, Formatter};

use bytemuck::{Pod, Zeroable};
use glam::{Vec3, Vec3A, vec3a};

use super::NQ_SCALE;
use crate::{aabb::Aabb, ray::Ray};

10/// A Compressed Wide BVH8 Node. repr(C), Pod, 80 bytes.
11// https://research.nvidia.com/sites/default/files/publications/ylitie2017hpg-paper.pdf
12#[derive(Clone, Copy, Default, PartialEq, Pod, Zeroable)]
13#[repr(C)]
14pub struct CwBvhNode {
15    /// Min point of node AABB
16    pub p: Vec3,
17
18    /// Exponent of child bounding box compression
19    /// Max point of node AABB could be calculated ex: `p.x + bitcast<f32>(e[0] << 23) * 255.0`
20    pub e: [u8; 3],
21
22    /// Bitmask indicating which children are internal nodes. 1 for internal, 0 for leaf
23    pub imask: u8,
24
25    /// Index of first child into `Vec<CwBvhNode>`
26    pub child_base_idx: u32,
27
28    /// Index of first primitive into primitive_indices `Vec<u32>`
29    pub primitive_base_idx: u32,
30
31    /// Meta data for each child
32    /// Empty child slot: The field is set to 00000000
33    ///
34    /// For leaves nodes: the low 5 bits store the primitive offset [0..24) from primitive_base_idx. And the high
35    /// 3 bits store the number of primitives in that leaf in a unary encoding.
36    /// A child leaf with 2 primitives with the first primitive starting at primitive_base_idx would be 0b01100000
37    /// A child leaf with 3 primitives with the first primitive starting at primitive_base_idx + 2 would be 0b11100010
38    /// A child leaf with 1 primitive with the first primitive starting at primitive_base_idx + 1 would be 0b00100001
39    ///
40    /// For internal nodes: The high 3 bits are set to 001 while the low 5 bits store the child slot index plus 24
41    /// i.e., the values range [24..32)
42    pub child_meta: [u8; 8],
43
44    // Note: deviation from the paper: the min&max are interleaved here.
45    /// Axis planes for each child.
46    /// The plane position could be calculated, for example, with `p.x + bitcast<f32>(e[0] << 23) * child_min_x[0]`
47    /// But in the actual intersection implementation the ray is transformed instead.
48    pub child_min_x: [u8; 8],
49    pub child_max_x: [u8; 8],
50    pub child_min_y: [u8; 8],
51    pub child_max_y: [u8; 8],
52    pub child_min_z: [u8; 8],
53    pub child_max_z: [u8; 8],
54}
55
56impl Debug for CwBvhNode {
57    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
58        f.debug_struct("CwBvhNode")
59            .field("p", &self.p)
60            .field("e", &self.e)
61            .field("imask", &format!("{:#010b}", &self.imask))
62            .field("child_base_idx", &self.child_base_idx)
63            .field("primitive_base_idx", &self.primitive_base_idx)
64            .field(
65                "child_meta",
66                &self
67                    .child_meta
68                    .iter()
69                    .map(|c| format!("{c:#010b}"))
70                    .collect::<Vec<_>>(),
71            )
72            .field("child_min_x", &self.child_min_x)
73            .field("child_max_x", &self.child_max_x)
74            .field("child_min_y", &self.child_min_y)
75            .field("child_max_y", &self.child_max_y)
76            .field("child_min_z", &self.child_min_z)
77            .field("child_max_z", &self.child_max_z)
78            .finish()
79    }
80}
81
/// Lower bound applied to a child's entry distance `t` in `CwBvhNode::intersect_ray_basic`
/// (used there in place of a ray tmin).
pub(crate) const EPSILON: f32 = 0.0001;

84impl CwBvhNode {
85    #[inline(always)]
86    pub fn intersect_ray(&self, ray: &Ray, oct_inv4: u32) -> u32 {
87        #[cfg(all(
88            any(target_arch = "x86", target_arch = "x86_64"),
89            target_feature = "sse2"
90        ))]
91        {
92            self.intersect_ray_simd(ray, oct_inv4)
93        }
94
95        #[cfg(not(all(
96            any(target_arch = "x86", target_arch = "x86_64"),
97            target_feature = "sse2"
98        )))]
99        {
100            self.intersect_ray_basic(ray, oct_inv4)
101        }
102    }
103
104    /// Intersects only one child at a time with the given ray. Limited simd usage on platforms that support it. Exists for reference & compatibility.
105    /// Traversal times with CwBvhNode::intersect_ray_simd take less than half the time vs intersect_ray_basic.
106    #[inline(always)]
107    pub fn intersect_ray_basic(&self, ray: &Ray, oct_inv4: u32) -> u32 {
108        let adjusted_ray_dir_inv = self.compute_extent() * ray.inv_direction;
109        let adjusted_ray_origin = (Vec3A::from(self.p) - ray.origin) * ray.inv_direction;
110
111        let mut hit_mask = 0;
112
113        let rdx = ray.direction.x < 0.0;
114        let rdy = ray.direction.y < 0.0;
115        let rdz = ray.direction.z < 0.0;
116
117        let (child_bits8, bit_index8) = self.get_child_and_index_bits(oct_inv4);
118
119        for child in 0..8 {
120            let q_lo_x = self.child_min_x[child];
121            let q_lo_y = self.child_min_y[child];
122            let q_lo_z = self.child_min_z[child];
123
124            let q_hi_x = self.child_max_x[child];
125            let q_hi_y = self.child_max_y[child];
126            let q_hi_z = self.child_max_z[child];
127
128            let x_min = if rdx { q_hi_x } else { q_lo_x };
129            let x_max = if rdx { q_lo_x } else { q_hi_x };
130            let y_min = if rdy { q_hi_y } else { q_lo_y };
131            let y_max = if rdy { q_lo_y } else { q_hi_y };
132            let z_min = if rdz { q_hi_z } else { q_lo_z };
133            let z_max = if rdz { q_lo_z } else { q_hi_z };
134
135            let mut tmin3 = vec3a(x_min as f32, y_min as f32, z_min as f32);
136            let mut tmax3 = vec3a(x_max as f32, y_max as f32, z_max as f32);
137
138            // Account for grid origin and scale
139            tmin3 = tmin3 * adjusted_ray_dir_inv + adjusted_ray_origin;
140            tmax3 = tmax3 * adjusted_ray_dir_inv + adjusted_ray_origin;
141
142            let tmin = tmin3.max_element().max(EPSILON); //ray.tmin?
143            let tmax = tmax3.min_element().min(ray.tmax);
144
145            let intersected = tmin <= tmax;
146            if intersected {
147                let child_bits = extract_byte64(child_bits8, child);
148                let bit_index = extract_byte64(bit_index8, child);
149                hit_mask |= child_bits << bit_index;
150            }
151        }
152
153        hit_mask
154    }
155
156    #[inline(always)]
157    pub fn intersect_aabb(&self, aabb: &Aabb, oct_inv4: u32) -> u32 {
158        let extent_rcp = 1.0 / self.compute_extent();
159        let p = Vec3A::from(self.p);
160
161        // Transform the query aabb into the node's local space
162        let adjusted_aabb = Aabb::new((aabb.min - p) * extent_rcp, (aabb.max - p) * extent_rcp);
163
164        let mut hit_mask = 0;
165
166        let (child_bits8, bit_index8) = self.get_child_and_index_bits(oct_inv4);
167
168        for child in 0..8 {
169            if self.local_child_aabb(child).intersect_aabb(&adjusted_aabb) {
170                let child_bits = extract_byte64(child_bits8, child);
171                let bit_index = extract_byte64(bit_index8, child);
172                hit_mask |= child_bits << bit_index;
173            }
174        }
175
176        hit_mask
177    }
178
179    #[inline(always)]
180    pub fn contains_point(&self, point: &Vec3A, oct_inv4: u32) -> u32 {
181        let extent_rcp = 1.0 / self.compute_extent();
182        let p = Vec3A::from(self.p);
183
184        // Transform the query point into the node's local space
185        let adjusted_point = (*point - p) * extent_rcp;
186
187        let mut hit_mask = 0;
188
189        let (child_bits8, bit_index8) = self.get_child_and_index_bits(oct_inv4);
190
191        for child in 0..8 {
192            if self.local_child_aabb(child).contains_point(adjusted_point) {
193                let child_bits = extract_byte64(child_bits8, child);
194                let bit_index = extract_byte64(bit_index8, child);
195                hit_mask |= child_bits << bit_index;
196            }
197        }
198
199        hit_mask
200    }
201
202    // TODO intersect frustum
203    // https://github.com/zeux/niagara/blob/bf90aa8c78e352d3b753b35553a3bcc8c65ef7a0/src/shaders/drawcull.comp.glsl#L71
204    // https://iquilezles.org/articles/frustumcorrect/
205
206    #[inline(always)]
207    pub fn get_child_and_index_bits(&self, oct_inv4: u32) -> (u64, u64) {
208        let mut oct_inv8 = oct_inv4 as u64;
209        oct_inv8 |= oct_inv8 << 32;
210        let meta8 = u64::from_le_bytes(self.child_meta);
211
212        // (meta8 & (meta8 << 1)) takes advantage of the offset indexing for inner nodes [24..32)
213        // [0b00011000..=0b00011111). For leaf nodes [0..24) these two bits (0b00011000) are never both set.
214        let inner_mask = 0b0001000000010000000100000001000000010000000100000001000000010000;
215        let is_inner8 = (meta8 & (meta8 << 1)) & inner_mask;
216
217        // 00010000 >> 4: 00000001, then 00000001 * 0xff: 11111111
218        let inner_mask8 = (is_inner8 >> 4) * 0xffu64;
219
220        // Each byte of bit_index8 contains the traversal priority, biased by 24, for internal nodes, and
221        // the triangle offset for leaf nodes. The bit index will later be used to shift the child bits.
222        let index_mask = 0b0001111100011111000111110001111100011111000111110001111100011111;
223        let bit_index8 = (meta8 ^ (oct_inv8 & inner_mask8)) & index_mask;
224
225        // For internal nodes child_bits8 will just be 1 in each byte, so that bit will then be shifted into the high
226        // byte of the node hit_mask (see CwBvhNode::intersect_ray). For leaf nodes it will have the unary encoded
227        // leaf primitive count and that will be shifted into the lower 24 bits of node hit_mask.
228        let child_mask = 0b0000011100000111000001110000011100000111000001110000011100000111;
229        let child_bits8 = (meta8 >> 5) & child_mask;
230        (child_bits8, bit_index8)
231    }
232
233    /// Get local child aabb position relative to the parent
234    #[inline(always)]
235    pub fn local_child_aabb(&self, child: usize) -> Aabb {
236        Aabb::new(
237            vec3a(
238                self.child_min_x[child] as f32,
239                self.child_min_y[child] as f32,
240                self.child_min_z[child] as f32,
241            ),
242            vec3a(
243                self.child_max_x[child] as f32,
244                self.child_max_y[child] as f32,
245                self.child_max_z[child] as f32,
246            ),
247        )
248    }
249
250    #[inline(always)]
251    pub fn child_aabb(&self, child: usize) -> Aabb {
252        let e = self.compute_extent();
253        let p: Vec3A = self.p.into();
254        let mut local_aabb = self.local_child_aabb(child);
255        local_aabb.min = local_aabb.min * e + p;
256        local_aabb.max = local_aabb.max * e + p;
257        local_aabb
258    }
259
260    #[inline(always)]
261    pub fn aabb(&self) -> Aabb {
262        let e = self.compute_extent();
263        let p: Vec3A = self.p.into();
264        Aabb::new(p, p + e * NQ_SCALE)
265    }
266
267    /// Convert stored extent exponent into float vector
268    #[inline(always)]
269    pub fn compute_extent(&self) -> Vec3A {
270        vec3a(
271            f32::from_bits((self.e[0] as u32) << 23),
272            f32::from_bits((self.e[1] as u32) << 23),
273            f32::from_bits((self.e[2] as u32) << 23),
274        )
275    }
276
277    // If the child is empty this will also return true. If needed also use CwBvh::is_child_empty().
278    #[inline(always)]
279    pub fn is_leaf(&self, child: usize) -> bool {
280        (self.imask & (1 << child)) == 0
281    }
282
283    #[inline(always)]
284    pub fn is_child_empty(&self, child: usize) -> bool {
285        self.child_meta[child] == 0
286    }
287
288    /// Returns the primitive starting index and primitive count for the given child.
289    #[inline(always)]
290    pub fn child_primitives(&self, child: usize) -> (u32, u32) {
291        let child_meta = self.child_meta[child];
292        let starting_index = self.primitive_base_idx + (self.child_meta[child] & 0b11111) as u32;
293        let primitive_count = (child_meta & 0b11100000).count_ones();
294        (starting_index, primitive_count)
295    }
296
297    /// Returns the node index of the given child.
298    #[inline(always)]
299    pub fn child_node_index(&self, child: usize) -> u32 {
300        let child_meta = self.child_meta[child];
301        let slot_index = (child_meta & 0b11111) as usize - 24;
302        let relative_index = (self.imask as u32 & !(0xffffffffu32 << slot_index)).count_ones();
303        self.child_base_idx + relative_index
304    }
305}
306
/// Returns byte `b` of `x` (byte 0 is the least significant), zero-extended to `u32`.
///
/// `b` must be in `0..4`; larger values shift by 32 or more bits, which panics in debug builds.
#[inline(always)]
pub fn extract_byte(x: u32, b: u32) -> u32 {
    debug_assert!(b < 4, "byte index {b} out of range for u32");
    (x >> (b * 8)) & 0xFFu32
}

/// Returns byte `b` of `x` (byte 0 is the least significant), zero-extended to `u32`.
///
/// `b` must be in `0..8`; larger values shift by 64 or more bits, which panics in debug builds.
#[inline(always)]
pub fn extract_byte64(x: u64, b: usize) -> u32 {
    debug_assert!(b < 8, "byte index {b} out of range for u64");
    ((x >> (b * 8)) as u32) & 0xFFu32
}