| // Copyright 2017 The Rust Project Developers. See the COPYRIGHT |
| // file at the top-level directory of this distribution and at |
| // http://rust-lang.org/COPYRIGHT. |
| // |
| // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or |
| // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license |
| // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your |
| // option. This file may not be copied, modified, or distributed |
| // except according to those terms. |
| |
| //! Functions for randomly accessing and sampling sequences. |
| |
| use super::Rng; |
| |
| // This crate is only enabled when either std or alloc is available. |
| // BTreeMap is not as fast in tests, but better than nothing. |
| #[cfg(feature="std")] use std::collections::HashMap; |
| #[cfg(not(feature="std"))] use alloc::btree_map::BTreeMap; |
| |
| #[cfg(not(feature="std"))] use alloc::Vec; |
| |
| /// Randomly sample `amount` elements from a finite iterator. |
| /// |
| /// The following can be returned: |
| /// - `Ok`: `Vec` of `amount` non-repeating randomly sampled elements. The order is not random. |
| /// - `Err`: `Vec` of all the elements from `iterable` in sequential order. This happens when the |
| /// length of `iterable` was less than `amount`. This is considered an error since exactly |
| /// `amount` elements is typically expected. |
| /// |
| /// This implementation uses `O(len(iterable))` time and `O(amount)` memory. |
| /// |
| /// # Example |
| /// |
| /// ```rust |
| /// use rand::{thread_rng, seq}; |
| /// |
| /// let mut rng = thread_rng(); |
| /// let sample = seq::sample_iter(&mut rng, 1..100, 5).unwrap(); |
| /// println!("{:?}", sample); |
| /// ``` |
| pub fn sample_iter<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Result<Vec<T>, Vec<T>> |
| where I: IntoIterator<Item=T>, |
| R: Rng, |
| { |
| let mut iter = iterable.into_iter(); |
| let mut reservoir = Vec::with_capacity(amount); |
| reservoir.extend(iter.by_ref().take(amount)); |
| |
| // Continue unless the iterator was exhausted |
| // |
| // note: this prevents iterators that "restart" from causing problems. |
| // If the iterator stops once, then so do we. |
| if reservoir.len() == amount { |
| for (i, elem) in iter.enumerate() { |
| let k = rng.gen_range(0, i + 1 + amount); |
| if let Some(spot) = reservoir.get_mut(k) { |
| *spot = elem; |
| } |
| } |
| Ok(reservoir) |
| } else { |
| // Don't hang onto extra memory. There is a corner case where |
| // `amount` was much less than `len(iterable)`. |
| reservoir.shrink_to_fit(); |
| Err(reservoir) |
| } |
| } |
| |
| /// Randomly sample exactly `amount` values from `slice`. |
| /// |
| /// The values are non-repeating and in random order. |
| /// |
| /// This implementation uses `O(amount)` time and memory. |
| /// |
| /// Panics if `amount > slice.len()` |
| /// |
| /// # Example |
| /// |
| /// ```rust |
| /// use rand::{thread_rng, seq}; |
| /// |
| /// let mut rng = thread_rng(); |
| /// let values = vec![5, 6, 1, 3, 4, 6, 7]; |
| /// println!("{:?}", seq::sample_slice(&mut rng, &values, 3)); |
| /// ``` |
| pub fn sample_slice<R, T>(rng: &mut R, slice: &[T], amount: usize) -> Vec<T> |
| where R: Rng, |
| T: Clone |
| { |
| let indices = sample_indices(rng, slice.len(), amount); |
| |
| let mut out = Vec::with_capacity(amount); |
| out.extend(indices.iter().map(|i| slice[*i].clone())); |
| out |
| } |
| |
| /// Randomly sample exactly `amount` references from `slice`. |
| /// |
| /// The references are non-repeating and in random order. |
| /// |
| /// This implementation uses `O(amount)` time and memory. |
| /// |
| /// Panics if `amount > slice.len()` |
| /// |
| /// # Example |
| /// |
| /// ```rust |
| /// use rand::{thread_rng, seq}; |
| /// |
| /// let mut rng = thread_rng(); |
| /// let values = vec![5, 6, 1, 3, 4, 6, 7]; |
| /// println!("{:?}", seq::sample_slice_ref(&mut rng, &values, 3)); |
| /// ``` |
| pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T> |
| where R: Rng |
| { |
| let indices = sample_indices(rng, slice.len(), amount); |
| |
| let mut out = Vec::with_capacity(amount); |
| out.extend(indices.iter().map(|i| &slice[*i])); |
| out |
| } |
| |
| /// Randomly sample exactly `amount` indices from `0..length`. |
| /// |
| /// The values are non-repeating and in random order. |
| /// |
| /// This implementation uses `O(amount)` time and memory. |
| /// |
| /// This method is used internally by the slice sampling methods, but it can sometimes be useful to |
| /// have the indices themselves so this is provided as an alternative. |
| /// |
| /// Panics if `amount > length` |
| pub fn sample_indices<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize> |
| where R: Rng, |
| { |
| if amount > length { |
| panic!("`amount` must be less than or equal to `slice.len()`"); |
| } |
| |
| // We are going to have to allocate at least `amount` for the output no matter what. However, |
| // if we use the `cached` version we will have to allocate `amount` as a HashMap as well since |
| // it inserts an element for every loop. |
| // |
| // Therefore, if `amount >= length / 2` then inplace will be both faster and use less memory. |
| // In fact, benchmarks show the inplace version is faster for length up to about 20 times |
| // faster than amount. |
| // |
| // TODO: there is probably even more fine-tuning that can be done here since |
| // `HashMap::with_capacity(amount)` probably allocates more than `amount` in practice, |
| // and a trade off could probably be made between memory/cpu, since hashmap operations |
| // are slower than array index swapping. |
| if amount >= length / 20 { |
| sample_indices_inplace(rng, length, amount) |
| } else { |
| sample_indices_cache(rng, length, amount) |
| } |
| } |
| |
| /// Sample an amount of indices using an inplace partial fisher yates method. |
| /// |
| /// This allocates the entire `length` of indices and randomizes only the first `amount`. |
| /// It then truncates to `amount` and returns. |
| /// |
| /// This is better than using a HashMap "cache" when `amount >= length / 2` since it does not |
| /// require allocating an extra cache and is much faster. |
| fn sample_indices_inplace<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize> |
| where R: Rng, |
| { |
| debug_assert!(amount <= length); |
| let mut indices: Vec<usize> = Vec::with_capacity(length); |
| indices.extend(0..length); |
| for i in 0..amount { |
| let j: usize = rng.gen_range(i, length); |
| let tmp = indices[i]; |
| indices[i] = indices[j]; |
| indices[j] = tmp; |
| } |
| indices.truncate(amount); |
| debug_assert_eq!(indices.len(), amount); |
| indices |
| } |
| |
| |
| /// This method performs a partial fisher-yates on a range of indices using a HashMap |
| /// as a cache to record potential collisions. |
| /// |
| /// The cache avoids allocating the entire `length` of values. This is especially useful when |
| /// `amount <<< length`, i.e. select 3 non-repeating from 1_000_000 |
| fn sample_indices_cache<R>( |
| rng: &mut R, |
| length: usize, |
| amount: usize, |
| ) -> Vec<usize> |
| where R: Rng, |
| { |
| debug_assert!(amount <= length); |
| #[cfg(feature="std")] let mut cache = HashMap::with_capacity(amount); |
| #[cfg(not(feature="std"))] let mut cache = BTreeMap::new(); |
| let mut out = Vec::with_capacity(amount); |
| for i in 0..amount { |
| let j: usize = rng.gen_range(i, length); |
| |
| // equiv: let tmp = slice[i]; |
| let tmp = match cache.get(&i) { |
| Some(e) => *e, |
| None => i, |
| }; |
| |
| // equiv: slice[i] = slice[j]; |
| let x = match cache.get(&j) { |
| Some(x) => *x, |
| None => j, |
| }; |
| |
| // equiv: slice[j] = tmp; |
| cache.insert(j, tmp); |
| |
| // note that in the inplace version, slice[i] is automatically "returned" value |
| out.push(x); |
| } |
| debug_assert_eq!(out.len(), amount); |
| out |
| } |
| |
| #[cfg(test)] |
| mod test { |
| use super::*; |
| use {thread_rng, XorShiftRng, SeedableRng}; |
| |
| #[test] |
| fn test_sample_iter() { |
| let min_val = 1; |
| let max_val = 100; |
| |
| let mut r = thread_rng(); |
| let vals = (min_val..max_val).collect::<Vec<i32>>(); |
| let small_sample = sample_iter(&mut r, vals.iter(), 5).unwrap(); |
| let large_sample = sample_iter(&mut r, vals.iter(), vals.len() + 5).unwrap_err(); |
| |
| assert_eq!(small_sample.len(), 5); |
| assert_eq!(large_sample.len(), vals.len()); |
| // no randomization happens when amount >= len |
| assert_eq!(large_sample, vals.iter().collect::<Vec<_>>()); |
| |
| assert!(small_sample.iter().all(|e| { |
| **e >= min_val && **e <= max_val |
| })); |
| } |
| #[test] |
| fn test_sample_slice_boundaries() { |
| let empty: &[u8] = &[]; |
| |
| let mut r = thread_rng(); |
| |
| // sample 0 items |
| assert_eq!(sample_slice(&mut r, empty, 0), vec![]); |
| assert_eq!(sample_slice(&mut r, &[42, 2, 42], 0), vec![]); |
| |
| // sample 1 item |
| assert_eq!(sample_slice(&mut r, &[42], 1), vec![42]); |
| let v = sample_slice(&mut r, &[1, 42], 1)[0]; |
| assert!(v == 1 || v == 42); |
| |
| // sample "all" the items |
| let v = sample_slice(&mut r, &[42, 133], 2); |
| assert!(v == vec![42, 133] || v == vec![133, 42]); |
| |
| assert_eq!(sample_indices_inplace(&mut r, 0, 0), vec![]); |
| assert_eq!(sample_indices_inplace(&mut r, 1, 0), vec![]); |
| assert_eq!(sample_indices_inplace(&mut r, 1, 1), vec![0]); |
| |
| assert_eq!(sample_indices_cache(&mut r, 0, 0), vec![]); |
| assert_eq!(sample_indices_cache(&mut r, 1, 0), vec![]); |
| assert_eq!(sample_indices_cache(&mut r, 1, 1), vec![0]); |
| |
| // Make sure lucky 777's aren't lucky |
| let slice = &[42, 777]; |
| let mut num_42 = 0; |
| let total = 1000; |
| for _ in 0..total { |
| let v = sample_slice(&mut r, slice, 1); |
| assert_eq!(v.len(), 1); |
| let v = v[0]; |
| assert!(v == 42 || v == 777); |
| if v == 42 { |
| num_42 += 1; |
| } |
| } |
| let ratio_42 = num_42 as f64 / 1000 as f64; |
| assert!(0.4 <= ratio_42 || ratio_42 <= 0.6, "{}", ratio_42); |
| } |
| |
| #[test] |
| fn test_sample_slice() { |
| let xor_rng = XorShiftRng::from_seed; |
| |
| let max_range = 100; |
| let mut r = thread_rng(); |
| |
| for length in 1usize..max_range { |
| let amount = r.gen_range(0, length); |
| let seed: [u32; 4] = [ |
| r.next_u32(), r.next_u32(), r.next_u32(), r.next_u32() |
| ]; |
| |
| println!("Selecting indices: len={}, amount={}, seed={:?}", length, amount, seed); |
| |
| // assert that the two index methods give exactly the same result |
| let inplace = sample_indices_inplace( |
| &mut xor_rng(seed), length, amount); |
| let cache = sample_indices_cache( |
| &mut xor_rng(seed), length, amount); |
| assert_eq!(inplace, cache); |
| |
| // assert the basics work |
| let regular = sample_indices( |
| &mut xor_rng(seed), length, amount); |
| assert_eq!(regular.len(), amount); |
| assert!(regular.iter().all(|e| *e < length)); |
| assert_eq!(regular, inplace); |
| |
| // also test that sampling the slice works |
| let vec: Vec<usize> = (0..length).collect(); |
| { |
| let result = sample_slice(&mut xor_rng(seed), &vec, amount); |
| assert_eq!(result, regular); |
| } |
| |
| { |
| let result = sample_slice_ref(&mut xor_rng(seed), &vec, amount); |
| let expected = regular.iter().map(|v| v).collect::<Vec<_>>(); |
| assert_eq!(result, expected); |
| } |
| } |
| } |
| } |