Skip to content

Commit ed19c34

Browse files
authored
Minor fixes to modelbuilder
1 parent 250e81a commit ed19c34

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

pymc_extras/model_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,7 @@ def load(cls, fname: str):
446446
sampler_config=json.loads(idata.attrs["sampler_config"]),
447447
)
448448
model.idata = idata
449+
model.is_fitted_ = True
449450
dataset = idata.fit_data.to_dataframe()
450451
X = dataset.drop(columns=[model.output_var])
451452
y = dataset[model.output_var]
@@ -526,6 +527,8 @@ def fit(
526527
)
527528
self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore
528529

530+
self.is_fitted_ = True
531+
529532
return self.idata # type: ignore
530533

531534
def predict(

0 commit comments

Comments
 (0)