obvhs/ploc/
mod.rs

1//! PLOC (Parallel, Locally Ordered Clustering) BVH 2 Builder.
2
3pub mod morton;
4pub mod rebuild;
5
6// https://madmann91.github.io/2021/05/05/ploc-revisited.html
7// https://github.com/meistdan/ploc/
8// https://meistdan.github.io/publications/ploc/paper.pdf
9// https://github.com/madmann91/bvh/blob/v1/include/bvh/locally_ordered_clustering_builder.hpp
10
11use std::{f32, mem};
12
13use bytemuck::{Pod, Zeroable, cast_slice_mut, zeroed_vec};
14use glam::DVec3;
15use rdst::RadixKey;
16
17#[cfg(feature = "parallel")]
18use rayon::{
19    iter::{
20        IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator,
21        ParallelIterator,
22    },
23    slice::ParallelSliceMut,
24};
25
26#[cfg(not(feature = "parallel"))]
27use rdst::RadixSort;
28
29use crate::bvh2::DEFAULT_MAX_STACK_DEPTH;
30use crate::ploc::morton::{morton_encode_u64_unorm, morton_encode_u128_unorm};
31use crate::{Boundable, bvh2::node::Bvh2Node};
32use crate::{aabb::Aabb, bvh2::Bvh2};
33
/// Reusable scratch allocations for building [`Bvh2`]s with PLOC.
/// Keep one of these around between builds to avoid reallocating.
#[derive(Clone)]
pub struct PlocBuilder {
    /// Working set of nodes for the current merge pass.
    pub current_nodes: Vec<Bvh2Node>,
    /// Destination for nodes kept or created by the current merge pass;
    /// swapped with `current_nodes` after each pass.
    pub next_nodes: Vec<Bvh2Node>,

    // Enough space/align for Morton64 or Morton128. If this is updated make sure to also update anything that uses it.
    // As things depend on it being exactly Vec<[u128; 2]>
    pub mortons: Vec<[u128; 2]>,

    /// Per-chunk partial AABBs used when computing the total AABB in parallel.
    #[cfg(feature = "parallel")]
    pub local_aabbs: Vec<Aabb>,
}
46
47impl Default for PlocBuilder {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
impl PlocBuilder {
    /// Initialize a ploc builder. After initial building, keep around this builder to reuse the associated allocations.
    pub fn new() -> PlocBuilder {
        crate::scope!("preallocate_builder");
        PlocBuilder {
            current_nodes: Vec::new(),
            next_nodes: Vec::new(),
            mortons: Vec::new(),

            #[cfg(feature = "parallel")]
            local_aabbs: Vec::new(),
        }
    }

    /// Initialize a ploc builder with pre-allocated capacity for building a bvh with prim_count.
    /// After initial building, keep around this builder to reuse the associated allocations.
    pub fn with_capacity(prim_count: usize) -> PlocBuilder {
        crate::scope!("preallocate_builder");
        PlocBuilder {
            current_nodes: zeroed_vec(prim_count),
            next_nodes: zeroed_vec(prim_count),
            mortons: zeroed_vec(prim_count),

            // Resized to the actual chunk count during parallel building.
            #[cfg(feature = "parallel")]
            local_aabbs: zeroed_vec(128),
        }
    }

    /// Builds a new [`Bvh2`] from `aabbs`. See [`PlocBuilder::build_with_bvh`]
    /// to reuse an existing bvh's allocations instead.
    ///
    /// # Arguments
    /// * `search_distance` - Which search distance should be used when building the ploc.
    /// * `aabbs` - A list of bounding boxes. Should correspond to the number and order of primitives.
    /// * `indices` - The list indices used to index into the list of primitives. This allows for
    ///   flexibility in which primitives are included in the bvh and in what order they are referenced.
    ///   Often this would just be equivalent to: (0..aabbs.len() as u32).collect::<Vec<_>>()
    /// * `sort_precision` - Bits used for ploc radix sort. More bits results in a more accurate but slower sort.
    /// * `search_depth_threshold` - Below this depth a search distance of 1 will be used. Set to 0 to bypass and
    ///   just use PlocSearchDistance. When trying to optimize build time it can be beneficial to limit the search
    ///   distance for the first few passes as that is when the largest number of primitives are being considered.
    ///   The pairs are initially found more quickly since it doesn't need to search as far, and they are also
    ///   found more often, since the nodes need to both agree to become paired. This also seems to occasionally
    ///   result in an overall better bvh structure.
    #[inline]
    pub fn build<T: Boundable>(
        &mut self,
        search_distance: PlocSearchDistance,
        aabbs: &[T],
        indices: Vec<u32>,
        sort_precision: SortPrecision,
        search_depth_threshold: usize,
    ) -> Bvh2 {
        let mut bvh = Bvh2::zeroed(aabbs.len());
        self.build_with_bvh(
            &mut bvh,
            search_distance,
            aabbs,
            indices,
            sort_precision,
            search_depth_threshold,
        );
        bvh
    }

    /// # Arguments
    /// * `bvh` - An existing bvh. The builder will clear this bvh and reuse its allocations.
    /// * `search_distance` - Which search distance should be used when building the ploc.
    /// * `aabbs` - A list of bounding boxes. Should correspond to the number and order of primitives.
    /// * `indices` - The list indices used to index into the list of primitives. This allows for
    ///   flexibility in which primitives are included in the bvh and in what order they are referenced.
    ///   Often this would just be equivalent to: (0..aabbs.len() as u32).collect::<Vec<_>>()
    /// * `sort_precision` - Bits used for ploc radix sort. More bits results in a more accurate but slower sort.
    /// * `search_depth_threshold` - Below this depth a search distance of 1 will be used. Set to 0 to bypass and
    ///   just use PlocSearchDistance. When trying to optimize build time it can be beneficial to limit the search
    ///   distance for the first few passes as that is when the largest number of primitives are being considered.
    ///   The pairs are initially found more quickly since it doesn't need to search as far, and they are also
    ///   found more often, since the nodes need to both agree to become paired. This also seems to occasionally
    ///   result in an overall better bvh structure.
    pub fn build_with_bvh<T: Boundable>(
        &mut self,
        bvh: &mut Bvh2,
        search_distance: PlocSearchDistance,
        aabbs: &[T],
        indices: Vec<u32>,
        sort_precision: SortPrecision,
        search_depth_threshold: usize,
    ) {
        // Dispatch to a build monomorphized on the chosen search distance.
        let search_thresh = search_depth_threshold;
        match search_distance {
            PlocSearchDistance::Minimum => {
                self.build_ploc::<1, T>(bvh, aabbs, indices, sort_precision, search_thresh)
            }
            PlocSearchDistance::VeryLow => {
                self.build_ploc::<2, T>(bvh, aabbs, indices, sort_precision, search_thresh)
            }
            PlocSearchDistance::Low => {
                self.build_ploc::<6, T>(bvh, aabbs, indices, sort_precision, search_thresh)
            }
            PlocSearchDistance::Medium => {
                self.build_ploc::<14, T>(bvh, aabbs, indices, sort_precision, search_thresh)
            }
            PlocSearchDistance::High => {
                self.build_ploc::<24, T>(bvh, aabbs, indices, sort_precision, search_thresh)
            }
            PlocSearchDistance::VeryHigh => {
                self.build_ploc::<32, T>(bvh, aabbs, indices, sort_precision, search_thresh)
            }
        }
    }

    /// # Arguments
    /// * `bvh` - An existing bvh. The builder will clear this bvh and reuse its allocations.
    /// * `aabbs` - A list of bounding boxes. Should correspond to the number and order of primitives.
    /// * `sort_precision` - Bits used for ploc radix sort. More bits results in a more accurate but slower sort.
    /// * `search_depth_threshold` - Below this depth a search distance of 1 will be used. Set to 0 to bypass and
    ///   just use SEARCH_DISTANCE.
    ///
    /// SEARCH_DISTANCE should be <= 32
    pub fn build_ploc<const SEARCH_DISTANCE: usize, T: Boundable>(
        &mut self,
        bvh: &mut Bvh2,
        aabbs: &[T],
        indices: Vec<u32>,
        sort_precision: SortPrecision,
        search_depth_threshold: usize,
    ) {
        crate::scope!("build_ploc");

        let prim_count = aabbs.len();

        bvh.reset_for_reuse(prim_count, Some(indices));

        if prim_count == 0 {
            return;
        }

        // Creates a leaf node for one primitive and folds its AABB into the
        // running accumulator.
        #[inline]
        fn init_node(prim_index: &u32, aabb: Aabb, local_aabb: &mut Aabb) -> Bvh2Node {
            local_aabb.extend(aabb.min);
            local_aabb.extend(aabb.max);
            debug_assert!(!aabb.min.is_nan());
            debug_assert!(!aabb.max.is_nan());
            Bvh2Node::new(aabb, 1, *prim_index)
        }

        let mut total_aabb = None;
        self.current_nodes.resize(prim_count, Default::default());

        // TODO perf/forte Due to rayon overhead using par_iter can be slower than just iter for small quantities of nodes.
        // 500k chosen from testing various tri counts with the demoscene example
        #[cfg(feature = "parallel")]
        let min_parallel = 500_000;

        // Parallel leaf init: each chunk accumulates into its own local AABB;
        // the locals are then merged into the total sequentially below.
        #[cfg(feature = "parallel")]
        if prim_count >= min_parallel {
            let chunk_size = aabbs.len().div_ceil(rayon::current_num_threads());

            let chunks = self
                .current_nodes
                .par_iter_mut()
                .zip(&bvh.primitive_indices)
                .enumerate()
                .chunks(chunk_size);

            self.local_aabbs.resize(chunks.len(), Aabb::empty());

            chunks
                .zip(self.local_aabbs.par_iter_mut())
                .for_each(|(data, local_aabb)| {
                    for (i, (node, prim_index)) in data {
                        *node = init_node(prim_index, aabbs[i].aabb(), local_aabb);
                    }
                });

            let mut total = Aabb::empty();
            for local_aabb in self.local_aabbs.iter_mut() {
                total.extend(local_aabb.min);
                total.extend(local_aabb.max);
            }
            total_aabb = Some(total);
        }

        // Sequential leaf init (also the fallback below the parallel threshold).
        if total_aabb.is_none() {
            let mut total = Aabb::empty();
            self.current_nodes
                .iter_mut()
                .zip(&bvh.primitive_indices)
                .zip(aabbs)
                .for_each(|((node, prim_index), aabb)| {
                    *node = init_node(prim_index, aabb.aabb(), &mut total)
                });
            total_aabb = Some(total);
        }

        self.build_ploc_from_leaves::<SEARCH_DISTANCE, false>(
            bvh,
            total_aabb.unwrap(),
            sort_precision,
            search_depth_threshold,
        );
    }

    /// Prefer using Bvh2::build(), Bvh2::build_with_bvh(), Bvh2::build_ploc(), Bvh2::partial_rebuild(),
    /// or Bvh2::full_rebuild(). This is only public for non-typical usages.
    /// REBUILD is for partial BVH rebuilds. In that case inner nodes should be freed by setting them to invalid
    /// (with Bvh2Node::set_invalid()) and both respective inner and leaf nodes moved on to PlocBuilder::current_nodes.
    /// They must always be removed in pairs with the starting on an odd number. See PlocBuilder::partial_rebuild()
    pub fn build_ploc_from_leaves<const SEARCH_DISTANCE: usize, const REBUILD: bool>(
        &mut self,
        bvh: &mut Bvh2,
        total_aabb: Aabb,
        sort_precision: SortPrecision,
        search_depth_threshold: usize,
    ) {
        crate::scope!("build_ploc_from_leaves");

        let prim_count = self.current_nodes.len();

        if prim_count == 0 {
            return;
        }

        // Merge nodes until there is only one left
        let nodes_count = (2 * prim_count as i64 - 1).max(0) as usize;

        // Nodes are written back-to-front. For REBUILD, invalidated slots in
        // the existing node array are reused instead of resizing.
        let mut insert_index = if REBUILD {
            if bvh.nodes.is_empty() {
                return;
            }
            assert!(bvh.nodes.len() >= nodes_count);
            bvh.nodes.len() - 1
        } else {
            bvh.nodes.resize(nodes_count, Bvh2Node::default());
            nodes_count
        };

        // Maps AABB centers into the total AABB's range for morton encoding.
        let scale = 1.0 / total_aabb.diagonal().as_dvec3();
        let offset = -total_aabb.min.as_dvec3() * scale;

        // One [u128; 2] holds one Morton128, or two Morton64s.
        let mortons_size = match sort_precision {
            SortPrecision::U128 => prim_count,
            SortPrecision::U64 => prim_count.div_ceil(2),
        };
        self.mortons.resize(mortons_size, Default::default());
        self.next_nodes.resize(prim_count, Default::default());

        // Sort primitives according to their morton code
        sort_precision.sort_nodes(
            &mut self.current_nodes,
            &mut self.next_nodes,
            &mut self.mortons,
            scale,
            offset,
        );
        mem::swap(&mut self.current_nodes, &mut self.next_nodes);

        // Merge offsets are stored as i8, so the search distance must fit.
        assert!(i8::MAX as usize > SEARCH_DISTANCE);

        // The morton scratch space is no longer needed; reuse it for the
        // per-node merge offsets.
        let merge_buffer: &mut [i8] = &mut cast_slice_mut(&mut self.mortons)[..prim_count];

        #[cfg(not(feature = "parallel"))]
        let mut cache = SearchCache::<SEARCH_DISTANCE>::default();

        #[cfg(feature = "parallel")]
        let threads = rayon::current_num_threads();

        // One SearchCache per parallel chunk (several per thread).
        #[cfg(feature = "parallel")]
        let mut cache = if prim_count < 4000 {
            vec![SearchCache::<SEARCH_DISTANCE>::default()]
        } else {
            vec![SearchCache::<SEARCH_DISTANCE>::default(); threads * 4]
        };

        let mut depth: usize = 0;
        let mut next_nodes_idx = 0;
        let mut count = prim_count;
        while count > 1 {
            // Phase 1: record, for each node, the signed offset to the
            // neighbor it would prefer to merge with.
            let merge = &mut merge_buffer[..count];
            if SEARCH_DISTANCE == 1 || depth < search_depth_threshold {
                // Distance-1 path: each node only compares the pairing cost of
                // its left vs right neighbor (also used below the depth threshold).
                let mut last_cost = f32::INFINITY;
                let calculate_costs = |(i, merge_n): (usize, &mut i8)| {
                    let cost = self.current_nodes[i]
                        .aabb()
                        .union(self.current_nodes[i + 1].aabb())
                        .half_area();
                    *merge_n = if last_cost < cost { -1 } else { 1 };
                    last_cost = cost;
                };

                let count_m1 = count - 1;
                let merge_m1 = &mut merge[..count_m1];

                #[cfg(feature = "parallel")]
                {
                    let chunk_size = merge_m1.len().div_ceil(threads);
                    // Each chunk recomputes the cost at its left boundary so
                    // chunks can be processed independently.
                    let calculate_costs_parallel = |(chunk_id, chunk): (usize, &mut [i8])| {
                        let start = chunk_id * chunk_size;
                        let mut last_cost = if start == 0 {
                            f32::INFINITY
                        } else {
                            self.current_nodes[start - 1]
                                .aabb()
                                .union(self.current_nodes[start].aabb())
                                .half_area()
                        };
                        for (local_n, merge_n) in chunk.iter_mut().enumerate() {
                            let i = local_n + start;
                            let cost = self.current_nodes[i]
                                .aabb()
                                .union(self.current_nodes[i + 1].aabb())
                                .half_area();
                            *merge_n = if last_cost < cost { -1 } else { 1 };
                            last_cost = cost;
                        }
                    };

                    // TODO perf/forte Due to rayon overhead using par_iter can be slower than just iter for small quantities.
                    // 300k chosen from testing various scenes in tray racing
                    if count < 300_000 {
                        merge_m1.iter_mut().enumerate().for_each(calculate_costs);
                    } else {
                        merge_m1
                            .par_chunks_mut(chunk_size.max(1))
                            .enumerate()
                            .for_each(calculate_costs_parallel)
                    }
                }
                #[cfg(not(feature = "parallel"))]
                {
                    merge_m1.iter_mut().enumerate().for_each(calculate_costs);
                }
                // The last node has no right neighbor; it can only pair left.
                merge[count_m1] = -1;
            } else {
                // General path: search up to SEARCH_DISTANCE nodes on either
                // side, caching pair costs to avoid recomputing them.
                #[cfg(not(feature = "parallel"))]
                merge.iter_mut().enumerate().for_each(|(index, best)| {
                    *best = cache.find_best_node(index, &self.current_nodes[..count]);
                });

                #[cfg(feature = "parallel")]
                {
                    // TODO perf/forte Due to rayon overhead using par_iter can be slower than just iter for small quantities.
                    // 4k chosen from testing with demoscene
                    if count < 4000 {
                        let cache = &mut cache[0];
                        merge.iter_mut().enumerate().for_each(|(index, best)| {
                            *best = cache.find_best_node(index, &self.current_nodes[..count]);
                        });
                    } else {
                        // Split search into chunks in parallel
                        let chunk_size = merge.len().div_ceil(cache.len());
                        let chunks = merge.par_chunks_mut(merge.len().div_ceil(cache.len()));
                        if chunks.len() > cache.len() {
                            cache.resize(chunks.len(), SearchCache::<SEARCH_DISTANCE>::default());
                        }
                        chunks.zip(cache.par_iter_mut()).enumerate().for_each(
                            |(chunk, (bests, cache))| {
                                for (i, best) in bests.iter_mut().enumerate() {
                                    let index = chunk * chunk_size + i;
                                    *best = cache.find_best_node_parallel(
                                        index,
                                        i,
                                        &self.current_nodes[..count],
                                    );
                                }
                            },
                        );
                    }
                }
            };
            // Phase 2: merge mutually-agreeing pairs; carry everything else
            // over to the next pass.
            let mut index = 0;
            // Tried making this parallel but it was similar perf as the sequential version below. Could be memory bound?
            // https://github.com/DGriffin91/pool_racing/commit/a35b92496a1c28043b11565ee48dff0137ada68f
            while index < count {
                let index_offset = merge[index] as i64;
                let best_index = (index as i64 + index_offset) as usize;
                // The two nodes should be merged if they agree on their respective merge indices.
                if best_index as i64 + merge[best_index] as i64 != index as i64 {
                    // If not, the current node should be kept for the next iteration
                    self.next_nodes[next_nodes_idx] = self.current_nodes[index];
                    next_nodes_idx += 1;
                    index += 1;
                    continue;
                }

                // Since we only need to merge once, we only merge if the first index is less than the second.
                if best_index > index {
                    index += 1;
                    continue;
                }

                debug_assert_ne!(best_index, index);

                let left = self.current_nodes[index];
                let right = self.current_nodes[best_index];

                let first_child;

                // Reserve space in the target array for the two children
                if REBUILD {
                    // Scan backwards for the next invalidated pair of slots.
                    loop {
                        // An out of bounds error here could indicate NaN present in input aabb. Try running in debug mode.
                        let left_slot = &mut bvh.nodes[insert_index - 1];
                        if !left_slot.valid() {
                            *left_slot = left;
                            debug_assert!(!bvh.nodes[insert_index].valid());
                            bvh.nodes[insert_index] = right;
                            first_child = insert_index - 1;
                            insert_index -= 2;
                            break;
                        }
                        insert_index -= 2;
                    }
                } else {
                    debug_assert!(insert_index >= 2);
                    insert_index -= 2;
                    // An out of bounds error here could indicate NaN present in input aabb. Try running in debug mode.
                    bvh.nodes[insert_index] = left;
                    bvh.nodes[insert_index + 1] = right;
                    first_child = insert_index;
                }

                // Create the parent node and place it in the array for the next iteration
                self.next_nodes[next_nodes_idx] =
                    Bvh2Node::new(left.aabb().union(right.aabb()), 0, first_child as u32);
                next_nodes_idx += 1;

                if SEARCH_DISTANCE == 1 && index_offset == 1 {
                    // If the search distance is only 1, and the next index was merged with this one,
                    // we can skip the next index.
                    // The code for this with the while loop seemed to also be slightly faster than:
                    //     for (index, best_index) in merge.iter().enumerate() {
                    // even in the other cases. For some reason...
                    index += 2;
                } else {
                    index += 1;
                }
            }

            mem::swap(&mut self.next_nodes, &mut self.current_nodes);
            count = next_nodes_idx;
            next_nodes_idx = 0;
            depth += 1;
        }

        if !REBUILD {
            debug_assert_eq!(insert_index, 1);
        }

        // The single remaining node is the root.
        bvh.nodes[0] = self.current_nodes[0];

        bvh.max_depth = DEFAULT_MAX_STACK_DEPTH.max(depth + 1);
        bvh.children_are_ordered_after_parents = !REBUILD;
    }
}
505
506// For reference/testing
507#[allow(dead_code)]
508fn find_best_node_basic(index: usize, nodes: &[Bvh2Node], search_distance: usize) -> i8 {
509    let mut best_node = index;
510    let mut best_cost = f32::INFINITY;
511
512    let begin = index - search_distance.min(index);
513    let end = (index + search_distance + 1).min(nodes.len());
514
515    let our_aabb = nodes[index].aabb();
516    for other in begin..end {
517        if other == index {
518            continue;
519        }
520        let cost = our_aabb.union(nodes[other].aabb()).half_area();
521        if cost <= best_cost {
522            best_node = other;
523            best_cost = cost;
524        }
525    }
526
527    (best_node as i64 - index as i64) as i8
528}
529
/// In PLOC, the number of nodes before and after the current one that are evaluated for pairing.
/// Minimum (1) has a fast path in building and still results in decent quality BVHs especially
/// when paired with a bit of reinsertion.
/// See [`PlocBuilder::build_with_bvh`] for how each variant maps to a concrete distance.
#[derive(Default, Clone, Copy, Debug)]
pub enum PlocSearchDistance {
    /// 1
    Minimum,
    /// 2
    VeryLow,
    /// 6
    Low,
    #[default]
    /// 14
    Medium,
    /// 24
    High,
    /// 32
    VeryHigh,
}
549
550impl From<u32> for PlocSearchDistance {
551    fn from(value: u32) -> Self {
552        match value {
553            1 => PlocSearchDistance::Minimum,
554            2 => PlocSearchDistance::VeryLow,
555            6 => PlocSearchDistance::Low,
556            14 => PlocSearchDistance::Medium,
557            24 => PlocSearchDistance::High,
558            32 => PlocSearchDistance::VeryHigh,
559            _ => panic!("Invalid value for PlocSearchDistance: {value}"),
560        }
561    }
562}
563
// Tried using a Vec it was ~30% slower with a search distance of 14.
// Tried making the Vec flat, used get_unchecked, etc... (without those it was ~80% slower)
/// Rolling cache of pair costs, indexed modulo SEARCH_DISTANCE in both dimensions.
#[derive(Clone, Copy)]
pub struct SearchCache<const SEARCH_DISTANCE: usize>([[f32; SEARCH_DISTANCE]; SEARCH_DISTANCE]);
568
569impl<const SEARCH_DISTANCE: usize> Default for SearchCache<SEARCH_DISTANCE> {
570    fn default() -> Self {
571        SearchCache([[0.0; SEARCH_DISTANCE]; SEARCH_DISTANCE])
572    }
573}
574
impl<const SEARCH_DISTANCE: usize> SearchCache<SEARCH_DISTANCE> {
    /// Reads the cached cost of pairing `index` with an earlier node `other`,
    /// as written by that node's own call to `front`.
    #[inline]
    fn back(&self, index: usize, other: usize) -> f32 {
        // Note: the compiler removes the bounds check due to the % SEARCH_DISTANCE
        self.0[other % SEARCH_DISTANCE][index % SEARCH_DISTANCE]
    }

    /// Mutable slot for the cost of pairing `index` with a later node `other`.
    #[inline]
    fn front(&mut self, index: usize, other: usize) -> &mut f32 {
        &mut self.0[index % SEARCH_DISTANCE][other % SEARCH_DISTANCE]
    }

    /// Like `find_best_node`, but safe when the merge buffer is processed in
    /// chunks: `i` is the position within the chunk, and cached back-costs are
    /// only trusted once this chunk has filled its own cache entries.
    #[allow(dead_code)]
    fn find_best_node_parallel(&mut self, index: usize, i: usize, nodes: &[Bvh2Node]) -> i8 {
        let mut best_node = index;
        let mut best_cost = f32::INFINITY;

        let begin = index - SEARCH_DISTANCE.min(index);
        let end = (index + SEARCH_DISTANCE + 1).min(nodes.len());

        let our_aabb = nodes[index].aabb();
        for other in begin..index {
            // When using the cache in parallel, the search is broken into chunks. This means the first
            // n = SEARCH_DISTANCE slots in the cache won't have been filled yet.
            // (TODO this could be tighter, using more of the cache within the n = SEARCH_DISTANCE range as it's filled)
            let area = if i <= SEARCH_DISTANCE {
                our_aabb.union(nodes[other].aabb()).half_area()
            } else {
                self.back(index, other)
            };

            if area <= best_cost {
                best_node = other;
                best_cost = area;
            }
        }

        // Forward half: compute each cost and cache it for later nodes.
        ((index + 1)..end).for_each(|other| {
            let cost = our_aabb.union(nodes[other].aabb()).half_area();
            *self.front(index, other) = cost;
            if cost <= best_cost {
                best_node = other;
                best_cost = cost;
            }
        });

        (best_node as i64 - index as i64) as i8
    }

    /// Finds the neighbor within SEARCH_DISTANCE of `index` with the lowest
    /// pairing cost (half-area of the union AABB), returning its signed offset.
    /// Costs against earlier nodes come from the cache; costs against later
    /// nodes are computed here and cached for those nodes' own searches.
    fn find_best_node(&mut self, index: usize, nodes: &[Bvh2Node]) -> i8 {
        let mut best_node = index;
        let mut best_cost = f32::INFINITY;

        let begin = index - SEARCH_DISTANCE.min(index);
        let end = (index + SEARCH_DISTANCE + 1).min(nodes.len());

        for other in begin..index {
            let area = self.back(index, other);
            if area <= best_cost {
                best_node = other;
                best_cost = area;
            }
        }

        let our_aabb = nodes[index].aabb();
        ((index + 1)..end).for_each(|other| {
            let cost = our_aabb.union(nodes[other].aabb()).half_area();
            *self.front(index, other) = cost;
            if cost <= best_cost {
                best_node = other;
                best_cost = cost;
            }
        });

        (best_node as i64 - index as i64) as i8
    }
}
652
653// ---------------------
654// --- SORTING NODES ---
655// ---------------------
656
/// Bits used for the morton-code sort of leaves. More bits results in a more
/// accurate but slower sort.
#[derive(Debug, Copy, Clone)]
pub enum SortPrecision {
    /// 128-bit morton codes (`Morton128`).
    U128,
    /// 64-bit morton codes (`Morton64`).
    U64,
}
662
impl SortPrecision {
    /// Sorts `nodes` into `sorted` by the morton code of each node's AABB
    /// center (mapped by `scale`/`offset`), reusing `mortons_allocation` as
    /// scratch space for either morton key width.
    fn sort_nodes(
        &self,
        nodes: &mut [Bvh2Node],
        sorted: &mut [Bvh2Node],
        mortons_allocation: &mut [[u128; 2]],
        scale: DVec3,
        offset: DVec3,
    ) {
        match self {
            SortPrecision::U128 => {
                // One [u128; 2] holds exactly one Morton128 (code + index + padding).
                let mortons = cast_slice_mut(mortons_allocation);
                sort_nodes_by_morton::<Morton128>(*self, nodes, sorted, mortons, scale, offset)
            }
            SortPrecision::U64 => {
                // One u128 holds exactly one Morton64 (code + index), so only
                // the first nodes.len() u128s of the allocation are needed.
                let smaller: &mut [u128] = cast_slice_mut(mortons_allocation);
                let mortons = cast_slice_mut(&mut smaller[..nodes.len()]);
                sort_nodes_by_morton::<Morton64>(*self, nodes, sorted, mortons, scale, offset)
            }
        }
    }
}
685
/// Sort key pairing a 128-bit morton code with the index of the node it was
/// generated from. Padded to 32 bytes so it can alias `PlocBuilder::mortons`.
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
struct Morton128 {
    code: u128,
    index: u64,
    padding: u64,
}
693
impl RadixKey for Morton128 {
    // 16 radix levels: one per byte of the 128-bit code.
    const LEVELS: usize = 16;

    #[inline(always)]
    fn get_level(&self, level: usize) -> u8 {
        // Delegate to the u128 key; `index`/`padding` do not affect ordering.
        self.code.get_level(level)
    }
}
702
/// Sort key pairing a 64-bit morton code with the index of the node it was
/// generated from. 16 bytes, so two fit in one `[u128; 2]` slot.
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
struct Morton64 {
    code: u64,
    index: u64,
}
709
impl RadixKey for Morton64 {
    // 8 radix levels: one per byte of the 64-bit code.
    const LEVELS: usize = 8;

    #[inline(always)]
    fn get_level(&self, level: usize) -> u8 {
        // Delegate to the u64 key; `index` does not affect ordering.
        self.code.get_level(level)
    }
}
718
/// Abstraction over `Morton64`/`Morton128` so `sort_nodes_by_morton` can be
/// generic over sort precision.
trait MortonCode: RadixKey + Send + Sync + Copy {
    /// Builds a key from a node index and the node's AABB center, pre-mapped
    /// for the `morton_encode_*_unorm` encoders.
    fn new(index: usize, center: DVec3) -> Self;
    /// The original node index this key was created with.
    fn index(&self) -> usize;
    /// The 64-bit code; panics for 128-bit keys.
    fn code64(&self) -> u64;
    /// The 128-bit code; panics for 64-bit keys.
    fn code128(&self) -> u128;
}
725
impl MortonCode for Morton128 {
    #[inline(always)]
    fn new(index: usize, center: DVec3) -> Self {
        Morton128 {
            index: index as u64,
            code: morton_encode_u128_unorm(center),
            padding: Default::default(),
        }
    }
    #[inline(always)]
    fn index(&self) -> usize {
        self.index as usize
    }
    #[inline(always)]
    fn code64(&self) -> u64 {
        // A 128-bit code cannot be read as 64 bits; reaching this is a bug.
        panic!("Don't sort Morton128 using code64");
    }
    #[inline(always)]
    fn code128(&self) -> u128 {
        self.code
    }
}
748
impl MortonCode for Morton64 {
    #[inline(always)]
    fn new(index: usize, center: DVec3) -> Self {
        Morton64 {
            index: index as u64,
            code: morton_encode_u64_unorm(center),
        }
    }
    #[inline(always)]
    fn index(&self) -> usize {
        self.index as usize
    }
    #[inline(always)]
    fn code64(&self) -> u64 {
        self.code
    }
    #[inline(always)]
    fn code128(&self) -> u128 {
        // A 64-bit key has no 128-bit code; reaching this is a bug.
        panic!("Don't sort Morton64 using code128");
    }
}
770
/// Computes a morton key for each node's AABB center (mapped by `scale` and
/// `offset`), sorts the keys, then writes the nodes into `sorted_nodes` in
/// morton order. `nodes` itself is not reordered.
fn sort_nodes_by_morton<M: MortonCode>(
    precision: SortPrecision,
    nodes: &mut [Bvh2Node],
    sorted_nodes: &mut [Bvh2Node],
    mortons: &mut [M],
    scale: DVec3,
    offset: DVec3,
) {
    crate::scope!("sort_nodes");
    let nodes_count = nodes.len();

    // Each key records the node's original index so nodes can be gathered
    // back into sorted order afterwards.
    let gen_mort = |(index, (morton, leaf)): (usize, (&mut M, &Bvh2Node))| {
        let center = leaf.aabb().center().as_dvec3() * scale + offset;
        *morton = M::new(index, center);
    };

    #[cfg(feature = "parallel")]
    {
        // par_iter overhead outweighs the win below this count.
        let min_parallel = 100_000;
        if nodes_count > min_parallel {
            mortons
                .par_iter_mut()
                .zip(nodes.par_iter())
                .enumerate()
                .for_each(gen_mort);
        } else {
            mortons
                .iter_mut()
                .zip(nodes.iter())
                .enumerate()
                .for_each(gen_mort);
        }
    }
    #[cfg(not(feature = "parallel"))]
    mortons
        .iter_mut()
        .zip(nodes.iter())
        .enumerate()
        .for_each(gen_mort);

    #[cfg(feature = "parallel")]
    {
        match precision {
            SortPrecision::U128 => mortons.par_sort_unstable_by_key(|m| m.code128()),
            SortPrecision::U64 => mortons.par_sort_unstable_by_key(|m| m.code64()),
        }
    }
    #[cfg(not(feature = "parallel"))]
    {
        // Comparison sort for smaller counts, radix sort for large ones.
        match nodes_count {
            0..=250_000 => match precision {
                SortPrecision::U128 => mortons.sort_unstable_by_key(|m| m.code128()),
                SortPrecision::U64 => mortons.sort_unstable_by_key(|m| m.code64()),
            },
            _ => mortons.radix_sort_unstable(),
        };
    }

    // Gather nodes into sorted order via the indices stored in the keys.
    let remap = |(n, m): (&mut Bvh2Node, &M)| *n = nodes[m.index()];

    #[cfg(feature = "parallel")]
    {
        let min_parallel = 100_000;
        if nodes_count > min_parallel {
            sorted_nodes
                .par_iter_mut()
                .zip(mortons.par_iter())
                .for_each(remap)
        } else {
            sorted_nodes.iter_mut().zip(mortons.iter()).for_each(remap)
        }
    }
    #[cfg(not(feature = "parallel"))]
    {
        sorted_nodes.iter_mut().zip(mortons.iter()).for_each(remap);
    }
}