obvhs/bvh2/leaf_collapser.rs

use bytemuck::zeroed_vec;

#[cfg(feature = "parallel")]
use rayon::{
    iter::{
        IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
    },
    slice::ParallelSliceMut,
};

#[cfg(feature = "parallel")]
use std::sync::atomic::{AtomicU32, Ordering};

use crate::bvh2::{Bvh2, Bvh2Node};

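/// Collapses subtrees of the BVH into larger leaves wherever the SAH judges a merged
/// leaf cheaper than keeping the split. `max_prims` caps the primitive count per
/// collapsed leaf; `traversal_cost` is the relative cost of traversing an internal
/// node (higher values favor larger leaves). Needs `bvh.parents`; computes it on
/// demand and restores the previous state afterwards.
///
/// A minimal usage sketch, mirroring `tests::test_collapse` below (builder arguments
/// are assumptions taken from that test):
///
/// ```ignore
/// let mut bvh = PlocBuilder::new().build(
///     PlocSearchDistance::VeryLow,
///     &aabbs,
///     indices,
///     SortPrecision::U64,
///     1,
/// );
/// collapse(&mut bvh, 8, 1.0); // allow up to 8 primitives per leaf
/// ```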
pub fn collapse(bvh: &mut Bvh2, max_prims: u32, traversal_cost: f32) {
    crate::scope!("collapse");
    let nodes_qty = bvh.nodes.len();

    // Nothing to do if leaves can't grow, or the tree / primitive count is already
    // small enough that no collapse could apply.
    if max_prims <= 1 || nodes_qty as u32 <= max_prims * 2 + 1 {
        return;
    }

    if !bvh.primitive_indices.is_empty() && bvh.primitive_indices.len() as u32 <= max_prims {
        return;
    }

    if bvh.nodes.is_empty() || bvh.nodes[0].is_leaf() {
        return;
    }

    let previously_had_parents = !bvh.parents.is_empty();

    bvh.init_parents_if_uninit();

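    // node_counts[i] == 1 marks a surviving node; prim_counts[i] is the number of
    // primitives node i will hold as a leaf (0 while it stays internal). Children
    // absorbed by a collapse have both zeroed. The prefix sums below turn these
    // into compacted node indices and primitive offsets.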
    let mut node_counts = vec![1u32; nodes_qty];
    let mut prim_counts = vec![0u32; nodes_qty];
    let node_count;

    {
        let node_counts = as_slice_of_sometimes_atomic_u32(&mut node_counts);
        let prim_counts = as_slice_of_sometimes_atomic_u32(&mut prim_counts);

        bottom_up_traverse(bvh, |leaf, i| {
            if leaf {
                prim_counts[i].set(bvh.nodes[i].prim_count);
            } else {
                let node = &bvh.nodes[i];
                debug_assert!(!node.is_leaf());
                let first_child = node.first_index as usize;

                let left_count = prim_counts[first_child].get();
                let right_count = prim_counts[first_child + 1].get();
                let total_count = left_count + right_count;

                // Both children must still be collapsible and the merged leaf must
                // stay within max_prims.
                if left_count > 0 && right_count > 0 && total_count <= max_prims {
                    let left = bvh.nodes[first_child];
                    let right = bvh.nodes[first_child + 1];
                    // SAH comparison: the merged leaf costs A * n, minus the saved
                    // traversal cost A * traversal_cost, against the children's
                    // summed leaf costs.
                    let collapse_cost =
                        node.aabb().half_area() * (total_count as f32 - traversal_cost);
                    let base_cost = left.aabb().half_area() * left_count as f32
                        + right.aabb().half_area() * right_count as f32;
                    // Two single-primitive leaves aliasing the same primitive slot
                    // are always merged.
                    let both_have_same_prim =
                        (left.first_index == right.first_index) && total_count == 2;

                    if collapse_cost <= base_cost || both_have_same_prim {
                        // Collapse: this node absorbs both children.
                        prim_counts[i].set(total_count);
                        prim_counts[first_child].set(0);
                        prim_counts[first_child + 1].set(0);
                        node_counts[first_child].set(0);
                        node_counts[first_child + 1].set(0);
                    }
                }
            }
        });
    }

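    // Inclusive prefix sums: for a surviving node i, node_counts[i] - 1 becomes its
    // index in the compacted node array, and prim_counts[i] the end of its new
    // primitive range.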
    #[cfg(feature = "parallel")]
    {
        parallel_prefix_sum(&mut node_counts);
        parallel_prefix_sum(&mut prim_counts);
    }
    #[cfg(not(feature = "parallel"))]
    {
        prefix_sum(&mut node_counts);
        prefix_sum(&mut prim_counts);
    }

    let mut indices_copy = Vec::new();
    let mut nodes_copy = Vec::new();
    {
        node_count = node_counts[bvh.nodes.len() - 1];
        if prim_counts[0] > 0 {
            // The root itself collapsed: the whole BVH becomes a single leaf over
            // all primitives, which are already in place.
            bvh.nodes[0].first_index = 0;
            bvh.nodes[0].prim_count = prim_counts[0];
            std::mem::swap(&mut bvh.primitive_indices, &mut indices_copy);
            std::mem::swap(&mut bvh.nodes, &mut nodes_copy);
        } else {
            nodes_copy = zeroed_vec(node_count as usize);
            indices_copy = zeroed_vec(prim_counts[bvh.nodes.len() - 1] as usize);
            nodes_copy[0] = bvh.nodes[0];
            nodes_copy[0].first_index = node_counts[nodes_copy[0].first_index as usize - 1];
        }
    }

    {
        let indices_copy = as_slice_of_sometimes_atomic_u32(&mut indices_copy);

        #[cfg(feature = "parallel")]
        let mut needs_traversal = Vec::with_capacity(bvh.nodes.len().div_ceil(4));

        // For a collapsed node `i`, gathers the primitive indices of every original
        // leaf below it into the new leaf's contiguous range (stackless traversal,
        // using the parents array to step across to unvisited right siblings).
        #[allow(unused_mut)]
        let mut top_down_traverse = |i| {
            #[allow(clippy::unnecessary_cast)]
            let i = i as usize;
            let mut first_prim = prim_counts[i - 1];
            let mut j = i;
            loop {
                let node = bvh.nodes[j];
                if node.is_leaf() {
                    for n in 0..node.prim_count {
                        indices_copy[(first_prim + n) as usize]
                            .set(bvh.primitive_indices[(node.first_index + n) as usize]);
                    }

                    first_prim += node.prim_count;
                    while !Bvh2Node::is_left_sibling(j) && j != i {
                        j = bvh.parents[j] as usize;
                    }
                    if j == i {
                        break;
                    }
                    j = Bvh2Node::get_sibling_id(j);
                } else {
                    j = node.first_index as usize;
                }
            }
        };

        // Compact surviving nodes into nodes_copy. An unchanged primitive range
        // means the node stays internal (remap its child pointer); a widened range
        // means it collapsed into a leaf and needs its primitives gathered.
        (1..bvh.nodes.len()).for_each(|i| {
            let node_id = node_counts[i - 1] as usize;
            if node_id == node_counts[i] as usize {
                // Node was absorbed into a collapsed ancestor.
                return;
            }
            nodes_copy[node_id] = bvh.nodes[i];
            let first_prim = prim_counts[i - 1];
            if first_prim == prim_counts[i] {
                let first_child = &mut nodes_copy[node_id].first_index;
                *first_child = node_counts[*first_child as usize - 1];
            } else {
                nodes_copy[node_id].prim_count = prim_counts[i] - first_prim;
                nodes_copy[node_id].first_index = first_prim;
                #[cfg(feature = "parallel")]
                needs_traversal.push(i as u32);
                #[cfg(not(feature = "parallel"))]
                top_down_traverse(i);
            }
        });

        #[cfg(feature = "parallel")]
        needs_traversal.into_par_iter().for_each(top_down_traverse);
    }

    std::mem::swap(&mut bvh.nodes, &mut nodes_copy);
    std::mem::swap(&mut bvh.primitive_indices, &mut indices_copy);

    // Cached acceleration data now refers to the old node layout; refresh it, or
    // drop the parents array again if it was only computed for this pass.
    if previously_had_parents {
        bvh.update_parents();
    } else {
        bvh.parents.clear();
    }
    if !bvh.primitives_to_nodes.is_empty() {
        bvh.update_primitives_to_nodes();
    }
}

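/// Serial bottom-up traversal: calls `process_node` for every leaf, then walks each
/// path toward the root, visiting an internal node only after both of its children
/// have been processed (tracked via a per-node counter).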
#[cfg(not(feature = "parallel"))]
fn bottom_up_traverse<F>(bvh: &Bvh2, mut process_node: F)
where
    F: FnMut(bool, usize),
{
    // Single-node BVH: the root is the only leaf.
    if bvh.nodes.len() == 1 {
        process_node(true, 0);
        return;
    }

    let mut flags: Vec<u8> = zeroed_vec(bvh.nodes.len());

    (1..bvh.nodes.len()).for_each(|i| {
        if bvh.nodes[i].is_leaf() {
            process_node(true, i);

            let mut j = i;
            while j != 0 {
                j = bvh.parents[j] as usize;

                let flag = &mut flags[j];

                // Mirrors the atomic version below: the first arrival at a parent
                // stops climbing, the second one processes it and continues.
                let previous_flag = *flag;
                *flag = previous_flag.saturating_add(1);
                if previous_flag != 1 {
                    break;
                }
                *flag = 0;

                process_node(false, j);
            }
        }
    });
}

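/// Parallel bottom-up traversal: leaves are visited concurrently, and a per-node
/// atomic counter guarantees each internal node is processed exactly once, by
/// whichever thread arrives at it second.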
#[cfg(feature = "parallel")]
fn bottom_up_traverse<F>(bvh: &Bvh2, process_node: F)
where
    F: Fn(bool, usize) + Sync + Send,
{
    if bvh.nodes.len() == 1 {
        process_node(true, 0);
        return;
    }

    let flags = vec![0u32; bvh.nodes.len()]
        .into_iter()
        .map(AtomicU32::new)
        .collect::<Vec<_>>();

    (1..bvh.nodes.len()).into_par_iter().for_each(|i| {
        if bvh.nodes[i].is_leaf() {
            process_node(true, i);

            let mut j = i;
            while j != 0 {
                j = bvh.parents[j] as usize;

                let flag = &flags[j];

                // First thread to reach the parent stops; the second processes it.
                if flag.fetch_add(1, Ordering::SeqCst) != 1 {
                    break;
                }
                flag.store(0, Ordering::SeqCst);

                process_node(false, j);
            }
        }
    });
}

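/// Inclusive prefix sum over `data`, parallelized by chunking: scan each chunk
/// locally, scan the per-chunk totals serially, then add each chunk's starting
/// offset in a parallel fix-up pass.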
#[cfg(feature = "parallel")]
fn parallel_prefix_sum<T>(data: &mut [T])
where
    T: std::ops::Add + std::ops::AddAssign + Send + Default + Clone + Copy,
{
    let chunk_size = 1.max(data.len().div_ceil(rayon::current_num_threads()));
    let chunks = data.par_chunks_mut(chunk_size);
    let mut partial_sums: Vec<T> = vec![Default::default(); chunks.len()];

    // Scan each chunk independently, recording its total.
    chunks
        .zip(partial_sums.par_iter_mut())
        .for_each(|(chunk, partial_sum)| *partial_sum = prefix_sum(chunk));

    // Serial scan over the (few) per-chunk totals.
    prefix_sum(&mut partial_sums);

    // Add each preceding chunk's total to the elements of the following chunk.
    data.par_chunks_mut(chunk_size)
        .skip(1)
        .zip(partial_sums)
        .for_each(|(chunk, partial_sum)| chunk.iter_mut().for_each(move |n| *n += partial_sum));
}

/// Serial inclusive prefix sum; returns the total of the slice.
#[inline]
fn prefix_sum<T>(data: &mut [T]) -> T
where
    T: std::ops::Add + std::ops::AddAssign + Send + Default + Clone + Copy,
{
    let mut sum: T = Default::default();
    data.iter_mut().for_each(|count| {
        sum += *count;
        *count = sum;
    });
    sum
}

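/// A `u32` that is only atomic when the `parallel` feature is enabled, letting the
/// same traversal code run single-threaded without paying for atomic operations.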
pub struct SometimesAtomicU32 {
    #[cfg(feature = "parallel")]
    pub value: AtomicU32,
    #[cfg(not(feature = "parallel"))]
    pub value: u32,
}

impl SometimesAtomicU32 {
    #[inline]
    pub fn new(value: u32) -> SometimesAtomicU32 {
        #[cfg(feature = "parallel")]
        {
            SometimesAtomicU32 {
                value: AtomicU32::new(value),
            }
        }
        #[cfg(not(feature = "parallel"))]
        {
            SometimesAtomicU32 { value }
        }
    }

    #[inline]
    #[cfg(feature = "parallel")]
    pub fn set(&self, value: u32) {
        self.value.store(value, Ordering::SeqCst);
    }

    #[inline]
    #[cfg(not(feature = "parallel"))]
    pub fn set(&mut self, value: u32) {
        self.value = value;
    }

    #[inline]
    pub fn get(&self) -> u32 {
        #[cfg(feature = "parallel")]
        {
            self.value.load(Ordering::SeqCst)
        }
        #[cfg(not(feature = "parallel"))]
        {
            self.value
        }
    }
}

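/// Reinterprets `&mut [u32]` as `&mut [SometimesAtomicU32]`. Sound because the
/// asserts below check identical size and alignment, and `AtomicU32` has the same
/// in-memory representation as `u32`.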
#[inline]
fn as_slice_of_sometimes_atomic_u32(slice: &mut [u32]) -> &mut [SometimesAtomicU32] {
    assert_eq!(size_of::<SometimesAtomicU32>(), size_of::<u32>());
    assert_eq!(align_of::<SometimesAtomicU32>(), align_of::<u32>());
    let atomic_slice: &mut [SometimesAtomicU32] = unsafe {
        core::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut SometimesAtomicU32, slice.len())
    };
    atomic_slice
}

#[cfg(test)]
mod tests {

    use crate::{
        ploc::{PlocBuilder, PlocSearchDistance, SortPrecision},
        test_util::geometry::demoscene,
    };

    use super::*;

    #[test]
    fn test_collapse() {
        let tris = demoscene(32, 0);
        let mut aabbs = Vec::with_capacity(tris.len());
        let mut indices = Vec::with_capacity(tris.len());
        for (i, primitive) in tris.iter().enumerate() {
            indices.push(i as u32);
            aabbs.push(primitive.aabb());
        }
        {
            let mut bvh = PlocBuilder::new().build(
                PlocSearchDistance::VeryLow,
                &aabbs,
                indices.clone(),
                SortPrecision::U64,
                1,
            );
            bvh.validate(&tris, false, false);
            collapse(&mut bvh, 8, 1.0);
            bvh.validate(&tris, false, false);
        }
        {
            let mut bvh = PlocBuilder::new().build(
                PlocSearchDistance::VeryLow,
                &aabbs,
                indices,
                SortPrecision::U64,
                1,
            );
            bvh.validate(&tris, false, false);
            bvh.init_primitives_to_nodes_if_uninit();
            bvh.init_parents_if_uninit();
            bvh.validate(&tris, false, false);
            collapse(&mut bvh, 8, 1.0);
            bvh.validate(&tris, false, false);
        }
    }
}