// obvhs/bvh2/leaf_collapser.rs

1use bytemuck::zeroed_vec;
2
3// Based on https://github.com/madmann91/bvh/blob/2fd0db62022993963a7343669275647cb073e19a/include/bvh/leaf_collapser.hpp
4#[cfg(feature = "parallel")]
5use rayon::{
6    iter::{
7        IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
8    },
9    slice::ParallelSliceMut,
10};
11
12#[cfg(feature = "parallel")]
13use std::sync::atomic::{AtomicU32, Ordering};
14
15use crate::bvh2::{Bvh2, Bvh2Node};
16
/// Collapses leaves of the BVH according to the SAH. This optimization
/// is only helpful for bottom-up builders, as top-down builders already
/// have a termination criterion that prevents leaf creation when the SAH
/// cost does not improve.
///
/// * `max_prims` - maximum number of primitives a collapsed leaf may hold.
/// * `traversal_cost` - SAH cost of traversing an inner node; larger values
///   make collapsing more aggressive.
pub fn collapse(bvh: &mut Bvh2, max_prims: u32, traversal_cost: f32) {
    crate::scope!("collapse");
    let nodes_qty = bvh.nodes.len();

    // Nothing to do when leaves may hold at most one primitive, or when the
    // tree is already too small for any collapse to apply.
    if max_prims <= 1 || nodes_qty as u32 <= max_prims * 2 + 1 {
        return;
    }

    // The whole primitive set already fits in a single leaf; leave as-is.
    if !bvh.primitive_indices.is_empty() && bvh.primitive_indices.len() as u32 <= max_prims {
        return;
    }

    // Empty tree, or the root is already a leaf: nothing to collapse.
    if bvh.nodes.is_empty() || bvh.nodes[0].is_leaf() {
        return;
    }

    // Remember whether parent links existed before so we can restore that
    // state (recompute vs. clear) at the end.
    let previously_had_parents = !bvh.parents.is_empty();

    // bottom_up_traverse below needs parent links.
    bvh.init_parents_if_uninit();

    // Per-node counters filled by the bottom-up pass:
    // node_counts[i] == 1 if node i survives, 0 if it was absorbed into a leaf;
    // prim_counts[i]  > 0 if node i ends up a leaf (number of prims it holds).
    let mut node_counts = vec![1u32; nodes_qty];
    let mut prim_counts = vec![0u32; nodes_qty];
    let node_count;

    {
        // View the counters as (possibly) atomic so the parallel bottom-up
        // traversal can update them from multiple threads.
        let node_counts = as_slice_of_sometimes_atomic_u32(&mut node_counts);
        let prim_counts = as_slice_of_sometimes_atomic_u32(&mut prim_counts);

        // Bottom-up traversal to collapse leaves
        bottom_up_traverse(bvh, |leaf, i| {
            if leaf {
                prim_counts[i].set(bvh.nodes[i].prim_count);
            } else {
                let node = &bvh.nodes[i];
                debug_assert!(!node.is_leaf());
                let first_child = node.first_index as usize;

                // Children are visited before parents, so these counts are final.
                let left_count = prim_counts[first_child].get();
                let right_count = prim_counts[first_child + 1].get();
                let total_count = left_count + right_count;

                // Compute the cost of collapsing this node when both children are leaves
                // (a count of 0 means that child is still an inner node).
                if left_count > 0 && right_count > 0 && total_count <= max_prims {
                    let left = bvh.nodes[first_child];
                    let right = bvh.nodes[first_child + 1];
                    let collapse_cost =
                        node.aabb().half_area() * (total_count as f32 - traversal_cost);
                    let base_cost = left.aabb().half_area() * left_count as f32
                        + right.aabb().half_area() * right_count as f32;
                    let both_have_same_prim =
                        (left.first_index == right.first_index) && total_count == 2;

                    // Collapse them if cost of the collapsed node is lower, or both children contain the same primitive (as a result of splits)
                    if collapse_cost <= base_cost || both_have_same_prim {
                        //if both_have_same_prim { 1 } else { total_count }; // TODO, Reduce total count (was showing artifacts)
                        prim_counts[i].set(total_count);
                        prim_counts[first_child].set(0);
                        prim_counts[first_child + 1].set(0);
                        node_counts[first_child].set(0);
                        node_counts[first_child + 1].set(0);
                    }
                }
            }
        });
    }

    // Inclusive prefix sums turn the counts into offsets: node_counts[i - 1] is
    // the index of surviving node i in the compacted node array, and
    // prim_counts[i - 1] the first slot of its primitives in the new index array.
    #[cfg(feature = "parallel")]
    {
        parallel_prefix_sum(&mut node_counts);
        parallel_prefix_sum(&mut prim_counts);
    }
    #[cfg(not(feature = "parallel"))]
    {
        prefix_sum(&mut node_counts);
        prefix_sum(&mut prim_counts);
    }

    let mut indices_copy = Vec::new();
    let mut nodes_copy = Vec::new();
    {
        // Total number of surviving nodes (last inclusive prefix-sum entry).
        node_count = node_counts[bvh.nodes.len() - 1];
        if prim_counts[0] > 0 {
            // This means the root node has become a leaf.
            // We avoid copying the data and just swap the old prim array with the new one.
            bvh.nodes[0].first_index = 0;
            bvh.nodes[0].prim_count = prim_counts[0];
            std::mem::swap(&mut bvh.primitive_indices, &mut indices_copy);
            std::mem::swap(&mut bvh.nodes, &mut nodes_copy);
        } else {
            // Allocate the compacted arrays and copy the root over, remapping
            // its child pointer through the node offset table.
            nodes_copy = zeroed_vec(node_count as usize);
            indices_copy = zeroed_vec(prim_counts[bvh.nodes.len() - 1] as usize);
            nodes_copy[0] = bvh.nodes[0];
            nodes_copy[0].first_index = node_counts[nodes_copy[0].first_index as usize - 1];
        }
    }

    {
        let indices_copy = as_slice_of_sometimes_atomic_u32(&mut indices_copy);

        #[cfg(feature = "parallel")]
        let mut needs_traversal = Vec::with_capacity(bvh.nodes.len().div_ceil(4));

        #[allow(unused_mut)]
        let mut top_down_traverse = |i| {
            // Top-down traversal to store the prims contained in this subtree.
            #[allow(clippy::unnecessary_cast)]
            let i = i as usize;
            let mut first_prim = prim_counts[i - 1];
            let mut j = i;
            loop {
                let node = bvh.nodes[j];
                if node.is_leaf() {
                    // Append this old leaf's primitives to the new leaf's range.
                    for n in 0..node.prim_count {
                        indices_copy[(first_prim + n) as usize]
                            .set(bvh.primitive_indices[(node.first_index + n) as usize]);
                    }

                    first_prim += node.prim_count;
                    // Climb until we reach a left sibling (or the subtree root),
                    // then continue with the right sibling.
                    while !Bvh2Node::is_left_sibling(j) && j != i {
                        j = bvh.parents[j] as usize;
                    }
                    if j == i {
                        break;
                    }
                    j = Bvh2Node::get_sibling_id(j);
                } else {
                    j = node.first_index as usize;
                }
            }
        };

        // Compact every surviving node (root was already handled above).
        (1..bvh.nodes.len()).for_each(|i| {
            let node_id = node_counts[i - 1] as usize;
            // Equal adjacent offsets mean node i was absorbed into an ancestor leaf.
            if node_id == node_counts[i] as usize {
                return;
            }
            nodes_copy[node_id] = bvh.nodes[i];
            let first_prim = prim_counts[i - 1];
            if first_prim == prim_counts[i] {
                // Still an inner node: remap its child pointer to the new layout.
                let first_child = &mut nodes_copy[node_id].first_index;
                *first_child = node_counts[*first_child as usize - 1];
            } else {
                // Node i is the root of a collapsed leaf: point it at its prim
                // range and gather the primitives from its old subtree.
                nodes_copy[node_id].prim_count = prim_counts[i] - first_prim;
                nodes_copy[node_id].first_index = first_prim;
                #[cfg(feature = "parallel")]
                needs_traversal.push(i as u32);
                #[cfg(not(feature = "parallel"))]
                top_down_traverse(i);
            }
        });

        // In the parallel build, the per-leaf gathers were deferred; run them in parallel now.
        #[cfg(feature = "parallel")]
        needs_traversal.into_par_iter().for_each(top_down_traverse);
    }

    std::mem::swap(&mut bvh.nodes, &mut nodes_copy);
    std::mem::swap(&mut bvh.primitive_indices, &mut indices_copy);

    if previously_had_parents {
        // If we had parents already computed before collapse we need to recompute them now
        // TODO perf there might be a way to update this during collapse
        bvh.update_parents();
    } else {
        // If not, skip the extra computation
        bvh.parents.clear();
    }
    if !bvh.primitives_to_nodes.is_empty() {
        // If primitives_to_nodes already existed we need to make sure it remains valid.
        // TODO perf there might be a way to update this during collapse
        bvh.update_primitives_to_nodes();
    }
}
193
// Based on https://github.com/madmann91/bvh/blob/2fd0db62022993963a7343669275647cb073e19a/include/bvh/bottom_up_algorithm.hpp
#[cfg(not(feature = "parallel"))]
/// Visits every node of the BVH bottom-up, calling `process_node(is_leaf, node_index)`
/// once per node, guaranteeing that both children of an inner node are processed
/// before the inner node itself.
///
/// Caller must make sure Bvh2::parents is initialized
fn bottom_up_traverse<F>(
    bvh: &Bvh2,
    mut process_node: F, // True is for leaf
) where
    F: FnMut(bool, usize),
{
    // Special case if the BVH is just a leaf
    if bvh.nodes.len() == 1 {
        process_node(true, 0);
        return;
    }

    // Visit counter per inner node; mirrors the atomic scheme used by the
    // parallel version so each inner node is processed exactly once, after
    // both of its children.
    let mut flags: Vec<u8> = zeroed_vec(bvh.nodes.len());

    // Iterate through all nodes starting from 1, since node 0 is assumed to be the root
    (1..bvh.nodes.len()).for_each(|i| {
        // Always start at leaf
        if bvh.nodes[i].is_leaf() {
            process_node(true, i);

            // Process inner nodes on the path from that leaf up to the root
            let mut j = i;
            while j != 0 {
                j = bvh.parents[j] as usize;

                let flag = &mut flags[j];

                // Make sure that the children of this inner node have been processed:
                // the first arrival sees 0 and stops; the second sees 1 and continues.
                let previous_flag = *flag;
                *flag = previous_flag.saturating_add(1);
                if previous_flag != 1 {
                    break;
                }
                *flag = 0;

                process_node(false, j);
            }
        }
    });
}
237
// Based on https://github.com/madmann91/bvh/blob/2fd0db62022993963a7343669275647cb073e19a/include/bvh/bottom_up_algorithm.hpp
// https://research.nvidia.com/sites/default/files/pubs/2012-06_Maximizing-Parallelism-in/karras2012hpg_paper.pdf
// Paths from leaf nodes to the root are processed in parallel. Each thread starts from one leaf node and walks up the
// tree using parent pointers. We track how many threads have visited each internal node using atomic counters—the first
// thread terminates immediately while the second one gets to process the node. This way, each node is processed by
// exactly one thread, which leads to O(n) time complexity.
#[cfg(feature = "parallel")]
/// Visits every node of the BVH bottom-up in parallel, calling
/// `process_node(is_leaf, node_index)` exactly once per node; both children of
/// an inner node are always processed before the inner node itself.
///
/// Caller must make sure Bvh2::parents is initialized
fn bottom_up_traverse<F>(
    bvh: &Bvh2,
    process_node: F, // True is for leaf
) where
    F: Fn(bool, usize) + Sync + Send,
{
    // Special case if the BVH is just a leaf

    if bvh.nodes.len() == 1 {
        process_node(true, 0);
        return;
    }

    // Compiles down to just alloc_zeroed https://users.rust-lang.org/t/create-vector-of-atomicusize-etc/121695/5
    let flags = vec![0u32; bvh.nodes.len()]
        .into_iter()
        .map(AtomicU32::new)
        .collect::<Vec<_>>();

    // Iterate through all nodes starting from 1, since node 0 is assumed to be the root
    (1..bvh.nodes.len()).into_par_iter().for_each(|i| {
        // Always start at leaf
        if bvh.nodes[i].is_leaf() {
            process_node(true, i);

            // Process inner nodes on the path from that leaf up to the root
            let mut j = i;
            while j != 0 {
                j = bvh.parents[j] as usize;

                let flag = &flags[j];

                // Make sure that the children of this inner node have been processed:
                // the first thread to arrive reads 0 and terminates; the second reads 1
                // and carries on, so exactly one thread processes each inner node.
                if flag.fetch_add(1, Ordering::SeqCst) != 1 {
                    break;
                }
                flag.store(0, Ordering::SeqCst);

                process_node(false, j);
            }
        }
    });
}
289
#[cfg(feature = "parallel")]
/// In-place inclusive prefix sum, parallelized with rayon:
/// 1. split `data` into roughly one chunk per thread and prefix-sum each chunk locally,
/// 2. prefix-sum the per-chunk totals,
/// 3. add each preceding total to every element of the following chunk.
fn parallel_prefix_sum<T>(data: &mut [T])
where
    T: std::ops::Add + std::ops::AddAssign + Send + Default + Clone + Copy,
{
    // Split into chunks
    let chunk_size = 1.max(data.len().div_ceil(rayon::current_num_threads()));
    let chunks = data.par_chunks_mut(chunk_size);
    let mut partial_sums: Vec<T> = vec![Default::default(); chunks.len()];

    // Compute local prefix sum in parallel
    chunks
        .zip(partial_sums.par_iter_mut())
        .for_each(|(chunk, partial_sum)| *partial_sum = prefix_sum(chunk));

    // Compute partial sums
    prefix_sum(&mut partial_sums);

    // Apply partial sums. Chunk 0 needs no offset, so the chunk iterator is
    // skipped by one while the totals iterator is not: chunk k pairs with the
    // inclusive total of chunks 0..k.
    data.par_chunks_mut(chunk_size)
        .skip(1)
        .zip(partial_sums)
        .for_each(|(chunk, partial_sum)| chunk.iter_mut().for_each(move |n| *n += partial_sum));
}
314
/// Computes an in-place inclusive prefix sum over `data` and returns the total.
///
/// After the call, `data[k]` holds the sum of the original `data[0..=k]`; the
/// returned value is the sum of the whole original slice (the default value of
/// `T` for an empty slice).
#[inline]
fn prefix_sum<T>(data: &mut [T]) -> T
where
    T: std::ops::Add + std::ops::AddAssign + Send + Default + Clone + Copy,
{
    let mut running: T = Default::default();
    for slot in data.iter_mut() {
        running += *slot;
        *slot = running;
    }
    running
}
327
/// A `u32` that is atomic when the `parallel` feature is enabled and a plain
/// `u32` otherwise. Lets the collapse pass share one code path between the
/// single-threaded and rayon-based traversals without paying for atomic
/// operations in the single-threaded build.
pub struct SometimesAtomicU32 {
    #[cfg(feature = "parallel")]
    pub value: AtomicU32,
    #[cfg(not(feature = "parallel"))]
    pub value: u32,
}

impl SometimesAtomicU32 {
    /// Wraps `value` in the representation appropriate for the current build.
    #[inline]
    pub fn new(value: u32) -> SometimesAtomicU32 {
        #[cfg(feature = "parallel")]
        {
            SometimesAtomicU32 {
                value: AtomicU32::new(value),
            }
        }
        #[cfg(not(feature = "parallel"))]
        {
            SometimesAtomicU32 { value }
        }
    }

    /// Stores `value`. Takes `&self` because the atomic provides interior
    /// mutability.
    // Fix: removed a dead `#[cfg(not(feature = "parallel"))]` block that was
    // left inside this `#[cfg(feature = "parallel")]` method — contradictory
    // cfgs meant it could never be compiled (and would not compile if it were,
    // since it assigned through `&self`).
    #[inline]
    #[cfg(feature = "parallel")]
    pub fn set(&self, value: u32) {
        self.value.store(value, Ordering::SeqCst);
    }

    /// Stores `value`. Requires `&mut self` since the non-parallel
    /// representation has no interior mutability.
    #[inline]
    #[cfg(not(feature = "parallel"))]
    pub fn set(&mut self, value: u32) {
        self.value = value;
    }

    /// Loads the current value.
    #[inline]
    pub fn get(&self) -> u32 {
        #[cfg(feature = "parallel")]
        {
            self.value.load(Ordering::SeqCst)
        }
        #[cfg(not(feature = "parallel"))]
        {
            self.value
        }
    }
}
378
/// Reinterprets a `&mut [u32]` as `&mut [SometimesAtomicU32]` so the same
/// buffer can be written through shared references when the wrapper is atomic.
#[inline]
fn as_slice_of_sometimes_atomic_u32(slice: &mut [u32]) -> &mut [SometimesAtomicU32] {
    // Layout guards: the raw cast below is only sound if the wrapper has the
    // same size and alignment as u32 (true for both the plain-u32 and the
    // AtomicU32 representation, but checked here rather than assumed).
    assert_eq!(size_of::<SometimesAtomicU32>(), size_of::<u32>());
    assert_eq!(align_of::<SometimesAtomicU32>(), align_of::<u32>());
    // SAFETY: pointer and length come from a valid exclusive slice borrow, and
    // the asserts above verify that SometimesAtomicU32 is layout-compatible
    // with u32, so every element of the reinterpreted slice is validly
    // initialized and properly aligned. The returned borrow inherits the
    // lifetime of `slice`, so no aliasing is introduced.
    let atomic_slice: &mut [SometimesAtomicU32] = unsafe {
        core::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut SometimesAtomicU32, slice.len())
    };
    // Alternatively:
    //let slice: &mut [SometimesAtomicU32] = unsafe { &mut *((slice.as_mut_slice() as *mut [u32]) as *mut [SometimesAtomicU32]) };
    atomic_slice
}
390
#[cfg(test)]
mod tests {

    use crate::{
        ploc::{PlocBuilder, PlocSearchDistance, SortPrecision},
        test_util::geometry::demoscene,
    };

    use super::*;

    // Builds a BVH over a procedural test scene and checks that it still
    // validates after leaf collapsing, both with and without the auxiliary
    // parents / primitives_to_nodes arrays present.
    #[test]
    fn test_collapse() {
        let tris = demoscene(32, 0);
        let mut aabbs = Vec::with_capacity(tris.len());
        let mut indices = Vec::with_capacity(tris.len());
        for (i, primitive) in tris.iter().enumerate() {
            indices.push(i as u32);
            aabbs.push(primitive.aabb());
        }
        {
            // Test without init_primitives_to_nodes & init_parents
            let mut bvh = PlocBuilder::new().build(
                PlocSearchDistance::VeryLow,
                &aabbs,
                indices.clone(),
                SortPrecision::U64,
                1,
            );
            bvh.validate(&tris, false, false);
            collapse(&mut bvh, 8, 1.0);
            bvh.validate(&tris, false, false);
        }
        {
            // Test with init_primitives_to_nodes & init_parents
            let mut bvh = PlocBuilder::new().build(
                PlocSearchDistance::VeryLow,
                &aabbs,
                indices,
                SortPrecision::U64,
                1,
            );
            bvh.validate(&tris, false, false);
            bvh.init_primitives_to_nodes_if_uninit();
            bvh.init_parents_if_uninit();
            bvh.validate(&tris, false, false);
            collapse(&mut bvh, 8, 1.0);
            bvh.validate(&tris, false, false);
        }
    }
}