77from pytensor .tensor .exceptions import NotScalarConstantError
88
99from pymc_experimental .utils .model_fgraph import (
10+ ModelDeterministic ,
1011 ModelFreeRV ,
12+ ModelNamed ,
13+ ModelObservedRV ,
14+ ModelPotential ,
1115 ModelVar ,
1216 fgraph_from_model ,
1317 model_deterministic ,
@@ -23,11 +27,17 @@ def test_basic():
2327 y = pm .Deterministic ("y" , x + 1 )
2428 w = pm .HalfNormal ("w" , pm .math .exp (y ))
2529 z = pm .Normal ("z" , y , w , observed = [0 , 1 , 2 ], dims = ("test_dim" ,))
26- pm .Potential ("pot" , x * 2 )
30+ pot = pm .Potential ("pot" , x * 2 )
2731
28- m_fgraph = fgraph_from_model (m_old )
32+ m_fgraph , memo = fgraph_from_model (m_old )
2933 assert isinstance (m_fgraph , FunctionGraph )
3034
35+ assert isinstance (memo [x ].owner .op , ModelFreeRV )
36+ assert isinstance (memo [y ].owner .op , ModelDeterministic )
37+ assert isinstance (memo [w ].owner .op , ModelFreeRV )
38+ assert isinstance (memo [z ].owner .op , ModelObservedRV )
39+ assert isinstance (memo [pot ].owner .op , ModelPotential )
40+
3141 m_new = model_from_fgraph (m_fgraph )
3242 assert isinstance (m_new , pm .Model )
3343
@@ -79,7 +89,12 @@ def test_data():
7989 mu = pm .Deterministic ("mu" , b0 + b1 * x , dims = ("test_dim" ,))
8090 obs = pm .Normal ("obs" , mu , sigma = 1e-5 , observed = y , dims = ("test_dim" ,))
8191
82- m_new = model_from_fgraph (fgraph_from_model (m_old ))
92+ m_fgraph , memo = fgraph_from_model (m_old )
93+ assert isinstance (memo [x ].owner .op , ModelNamed )
94+ assert isinstance (memo [y ].owner .op , ModelNamed )
95+ assert isinstance (memo [b0 ].owner .op , ModelNamed )
96+
97+ m_new = model_from_fgraph (m_fgraph )
8398
8499 # ConstantData is preserved
85100 assert m_new ["b0" ].data == m_old ["b0" ].data
@@ -125,7 +140,7 @@ def test_deterministics():
125140 assert m ["y" ].owner .inputs [3 ] is m ["mu" ]
126141 assert m ["y" ].owner .inputs [4 ] is not m ["sigma" ]
127142
128- fg = fgraph_from_model (m )
143+ fg , _ = fgraph_from_model (m )
129144
130145 # Check that no Deterministics are in graph of x to y and y to z
131146 x , y , z , det_mu , det_sigma , det_y_ , det_y__ = fg .outputs
@@ -173,7 +188,7 @@ def test_sub_model_error():
173188 with pm .Model () as sub_m :
174189 y = pm .Normal ("y" , x )
175190
176- nodes = [v for v in fgraph_from_model (m ).toposort () if not isinstance (v .op , ModelVar )]
191+ nodes = [v for v in fgraph_from_model (m )[ 0 ] .toposort () if not isinstance (v .op , ModelVar )]
177192 assert len (nodes ) == 2
178193 assert isinstance (nodes [0 ].op , pm .Beta )
179194 assert isinstance (nodes [1 ].op , pm .Normal )
@@ -234,7 +249,7 @@ def test_fgraph_rewrite(non_centered_rewrite):
234249 subject_mean = pm .Normal ("subject_mean" , group_mean , group_std , dims = ("subject" ,))
235250 obs = pm .Normal ("obs" , subject_mean , 1 , observed = np .zeros (10 ), dims = ("subject" ,))
236251
237- fg = fgraph_from_model (m_old )
252+ fg , _ = fgraph_from_model (m_old )
238253 non_centered_rewrite .apply (fg )
239254
240255 m_new = model_from_fgraph (fg )
0 commit comments