1use std::fmt::{self, Formatter};
2
3use crate::{aabb::Aabb, ray::Ray};
4use bytemuck::{Pod, Zeroable};
5use glam::{Vec3, Vec3A, vec3a};
6use std::fmt::Debug;
7
8use super::NQ_SCALE;
9
/// A single node of the compressed wide BVH, holding up to eight children.
///
/// Child bounds are stored quantized to one byte per plane, relative to the grid origin `p`
/// and the per-axis power-of-two scale encoded in `e`. The struct is `#[repr(C)]` and `Pod`,
/// so the order and types of the fields are part of its binary layout and must not change.
#[derive(Clone, Copy, Default, PartialEq, Pod, Zeroable)]
#[repr(C)]
pub struct CwBvhNode {
    /// Origin of the quantization grid; `aabb()` uses it as the minimum corner of the node.
    pub p: Vec3,

    /// Per-axis exponent bytes. `compute_extent()` places each byte in the exponent field of
    /// an `f32` (`bits << 23`), giving the world-space size of one quantization step.
    pub e: [u8; 3],

    /// Bitmask of inner-node children: `is_leaf(child)` is true when bit `child` is clear.
    pub imask: u8,

    /// Base index to which `child_node_index()` adds the child's relative index.
    pub child_base_idx: u32,

    /// Base index to which `child_primitives()` adds the child's primitive offset.
    pub primitive_base_idx: u32,

    /// One meta byte per child slot.
    ///
    /// * `0` marks an empty slot (see `is_child_empty()`).
    /// * Leaf: the low 5 bits are the offset from `primitive_base_idx`, and the number of set
    ///   bits among the high 3 bits is the primitive count (see `child_primitives()`).
    /// * Inner node: the low 5 bits hold the slot index plus 24 (see `child_node_index()`).
    pub child_meta: [u8; 8],

    /// Quantized minimum x plane of each child, in grid steps from `p.x`.
    pub child_min_x: [u8; 8],
    /// Quantized maximum x plane of each child, in grid steps from `p.x`.
    pub child_max_x: [u8; 8],
    /// Quantized minimum y plane of each child, in grid steps from `p.y`.
    pub child_min_y: [u8; 8],
    /// Quantized maximum y plane of each child, in grid steps from `p.y`.
    pub child_max_y: [u8; 8],
    /// Quantized minimum z plane of each child, in grid steps from `p.z`.
    pub child_min_z: [u8; 8],
    /// Quantized maximum z plane of each child, in grid steps from `p.z`.
    pub child_max_z: [u8; 8],
}
55
56impl Debug for CwBvhNode {
57 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
58 f.debug_struct("CwBvhNode")
59 .field("p", &self.p)
60 .field("e", &self.e)
61 .field("imask", &format!("{:#010b}", &self.imask))
62 .field("child_base_idx", &self.child_base_idx)
63 .field("primitive_base_idx", &self.primitive_base_idx)
64 .field(
65 "child_meta",
66 &self
67 .child_meta
68 .iter()
69 .map(|c| format!("{c:#010b}"))
70 .collect::<Vec<_>>(),
71 )
72 .field("child_min_x", &self.child_min_x)
73 .field("child_max_x", &self.child_max_x)
74 .field("child_min_y", &self.child_min_y)
75 .field("child_max_y", &self.child_max_y)
76 .field("child_min_z", &self.child_min_z)
77 .field("child_max_z", &self.child_max_z)
78 .finish()
79 }
80}
81
/// Smallest ray parameter accepted as a hit; `intersect_ray_basic` clamps each child's entry
/// distance to at least this value.
pub(crate) const EPSILON: f32 = 1.0e-4;
83
84impl CwBvhNode {
    /// Intersects `ray` with the children of this node and returns the resulting hit mask.
    ///
    /// The implementation is selected at compile time: on x86/x86_64 targets built with the
    /// `sse2` target feature this calls `intersect_ray_simd`, otherwise it calls
    /// `intersect_ray_basic`. `oct_inv4` is passed through unchanged.
    ///
    /// NOTE(review): `intersect_ray_simd` is defined outside this view; it is presumably
    /// expected to return the same mask layout as `intersect_ray_basic` — confirm.
    #[inline(always)]
    pub fn intersect_ray(&self, ray: &Ray, oct_inv4: u32) -> u32 {
        // Exactly one of the two blocks below survives cfg-stripping and becomes the tail
        // expression of the function.
        #[cfg(all(
            any(target_arch = "x86", target_arch = "x86_64"),
            target_feature = "sse2"
        ))]
        {
            self.intersect_ray_simd(ray, oct_inv4)
        }

        // Fallback for every other target / feature combination.
        #[cfg(not(all(
            any(target_arch = "x86", target_arch = "x86_64"),
            target_feature = "sse2"
        )))]
        {
            self.intersect_ray_basic(ray, oct_inv4)
        }
    }
103
104 #[inline(always)]
107 pub fn intersect_ray_basic(&self, ray: &Ray, oct_inv4: u32) -> u32 {
108 let adjusted_ray_dir_inv = self.compute_extent() * ray.inv_direction;
109 let adjusted_ray_origin = (Vec3A::from(self.p) - ray.origin) * ray.inv_direction;
110
111 let mut hit_mask = 0;
112
113 let rdx = ray.direction.x < 0.0;
114 let rdy = ray.direction.y < 0.0;
115 let rdz = ray.direction.z < 0.0;
116
117 let (child_bits8, bit_index8) = self.get_child_and_index_bits(oct_inv4);
118
119 for child in 0..8 {
120 let q_lo_x = self.child_min_x[child];
121 let q_lo_y = self.child_min_y[child];
122 let q_lo_z = self.child_min_z[child];
123
124 let q_hi_x = self.child_max_x[child];
125 let q_hi_y = self.child_max_y[child];
126 let q_hi_z = self.child_max_z[child];
127
128 let x_min = if rdx { q_hi_x } else { q_lo_x };
129 let x_max = if rdx { q_lo_x } else { q_hi_x };
130 let y_min = if rdy { q_hi_y } else { q_lo_y };
131 let y_max = if rdy { q_lo_y } else { q_hi_y };
132 let z_min = if rdz { q_hi_z } else { q_lo_z };
133 let z_max = if rdz { q_lo_z } else { q_hi_z };
134
135 let mut tmin3 = vec3a(x_min as f32, y_min as f32, z_min as f32);
136 let mut tmax3 = vec3a(x_max as f32, y_max as f32, z_max as f32);
137
138 tmin3 = tmin3 * adjusted_ray_dir_inv + adjusted_ray_origin;
140 tmax3 = tmax3 * adjusted_ray_dir_inv + adjusted_ray_origin;
141
142 let tmin = tmin3.max_element().max(EPSILON); let tmax = tmax3.min_element().min(ray.tmax);
144
145 let intersected = tmin <= tmax;
146 if intersected {
147 let child_bits = extract_byte64(child_bits8, child);
148 let bit_index = extract_byte64(bit_index8, child);
149 hit_mask |= child_bits << bit_index;
150 }
151 }
152
153 hit_mask
154 }
155
156 #[inline(always)]
157 pub fn intersect_aabb(&self, aabb: &Aabb, oct_inv4: u32) -> u32 {
158 let extent_rcp = 1.0 / self.compute_extent();
159 let p = Vec3A::from(self.p);
160
161 let adjusted_aabb = Aabb::new((aabb.min - p) * extent_rcp, (aabb.max - p) * extent_rcp);
163
164 let mut hit_mask = 0;
165
166 let (child_bits8, bit_index8) = self.get_child_and_index_bits(oct_inv4);
167
168 for child in 0..8 {
169 if self.local_child_aabb(child).intersect_aabb(&adjusted_aabb) {
170 let child_bits = extract_byte64(child_bits8, child);
171 let bit_index = extract_byte64(bit_index8, child);
172 hit_mask |= child_bits << bit_index;
173 }
174 }
175
176 hit_mask
177 }
178
179 #[inline(always)]
180 pub fn contains_point(&self, point: &Vec3A, oct_inv4: u32) -> u32 {
181 let extent_rcp = 1.0 / self.compute_extent();
182 let p = Vec3A::from(self.p);
183
184 let adjusted_point = (*point - p) * extent_rcp;
186
187 let mut hit_mask = 0;
188
189 let (child_bits8, bit_index8) = self.get_child_and_index_bits(oct_inv4);
190
191 for child in 0..8 {
192 if self.local_child_aabb(child).contains_point(adjusted_point) {
193 let child_bits = extract_byte64(child_bits8, child);
194 let bit_index = extract_byte64(bit_index8, child);
195 hit_mask |= child_bits << bit_index;
196 }
197 }
198
199 hit_mask
200 }
201
202 #[inline(always)]
207 pub fn get_child_and_index_bits(&self, oct_inv4: u32) -> (u64, u64) {
208 let mut oct_inv8 = oct_inv4 as u64;
209 oct_inv8 |= oct_inv8 << 32;
210 let meta8 = u64::from_le_bytes(self.child_meta);
211
212 let inner_mask = 0b0001000000010000000100000001000000010000000100000001000000010000;
215 let is_inner8 = (meta8 & (meta8 << 1)) & inner_mask;
216
217 let inner_mask8 = (is_inner8 >> 4) * 0xffu64;
219
220 let index_mask = 0b0001111100011111000111110001111100011111000111110001111100011111;
223 let bit_index8 = (meta8 ^ (oct_inv8 & inner_mask8)) & index_mask;
224
225 let child_mask = 0b0000011100000111000001110000011100000111000001110000011100000111;
229 let child_bits8 = (meta8 >> 5) & child_mask;
230 (child_bits8, bit_index8)
231 }
232
233 #[inline(always)]
235 pub fn local_child_aabb(&self, child: usize) -> Aabb {
236 Aabb::new(
237 vec3a(
238 self.child_min_x[child] as f32,
239 self.child_min_y[child] as f32,
240 self.child_min_z[child] as f32,
241 ),
242 vec3a(
243 self.child_max_x[child] as f32,
244 self.child_max_y[child] as f32,
245 self.child_max_z[child] as f32,
246 ),
247 )
248 }
249
250 #[inline(always)]
251 pub fn child_aabb(&self, child: usize) -> Aabb {
252 let e = self.compute_extent();
253 let p: Vec3A = self.p.into();
254 let mut local_aabb = self.local_child_aabb(child);
255 local_aabb.min = local_aabb.min * e + p;
256 local_aabb.max = local_aabb.max * e + p;
257 local_aabb
258 }
259
260 #[inline(always)]
261 pub fn aabb(&self) -> Aabb {
262 let e = self.compute_extent();
263 let p: Vec3A = self.p.into();
264 Aabb::new(p, p + e * NQ_SCALE)
265 }
266
267 #[inline(always)]
269 pub fn compute_extent(&self) -> Vec3A {
270 vec3a(
271 f32::from_bits((self.e[0] as u32) << 23),
272 f32::from_bits((self.e[1] as u32) << 23),
273 f32::from_bits((self.e[2] as u32) << 23),
274 )
275 }
276
277 #[inline(always)]
279 pub fn is_leaf(&self, child: usize) -> bool {
280 (self.imask & (1 << child)) == 0
281 }
282
283 #[inline(always)]
284 pub fn is_child_empty(&self, child: usize) -> bool {
285 self.child_meta[child] == 0
286 }
287
288 #[inline(always)]
290 pub fn child_primitives(&self, child: usize) -> (u32, u32) {
291 let child_meta = self.child_meta[child];
292 let starting_index = self.primitive_base_idx + (self.child_meta[child] & 0b11111) as u32;
293 let primitive_count = (child_meta & 0b11100000).count_ones();
294 (starting_index, primitive_count)
295 }
296
297 #[inline(always)]
299 pub fn child_node_index(&self, child: usize) -> u32 {
300 let child_meta = self.child_meta[child];
301 let slot_index = (child_meta & 0b11111) as usize - 24;
302 let relative_index = (self.imask as u32 & !(0xffffffffu32 << slot_index)).count_ones();
303 self.child_base_idx + relative_index
304 }
305}
306
/// Returns byte `b` of `x`, where byte 0 is the least significant byte.
///
/// `b` must be in `0..4`; larger values overflow the shift.
#[inline(always)]
pub fn extract_byte(x: u32, b: u32) -> u32 {
    let shifted = x >> (b * 8);
    // Truncating to `u8` keeps exactly the low eight bits.
    u32::from(shifted as u8)
}
311
/// Returns byte `b` of `x` as a `u32`, where byte 0 is the least significant byte.
///
/// `b` must be in `0..8`; larger values overflow the shift.
#[inline(always)]
pub fn extract_byte64(x: u64, b: usize) -> u32 {
    let shifted = x >> (b * 8);
    // Truncating to `u8` keeps exactly the low eight bits.
    u32::from(shifted as u8)
}