Skip to content

Commit 0f4da2a

Browse files
committed
Allow inlining of Deterministics and Data in fgraph IR
1 parent faf003a commit 0f4da2a

File tree

2 files changed

+43
-17
lines changed

2 files changed

+43
-17
lines changed

pymc_experimental/tests/utils/test_model_fgraph.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,28 +76,37 @@ def test_basic():
7676
)
7777

7878

79-
def test_data():
79+
@pytest.mark.parametrize("inline_views", (False, True))
80+
def test_data(inline_views):
8081
"""Test shared RNGs, MutableData, ConstantData and Dim lengths are handled correctly.
8182
8283
Everything should be preserved across new and old models, except for shared RNGs
8384
"""
8485
with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
8586
x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
8687
y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",))
87-
b0 = pm.ConstantData("b0", 0.0)
88+
b0 = pm.ConstantData("b0", np.zeros(3))
8889
b1 = pm.Normal("b1")
8990
mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
9091
obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))
9192

92-
m_fgraph, memo = fgraph_from_model(m_old)
93+
m_fgraph, memo = fgraph_from_model(m_old, inlined_views=inline_views)
9394
assert isinstance(memo[x].owner.op, ModelNamed)
9495
assert isinstance(memo[y].owner.op, ModelNamed)
9596
assert isinstance(memo[b0].owner.op, ModelNamed)
97+
mu_val = memo[mu].owner.inputs[0]
98+
if not inline_views:
99+
# Add(b0, Mul(FreeRV(b1), x) not Add(Named(b0), Mul(FreeRV(b1), Named(x))
100+
assert mu_val.owner.inputs[0] is memo[b0].owner.inputs[0]
101+
assert mu_val.owner.inputs[1].owner.inputs[1] is memo[x].owner.inputs[0]
102+
else:
103+
assert mu_val.owner.inputs[0] is memo[b0]
104+
assert mu_val.owner.inputs[1].owner.inputs[1] is memo[x]
96105

97106
m_new = model_from_fgraph(m_fgraph)
98107

99108
# ConstantData is preserved
100-
assert m_new["b0"].data == m_old["b0"].data
109+
assert np.all(m_new["b0"].data == m_old["b0"].data)
101110

102111
# Shared non-rng shared variables are preserved
103112
assert m_new["x"].container is x.container
@@ -114,7 +123,8 @@ def test_data():
114123
np.testing.assert_array_almost_equal(pm.draw(m_new["x"]), [100.0, 200.0])
115124

116125

117-
def test_deterministics():
126+
@pytest.mark.parametrize("inline_views", (False, True))
127+
def test_deterministics(inline_views):
118128
"""Test handling of deterministics.
119129
120130
We don't want Deterministics in the middle of the FunctionGraph, as they would make rewrites cumbersome
@@ -140,22 +150,27 @@ def test_deterministics():
140150
assert m["y"].owner.inputs[3] is m["mu"]
141151
assert m["y"].owner.inputs[4] is not m["sigma"]
142152

143-
fg, _ = fgraph_from_model(m)
153+
fg, _ = fgraph_from_model(m, inlined_views=inline_views)
144154

145155
# Check that no Deterministics are in graph of x to y and y to z
146156
x, y, z, det_mu, det_sigma, det_y_, det_y__ = fg.outputs
147157
# [Det(mu), Det(sigma)]
148158
mu = det_mu.owner.inputs[0]
149159
sigma = det_sigma.owner.inputs[0]
150-
# [FreeRV(y(mu, sigma))] not [FreeRV(y(Det(mu), Det(sigma)))]
151-
assert y.owner.inputs[0].owner.inputs[3] is mu
152160
assert y.owner.inputs[0].owner.inputs[4] is sigma
153-
# [FreeRV(z(y))] not [FreeRV(z(Det(Det(y))))]
154-
assert z.owner.inputs[0].owner.inputs[3] is y
155-
# [Det(y), Det(y)], not [Det(y), Det(Det(y))]
156-
assert det_y_.owner.inputs[0] is y
157-
assert det_y__.owner.inputs[0] is y
158161
assert det_y_ is not det_y__
162+
assert det_y_.owner.inputs[0] is y
163+
if not inline_views:
164+
# FreeRV(y(mu, sigma)) not FreeRV(y(Det(mu), Det(sigma)))
165+
assert y.owner.inputs[0].owner.inputs[3] is mu
166+
# FreeRV(z(y)) not FreeRV(z(Det(Det(y))))
167+
assert z.owner.inputs[0].owner.inputs[3] is y
168+
# Det(y), not Det(Det(y))
169+
assert det_y__.owner.inputs[0] is y
170+
else:
171+
assert y.owner.inputs[0].owner.inputs[3] is det_mu
172+
assert z.owner.inputs[0].owner.inputs[3] is det_y__
173+
assert det_y__.owner.inputs[0] is det_y_
159174

160175
# Both mu and sigma deterministics are now in the graph of x to y
161176
m = model_from_fgraph(fg)

pymc_experimental/utils/model_fgraph.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,20 @@ def local_remove_identity(fgraph, node):
109109
remove_identity_rewrite = out2in(local_remove_identity)
110110

111111

112-
def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Variable]]:
112+
def fgraph_from_model(
113+
model: Model, inlined_views=False
114+
) -> Tuple[FunctionGraph, Dict[Variable, Variable]]:
113115
"""Convert Model to FunctionGraph.
114116
115117
See: model_from_fgraph
116118
119+
Parameters
120+
----------
121+
model: PyMC model
122+
inlined_views: bool, default False
123+
Whether "view" variables (Deterministics and Data) should be inlined among RVs in the fgraph,
124+
or show up as separate branches.
125+
117126
Returns
118127
-------
119128
fgraph: FunctionGraph
@@ -141,13 +150,15 @@ def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Varia
141150
# We copy Deterministics (Identity Op) so that they don't show in between "main" variables
142151
# We later remove these Identity Ops when we have a Deterministic ModelVar Op as a separator
143152
old_deterministics = model.deterministics
144-
deterministics = [det.copy(det.name) for det in old_deterministics]
153+
deterministics = [det if inlined_views else det.copy(det.name) for det in old_deterministics]
145154
# Other variables that are in model.named_vars but are not any of the categories above
146155
# E.g., MutableData, ConstantData, _dim_lengths
147156
# We use the same trick as deterministics!
148157
accounted_for = free_rvs + observed_rvs + potentials + old_deterministics
149158
old_other_named_vars = [var for var in model.named_vars.values() if var not in accounted_for]
150-
other_named_vars = [var.copy(var.name) for var in old_other_named_vars]
159+
other_named_vars = [
160+
var if inlined_views else var.copy(var.name) for var in old_other_named_vars
161+
]
151162
value_vars = [val for val in rvs_to_values.values() if val not in old_other_named_vars]
152163

153164
model_vars = rvs + potentials + deterministics + other_named_vars + value_vars
@@ -211,7 +222,7 @@ def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Varia
211222
# Reference model vars in memo
212223
inverse_memo = {v: k for k, v in memo.items()}
213224
for var, model_var in replacements:
214-
if isinstance(
225+
if not inlined_views and isinstance(
215226
model_var.owner is not None and model_var.owner.op, (ModelDeterministic, ModelNamed)
216227
):
217228
# Ignore extra identity that will be removed at the end

0 commit comments

Comments
 (0)