1pub mod morton;
4pub mod rebuild;
5
6use std::{f32, mem};
12
13use bytemuck::{Pod, Zeroable, cast_slice_mut, zeroed_vec};
14use glam::DVec3;
15use rdst::RadixKey;
16
17#[cfg(feature = "parallel")]
18use rayon::{
19 iter::{
20 IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator,
21 ParallelIterator,
22 },
23 slice::ParallelSliceMut,
24};
25
26#[cfg(not(feature = "parallel"))]
27use rdst::RadixSort;
28
29use crate::bvh2::DEFAULT_MAX_STACK_DEPTH;
30use crate::ploc::morton::{morton_encode_u64_unorm, morton_encode_u128_unorm};
31use crate::{Boundable, bvh2::node::Bvh2Node};
32use crate::{aabb::Aabb, bvh2::Bvh2};
33
/// Reusable scratch state for building a [`Bvh2`] with the PLOC
/// (Parallel Locally-Ordered Clustering) algorithm.
///
/// Keeping one `PlocBuilder` alive and reusing it across builds avoids
/// reallocating the scratch buffers for every BVH construction.
#[derive(Clone)]
pub struct PlocBuilder {
    /// Working set of cluster nodes for the current PLOC merge pass.
    pub current_nodes: Vec<Bvh2Node>,
    /// Scratch buffer the next pass's clusters are written into; swapped
    /// with `current_nodes` after each pass.
    pub next_nodes: Vec<Bvh2Node>,

    /// Morton-code sort scratch. Each 32-byte `[u128; 2]` entry is
    /// reinterpreted (via bytemuck) as one `Morton128` or two `Morton64`s;
    /// the same allocation is later reused as the `i8` merge-offset buffer.
    pub mortons: Vec<[u128; 2]>,

    /// Per-chunk partial AABBs used to reduce the total bounds in parallel.
    #[cfg(feature = "parallel")]
    pub local_aabbs: Vec<Aabb>,
}
46
47impl Default for PlocBuilder {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
impl PlocBuilder {
    /// Creates a builder with empty scratch buffers; they grow on first use.
    pub fn new() -> PlocBuilder {
        crate::scope!("preallocate_builder");
        PlocBuilder {
            current_nodes: Vec::new(),
            next_nodes: Vec::new(),
            mortons: Vec::new(),

            #[cfg(feature = "parallel")]
            local_aabbs: Vec::new(),
        }
    }

    /// Creates a builder with scratch buffers pre-sized for `prim_count`
    /// primitives so the first build does not need to reallocate.
    pub fn with_capacity(prim_count: usize) -> PlocBuilder {
        crate::scope!("preallocate_builder");
        PlocBuilder {
            current_nodes: zeroed_vec(prim_count),
            next_nodes: zeroed_vec(prim_count),
            mortons: zeroed_vec(prim_count),

            // Sized for a typical chunk count; resized on demand during builds.
            #[cfg(feature = "parallel")]
            local_aabbs: zeroed_vec(128),
        }
    }

    /// Builds and returns a fresh [`Bvh2`] over `aabbs`.
    ///
    /// Convenience wrapper around [`Self::build_with_bvh`] that allocates a
    /// zeroed `Bvh2` sized for `aabbs.len()` primitives.
    ///
    /// * `indices` - per-leaf primitive indices, moved into the BVH.
    /// * `sort_precision` - morton code width used for the spatial sort.
    /// * `search_depth_threshold` - passes shallower than this use the cheap
    ///   adjacent-pair heuristic instead of the full neighborhood search.
    #[inline]
    pub fn build<T: Boundable>(
        &mut self,
        search_distance: PlocSearchDistance,
        aabbs: &[T],
        indices: Vec<u32>,
        sort_precision: SortPrecision,
        search_depth_threshold: usize,
    ) -> Bvh2 {
        let mut bvh = Bvh2::zeroed(aabbs.len());
        self.build_with_bvh(
            &mut bvh,
            search_distance,
            aabbs,
            indices,
            sort_precision,
            search_depth_threshold,
        );
        bvh
    }

    /// Builds into an existing `bvh`, dispatching `search_distance` to a
    /// monomorphized `build_ploc::<SEARCH_DISTANCE, _>` instantiation
    /// (distances 1, 2, 6, 14, 24, 32 respectively).
    pub fn build_with_bvh<T: Boundable>(
        &mut self,
        bvh: &mut Bvh2,
        search_distance: PlocSearchDistance,
        aabbs: &[T],
        indices: Vec<u32>,
        sort_precision: SortPrecision,
        search_depth_threshold: usize,
    ) {
        let search_thresh = search_depth_threshold;
        match search_distance {
            PlocSearchDistance::Minimum => {
                self.build_ploc::<1, T>(bvh, aabbs, indices, sort_precision, search_thresh)
            }
            PlocSearchDistance::VeryLow => {
                self.build_ploc::<2, T>(bvh, aabbs, indices, sort_precision, search_thresh)
            }
            PlocSearchDistance::Low => {
                self.build_ploc::<6, T>(bvh, aabbs, indices, sort_precision, search_thresh)
            }
            PlocSearchDistance::Medium => {
                self.build_ploc::<14, T>(bvh, aabbs, indices, sort_precision, search_thresh)
            }
            PlocSearchDistance::High => {
                self.build_ploc::<24, T>(bvh, aabbs, indices, sort_precision, search_thresh)
            }
            PlocSearchDistance::VeryHigh => {
                self.build_ploc::<32, T>(bvh, aabbs, indices, sort_precision, search_thresh)
            }
        }
    }

    /// Initializes one leaf cluster per primitive (accumulating the total
    /// bounds along the way), then runs the PLOC merge loop via
    /// [`Self::build_ploc_from_leaves`].
    pub fn build_ploc<const SEARCH_DISTANCE: usize, T: Boundable>(
        &mut self,
        bvh: &mut Bvh2,
        aabbs: &[T],
        indices: Vec<u32>,
        sort_precision: SortPrecision,
        search_depth_threshold: usize,
    ) {
        crate::scope!("build_ploc");

        let prim_count = aabbs.len();

        bvh.reset_for_reuse(prim_count, Some(indices));

        if prim_count == 0 {
            return;
        }

        // Makes a leaf node (prim_count = 1) and folds its bounds into the
        // running total AABB.
        #[inline]
        fn init_node(prim_index: &u32, aabb: Aabb, local_aabb: &mut Aabb) -> Bvh2Node {
            local_aabb.extend(aabb.min);
            local_aabb.extend(aabb.max);
            debug_assert!(!aabb.min.is_nan());
            debug_assert!(!aabb.max.is_nan());
            Bvh2Node::new(aabb, 1, *prim_index)
        }

        let mut total_aabb = None;
        self.current_nodes.resize(prim_count, Default::default());

        // Below this primitive count the parallel leaf init isn't worth the
        // chunking overhead; the sequential fallback below runs instead.
        #[cfg(feature = "parallel")]
        let min_parallel = 500_000;

        #[cfg(feature = "parallel")]
        if prim_count >= min_parallel {
            let chunk_size = aabbs.len().div_ceil(rayon::current_num_threads());

            let chunks = self
                .current_nodes
                .par_iter_mut()
                .zip(&bvh.primitive_indices)
                .enumerate()
                .chunks(chunk_size);

            // One partial AABB per chunk; reduced sequentially afterwards.
            self.local_aabbs.resize(chunks.len(), Aabb::empty());

            chunks
                .zip(self.local_aabbs.par_iter_mut())
                .for_each(|(data, local_aabb)| {
                    for (i, (node, prim_index)) in data {
                        *node = init_node(prim_index, aabbs[i].aabb(), local_aabb);
                    }
                });

            let mut total = Aabb::empty();
            for local_aabb in self.local_aabbs.iter_mut() {
                total.extend(local_aabb.min);
                total.extend(local_aabb.max);
            }
            total_aabb = Some(total);
        }

        // Sequential leaf init: runs when the parallel feature is off or the
        // input is below the parallel threshold.
        if total_aabb.is_none() {
            let mut total = Aabb::empty();
            self.current_nodes
                .iter_mut()
                .zip(&bvh.primitive_indices)
                .zip(aabbs)
                .for_each(|((node, prim_index), aabb)| {
                    *node = init_node(prim_index, aabb.aabb(), &mut total)
                });
            total_aabb = Some(total);
        }

        self.build_ploc_from_leaves::<SEARCH_DISTANCE, false>(
            bvh,
            total_aabb.unwrap(),
            sort_precision,
            search_depth_threshold,
        );
    }

    /// Runs the PLOC merge loop on the leaves already present in
    /// `self.current_nodes`, writing the node tree into `bvh`.
    ///
    /// Nodes are allocated back-to-front in `bvh.nodes`. When `REBUILD` is
    /// false the array is resized to `2 * prim_count - 1` and filled from the
    /// end; when `REBUILD` is true the existing array is kept and merged
    /// pairs are written into free slots found by scanning backwards.
    ///
    /// `total_aabb` must bound every leaf: it defines the unit-cube
    /// quantization used for the morton sort.
    pub fn build_ploc_from_leaves<const SEARCH_DISTANCE: usize, const REBUILD: bool>(
        &mut self,
        bvh: &mut Bvh2,
        total_aabb: Aabb,
        sort_precision: SortPrecision,
        search_depth_threshold: usize,
    ) {
        crate::scope!("build_ploc_from_leaves");

        let prim_count = self.current_nodes.len();

        if prim_count == 0 {
            return;
        }

        // A binary tree over prim_count leaves has 2n-1 nodes.
        let nodes_count = (2 * prim_count as i64 - 1).max(0) as usize;

        let mut insert_index = if REBUILD {
            if bvh.nodes.is_empty() {
                return;
            }
            assert!(bvh.nodes.len() >= nodes_count);
            bvh.nodes.len() - 1
        } else {
            bvh.nodes.resize(nodes_count, Bvh2Node::default());
            nodes_count
        };

        // Maps world-space centers into the unit cube for morton encoding.
        let scale = 1.0 / total_aabb.diagonal().as_dvec3();
        let offset = -total_aabb.min.as_dvec3() * scale;

        // Each [u128; 2] entry (32 bytes) holds one Morton128 or two Morton64s.
        let mortons_size = match sort_precision {
            SortPrecision::U128 => prim_count,
            SortPrecision::U64 => prim_count.div_ceil(2),
        };
        self.mortons.resize(mortons_size, Default::default());
        self.next_nodes.resize(prim_count, Default::default());

        // Sorts current_nodes into next_nodes by morton code, then swaps so
        // current_nodes holds the spatially ordered leaves.
        sort_precision.sort_nodes(
            &mut self.current_nodes,
            &mut self.next_nodes,
            &mut self.mortons,
            scale,
            offset,
        );
        mem::swap(&mut self.current_nodes, &mut self.next_nodes);

        // Merge offsets are stored as i8, so the search distance must fit.
        assert!(i8::MAX as usize > SEARCH_DISTANCE);

        // The morton allocation is no longer needed; reuse its bytes as the
        // per-node signed merge-offset buffer.
        let merge_buffer: &mut [i8] = &mut cast_slice_mut(&mut self.mortons)[..prim_count];

        #[cfg(not(feature = "parallel"))]
        let mut cache = SearchCache::<SEARCH_DISTANCE>::default();

        #[cfg(feature = "parallel")]
        let threads = rayon::current_num_threads();

        // One SearchCache per worker chunk; oversubscribed 4x so chunks stay
        // small enough to balance.
        #[cfg(feature = "parallel")]
        let mut cache = if prim_count < 4000 {
            vec![SearchCache::<SEARCH_DISTANCE>::default()]
        } else {
            vec![SearchCache::<SEARCH_DISTANCE>::default(); threads * 4]
        };

        let mut depth: usize = 0;
        let mut next_nodes_idx = 0;
        let mut count = prim_count;
        while count > 1 {
            let merge = &mut merge_buffer[..count];
            // Shallow passes (and SEARCH_DISTANCE == 1) use a cheap scan that
            // only compares each node's merge cost with its two neighbors:
            // merge[i] = -1 pairs left, +1 pairs right.
            if SEARCH_DISTANCE == 1 || depth < search_depth_threshold {
                let mut last_cost = f32::INFINITY;
                let calculate_costs = |(i, merge_n): (usize, &mut i8)| {
                    let cost = self.current_nodes[i]
                        .aabb()
                        .union(self.current_nodes[i + 1].aabb())
                        .half_area();
                    *merge_n = if last_cost < cost { -1 } else { 1 };
                    last_cost = cost;
                };

                let count_m1 = count - 1;
                let merge_m1 = &mut merge[..count_m1];

                #[cfg(feature = "parallel")]
                {
                    let chunk_size = merge_m1.len().div_ceil(threads);
                    // Same scan, but each chunk re-derives the cost of the
                    // pair straddling its left boundary.
                    let calculate_costs_parallel = |(chunk_id, chunk): (usize, &mut [i8])| {
                        let start = chunk_id * chunk_size;
                        let mut last_cost = if start == 0 {
                            f32::INFINITY
                        } else {
                            self.current_nodes[start - 1]
                                .aabb()
                                .union(self.current_nodes[start].aabb())
                                .half_area()
                        };
                        for (local_n, merge_n) in chunk.iter_mut().enumerate() {
                            let i = local_n + start;
                            let cost = self.current_nodes[i]
                                .aabb()
                                .union(self.current_nodes[i + 1].aabb())
                                .half_area();
                            *merge_n = if last_cost < cost { -1 } else { 1 };
                            last_cost = cost;
                        }
                    };

                    if count < 300_000 {
                        merge_m1.iter_mut().enumerate().for_each(calculate_costs);
                    } else {
                        merge_m1
                            .par_chunks_mut(chunk_size.max(1))
                            .enumerate()
                            .for_each(calculate_costs_parallel)
                    }
                }
                #[cfg(not(feature = "parallel"))]
                {
                    merge_m1.iter_mut().enumerate().for_each(calculate_costs);
                }
                // The last node has no right neighbor; it must pair left.
                merge[count_m1] = -1;
            } else {
                // Deep passes: full neighborhood search, each node picks the
                // partner within SEARCH_DISTANCE minimizing merged half-area.
                #[cfg(not(feature = "parallel"))]
                merge.iter_mut().enumerate().for_each(|(index, best)| {
                    *best = cache.find_best_node(index, &self.current_nodes[..count]);
                });

                #[cfg(feature = "parallel")]
                {
                    if count < 4000 {
                        let cache = &mut cache[0];
                        merge.iter_mut().enumerate().for_each(|(index, best)| {
                            *best = cache.find_best_node(index, &self.current_nodes[..count]);
                        });
                    } else {
                        let chunk_size = merge.len().div_ceil(cache.len());
                        let chunks = merge.par_chunks_mut(merge.len().div_ceil(cache.len()));
                        if chunks.len() > cache.len() {
                            cache.resize(chunks.len(), SearchCache::<SEARCH_DISTANCE>::default());
                        }
                        chunks.zip(cache.par_iter_mut()).enumerate().for_each(
                            |(chunk, (bests, cache))| {
                                for (i, best) in bests.iter_mut().enumerate() {
                                    let index = chunk * chunk_size + i;
                                    *best = cache.find_best_node_parallel(
                                        index,
                                        i,
                                        &self.current_nodes[..count],
                                    );
                                }
                            },
                        );
                    }
                }
            };
            // Merge phase: only mutually-preferring pairs are merged; every
            // other node is carried unchanged into the next pass.
            let mut index = 0;
            while index < count {
                let index_offset = merge[index] as i64;
                let best_index = (index as i64 + index_offset) as usize;
                // Not a mutual pair -> carry this node forward as-is.
                if best_index as i64 + merge[best_index] as i64 != index as i64 {
                    self.next_nodes[next_nodes_idx] = self.current_nodes[index];
                    next_nodes_idx += 1;
                    index += 1;
                    continue;
                }

                // Mutual pair, but partner is to the right: the merge is
                // performed when the scan reaches the partner.
                if best_index > index {
                    index += 1;
                    continue;
                }

                debug_assert_ne!(best_index, index);

                let left = self.current_nodes[index];
                let right = self.current_nodes[best_index];

                let first_child;

                if REBUILD {
                    // Scan backwards for a free pair of slots.
                    // NOTE(review): assumes `!valid()` marks a vacant slot and
                    // vacancies come in adjacent pairs — confirm in Bvh2Node.
                    loop {
                        let left_slot = &mut bvh.nodes[insert_index - 1];
                        if !left_slot.valid() {
                            *left_slot = left;
                            debug_assert!(!bvh.nodes[insert_index].valid());
                            bvh.nodes[insert_index] = right;
                            first_child = insert_index - 1;
                            insert_index -= 2;
                            break;
                        }
                        insert_index -= 2;
                    }
                } else {
                    // Fresh build: children are always allocated from the back.
                    debug_assert!(insert_index >= 2);
                    insert_index -= 2;
                    bvh.nodes[insert_index] = left;
                    bvh.nodes[insert_index + 1] = right;
                    first_child = insert_index;
                }

                // Interior node (prim_count = 0) pointing at the child pair.
                self.next_nodes[next_nodes_idx] =
                    Bvh2Node::new(left.aabb().union(right.aabb()), 0, first_child as u32);
                next_nodes_idx += 1;

                // NOTE(review): merges only trigger here with best_index <
                // index (negative offset), so this +2 fast path looks
                // unreachable — verify intent.
                if SEARCH_DISTANCE == 1 && index_offset == 1 {
                    index += 2;
                } else {
                    index += 1;
                }
            }

            // The surviving clusters become the next pass's input.
            mem::swap(&mut self.next_nodes, &mut self.current_nodes);
            count = next_nodes_idx;
            next_nodes_idx = 0;
            depth += 1;
        }

        if !REBUILD {
            debug_assert_eq!(insert_index, 1);
        }

        // The final remaining cluster is the root.
        bvh.nodes[0] = self.current_nodes[0];

        bvh.max_depth = DEFAULT_MAX_STACK_DEPTH.max(depth + 1);
        bvh.children_are_ordered_after_parents = !REBUILD;
    }
}
505
506#[allow(dead_code)]
508fn find_best_node_basic(index: usize, nodes: &[Bvh2Node], search_distance: usize) -> i8 {
509 let mut best_node = index;
510 let mut best_cost = f32::INFINITY;
511
512 let begin = index - search_distance.min(index);
513 let end = (index + search_distance + 1).min(nodes.len());
514
515 let our_aabb = nodes[index].aabb();
516 for other in begin..end {
517 if other == index {
518 continue;
519 }
520 let cost = our_aabb.union(nodes[other].aabb()).half_area();
521 if cost <= best_cost {
522 best_node = other;
523 best_cost = cost;
524 }
525 }
526
527 (best_node as i64 - index as i64) as i8
528}
529
/// How far PLOC searches on either side of a node for its best merge
/// partner. Larger distances trade build speed for BVH quality. The numeric
/// distances listed per variant match the `From<u32>` mapping and the
/// dispatch in `PlocBuilder::build_with_bvh`.
#[derive(Default, Clone, Copy, Debug)]
pub enum PlocSearchDistance {
    /// Search distance 1: only the immediate neighbor.
    Minimum,
    /// Search distance 2.
    VeryLow,
    /// Search distance 6.
    Low,
    /// Search distance 14.
    #[default]
    Medium,
    /// Search distance 24.
    High,
    /// Search distance 32.
    VeryHigh,
}
549
550impl From<u32> for PlocSearchDistance {
551 fn from(value: u32) -> Self {
552 match value {
553 1 => PlocSearchDistance::Minimum,
554 2 => PlocSearchDistance::VeryLow,
555 6 => PlocSearchDistance::Low,
556 14 => PlocSearchDistance::Medium,
557 24 => PlocSearchDistance::High,
558 32 => PlocSearchDistance::VeryHigh,
559 _ => panic!("Invalid value for PlocSearchDistance: {value}"),
560 }
561 }
562}
563
/// Rolling cache of pairwise merge costs for the PLOC neighborhood search.
///
/// Both dimensions are indexed modulo `SEARCH_DISTANCE`, so the matrix only
/// ever holds costs for the sliding window of the most recent pivots — a
/// cost computed while scanning forward from one pivot is read back when a
/// later pivot scans backward over the same pair.
#[derive(Clone, Copy)]
pub struct SearchCache<const SEARCH_DISTANCE: usize>([[f32; SEARCH_DISTANCE]; SEARCH_DISTANCE]);
568
569impl<const SEARCH_DISTANCE: usize> Default for SearchCache<SEARCH_DISTANCE> {
570 fn default() -> Self {
571 SearchCache([[0.0; SEARCH_DISTANCE]; SEARCH_DISTANCE])
572 }
573}
574
impl<const SEARCH_DISTANCE: usize> SearchCache<SEARCH_DISTANCE> {
    /// Reads the previously computed cost between pivot `index` and the
    /// earlier node `other` — the slot `front` wrote when `other` was the
    /// pivot scanning forward past `index`.
    #[inline]
    fn back(&self, index: usize, other: usize) -> f32 {
        self.0[other % SEARCH_DISTANCE][index % SEARCH_DISTANCE]
    }

    /// Mutable slot for the cost between pivot `index` and the later node
    /// `other`; `back` reads the same slot with the roles swapped.
    #[inline]
    fn front(&mut self, index: usize, other: usize) -> &mut f32 {
        &mut self.0[index % SEARCH_DISTANCE][other % SEARCH_DISTANCE]
    }

    /// Chunked variant of [`Self::find_best_node`] for parallel scans.
    ///
    /// `index` is the global node index; `i` is the position within this
    /// worker's chunk. For the first `SEARCH_DISTANCE + 1` entries of a
    /// chunk the backward window may belong to the previous chunk (whose
    /// costs were written to a different cache), so those costs are
    /// recomputed directly instead of read from the cache.
    #[allow(dead_code)]
    fn find_best_node_parallel(&mut self, index: usize, i: usize, nodes: &[Bvh2Node]) -> i8 {
        let mut best_node = index;
        let mut best_cost = f32::INFINITY;

        // Clamp the search window to the valid node range.
        let begin = index - SEARCH_DISTANCE.min(index);
        let end = (index + SEARCH_DISTANCE + 1).min(nodes.len());

        let our_aabb = nodes[index].aabb();
        for other in begin..index {
            let area = if i <= SEARCH_DISTANCE {
                // Cache not yet warmed for this chunk: compute directly.
                our_aabb.union(nodes[other].aabb()).half_area()
            } else {
                self.back(index, other)
            };

            // `<=` keeps the rightmost candidate on ties.
            if area <= best_cost {
                best_node = other;
                best_cost = area;
            }
        }

        // Forward half: compute each cost and store it for later pivots.
        ((index + 1)..end).for_each(|other| {
            let cost = our_aabb.union(nodes[other].aabb()).half_area();
            *self.front(index, other) = cost;
            if cost <= best_cost {
                best_node = other;
                best_cost = cost;
            }
        });

        (best_node as i64 - index as i64) as i8
    }

    /// Finds the neighbor within `SEARCH_DISTANCE` of `index` whose merged
    /// AABB with `nodes[index]` has the smallest half-area, returning its
    /// offset relative to `index`.
    ///
    /// Relies on being called with sequentially increasing `index` so that
    /// the backward window was populated by earlier forward scans.
    fn find_best_node(&mut self, index: usize, nodes: &[Bvh2Node]) -> i8 {
        let mut best_node = index;
        let mut best_cost = f32::INFINITY;

        let begin = index - SEARCH_DISTANCE.min(index);
        let end = (index + SEARCH_DISTANCE + 1).min(nodes.len());

        // Backward half comes straight from the cache.
        for other in begin..index {
            let area = self.back(index, other);
            if area <= best_cost {
                best_node = other;
                best_cost = area;
            }
        }

        // Forward half: compute, cache, and compare.
        let our_aabb = nodes[index].aabb();
        ((index + 1)..end).for_each(|other| {
            let cost = our_aabb.union(nodes[other].aabb()).half_area();
            *self.front(index, other) = cost;
            if cost <= best_cost {
                best_node = other;
                best_cost = cost;
            }
        });

        (best_node as i64 - index as i64) as i8
    }
}
652
/// Bit width of the morton codes used to sort leaves along the space-filling
/// curve. `U128` gives a finer spatial ordering; `U64` halves the key size.
#[derive(Debug, Copy, Clone)]
pub enum SortPrecision {
    /// 128-bit morton codes (`Morton128`).
    U128,
    /// 64-bit morton codes (`Morton64`).
    U64,
}
662
impl SortPrecision {
    /// Sorts `nodes` into `sorted` by the morton code of each node's AABB
    /// center, quantized into the unit cube via `scale`/`offset`.
    ///
    /// `mortons_allocation` is raw scratch memory: each 32-byte `[u128; 2]`
    /// entry is reinterpreted as one `Morton128`, or viewed as two `u128`s
    /// of which `nodes.len()` are re-cast as 16-byte `Morton64`s.
    fn sort_nodes(
        &self,
        nodes: &mut [Bvh2Node],
        sorted: &mut [Bvh2Node],
        mortons_allocation: &mut [[u128; 2]],
        scale: DVec3,
        offset: DVec3,
    ) {
        match self {
            SortPrecision::U128 => {
                let mortons = cast_slice_mut(mortons_allocation);
                sort_nodes_by_morton::<Morton128>(*self, nodes, sorted, mortons, scale, offset)
            }
            SortPrecision::U64 => {
                let smaller: &mut [u128] = cast_slice_mut(mortons_allocation);
                let mortons = cast_slice_mut(&mut smaller[..nodes.len()]);
                sort_nodes_by_morton::<Morton64>(*self, nodes, sorted, mortons, scale, offset)
            }
        }
    }
}
685
/// 128-bit morton sort key plus the index of the node it was generated
/// from, padded to 32 bytes so it exactly overlays one `[u128; 2]` entry of
/// the scratch allocation for the bytemuck cast.
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
struct Morton128 {
    /// Morton-encoded (bit-interleaved) quantized position.
    code: u128,
    /// Index of the source node in the unsorted array.
    index: u64,
    /// Explicit padding to keep the struct Pod-castable at 32 bytes.
    padding: u64,
}
693
/// Radix-sort key: 16 one-byte levels spanning the 128-bit morton code.
/// `index` and `padding` deliberately do not participate in the ordering.
impl RadixKey for Morton128 {
    const LEVELS: usize = 16;

    #[inline(always)]
    fn get_level(&self, level: usize) -> u8 {
        // Delegate to the u128 key's RadixKey implementation.
        self.code.get_level(level)
    }
}
702
/// 64-bit morton sort key plus the index of the node it was generated from.
/// 16 bytes total, so two of them overlay one `[u128; 2]` scratch entry.
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
struct Morton64 {
    /// Morton-encoded (bit-interleaved) quantized position.
    code: u64,
    /// Index of the source node in the unsorted array.
    index: u64,
}
709
/// Radix-sort key: 8 one-byte levels spanning the 64-bit morton code.
/// `index` deliberately does not participate in the ordering.
impl RadixKey for Morton64 {
    const LEVELS: usize = 8;

    #[inline(always)]
    fn get_level(&self, level: usize) -> u8 {
        // Delegate to the u64 key's RadixKey implementation.
        self.code.get_level(level)
    }
}
718
/// Abstraction over the two morton key widths so `sort_nodes_by_morton` can
/// be generic. `code64`/`code128` each panic on the non-matching key type.
trait MortonCode: RadixKey + Send + Sync + Copy {
    /// Builds a key for node `index` from its quantized center position.
    fn new(index: usize, center: DVec3) -> Self;
    /// Index of the source node this key was generated from.
    fn index(&self) -> usize;
    /// The key as 64 bits (panics for the 128-bit key type).
    fn code64(&self) -> u64;
    /// The key as 128 bits (panics for the 64-bit key type).
    fn code128(&self) -> u128;
}
725
726impl MortonCode for Morton128 {
727 #[inline(always)]
728 fn new(index: usize, center: DVec3) -> Self {
729 Morton128 {
730 index: index as u64,
731 code: morton_encode_u128_unorm(center),
732 padding: Default::default(),
733 }
734 }
735 #[inline(always)]
736 fn index(&self) -> usize {
737 self.index as usize
738 }
739 #[inline(always)]
740 fn code64(&self) -> u64 {
741 panic!("Don't sort Morton128 using code64");
742 }
743 #[inline(always)]
744 fn code128(&self) -> u128 {
745 self.code
746 }
747}
748
749impl MortonCode for Morton64 {
750 #[inline(always)]
751 fn new(index: usize, center: DVec3) -> Self {
752 Morton64 {
753 index: index as u64,
754 code: morton_encode_u64_unorm(center),
755 }
756 }
757 #[inline(always)]
758 fn index(&self) -> usize {
759 self.index as usize
760 }
761 #[inline(always)]
762 fn code64(&self) -> u64 {
763 self.code
764 }
765 #[inline(always)]
766 fn code128(&self) -> u128 {
767 panic!("Don't sort Morton64 using code128");
768 }
769}
770
/// Sorts `nodes` into `sorted_nodes` by morton code in three phases, each
/// with a parallel variant above a size threshold:
/// 1. Generate one key per node from its AABB center, quantized via
///    `scale`/`offset` (expected to map `center` into the unit cube).
/// 2. Sort the keys by code.
/// 3. Gather: `sorted_nodes[j] = nodes[mortons[j].index()]`.
///
/// `mortons` must have the same length as `nodes` and `sorted_nodes`.
fn sort_nodes_by_morton<M: MortonCode>(
    precision: SortPrecision,
    nodes: &mut [Bvh2Node],
    sorted_nodes: &mut [Bvh2Node],
    mortons: &mut [M],
    scale: DVec3,
    offset: DVec3,
) {
    crate::scope!("sort_nodes");
    let nodes_count = nodes.len();

    // Phase 1: one key per node, tagged with its original index.
    let gen_mort = |(index, (morton, leaf)): (usize, (&mut M, &Bvh2Node))| {
        let center = leaf.aabb().center().as_dvec3() * scale + offset;
        *morton = M::new(index, center);
    };

    #[cfg(feature = "parallel")]
    {
        let min_parallel = 100_000;
        if nodes_count > min_parallel {
            mortons
                .par_iter_mut()
                .zip(nodes.par_iter())
                .enumerate()
                .for_each(gen_mort);
        } else {
            mortons
                .iter_mut()
                .zip(nodes.iter())
                .enumerate()
                .for_each(gen_mort);
        }
    }
    #[cfg(not(feature = "parallel"))]
    mortons
        .iter_mut()
        .zip(nodes.iter())
        .enumerate()
        .for_each(gen_mort);

    // Phase 2: sort by code. The serial path uses a comparison sort for
    // smaller inputs and rdst's radix sort for larger ones.
    #[cfg(feature = "parallel")]
    {
        match precision {
            SortPrecision::U128 => mortons.par_sort_unstable_by_key(|m| m.code128()),
            SortPrecision::U64 => mortons.par_sort_unstable_by_key(|m| m.code64()),
        }
    }
    #[cfg(not(feature = "parallel"))]
    {
        match nodes_count {
            0..=250_000 => match precision {
                SortPrecision::U128 => mortons.sort_unstable_by_key(|m| m.code128()),
                SortPrecision::U64 => mortons.sort_unstable_by_key(|m| m.code64()),
            },
            _ => mortons.radix_sort_unstable(),
        };
    }

    // Phase 3: write nodes into `sorted_nodes` in key order.
    let remap = |(n, m): (&mut Bvh2Node, &M)| *n = nodes[m.index()];

    #[cfg(feature = "parallel")]
    {
        let min_parallel = 100_000;
        if nodes_count > min_parallel {
            sorted_nodes
                .par_iter_mut()
                .zip(mortons.par_iter())
                .for_each(remap)
        } else {
            sorted_nodes.iter_mut().zip(mortons.iter()).for_each(remap)
        }
    }
    #[cfg(not(feature = "parallel"))]
    {
        sorted_nodes.iter_mut().zip(mortons.iter()).for_each(remap);
    }
}