@@ -31,7 +31,7 @@ use rand::Rng;
3131/// ```rust
3232/// use rand_distr::{Pert, Distribution};
3333///
34- /// let d = Pert::new(0., 5., 2.5).unwrap();
34+ /// let d = Pert::new(0., 5.).with_mode( 2.5).unwrap();
3535/// let v = d.sample(&mut rand::thread_rng());
3636/// println!("{} is from a PERT distribution", v);
3737/// ```
@@ -82,35 +82,75 @@ where
8282 Exp1 : Distribution < F > ,
8383 Open01 : Distribution < F > ,
8484{
85- /// Set up the PERT distribution with defined `min`, `max` and `mode`.
85+ /// Construct a PERT distribution with defined `min`, `max`
8686 ///
87- /// This is equivalent to calling `Pert::new_with_shape` with `shape == 4.0`.
87+ /// # Example
88+ ///
89+ /// ```
90+ /// use rand_distr::Pert;
91+ /// let pert_dist = Pert::new(0.0, 10.0)
92+ /// .with_shape(3.5)
93+ /// .with_mean(3.0)
94+ /// .unwrap();
95+ /// # let _unused: Pert<f64> = pert_dist;
96+ /// ```
97+ #[ allow( clippy:: new_ret_no_self) ]
98+ #[ inline]
99+ pub fn new ( min : F , max : F ) -> PertBuilder < F > {
100+ let shape = F :: from ( 4.0 ) . unwrap ( ) ;
101+ PertBuilder { min, max, shape }
102+ }
103+ }
104+
105+ /// Struct used to build a [`Pert`]
106+ #[ derive( Debug ) ]
107+ pub struct PertBuilder < F > {
108+ min : F ,
109+ max : F ,
110+ shape : F ,
111+ }
112+
113+ impl < F > PertBuilder < F >
114+ where
115+ F : Float ,
116+ StandardNormal : Distribution < F > ,
117+ Exp1 : Distribution < F > ,
118+ Open01 : Distribution < F > ,
119+ {
120+ /// Set the shape parameter
121+ ///
122+ /// If not specified, this defaults to 4.
123+ #[ inline]
124+ pub fn with_shape ( mut self , shape : F ) -> PertBuilder < F > {
125+ self . shape = shape;
126+ self
127+ }
128+
129+ /// Specify the mean
88130 #[ inline]
89- pub fn new ( min : F , max : F , mode : F ) -> Result < Pert < F > , PertError > {
90- Pert :: new_with_shape ( min, max, mode, F :: from ( 4. ) . unwrap ( ) )
131+ pub fn with_mean ( self , mean : F ) -> Result < Pert < F > , PertError > {
132+ let two = F :: from ( 2.0 ) . unwrap ( ) ;
133+ let mode = ( ( self . shape + two) * mean - self . min - self . max ) / self . shape ;
134+ self . with_mode ( mode)
91135 }
92136
93- /// Set up the PERT distribution with defined `min`, `max`, ` mode` and
94- /// `shape`.
95- pub fn new_with_shape ( min : F , max : F , mode : F , shape : F ) -> Result < Pert < F > , PertError > {
96- if !( max > min) {
137+ /// Specify the mode
138+ # [ inline ]
139+ pub fn with_mode ( self , mode : F ) -> Result < Pert < F > , PertError > {
140+ if !( self . max > self . min ) {
97141 return Err ( PertError :: RangeTooSmall ) ;
98142 }
99- if !( mode >= min && max >= mode) {
143+ if !( mode >= self . min && self . max >= mode) {
100144 return Err ( PertError :: ModeRange ) ;
101145 }
102- if !( shape >= F :: from ( 0. ) . unwrap ( ) ) {
146+ if !( self . shape >= F :: from ( 0. ) . unwrap ( ) ) {
103147 return Err ( PertError :: ShapeTooSmall ) ;
104148 }
105149
150+ let ( min, max, shape) = ( self . min , self . max , self . shape ) ;
106151 let range = max - min;
107- let mu = ( min + max + shape * mode) / ( shape + F :: from ( 2. ) . unwrap ( ) ) ;
108- let v = if mu == mode {
109- shape * F :: from ( 0.5 ) . unwrap ( ) + F :: from ( 1. ) . unwrap ( )
110- } else {
111- ( mu - min) * ( F :: from ( 2. ) . unwrap ( ) * mode - min - max) / ( ( mode - mu) * ( max - min) )
112- } ;
113- let w = v * ( max - mu) / ( mu - min) ;
152+ let v = F :: from ( 1.0 ) . unwrap ( ) + shape * ( mode - min) / range;
153+ let w = F :: from ( 1.0 ) . unwrap ( ) + shape * ( max - mode) / range;
114154 let beta = Beta :: new ( v, w) . map_err ( |_| PertError :: RangeTooSmall ) ?;
115155 Ok ( Pert { min, range, beta } )
116156 }
@@ -136,17 +176,38 @@ mod test {
136176 #[ test]
137177 fn test_pert ( ) {
138178 for & ( min, max, mode) in & [ ( -1. , 1. , 0. ) , ( 1. , 2. , 1. ) , ( 5. , 25. , 25. ) ] {
139- let _distr = Pert :: new ( min, max, mode) . unwrap ( ) ;
179+ let _distr = Pert :: new ( min, max) . with_mode ( mode) . unwrap ( ) ;
140180 // TODO: test correctness
141181 }
142182
143183 for & ( min, max, mode) in & [ ( -1. , 1. , 2. ) , ( -1. , 1. , -2. ) , ( 2. , 1. , 1. ) ] {
144- assert ! ( Pert :: new( min, max, mode) . is_err( ) ) ;
184+ assert ! ( Pert :: new( min, max) . with_mode ( mode) . is_err( ) ) ;
145185 }
146186 }
147187
148188 #[ test]
149- fn pert_distributions_can_be_compared ( ) {
150- assert_eq ! ( Pert :: new( 1.0 , 3.0 , 2.0 ) , Pert :: new( 1.0 , 3.0 , 2.0 ) ) ;
189+ fn distributions_can_be_compared ( ) {
190+ let ( min, mode, max, shape) = ( 1.0 , 2.0 , 3.0 , 4.0 ) ;
191+ let p1 = Pert :: new ( min, max) . with_mode ( mode) . unwrap ( ) ;
192+ let mean = ( min + shape * mode + max) / ( shape + 2.0 ) ;
193+ let p2 = Pert :: new ( min, max) . with_mean ( mean) . unwrap ( ) ;
194+ assert_eq ! ( p1, p2) ;
195+ }
196+
197+ #[ test]
198+ fn mode_almost_half_range ( ) {
199+ assert ! ( Pert :: new( 0.0f32 , 0.48258883 ) . with_mode( 0.24129441 ) . is_ok( ) ) ;
200+ }
201+
202+ #[ test]
203+ fn almost_symmetric_about_zero ( ) {
204+ let distr = Pert :: new ( -10f32 , 10f32 ) . with_mode ( f32:: EPSILON ) ;
205+ assert ! ( distr. is_ok( ) ) ;
206+ }
207+
208+ #[ test]
209+ fn almost_symmetric ( ) {
210+ let distr = Pert :: new ( 0f32 , 2f32 ) . with_mode ( 1f32 + f32:: EPSILON ) ;
211+ assert ! ( distr. is_ok( ) ) ;
151212 }
152213}
0 commit comments