radsort/double_buffer.rs
1use core::{mem::MaybeUninit, slice};
2
3use alloc::{boxed::Box, vec::Vec};
4
/// A pair of equally sized buffers between which elements can be scattered back and forth.
///
/// One buffer is the caller's mutable slice; the other is scratch memory of the same length,
/// allocated by [`DoubleBuffer::new`]. At any moment one of the two is the read buffer (fully
/// initialized) and the other is the write buffer (treated as possibly uninitialized).
///
/// # Drop behavior
///
/// When dropped, the buffer makes sure the wrapped slice contains all the original elements,
/// copying them back from the scratch memory if the slice is currently the write buffer.
/// This relies on `Drop` actually running: leaking the value (e.g. with `core::mem::forget`)
/// while the slice is the write buffer can leave the slice with duplicated or missing
/// elements.
pub struct DoubleBuffer<'a, T> {
    /// The caller's slice, viewed as possibly uninitialized while it serves as write buffer.
    slice: &'a mut [MaybeUninit<T>],
    /// Heap-allocated scratch memory with the same length as `slice`.
    scratch: Box<[MaybeUninit<T>]>,
    /// `true` when `slice` is the write buffer (and therefore `scratch` is the read buffer).
    slice_is_write: bool,
}
17
18impl<'a, T> DoubleBuffer<'a, T> {
19 /// Creates a double buffer, allocating a scratch buffer of the same length as the input slice.
20 ///
21 /// The supplied slice becomes the read buffer, the scratch buffer becomes the write buffer.
22 pub fn new(slice: &'a mut [T]) -> Self {
23 // SAFETY: The Drop impl ensures that the slice is initialized.
24 let slice = unsafe { slice_as_uninit_mut(slice) };
25 let scratch = {
26 let mut v = Vec::with_capacity(slice.len());
27 // SAFETY: we just allocated this capacity and MaybeUninit can be garbage.
28 unsafe {
29 v.set_len(slice.len());
30 }
31 v.into_boxed_slice()
32 };
33 DoubleBuffer {
34 slice,
35 scratch,
36 slice_is_write: false,
37 }
38 }
39
40 /// Scatters the elements from the read buffer to the computed indices in
41 /// the write buffer. The read buffer is iterated from the beginning.
42 ///
43 /// Call `swap` after this function to commit the write buffer state.
44 ///
45 /// Returning an out-of-bounds index from the indexer causes this function
46 /// to immediately return, without iterating over the remaining elements.
47 pub fn scatter<F>(&mut self, mut indexer: F)
48 where
49 F: FnMut(&T) -> usize,
50 {
51 let (read, write) = self.as_read_write();
52
53 let len = write.len();
54
55 for t in read {
56 let index = indexer(t);
57 if index >= len {
58 return;
59 }
60 let write_ptr = write[index].as_mut_ptr();
61 unsafe {
62 // SAFETY: both pointers are valid for T, aligned, and nonoverlapping
63 write_ptr.copy_from_nonoverlapping(t as *const T, 1);
64 }
65 }
66 }
67
68 /// Returns the current read and write buffers.
69 fn as_read_write(&mut self) -> (&[T], &mut [MaybeUninit<T>]) {
70 let (read, write): (&[MaybeUninit<T>], &mut [MaybeUninit<T>]) = if self.slice_is_write {
71 (self.scratch.as_ref(), self.slice)
72 } else {
73 (self.slice, self.scratch.as_mut())
74 };
75
76 // SAFETY: The read buffer is always initialized.
77 let read = unsafe { slice_assume_init_ref(read) };
78
79 (read, write)
80 }
81
82 /// Swaps the read and write buffer, committing the write buffer state.
83 ///
84 /// # Safety
85 ///
86 /// The caller must ensure that every element of the write buffer was
87 /// written to before calling this function.
88 pub unsafe fn swap(&mut self) {
89 self.slice_is_write = !self.slice_is_write;
90 }
91}
92
93/// Ensures that the input slice contains all the original elements.
94impl<'a, T> Drop for DoubleBuffer<'a, T> {
95 fn drop(&mut self) {
96 if self.slice_is_write {
97 // The input slice is the write buffer, copy the consistent state from the read buffer
98 unsafe {
99 // SAFETY: `scratch` is the read buffer, it is initialized. The length is the same.
100 self.slice
101 .as_mut_ptr()
102 .copy_from_nonoverlapping(self.scratch.as_ptr(), self.slice.len());
103 }
104 self.slice_is_write = false;
105 }
106 }
107}
108
/// Reinterprets a shared slice of `MaybeUninit<T>` as a shared slice of `T`.
///
/// # Safety
///
/// Every element of `slice` must be initialized; reading an uninitialized `T` through the
/// returned slice is undefined behavior.
#[inline(always)]
pub unsafe fn slice_assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T] {
    let len = slice.len();
    let data = slice.as_ptr() as *const T;
    // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, so `data` points to
    // `len` consecutive `T`s, which the caller guarantees are initialized. The returned
    // borrow has the same lifetime as `slice`.
    unsafe { slice::from_raw_parts(data, len) }
}
119
/// Views a mutable slice of `T` as a mutable slice of `MaybeUninit<T>`.
///
/// # Safety
///
/// The returned view allows leaving elements uninitialized or duplicated (e.g. by bitwise
/// copies). The caller must make sure every element holds a valid, initialized `T` again
/// before the returned slice is dropped, because afterwards the elements are accessed as `T`
/// through the original slice.
#[inline(always)]
pub unsafe fn slice_as_uninit_mut<T>(slice: &mut [T]) -> &mut [MaybeUninit<T>] {
    let len = slice.len();
    let data = slice.as_mut_ptr() as *mut MaybeUninit<T>;
    // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, so `data` points to
    // `len` valid `MaybeUninit<T>` slots, borrowed exclusively for the lifetime of `slice`.
    unsafe { slice::from_raw_parts_mut(data, len) }
}