|
1 | | -# Copyright 2022 The PyMC Developers |
| 1 | +# Copyright 2023 The PyMC Developers |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 |
|
| 16 | +import sys |
| 17 | +import tempfile |
| 18 | + |
16 | 19 | import numpy as np |
17 | 20 | import pandas as pd |
18 | 21 | import pymc as pm |
| 22 | +import pytest |
19 | 23 |
|
20 | 24 | from pymc_experimental.model_builder import ModelBuilder |
21 | 25 |
|
@@ -77,93 +81,35 @@ def create_sample_input(cls): |
77 | 81 |
|
78 | 82 |
|
79 | 83 | def test_fit(): |
80 | | - with pm.Model() as model: |
81 | | - x = np.linspace(start=0, stop=1, num=100) |
82 | | - y = 5 * x + 3 |
83 | | - x = pm.MutableData("x", x) |
84 | | - y_data = pm.MutableData("y_data", y) |
85 | | - |
86 | | - a_loc = 7 |
87 | | - a_scale = 3 |
88 | | - b_loc = 5 |
89 | | - b_scale = 3 |
90 | | - obs_error = 2 |
91 | | - |
92 | | - a = pm.Normal("a", a_loc, sigma=a_scale) |
93 | | - b = pm.Normal("b", b_loc, sigma=b_scale) |
94 | | - obs_error = pm.HalfNormal("σ_model_fmc", obs_error) |
95 | | - |
96 | | - y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data) |
97 | | - |
98 | | - idata = pm.sample(tune=100, draws=200, chains=1, cores=1, target_accept=0.5) |
99 | | - idata.extend(pm.sample_prior_predictive()) |
100 | | - idata.extend(pm.sample_posterior_predictive(idata)) |
101 | | - |
102 | 84 | data, model_config, sampler_config = test_ModelBuilder.create_sample_input() |
103 | | - model_2 = test_ModelBuilder(model_config, sampler_config, data) |
104 | | - model_2.idata = model_2.fit() |
105 | | - assert str(model_2.idata.groups) == str(idata.groups) |
106 | | - |
| 85 | + model = test_ModelBuilder(model_config, sampler_config, data) |
| 86 | + model.fit() |
| 87 | + assert model.idata is not None |
| 88 | + assert "posterior" in model.idata.groups() |
107 | 89 |
|
108 | | -def test_predict(): |
109 | 90 | x_pred = np.random.uniform(low=0, high=1, size=100) |
110 | 91 | prediction_data = pd.DataFrame({"input": x_pred}) |
111 | | - data, model_config, sampler_config = test_ModelBuilder.create_sample_input() |
112 | | - model_2 = test_ModelBuilder(model_config, sampler_config, data) |
113 | | - model_2.idata = model_2.fit() |
114 | | - model_2.predict(prediction_data) |
115 | | - with pm.Model() as model: |
116 | | - x = np.linspace(start=0, stop=1, num=100) |
117 | | - y = 5 * x + 3 |
118 | | - x = pm.MutableData("x", x) |
119 | | - y_data = pm.MutableData("y_data", y) |
120 | | - a_loc = 7 |
121 | | - a_scale = 3 |
122 | | - b_loc = 5 |
123 | | - b_scale = 3 |
124 | | - obs_error = 2 |
125 | | - |
126 | | - a = pm.Normal("a", a_loc, sigma=a_scale) |
127 | | - b = pm.Normal("b", b_loc, sigma=b_scale) |
128 | | - obs_error = pm.HalfNormal("σ_model_fmc", obs_error) |
129 | | - |
130 | | - y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data) |
| 92 | + pred = model.predict(prediction_data) |
| 93 | + assert "y_model" in pred.keys() |
| 94 | + post_pred = model.predict_posterior(prediction_data) |
| 95 | + assert "y_model" in post_pred.keys() |
131 | 96 |
|
132 | | - idata = pm.sample(tune=10, draws=20, chains=3, cores=1) |
133 | | - idata.extend(pm.sample_prior_predictive()) |
134 | | - idata.extend(pm.sample_posterior_predictive(idata)) |
135 | | - y_test = pm.sample_posterior_predictive(idata) |
136 | | - |
137 | | - assert str(model_2.idata.groups) == str(idata.groups) |
138 | 97 |
|
| 98 | +@pytest.mark.skipif( |
| 99 | + sys.platform == "win32", reason="Permissions for temp files not granted on windows CI." |
| 100 | +) |
| 101 | +def test_save_load(): |
| 102 | + data, model_config, sampler_config = test_ModelBuilder.create_sample_input() |
| 103 | + model = test_ModelBuilder(model_config, sampler_config, data) |
| 104 | + model.fit() |
| 105 | + temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) |
| 106 | + model.save(temp.name) |
| 107 | + model2 = test_ModelBuilder.load(temp.name) |
| 108 | + assert model.idata.groups() == model2.idata.groups() |
139 | 109 |
|
140 | | -def test_predict_posterior(): |
141 | 110 | x_pred = np.random.uniform(low=0, high=1, size=100) |
142 | 111 | prediction_data = pd.DataFrame({"input": x_pred}) |
143 | | - data, model_config, sampler_config = test_ModelBuilder.create_sample_input() |
144 | | - model_2 = test_ModelBuilder(model_config, sampler_config, data) |
145 | | - model_2.idata = model_2.fit() |
146 | | - model_2.predict_posterior(prediction_data) |
147 | | - with pm.Model() as model: |
148 | | - x = np.linspace(start=0, stop=1, num=100) |
149 | | - y = 5 * x + 3 |
150 | | - x = pm.MutableData("x", x) |
151 | | - y_data = pm.MutableData("y_data", y) |
152 | | - a_loc = 7 |
153 | | - a_scale = 3 |
154 | | - b_loc = 5 |
155 | | - b_scale = 3 |
156 | | - obs_error = 2 |
157 | | - |
158 | | - a = pm.Normal("a", a_loc, sigma=a_scale) |
159 | | - b = pm.Normal("b", b_loc, sigma=b_scale) |
160 | | - obs_error = pm.HalfNormal("σ_model_fmc", obs_error) |
161 | | - |
162 | | - y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data) |
163 | | - |
164 | | - idata = pm.sample(tune=10, draws=20, chains=3, cores=1) |
165 | | - idata.extend(pm.sample_prior_predictive()) |
166 | | - idata.extend(pm.sample_posterior_predictive(idata)) |
167 | | - y_test = pm.sample_posterior_predictive(idata) |
168 | | - |
169 | | - assert str(model_2.idata.groups) == str(idata.groups) |
| 112 | + pred1 = model.predict(prediction_data) |
| 113 | + pred2 = model2.predict(prediction_data) |
| 114 | + assert pred1["y_model"].shape == pred2["y_model"].shape |
| 115 | + temp.close() |
0 commit comments