radsort/double_buffer.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
use core::{mem::MaybeUninit, slice};
use alloc::{boxed::Box, vec::Vec};
/// Double buffer over a caller-provided slice.
///
/// Pairs the borrowed slice with a heap-allocated scratch area of equal length, so that
/// elements can be scattered back and forth between the two buffers.
///
/// # Drop behavior
///
/// When dropped, the slice this buffer was constructed with is left holding all of the
/// original elements.
pub struct DoubleBuffer<'a, T> {
    /// `true` when `slice` is the write buffer and `scratch` is the read buffer;
    /// `false` for the opposite assignment.
    slice_is_write: bool,
    /// The caller's slice, viewed as possibly-uninitialized storage.
    slice: &'a mut [MaybeUninit<T>],
    /// Scratch storage with the same length as `slice`.
    scratch: Box<[MaybeUninit<T>]>,
}
impl<'a, T> DoubleBuffer<'a, T> {
    /// Creates a double buffer, allocating a scratch buffer of the same length as the input slice.
    ///
    /// The supplied slice becomes the read buffer, the scratch buffer becomes the write buffer.
    pub fn new(slice: &'a mut [T]) -> Self {
        // SAFETY: elements are only ever bitwise-moved between the two buffers, and the Drop
        // impl copies the read buffer back into `slice` if needed, so `slice` holds initialized
        // `T`s again once this borrow ends.
        let slice = unsafe { slice_as_uninit_mut(slice) };
        let scratch = {
            let mut v = Vec::with_capacity(slice.len());
            // SAFETY: we just allocated this capacity and MaybeUninit can be garbage.
            unsafe {
                v.set_len(slice.len());
            }
            v.into_boxed_slice()
        };
        DoubleBuffer {
            slice,
            scratch,
            slice_is_write: false,
        }
    }

    /// Scatters the elements from the read buffer to the computed indices in
    /// the write buffer. The read buffer is iterated from the beginning.
    ///
    /// Call `swap` after this function to commit the write buffer state.
    ///
    /// Returning an out-of-bounds index from the indexer causes this function
    /// to immediately return, without iterating over the remaining elements.
    /// In that case the write buffer is only partially written and `swap` must
    /// not be called until a later `scatter` visits every element.
    pub fn scatter<F>(&mut self, mut indexer: F)
    where
        F: FnMut(&T) -> usize,
    {
        let (read, write) = self.as_read_write();
        for t in read {
            // A single checked lookup both validates the index and yields the slot.
            let slot = match write.get_mut(indexer(t)) {
                Some(slot) => slot,
                None => return,
            };
            // SAFETY: `t` is a valid, initialized `T` and `slot` is valid for one `T`; the
            // read and write buffers are distinct allocations, so the regions don't overlap.
            unsafe {
                slot.as_mut_ptr().copy_from_nonoverlapping(t as *const T, 1);
            }
        }
    }

    /// Returns the current read and write buffers.
    fn as_read_write(&mut self) -> (&[T], &mut [MaybeUninit<T>]) {
        let (read, write): (&[MaybeUninit<T>], &mut [MaybeUninit<T>]) = if self.slice_is_write {
            (self.scratch.as_ref(), self.slice)
        } else {
            (self.slice, self.scratch.as_mut())
        };
        // SAFETY: The read buffer is always initialized: initially it is the caller's slice,
        // and `swap` only promotes a write buffer that holds a permutation of the old read buffer.
        let read = unsafe { slice_assume_init_ref(read) };
        (read, write)
    }

    /// Swaps the read and write buffer, committing the write buffer state.
    ///
    /// # Safety
    ///
    /// The write buffer must hold a permutation of the read buffer. Every element of the
    /// read buffer must have been bitwise copied into exactly one position of the write
    /// buffer, and every position of the write buffer must have been written.
    ///
    /// This holds when the most recent call to `scatter` visited every element (the
    /// indexer never returned an out-of-bounds index) and mapped distinct elements to
    /// distinct indices. Merely having written every position is NOT sufficient. Several
    /// partial `scatter` calls can fill the write buffer with duplicates, which would later
    /// cause some elements to be dropped twice and others to be leaked.
    pub unsafe fn swap(&mut self) {
        self.slice_is_write = !self.slice_is_write;
    }
}
/// Restores the caller's slice when the double buffer goes away.
///
/// If the caller's slice is currently the write buffer, the consistent state lives in the
/// scratch (read) buffer, so it is copied back over the slice. Otherwise the slice already
/// is the read buffer and nothing needs to be done.
impl<T> Drop for DoubleBuffer<'_, T> {
    fn drop(&mut self) {
        if !self.slice_is_write {
            // The slice is the read buffer and already holds every element.
            return;
        }
        let len = self.slice.len();
        let dst = self.slice.as_mut_ptr();
        // SAFETY: when the slice is the write buffer, `scratch` is the read buffer and is fully
        // initialized; both buffers have length `len` and are separate allocations.
        unsafe {
            self.scratch.as_ptr().copy_to_nonoverlapping(dst, len);
        }
        self.slice_is_write = false;
    }
}
/// Reinterprets a slice of `MaybeUninit<T>` as a slice of initialized `T`.
///
/// # Safety
///
/// Every element of `slice` must be initialized.
#[inline(always)]
pub unsafe fn slice_assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T] {
    // The slice-pointer cast keeps the length metadata and only changes the element type.
    let raw = slice as *const [MaybeUninit<T>] as *const [T];
    // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, so both slice types
    // share a layout, and the caller guarantees that every element is initialized.
    unsafe { &*raw }
}
/// Reinterprets a mutable slice of `T` as a mutable slice of `MaybeUninit<T>`.
///
/// # Safety
///
/// The returned view permits storing uninitialized values. The caller must ensure that every
/// element is initialized again before the returned borrow ends, because the original
/// `&mut [T]` becomes usable again at that point.
#[inline(always)]
pub unsafe fn slice_as_uninit_mut<T>(slice: &mut [T]) -> &mut [MaybeUninit<T>] {
    // The slice-pointer cast keeps the length metadata and only changes the element type.
    let raw = slice as *mut [T] as *mut [MaybeUninit<T>];
    // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, so both slice types
    // share a layout; re-initialization before the borrow ends is the caller's obligation.
    unsafe { &mut *raw }
}