@@ -76,28 +76,37 @@ def test_basic():
7676 )
7777
7878
79- def test_data ():
79+ @pytest .mark .parametrize ("inline_views" , (False , True ))
80+ def test_data (inline_views ):
8081 """Test shared RNGs, MutableData, ConstantData and Dim lengths are handled correctly.
8182
8283 Everything should be preserved across new and old models, except for shared RNGs
8384 """
8485 with pm .Model (coords_mutable = {"test_dim" : range (3 )}) as m_old :
8586 x = pm .MutableData ("x" , [0.0 , 1.0 , 2.0 ], dims = ("test_dim" ,))
8687 y = pm .MutableData ("y" , [10.0 , 11.0 , 12.0 ], dims = ("test_dim" ,))
87- b0 = pm .ConstantData ("b0" , 0.0 )
88+ b0 = pm .ConstantData ("b0" , np . zeros ( 3 ) )
8889 b1 = pm .Normal ("b1" )
8990 mu = pm .Deterministic ("mu" , b0 + b1 * x , dims = ("test_dim" ,))
9091 obs = pm .Normal ("obs" , mu , sigma = 1e-5 , observed = y , dims = ("test_dim" ,))
9192
92- m_fgraph , memo = fgraph_from_model (m_old )
93+ m_fgraph , memo = fgraph_from_model (m_old , inlined_views = inline_views )
9394 assert isinstance (memo [x ].owner .op , ModelNamed )
9495 assert isinstance (memo [y ].owner .op , ModelNamed )
9596 assert isinstance (memo [b0 ].owner .op , ModelNamed )
97+ mu_val = memo [mu ].owner .inputs [0 ]
98+ if not inline_views :
99+ # Add(b0, Mul(FreeRV(b1), x) not Add(Named(b0), Mul(FreeRV(b1), Named(x))
100+ assert mu_val .owner .inputs [0 ] is memo [b0 ].owner .inputs [0 ]
101+ assert mu_val .owner .inputs [1 ].owner .inputs [1 ] is memo [x ].owner .inputs [0 ]
102+ else :
103+ assert mu_val .owner .inputs [0 ] is memo [b0 ]
104+ assert mu_val .owner .inputs [1 ].owner .inputs [1 ] is memo [x ]
96105
97106 m_new = model_from_fgraph (m_fgraph )
98107
99108 # ConstantData is preserved
100- assert m_new ["b0" ].data == m_old ["b0" ].data
109+ assert np . all ( m_new ["b0" ].data == m_old ["b0" ].data )
101110
102111 # Shared non-rng shared variables are preserved
103112 assert m_new ["x" ].container is x .container
@@ -114,7 +123,8 @@ def test_data():
114123 np .testing .assert_array_almost_equal (pm .draw (m_new ["x" ]), [100.0 , 200.0 ])
115124
116125
117- def test_deterministics ():
126+ @pytest .mark .parametrize ("inline_views" , (False , True ))
127+ def test_deterministics (inline_views ):
118128 """Test handling of deterministics.
119129
120130 We don't want Deterministics in the middle of the FunctionGraph, as they would make rewrites cumbersome
@@ -140,22 +150,27 @@ def test_deterministics():
140150 assert m ["y" ].owner .inputs [3 ] is m ["mu" ]
141151 assert m ["y" ].owner .inputs [4 ] is not m ["sigma" ]
142152
143- fg , _ = fgraph_from_model (m )
153+ fg , _ = fgraph_from_model (m , inlined_views = inline_views )
144154
145155 # Check that no Deterministics are in graph of x to y and y to z
146156 x , y , z , det_mu , det_sigma , det_y_ , det_y__ = fg .outputs
147157 # [Det(mu), Det(sigma)]
148158 mu = det_mu .owner .inputs [0 ]
149159 sigma = det_sigma .owner .inputs [0 ]
150- # [FreeRV(y(mu, sigma))] not [FreeRV(y(Det(mu), Det(sigma)))]
151- assert y .owner .inputs [0 ].owner .inputs [3 ] is mu
152160 assert y .owner .inputs [0 ].owner .inputs [4 ] is sigma
153- # [FreeRV(z(y))] not [FreeRV(z(Det(Det(y))))]
154- assert z .owner .inputs [0 ].owner .inputs [3 ] is y
155- # [Det(y), Det(y)], not [Det(y), Det(Det(y))]
156- assert det_y_ .owner .inputs [0 ] is y
157- assert det_y__ .owner .inputs [0 ] is y
158161 assert det_y_ is not det_y__
162+ assert det_y_ .owner .inputs [0 ] is y
163+ if not inline_views :
164+ # FreeRV(y(mu, sigma)) not FreeRV(y(Det(mu), Det(sigma)))
165+ assert y .owner .inputs [0 ].owner .inputs [3 ] is mu
166+ # FreeRV(z(y)) not FreeRV(z(Det(Det(y))))
167+ assert z .owner .inputs [0 ].owner .inputs [3 ] is y
168+ # Det(y), not Det(Det(y))
169+ assert det_y__ .owner .inputs [0 ] is y
170+ else :
171+ assert y .owner .inputs [0 ].owner .inputs [3 ] is det_mu
172+ assert z .owner .inputs [0 ].owner .inputs [3 ] is det_y__
173+ assert det_y__ .owner .inputs [0 ] is det_y_
159174
160175 # Both mu and sigma deterministics are now in the graph of x to y
161176 m = model_from_fgraph (fg )
0 commit comments