|
25 | 25 |
|
26 | 26 | __all__ = ("compute_log_likelihood", "compute_log_prior") |
27 | 27 |
|
| 28 | +from pymc.model.transform.conditioning import remove_value_transforms |
| 29 | + |
28 | 30 |
|
29 | 31 | def compute_log_likelihood( |
30 | 32 | idata: InferenceData, |
@@ -126,46 +128,35 @@ def compute_log_density( |
126 | 128 | if kind not in ("likelihood", "prior"): |
127 | 129 | raise ValueError("kind must be either 'likelihood' or 'prior'") |
128 | 130 |
|
| 131 | + # We need to disable transforms, because the InferenceData only keeps the untransformed values |
| 132 | + umodel = remove_value_transforms(model) |
| 133 | + |
129 | 134 | if kind == "likelihood": |
130 | | - target_rvs = model.observed_RVs |
| 135 | + target_rvs = list(umodel.observed_RVs) |
131 | 136 | target_str = "observed_RVs" |
132 | 137 | else: |
133 | | - target_rvs = model.free_RVs |
| 138 | + target_rvs = list(umodel.free_RVs) |
134 | 139 | target_str = "free_RVs" |
135 | 140 |
|
136 | 141 | if var_names is None: |
137 | 142 | vars = target_rvs |
138 | 143 | var_names = tuple(rv.name for rv in vars) |
139 | 144 | else: |
140 | | - vars = [model.named_vars[name] for name in var_names] |
| 145 | + vars = [umodel.named_vars[name] for name in var_names] |
141 | 146 | if not set(vars).issubset(target_rvs): |
142 | 147 | raise ValueError(f"var_names must refer to {target_str} in the model. Got: {var_names}") |
143 | 148 |
|
144 | | - # We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values |
145 | | - try: |
146 | | - original_rvs_to_values = model.rvs_to_values |
147 | | - original_rvs_to_transforms = model.rvs_to_transforms |
148 | | - |
149 | | - model.rvs_to_values = { |
150 | | - rv: rv.clone() if rv not in model.observed_RVs else value |
151 | | - for rv, value in model.rvs_to_values.items() |
152 | | - } |
153 | | - model.rvs_to_transforms = {rv: None for rv in model.basic_RVs} |
154 | | - |
155 | | - elemwise_logdens_fn = model.compile_fn( |
156 | | - inputs=model.value_vars, |
157 | | - outs=model.logp(vars=vars, sum=False), |
158 | | - on_unused_input="ignore", |
159 | | - ) |
160 | | - finally: |
161 | | - model.rvs_to_values = original_rvs_to_values |
162 | | - model.rvs_to_transforms = original_rvs_to_transforms |
| 149 | + elemwise_logdens_fn = umodel.compile_fn( |
| 150 | + inputs=umodel.value_vars, |
| 151 | + outs=umodel.logp(vars=vars, sum=False), |
| 152 | + on_unused_input="ignore", |
| 153 | + ) |
163 | 154 |
|
164 | | - coords, dims = coords_and_dims_for_inferencedata(model) |
| 155 | + coords, dims = coords_and_dims_for_inferencedata(umodel) |
165 | 156 |
|
166 | 157 | logdens_dataset = apply_function_over_dataset( |
167 | 158 | elemwise_logdens_fn, |
168 | | - posterior[[rv.name for rv in model.free_RVs]], |
| 159 | + posterior[[rv.name for rv in umodel.free_RVs]], |
169 | 160 | output_var_names=var_names, |
170 | 161 | sample_dims=sample_dims, |
171 | 162 | dims=dims, |
|
0 commit comments