rand/distr/slice.rs
1// Copyright 2021 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! Distributions over slices
10
11use core::num::NonZeroUsize;
12
13use crate::distr::uniform::{UniformSampler, UniformUsize};
14use crate::distr::Distribution;
15#[cfg(feature = "alloc")]
16use alloc::string::String;
17
18/// A distribution to uniformly sample elements of a slice
19///
20/// Like [`IndexedRandom::choose`], this uniformly samples elements of a slice
21/// without modification of the slice (so called "sampling with replacement").
22/// This distribution object may be a little faster for repeated sampling (but
23/// slower for small numbers of samples).
24///
25/// ## Examples
26///
27/// Since this is a distribution, [`Rng::sample_iter`] and
28/// [`Distribution::sample_iter`] may be used, for example:
29/// ```
30/// use rand::distr::{Distribution, slice::Choose};
31///
32/// let vowels = ['a', 'e', 'i', 'o', 'u'];
33/// let vowels_dist = Choose::new(&vowels).unwrap();
34///
35/// // build a string of 10 vowels
36/// let vowel_string: String = vowels_dist
37///     .sample_iter(&mut rand::rng())
38///     .take(10)
39///     .collect();
40///
41/// println!("{}", vowel_string);
42/// assert_eq!(vowel_string.len(), 10);
43/// assert!(vowel_string.chars().all(|c| vowels.contains(&c)));
44/// ```
45///
46/// For a single sample, [`IndexedRandom::choose`] may be preferred:
47/// ```
48/// use rand::seq::IndexedRandom;
49///
50/// let vowels = ['a', 'e', 'i', 'o', 'u'];
51/// let mut rng = rand::rng();
52///
53/// println!("{}", vowels.choose(&mut rng).unwrap());
54/// ```
55///
56/// [`IndexedRandom::choose`]: crate::seq::IndexedRandom::choose
57/// [`Rng::sample_iter`]: crate::Rng::sample_iter
58#[derive(Debug, Clone, Copy)]
59pub struct Choose<'a, T> {
60    slice: &'a [T],
61    range: UniformUsize,
62    num_choices: NonZeroUsize,
63}
64
65impl<'a, T> Choose<'a, T> {
66    /// Create a new `Choose` instance which samples uniformly from the slice.
67    ///
68    /// Returns error [`Empty`] if the slice is empty.
69    pub fn new(slice: &'a [T]) -> Result<Self, Empty> {
70        let num_choices = NonZeroUsize::new(slice.len()).ok_or(Empty)?;
71
72        Ok(Self {
73            slice,
74            range: UniformUsize::new(0, num_choices.get()).unwrap(),
75            num_choices,
76        })
77    }
78
79    /// Returns the count of choices in this distribution
80    pub fn num_choices(&self) -> NonZeroUsize {
81        self.num_choices
82    }
83}
84
85impl<'a, T> Distribution<&'a T> for Choose<'a, T> {
86    fn sample<R: crate::Rng + ?Sized>(&self, rng: &mut R) -> &'a T {
87        let idx = self.range.sample(rng);
88
89        debug_assert!(
90            idx < self.slice.len(),
91            "Uniform::new(0, {}) somehow returned {}",
92            self.slice.len(),
93            idx
94        );
95
96        // Safety: at construction time, it was ensured that the slice was
97        // non-empty, and that the `Uniform` range produces values in range
98        // for the slice
99        unsafe { self.slice.get_unchecked(idx) }
100    }
101}
102
/// Error: empty slice
///
/// This error is returned when [`Choose::new`] is given an empty slice.
///
/// The type is a field-less marker, so all values compare equal; the
/// `PartialEq`/`Eq` impls let callers check for it with `==` or
/// `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Empty;

impl core::fmt::Display for Empty {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The message has no arguments, so it is written directly rather
        // than through the formatting machinery.
        f.write_str("Tried to create a `rand::distr::slice::Choose` with an empty slice")
    }
}

// `std::error::Error` is only available when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for Empty {}
120
#[cfg(feature = "alloc")]
impl super::SampleString for Choose<'_, char> {
    /// Append `len` sampled chars to `string`, reserving capacity in
    /// batches sized from the widest UTF-8 encoding found in the slice.
    fn append_string<R: crate::Rng + ?Sized>(&self, rng: &mut R, string: &mut String, len: usize) {
        // Widest UTF-8 encoding (in bytes) among the candidate chars, used to
        // size the reservations below. Long slices are not scanned; they are
        // assumed to need the 4-byte maximum.
        let widest = if self.slice.len() < 200 {
            let mut widest: usize = 1;
            for ch in self.slice {
                widest = widest.max(ch.len_utf8());
                // 4 bytes is the largest possible encoding, so stop looking.
                if widest >= 4 {
                    break;
                }
            }
            widest
        } else {
            4
        };

        // Upper bound on how many chars are appended per batch. Appending in
        // quarters lets later batches reuse capacity that earlier ones
        // reserved but did not fill. A single batch is used for short
        // requests and for slices holding only 1-byte chars.
        let batch_cap = if widest == 1 || len < 100 {
            len
        } else {
            len / 4
        };

        let mut pending = len;
        loop {
            let batch = batch_cap.min(pending);
            if batch == 0 {
                break;
            }
            string.reserve(widest * batch);
            string.extend(self.sample_iter(&mut *rng).take(batch));
            pending -= batch;
        }
    }
}
154
#[cfg(test)]
mod test {
    use super::*;

    /// Pins the first samples drawn for a fixed seed, so that an accidental
    /// change to the sampled values is noticed.
    #[test]
    fn value_stability() {
        let expected = b"eaxee";
        let dist = Choose::new(b"escaped emus explore extensively").unwrap();

        let sampled = dist
            .sample_iter(crate::test::rng(651))
            .take(expected.len());
        assert!(sampled.eq(expected.iter()));
    }
}
167}