1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import platform
1415
1516import numpy as np
1617import pymc as pm
1718
1819# general imports
1920import pytensor
21+
22+ pytensor .config .floatX = "float32"
2023import pytest
2124import scipy .stats .distributions as sp
2225
@@ -50,6 +53,12 @@ class TestGenExtremeClass:
5053 reason = "PyMC underflows earlier than scipy on float32" ,
5154 )
5255 def test_logp (self ):
56+ def ref_logp (value , mu , sigma , xi ):
57+ if 1 + xi * (value - mu ) / sigma > 0 :
58+ return sp .genextreme .logpdf (value , c = - xi , loc = mu , scale = sigma )
59+ else :
60+ return - np .inf
61+
5362 check_logp (
5463 GenExtreme ,
5564 R ,
@@ -58,15 +67,24 @@ def test_logp(self):
5867 "sigma" : Rplusbig ,
5968 "xi" : Domain ([- 1 , - 0.99 , - 0.5 , 0 , 0.5 , 0.99 , 1 ]),
6069 },
61- lambda value , mu , sigma , xi : sp .genextreme .logpdf (value , c = - xi , loc = mu , scale = sigma )
62- if 1 + xi * (value - mu ) / sigma > 0
63- else - np .inf ,
70+ ref_logp ,
71+ n_samples = - 1 ,
6472 )
6573
6674 if pytensor .config .floatX == "float32" :
6775 raise Exception ("Flaky test: It passed this time, but XPASS is not allowed." )
6876
77+ @pytest .mark .skipif (
78+ (pytensor .config .floatX == "float32" and platform .system () == "Windows" ),
79+ reason = "Scipy gives different results on Windows and does not match with desired accuracy" ,
80+ )
6981 def test_logcdf (self ):
82+ def ref_logcdf (value , mu , sigma , xi ):
83+ if 1 + xi * (value - mu ) / sigma > 0 :
84+ return sp .genextreme .logcdf (value , c = - xi , loc = mu , scale = sigma )
85+ else :
86+ return - np .inf
87+
7088 check_logcdf (
7189 GenExtreme ,
7290 R ,
@@ -75,9 +93,7 @@ def test_logcdf(self):
7593 "sigma" : Rplusbig ,
7694 "xi" : Domain ([- 1 , - 0.99 , - 0.5 , 0 , 0.5 , 0.99 , 1 ]),
7795 },
78- lambda value , mu , sigma , xi : sp .genextreme .logcdf (value , c = - xi , loc = mu , scale = sigma )
79- if 1 + xi * (value - mu ) / sigma > 0
80- else - np .inf ,
96+ ref_logcdf ,
8197 decimal = select_by_precision (float64 = 6 , float32 = 2 ),
8298 )
8399
0 commit comments