338 lines
11 KiB
Rust
338 lines
11 KiB
Rust
// Copyright 2017 The Rust Project Developers. See the COPYRIGHT
|
|
// file at the top-level directory of this distribution and at
|
|
// http://rust-lang.org/COPYRIGHT.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
|
|
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
|
|
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
|
|
// option. This file may not be copied, modified, or distributed
|
|
// except according to those terms.
|
|
|
|
//! Functions for randomly accessing and sampling sequences.
|
|
|
|
use super::Rng;
|
|
|
|
// This module is only enabled when either std or alloc is available.
|
|
// BTreeMap is not as fast in tests, but better than nothing.
|
|
#[cfg(feature="std")] use std::collections::HashMap;
|
|
#[cfg(not(feature="std"))] use alloc::btree_map::BTreeMap;
|
|
|
|
#[cfg(not(feature="std"))] use alloc::Vec;
|
|
|
|
/// Randomly sample `amount` elements from a finite iterator.
|
|
///
|
|
/// The following can be returned:
|
|
/// - `Ok`: `Vec` of `amount` non-repeating randomly sampled elements. The order is not random.
|
|
/// - `Err`: `Vec` of all the elements from `iterable` in sequential order. This happens when the
|
|
/// length of `iterable` was less than `amount`. This is considered an error since exactly
|
|
/// `amount` elements is typically expected.
|
|
///
|
|
/// This implementation uses `O(len(iterable))` time and `O(amount)` memory.
|
|
///
|
|
/// # Example
|
|
///
|
|
/// ```rust
|
|
/// use rand::{thread_rng, seq};
|
|
///
|
|
/// let mut rng = thread_rng();
|
|
/// let sample = seq::sample_iter(&mut rng, 1..100, 5).unwrap();
|
|
/// println!("{:?}", sample);
|
|
/// ```
|
|
pub fn sample_iter<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Result<Vec<T>, Vec<T>>
|
|
where I: IntoIterator<Item=T>,
|
|
R: Rng,
|
|
{
|
|
let mut iter = iterable.into_iter();
|
|
let mut reservoir = Vec::with_capacity(amount);
|
|
reservoir.extend(iter.by_ref().take(amount));
|
|
|
|
// Continue unless the iterator was exhausted
|
|
//
|
|
// note: this prevents iterators that "restart" from causing problems.
|
|
// If the iterator stops once, then so do we.
|
|
if reservoir.len() == amount {
|
|
for (i, elem) in iter.enumerate() {
|
|
let k = rng.gen_range(0, i + 1 + amount);
|
|
if let Some(spot) = reservoir.get_mut(k) {
|
|
*spot = elem;
|
|
}
|
|
}
|
|
Ok(reservoir)
|
|
} else {
|
|
// Don't hang onto extra memory. There is a corner case where
|
|
// `amount` was much less than `len(iterable)`.
|
|
reservoir.shrink_to_fit();
|
|
Err(reservoir)
|
|
}
|
|
}
|
|
|
|
/// Randomly sample exactly `amount` values from `slice`.
|
|
///
|
|
/// The values are non-repeating and in random order.
|
|
///
|
|
/// This implementation uses `O(amount)` time and memory.
|
|
///
|
|
/// Panics if `amount > slice.len()`
|
|
///
|
|
/// # Example
|
|
///
|
|
/// ```rust
|
|
/// use rand::{thread_rng, seq};
|
|
///
|
|
/// let mut rng = thread_rng();
|
|
/// let values = vec![5, 6, 1, 3, 4, 6, 7];
|
|
/// println!("{:?}", seq::sample_slice(&mut rng, &values, 3));
|
|
/// ```
|
|
pub fn sample_slice<R, T>(rng: &mut R, slice: &[T], amount: usize) -> Vec<T>
|
|
where R: Rng,
|
|
T: Clone
|
|
{
|
|
let indices = sample_indices(rng, slice.len(), amount);
|
|
|
|
let mut out = Vec::with_capacity(amount);
|
|
out.extend(indices.iter().map(|i| slice[*i].clone()));
|
|
out
|
|
}
|
|
|
|
/// Randomly sample exactly `amount` references from `slice`.
|
|
///
|
|
/// The references are non-repeating and in random order.
|
|
///
|
|
/// This implementation uses `O(amount)` time and memory.
|
|
///
|
|
/// Panics if `amount > slice.len()`
|
|
///
|
|
/// # Example
|
|
///
|
|
/// ```rust
|
|
/// use rand::{thread_rng, seq};
|
|
///
|
|
/// let mut rng = thread_rng();
|
|
/// let values = vec![5, 6, 1, 3, 4, 6, 7];
|
|
/// println!("{:?}", seq::sample_slice_ref(&mut rng, &values, 3));
|
|
/// ```
|
|
pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T>
|
|
where R: Rng
|
|
{
|
|
let indices = sample_indices(rng, slice.len(), amount);
|
|
|
|
let mut out = Vec::with_capacity(amount);
|
|
out.extend(indices.iter().map(|i| &slice[*i]));
|
|
out
|
|
}
|
|
|
|
/// Randomly sample exactly `amount` indices from `0..length`.
|
|
///
|
|
/// The values are non-repeating and in random order.
|
|
///
|
|
/// This implementation uses `O(amount)` time and memory.
|
|
///
|
|
/// This method is used internally by the slice sampling methods, but it can sometimes be useful to
|
|
/// have the indices themselves so this is provided as an alternative.
|
|
///
|
|
/// Panics if `amount > length`
|
|
pub fn sample_indices<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize>
|
|
where R: Rng,
|
|
{
|
|
if amount > length {
|
|
panic!("`amount` must be less than or equal to `slice.len()`");
|
|
}
|
|
|
|
// We are going to have to allocate at least `amount` for the output no matter what. However,
|
|
// if we use the `cached` version we will have to allocate `amount` as a HashMap as well since
|
|
// it inserts an element for every loop.
|
|
//
|
|
// Therefore, if `amount >= length / 2` then inplace will be both faster and use less memory.
|
|
// In fact, benchmarks show the inplace version is faster for length up to about 20 times
|
|
// faster than amount.
|
|
//
|
|
// TODO: there is probably even more fine-tuning that can be done here since
|
|
// `HashMap::with_capacity(amount)` probably allocates more than `amount` in practice,
|
|
// and a trade off could probably be made between memory/cpu, since hashmap operations
|
|
// are slower than array index swapping.
|
|
if amount >= length / 20 {
|
|
sample_indices_inplace(rng, length, amount)
|
|
} else {
|
|
sample_indices_cache(rng, length, amount)
|
|
}
|
|
}
|
|
|
|
/// Sample an amount of indices using an inplace partial fisher yates method.
|
|
///
|
|
/// This allocates the entire `length` of indices and randomizes only the first `amount`.
|
|
/// It then truncates to `amount` and returns.
|
|
///
|
|
/// This is better than using a HashMap "cache" when `amount >= length / 2` since it does not
|
|
/// require allocating an extra cache and is much faster.
|
|
fn sample_indices_inplace<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize>
|
|
where R: Rng,
|
|
{
|
|
debug_assert!(amount <= length);
|
|
let mut indices: Vec<usize> = Vec::with_capacity(length);
|
|
indices.extend(0..length);
|
|
for i in 0..amount {
|
|
let j: usize = rng.gen_range(i, length);
|
|
let tmp = indices[i];
|
|
indices[i] = indices[j];
|
|
indices[j] = tmp;
|
|
}
|
|
indices.truncate(amount);
|
|
debug_assert_eq!(indices.len(), amount);
|
|
indices
|
|
}
|
|
|
|
|
|
/// This method performs a partial fisher-yates on a range of indices using a HashMap
|
|
/// as a cache to record potential collisions.
|
|
///
|
|
/// The cache avoids allocating the entire `length` of values. This is especially useful when
|
|
/// `amount <<< length`, i.e. select 3 non-repeating from 1_000_000
|
|
fn sample_indices_cache<R>(
|
|
rng: &mut R,
|
|
length: usize,
|
|
amount: usize,
|
|
) -> Vec<usize>
|
|
where R: Rng,
|
|
{
|
|
debug_assert!(amount <= length);
|
|
#[cfg(feature="std")] let mut cache = HashMap::with_capacity(amount);
|
|
#[cfg(not(feature="std"))] let mut cache = BTreeMap::new();
|
|
let mut out = Vec::with_capacity(amount);
|
|
for i in 0..amount {
|
|
let j: usize = rng.gen_range(i, length);
|
|
|
|
// equiv: let tmp = slice[i];
|
|
let tmp = match cache.get(&i) {
|
|
Some(e) => *e,
|
|
None => i,
|
|
};
|
|
|
|
// equiv: slice[i] = slice[j];
|
|
let x = match cache.get(&j) {
|
|
Some(x) => *x,
|
|
None => j,
|
|
};
|
|
|
|
// equiv: slice[j] = tmp;
|
|
cache.insert(j, tmp);
|
|
|
|
// note that in the inplace version, slice[i] is automatically "returned" value
|
|
out.push(x);
|
|
}
|
|
debug_assert_eq!(out.len(), amount);
|
|
out
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use super::*;
    use {thread_rng, XorShiftRng, SeedableRng};

    /// `sample_iter` returns `Ok` with exactly `amount` in-range elements, and
    /// `Err` with the untouched input when the iterator is too short.
    #[test]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = thread_rng();
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = sample_iter(&mut r, vals.iter(), 5).unwrap();
        let large_sample = sample_iter(&mut r, vals.iter(), vals.len() + 5).unwrap_err();

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // no randomization happens when amount >= len
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `vals` is the half-open range `min_val..max_val`, so `max_val` itself
        // can never be sampled.
        assert!(small_sample.iter().all(|e| {
            **e >= min_val && **e < max_val
        }));
    }

    /// Edge cases for the slice/index samplers: empty input, zero amount,
    /// single element, full-length sample, and a rough fairness check.
    #[test]
    fn test_sample_slice_boundaries() {
        let empty: &[u8] = &[];

        let mut r = thread_rng();

        // sample 0 items
        assert_eq!(sample_slice(&mut r, empty, 0), vec![]);
        assert_eq!(sample_slice(&mut r, &[42, 2, 42], 0), vec![]);

        // sample 1 item
        assert_eq!(sample_slice(&mut r, &[42], 1), vec![42]);
        let v = sample_slice(&mut r, &[1, 42], 1)[0];
        assert!(v == 1 || v == 42);

        // sample "all" the items
        let v = sample_slice(&mut r, &[42, 133], 2);
        assert!(v == vec![42, 133] || v == vec![133, 42]);

        assert_eq!(sample_indices_inplace(&mut r, 0, 0), vec![]);
        assert_eq!(sample_indices_inplace(&mut r, 1, 0), vec![]);
        assert_eq!(sample_indices_inplace(&mut r, 1, 1), vec![0]);

        assert_eq!(sample_indices_cache(&mut r, 0, 0), vec![]);
        assert_eq!(sample_indices_cache(&mut r, 1, 0), vec![]);
        assert_eq!(sample_indices_cache(&mut r, 1, 1), vec![0]);

        // Make sure lucky 777's aren't lucky
        let slice = &[42, 777];
        let mut num_42 = 0;
        let total = 1000;
        for _ in 0..total {
            let v = sample_slice(&mut r, slice, 1);
            assert_eq!(v.len(), 1);
            let v = v[0];
            assert!(v == 42 || v == 777);
            if v == 42 {
                num_42 += 1;
            }
        }
        let ratio_42 = num_42 as f64 / total as f64;
        // Both bounds must hold; with `||` this assertion could never fail.
        assert!(0.4 <= ratio_42 && ratio_42 <= 0.6, "{}", ratio_42);
    }

    /// For identical seeds, the inplace and cache strategies, the public
    /// `sample_indices`, and both slice samplers must all agree.
    #[test]
    fn test_sample_slice() {
        let xor_rng = XorShiftRng::from_seed;

        let max_range = 100;
        let mut r = thread_rng();

        for length in 1usize..max_range {
            let amount = r.gen_range(0, length);
            let seed: [u32; 4] = [
                r.next_u32(), r.next_u32(), r.next_u32(), r.next_u32()
            ];

            println!("Selecting indices: len={}, amount={}, seed={:?}", length, amount, seed);

            // assert that the two index methods give exactly the same result
            let inplace = sample_indices_inplace(
                &mut xor_rng(seed), length, amount);
            let cache = sample_indices_cache(
                &mut xor_rng(seed), length, amount);
            assert_eq!(inplace, cache);

            // assert the basics work
            let regular = sample_indices(
                &mut xor_rng(seed), length, amount);
            assert_eq!(regular.len(), amount);
            assert!(regular.iter().all(|e| *e < length));
            assert_eq!(regular, inplace);

            // also test that sampling the slice works
            let vec: Vec<usize> = (0..length).collect();
            {
                let result = sample_slice(&mut xor_rng(seed), &vec, amount);
                assert_eq!(result, regular);
            }

            {
                // `vec[i] == i`, so references into `vec` compare equal to
                // references into `regular`.
                let result = sample_slice_ref(&mut xor_rng(seed), &vec, amount);
                let expected = regular.iter().collect::<Vec<_>>();
                assert_eq!(result, expected);
            }
        }
    }
}
|