radsort/
double_buffer.rs

1use core::{mem::MaybeUninit, slice};
2
3use alloc::{boxed::Box, vec::Vec};
4
/// A pair of equally sized buffers between which elements can be scattered freely.
///
/// One buffer is the mutable slice supplied by the caller, the other is scratch memory of the
/// same length allocated by the constructor. At any moment one of them acts as the *read* buffer
/// and the other one as the *write* buffer.
///
/// # Drop behavior
///
/// Dropping the double buffer ensures that the mutable slice it was constructed with contains
/// all the original elements, copying them back from the scratch memory when needed.
pub struct DoubleBuffer<'a, T> {
    /// The caller's slice, viewed as possibly-uninitialized storage.
    slice: &'a mut [MaybeUninit<T>],
    /// Owned scratch storage with the same length as `slice`.
    scratch: Box<[MaybeUninit<T>]>,
    /// `true` while `slice` is the write buffer (and `scratch` is the read buffer).
    slice_is_write: bool,
}
17
18impl<'a, T> DoubleBuffer<'a, T> {
19    /// Creates a double buffer, allocating a scratch buffer of the same length as the input slice.
20    ///
21    /// The supplied slice becomes the read buffer, the scratch buffer becomes the write buffer.
22    pub fn new(slice: &'a mut [T]) -> Self {
23        // SAFETY: The Drop impl ensures that the slice is initialized.
24        let slice = unsafe { slice_as_uninit_mut(slice) };
25        let scratch = {
26            let mut v = Vec::with_capacity(slice.len());
27            // SAFETY: we just allocated this capacity and MaybeUninit can be garbage.
28            unsafe {
29                v.set_len(slice.len());
30            }
31            v.into_boxed_slice()
32        };
33        DoubleBuffer {
34            slice,
35            scratch,
36            slice_is_write: false,
37        }
38    }
39
40    /// Scatters the elements from the read buffer to the computed indices in
41    /// the write buffer. The read buffer is iterated from the beginning.
42    ///
43    /// Call `swap` after this function to commit the write buffer state.
44    ///
45    /// Returning an out-of-bounds index from the indexer causes this function
46    /// to immediately return, without iterating over the remaining elements.
47    pub fn scatter<F>(&mut self, mut indexer: F)
48    where
49        F: FnMut(&T) -> usize,
50    {
51        let (read, write) = self.as_read_write();
52
53        let len = write.len();
54
55        for t in read {
56            let index = indexer(t);
57            if index >= len {
58                return;
59            }
60            let write_ptr = write[index].as_mut_ptr();
61            unsafe {
62                // SAFETY: both pointers are valid for T, aligned, and nonoverlapping
63                write_ptr.copy_from_nonoverlapping(t as *const T, 1);
64            }
65        }
66    }
67
68    /// Returns the current read and write buffers.
69    fn as_read_write(&mut self) -> (&[T], &mut [MaybeUninit<T>]) {
70        let (read, write): (&[MaybeUninit<T>], &mut [MaybeUninit<T>]) = if self.slice_is_write {
71            (self.scratch.as_ref(), self.slice)
72        } else {
73            (self.slice, self.scratch.as_mut())
74        };
75
76        // SAFETY: The read buffer is always initialized.
77        let read = unsafe { slice_assume_init_ref(read) };
78
79        (read, write)
80    }
81
82    /// Swaps the read and write buffer, committing the write buffer state.
83    ///
84    /// # Safety
85    ///
86    /// The caller must ensure that every element of the write buffer was
87    /// written to before calling this function.
88    pub unsafe fn swap(&mut self) {
89        self.slice_is_write = !self.slice_is_write;
90    }
91}
92
93/// Ensures that the input slice contains all the original elements.
94impl<'a, T> Drop for DoubleBuffer<'a, T> {
95    fn drop(&mut self) {
96        if self.slice_is_write {
97            // The input slice is the write buffer, copy the consistent state from the read buffer
98            unsafe {
99                // SAFETY: `scratch` is the read buffer, it is initialized. The length is the same.
100                self.slice
101                    .as_mut_ptr()
102                    .copy_from_nonoverlapping(self.scratch.as_ptr(), self.slice.len());
103            }
104            self.slice_is_write = false;
105        }
106    }
107}
108
/// Reinterprets a slice of `MaybeUninit<T>` as a slice of initialized `T` items.
///
/// The returned slice covers the same memory and has the same length as the input.
///
/// # Safety
///
/// The caller must ensure that all the items are initialized.
#[inline(always)]
pub unsafe fn slice_assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T] {
    let len = slice.len();
    let data = slice.as_ptr().cast::<T>();
    // SAFETY: `[MaybeUninit<T>]` and `[T]` have the same layout.
    unsafe { slice::from_raw_parts(data, len) }
}
119
/// Reinterprets a mutable slice of `T` as a mutable slice of `MaybeUninit<T>`.
///
/// The returned slice covers the same memory and has the same length as the input.
///
/// # Safety
///
/// The caller must ensure that all the items of the returned slice are
/// initialized before dropping it.
#[inline(always)]
pub unsafe fn slice_as_uninit_mut<T>(slice: &mut [T]) -> &mut [MaybeUninit<T>] {
    let len = slice.len();
    let data = slice.as_mut_ptr().cast::<MaybeUninit<T>>();
    // SAFETY: `[MaybeUninit<T>]` and `[T]` have the same layout.
    unsafe { slice::from_raw_parts_mut(data, len) }
}