Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/monty-hall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ fn simulate<R: Rng>(random_door: &Uniform<u32>, rng: &mut R) -> SimulationResult
// Returns the door the game host opens given our choice and knowledge of
// where the car is. The game host will never open the door with the car.
fn game_host_open<R: Rng>(car: u32, choice: u32, rng: &mut R) -> u32 {
use rand::seq::SliceRandom;
use rand::seq::IndexedRandom;
*free_doors(&[car, choice]).choose(rng).unwrap()
}

Expand Down
2 changes: 1 addition & 1 deletion rand_distr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ pub use self::weibull::{Error as WeibullError, Weibull};
pub use self::zipf::{Zeta, ZetaError, Zipf, ZipfError};
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub use rand::distributions::{WeightedError, WeightedIndex};
pub use rand::distributions::{WeightError, WeightedIndex};
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub use weighted_alias::WeightedAliasIndex;
Expand Down
49 changes: 23 additions & 26 deletions rand_distr/src/weighted_alias.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
//! This module contains an implementation of alias method for sampling random
//! indices with probabilities proportional to a collection of weights.

use super::WeightedError;
use super::WeightError;
use crate::{uniform::SampleUniform, Distribution, Uniform};
use core::fmt;
use core::iter::Sum;
Expand Down Expand Up @@ -79,18 +79,15 @@ pub struct WeightedAliasIndex<W: AliasableWeight> {
impl<W: AliasableWeight> WeightedAliasIndex<W> {
/// Creates a new [`WeightedAliasIndex`].
///
/// Returns an error if:
/// - The vector is empty.
/// - The vector is longer than `u32::MAX`.
/// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX /
/// weights.len()`.
/// - The sum of weights is zero.
pub fn new(weights: Vec<W>) -> Result<Self, WeightedError> {
/// Error cases:
/// - [`WeightError::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`.
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number,
/// negative or greater than `max = W::MAX / weights.len()`.
/// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero.
pub fn new(weights: Vec<W>) -> Result<Self, WeightError> {
let n = weights.len();
if n == 0 {
return Err(WeightedError::NoItem);
} else if n > ::core::u32::MAX as usize {
return Err(WeightedError::TooMany);
if n == 0 || n > ::core::u32::MAX as usize {
return Err(WeightError::InvalidInput);
}
let n = n as u32;

Expand All @@ -101,7 +98,7 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
.iter()
.all(|&w| W::ZERO <= w && w <= max_weight_size)
{
return Err(WeightedError::InvalidWeight);
return Err(WeightError::InvalidWeight);
}

// The sum of weights will represent 100% of no alias odds.
Expand All @@ -113,7 +110,7 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
weight_sum
};
if weight_sum == W::ZERO {
return Err(WeightedError::AllWeightsZero);
return Err(WeightError::InsufficientNonZero);
}

// `weight_sum` would have been zero if `try_from_lossy` causes an error here.
Expand Down Expand Up @@ -382,23 +379,23 @@ mod test {
// Floating point special cases
assert_eq!(
WeightedAliasIndex::new(vec![::core::f32::INFINITY]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
WeightedError::AllWeightsZero
WeightError::InsufficientNonZero
);
assert_eq!(
WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![::core::f32::NAN]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
}

Expand All @@ -416,11 +413,11 @@ mod test {
// Signed integer special cases
assert_eq!(
WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![::core::i128::MIN]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
}

Expand All @@ -438,11 +435,11 @@ mod test {
// Signed integer special cases
assert_eq!(
WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![::core::i8::MIN]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
}

Expand Down Expand Up @@ -486,15 +483,15 @@ mod test {

assert_eq!(
WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
WeightedError::NoItem
WeightError::InvalidInput
);
assert_eq!(
WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
WeightedError::AllWeightsZero
WeightError::InsufficientNonZero
);
assert_eq!(
WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
}

Expand Down
63 changes: 36 additions & 27 deletions rand_distr/src/weighted_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

use core::ops::SubAssign;

use super::WeightedError;
use super::WeightError;
use crate::Distribution;
use alloc::vec::Vec;
use rand::distributions::uniform::{SampleBorrow, SampleUniform};
Expand Down Expand Up @@ -98,15 +98,19 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
WeightedTreeIndex<W>
{
/// Creates a new [`WeightedTreeIndex`] from a slice of weights.
pub fn new<I>(weights: I) -> Result<Self, WeightedError>
///
/// Error cases:
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative.
/// - [`WeightError::Overflow`] when the sum of all weights overflows.
pub fn new<I>(weights: I) -> Result<Self, WeightError>
where
I: IntoIterator,
I::Item: SampleBorrow<W>,
{
let mut subtotals: Vec<W> = weights.into_iter().map(|x| x.borrow().clone()).collect();
for weight in subtotals.iter() {
if *weight < W::ZERO {
return Err(WeightedError::InvalidWeight);
if !(*weight >= W::ZERO) {
return Err(WeightError::InvalidWeight);
}
}
let n = subtotals.len();
Expand All @@ -115,7 +119,7 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
let parent = (i - 1) / 2;
subtotals[parent]
.checked_add_assign(&w)
.map_err(|()| WeightedError::Overflow)?;
.map_err(|()| WeightError::Overflow)?;
}
Ok(Self { subtotals })
}
Expand Down Expand Up @@ -164,14 +168,18 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
}

/// Appends a new weight at the end.
pub fn push(&mut self, weight: W) -> Result<(), WeightedError> {
if weight < W::ZERO {
return Err(WeightedError::InvalidWeight);
///
/// Error cases:
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative.
/// - [`WeightError::Overflow`] when the sum of all weights overflows.
pub fn push(&mut self, weight: W) -> Result<(), WeightError> {
if !(weight >= W::ZERO) {
return Err(WeightError::InvalidWeight);
}
if let Some(total) = self.subtotals.first() {
let mut total = total.clone();
if total.checked_add_assign(&weight).is_err() {
return Err(WeightedError::Overflow);
return Err(WeightError::Overflow);
}
}
let mut index = self.len();
Expand All @@ -184,9 +192,13 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
}

/// Updates the weight at an index.
pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError> {
if weight < W::ZERO {
return Err(WeightedError::InvalidWeight);
///
/// Error cases:
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative.
/// - [`WeightError::Overflow`] when the sum of all weights overflows.
pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightError> {
if !(weight >= W::ZERO) {
return Err(WeightError::InvalidWeight);
}
let old_weight = self.get(index);
if weight > old_weight {
Expand All @@ -195,7 +207,7 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
if let Some(total) = self.subtotals.first() {
let mut total = total.clone();
if total.checked_add_assign(&difference).is_err() {
return Err(WeightedError::Overflow);
return Err(WeightError::Overflow);
}
}
self.subtotals[index]
Expand Down Expand Up @@ -235,13 +247,10 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
///
/// Returns an error if there are no elements or all weights are zero. This
/// is unlike [`Distribution::sample`], which panics in those cases.
fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, WeightedError> {
if self.subtotals.is_empty() {
return Err(WeightedError::NoItem);
}
let total_weight = self.subtotals[0].clone();
fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, WeightError> {
let total_weight = self.subtotals.first().cloned().unwrap_or(W::ZERO);
if total_weight == W::ZERO {
return Err(WeightedError::AllWeightsZero);
return Err(WeightError::InsufficientNonZero);
}
let mut target_weight = rng.gen_range(W::ZERO..total_weight);
let mut index = 0;
Expand Down Expand Up @@ -296,19 +305,19 @@ mod test {
let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap();
assert_eq!(
tree.try_sample(&mut rng).unwrap_err(),
WeightedError::NoItem
WeightError::InsufficientNonZero
);
}

#[test]
fn test_overflow_error() {
assert_eq!(
WeightedTreeIndex::new(&[i32::MAX, 2]),
Err(WeightedError::Overflow)
Err(WeightError::Overflow)
);
let mut tree = WeightedTreeIndex::new(&[i32::MAX - 2, 1]).unwrap();
assert_eq!(tree.push(3), Err(WeightedError::Overflow));
assert_eq!(tree.update(1, 4), Err(WeightedError::Overflow));
assert_eq!(tree.push(3), Err(WeightError::Overflow));
assert_eq!(tree.update(1, 4), Err(WeightError::Overflow));
tree.update(1, 2).unwrap();
}

Expand All @@ -318,22 +327,22 @@ mod test {
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
assert_eq!(
tree.try_sample(&mut rng).unwrap_err(),
WeightedError::AllWeightsZero
WeightError::InsufficientNonZero
);
}

#[test]
fn test_invalid_weight_error() {
assert_eq!(
WeightedTreeIndex::<i32>::new(&[1, -1]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
let mut tree = WeightedTreeIndex::<i32>::new(&[]).unwrap();
assert_eq!(tree.push(-1).unwrap_err(), WeightedError::InvalidWeight);
assert_eq!(tree.push(-1).unwrap_err(), WeightError::InvalidWeight);
tree.push(1).unwrap();
assert_eq!(
tree.update(0, -1).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
}

Expand Down
2 changes: 1 addition & 1 deletion rand_distr/tests/pdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#![allow(clippy::float_cmp)]

use average::Histogram;
use rand::{Rng, SeedableRng};
use rand::Rng;
use rand_distr::{Normal, SkewNormal};

const HIST_LEN: usize = 100;
Expand Down
9 changes: 1 addition & 8 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,6 @@ pub mod hidden_export {
pub use super::float::IntoFloat; // used by rand_distr
}
pub mod uniform;
#[deprecated(
since = "0.8.0",
note = "use rand::distributions::{WeightedIndex, WeightedError} instead"
)]
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub mod weighted;

pub use self::bernoulli::{Bernoulli, BernoulliError};
pub use self::distribution::{Distribution, DistIter, DistMap};
Expand All @@ -126,7 +119,7 @@ pub use self::slice::Slice;
#[doc(inline)]
pub use self::uniform::Uniform;
#[cfg(feature = "alloc")]
pub use self::weighted_index::{Weight, WeightedError, WeightedIndex};
pub use self::weighted_index::{Weight, WeightError, WeightedIndex};

#[allow(unused)]
use crate::Rng;
Expand Down
14 changes: 7 additions & 7 deletions src/distributions/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use alloc::string::String;
/// [`Slice::new`] constructs a distribution referencing a slice and uniformly
/// samples references from the items in the slice. It may do extra work up
/// front to make sampling of multiple values faster; if only one sample from
/// the slice is required, [`SliceRandom::choose`] can be more efficient.
/// the slice is required, [`IndexedRandom::choose`] can be more efficient.
///
/// Steps are taken to avoid bias which might be present in naive
/// implementations; for example `slice[rng.gen() % slice.len()]` samples from
Expand All @@ -25,7 +25,7 @@ use alloc::string::String;
/// This distribution samples with replacement; each sample is independent.
/// Sampling without replacement requires state to be retained, and therefore
/// cannot be handled by a distribution; you should instead consider methods
/// on [`SliceRandom`], such as [`SliceRandom::choose_multiple`].
/// on [`IndexedRandom`], such as [`IndexedRandom::choose_multiple`].
///
/// # Example
///
Expand All @@ -48,21 +48,21 @@ use alloc::string::String;
/// assert!(vowel_string.chars().all(|c| vowels.contains(&c)));
/// ```
///
/// For a single sample, [`SliceRandom::choose`][crate::seq::SliceRandom::choose]
/// For a single sample, [`IndexedRandom::choose`][crate::seq::IndexedRandom::choose]
/// may be preferred:
///
/// ```
/// use rand::seq::SliceRandom;
/// use rand::seq::IndexedRandom;
///
/// let vowels = ['a', 'e', 'i', 'o', 'u'];
/// let mut rng = rand::thread_rng();
///
/// println!("{}", vowels.choose(&mut rng).unwrap())
/// ```
///
/// [`SliceRandom`]: crate::seq::SliceRandom
/// [`SliceRandom::choose`]: crate::seq::SliceRandom::choose
/// [`SliceRandom::choose_multiple`]: crate::seq::SliceRandom::choose_multiple
/// [`IndexedRandom`]: crate::seq::IndexedRandom
/// [`IndexedRandom::choose`]: crate::seq::IndexedRandom::choose
/// [`IndexedRandom::choose_multiple`]: crate::seq::IndexedRandom::choose_multiple
#[derive(Debug, Clone, Copy)]
pub struct Slice<'a, T> {
slice: &'a [T],
Expand Down
Loading