radsort/
double_buffer.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
use core::{mem::MaybeUninit, slice};

use alloc::{boxed::Box, vec::Vec};

/// Double buffer. Wraps a mutable slice and allocates a scratch memory of the same size, so that
/// elements can be freely scattered from buffer to buffer.
///
/// Elements travel between the two buffers as bitwise copies; only the current read buffer is
/// treated as holding the committed elements.
///
/// # Drop behavior
///
/// Drop ensures that the mutable slice this buffer was constructed with contains all the original
/// elements, in the order last committed by `swap`.
///
/// This relies on the destructor actually running. If the buffer is leaked (for example with
/// `core::mem::forget`) while the slice is the write buffer, the slice keeps whatever the last
/// `scatter` left in it, which may include duplicated elements.
pub struct DoubleBuffer<'a, T> {
    /// The caller's slice, viewed as possibly-uninitialized storage.
    slice: &'a mut [MaybeUninit<T>],
    /// Heap-allocated scratch buffer with the same length as `slice`.
    scratch: Box<[MaybeUninit<T>]>,
    /// `true` when `slice` is the write buffer (and `scratch` the read buffer).
    slice_is_write: bool,
}

impl<'a, T> DoubleBuffer<'a, T> {
    /// Creates a double buffer, allocating a scratch buffer of the same length as the input slice.
    ///
    /// The supplied slice becomes the read buffer, the scratch buffer becomes the write buffer.
    pub fn new(slice: &'a mut [T]) -> Self {
        // SAFETY: Only bitwise copies of this slice's own `T` values are ever written into it
        // (`scatter` copies from a fully committed read buffer, `Drop` copies the read buffer
        // back), so every item holds a valid `T` when the borrow ends. `Drop` restores the
        // committed state if the slice is the write buffer at that point.
        let slice = unsafe { slice_as_uninit_mut(slice) };
        let scratch = {
            let mut v = Vec::with_capacity(slice.len());
            // SAFETY: we just allocated this capacity and MaybeUninit can be garbage.
            unsafe {
                v.set_len(slice.len());
            }
            v.into_boxed_slice()
        };
        DoubleBuffer {
            slice,
            scratch,
            slice_is_write: false,
        }
    }

    /// Scatters the elements from the read buffer to the computed indices in
    /// the write buffer. The read buffer is iterated from the beginning.
    ///
    /// Call `swap` after this function to commit the write buffer state.
    ///
    /// Returning an out-of-bounds index from the indexer causes this function
    /// to immediately return, without iterating over the remaining elements.
    pub fn scatter<F>(&mut self, mut indexer: F)
    where
        F: FnMut(&T) -> usize,
    {
        let (read, write) = self.as_read_write();

        for t in read {
            // `get_mut` is the single bounds check; an out-of-range index ends the pass early.
            let slot = match write.get_mut(indexer(t)) {
                Some(slot) => slot,
                None => return,
            };
            // SAFETY: `slot` is a valid, aligned place for one `T` in the write buffer, and the
            // write buffer never overlaps the read buffer that `t` borrows from.
            unsafe {
                slot.as_mut_ptr().copy_from_nonoverlapping(t as *const T, 1);
            }
        }
    }

    /// Returns the current read and write buffers.
    ///
    /// The read buffer is returned as `&[T]` because it always holds the committed,
    /// fully initialized state.
    fn as_read_write(&mut self) -> (&[T], &mut [MaybeUninit<T>]) {
        let (read, write): (&[MaybeUninit<T>], &mut [MaybeUninit<T>]) = if self.slice_is_write {
            (self.scratch.as_ref(), self.slice)
        } else {
            (self.slice, self.scratch.as_mut())
        };

        // SAFETY: The read buffer is always initialized.
        let read = unsafe { slice_assume_init_ref(read) };

        (read, write)
    }

    /// Swaps the read and write buffer, committing the write buffer state.
    ///
    /// # Safety
    ///
    /// The write buffer must hold exactly the elements of the read buffer, each copied into
    /// exactly one slot. In other words, the writes since the last swap must form a permutation
    /// of the read buffer. A single complete `scatter` whose indexer maps the read buffer
    /// bijectively onto `0..len` satisfies this.
    ///
    /// It is not sufficient that every slot was written at least once. For example, two
    /// scatters with different indexers can leave the same element copied into several slots.
    /// Those duplicates would then be owned twice by the caller's slice.
    pub unsafe fn swap(&mut self) {
        self.slice_is_write = !self.slice_is_write;
    }
}

/// Restores the caller's slice on drop.
///
/// If the slice is currently the write buffer, the committed read buffer (`scratch`) is copied
/// back into it, so that it holds all the original elements.
impl<'a, T> Drop for DoubleBuffer<'a, T> {
    fn drop(&mut self) {
        if !self.slice_is_write {
            // The slice is already the read buffer and holds the committed state.
            return;
        }

        let count = self.slice.len();
        let src = self.scratch.as_ptr();
        let dst = self.slice.as_mut_ptr();
        // SAFETY: `scratch` is the read buffer, so all `count` items are initialized; both
        // buffers have length `count`, and the heap scratch never aliases the borrowed slice.
        unsafe { dst.copy_from_nonoverlapping(src, count) };
        self.slice_is_write = false;
    }
}

/// Reinterprets a slice of `MaybeUninit<T>` as a shared slice of `T`.
///
/// # Safety
///
/// Every item of `slice` must be initialized.
#[inline(always)]
pub unsafe fn slice_assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T] {
    let len = slice.len();
    let data = slice.as_ptr() as *const T;
    // SAFETY: `MaybeUninit<T>` has the same layout as `T`, pointer and length come from a live
    // slice, and the caller guarantees that every item is initialized.
    unsafe { slice::from_raw_parts(data, len) }
}

/// Views a mutable slice of `T` as a mutable slice of `MaybeUninit<T>`.
///
/// Assignments through the returned slice do not drop the `T` values they overwrite.
///
/// # Safety
///
/// Writing through the returned slice can leave items uninitialized or otherwise invalid. The
/// caller must ensure that every item holds a valid `T` again before the returned borrow ends,
/// because the original `&mut [T]` becomes usable again at that point.
#[inline(always)]
pub unsafe fn slice_as_uninit_mut<T>(slice: &mut [T]) -> &mut [MaybeUninit<T>] {
    // SAFETY: `[MaybeUninit<T>]` and `[T]` have the same layout.
    unsafe { slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut MaybeUninit<T>, slice.len()) }
}