11# assume
2- """
3- tilde_assume(context::SamplingContext, right, vn, vi)
4-
5- Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
6- accumulate the log probability, and return the sampled value with a context associated
7- with a sampler.
8-
9- Falls back to
10- ```julia
11- tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
12- ```
13- """
14- function tilde_assume (context:: SamplingContext , right, vn, vi)
15- return tilde_assume (context. rng, context. context, context. sampler, right, vn, vi)
16- end
17-
182function tilde_assume (context:: AbstractContext , args... )
193 return tilde_assume (childcontext (context), args... )
204end
215function tilde_assume (:: DefaultContext , right, vn, vi)
22- return assume (right, vn, vi)
23- end
24-
25- function tilde_assume (rng:: Random.AbstractRNG , context:: AbstractContext , args... )
26- return tilde_assume (rng, childcontext (context), args... )
27- end
28- function tilde_assume (rng:: Random.AbstractRNG , :: DefaultContext , sampler, right, vn, vi)
29- return assume (rng, sampler, right, vn, vi)
30- end
31- function tilde_assume (rng:: Random.AbstractRNG , :: InitContext , sampler, right, vn, vi)
32- @warn (
33- " Encountered SamplingContext->InitContext. This method will be removed in the next PR." ,
34- )
35- # just pretend the `InitContext` isn't there for now.
36- return assume (rng, sampler, right, vn, vi)
37- end
38- function tilde_assume (:: DefaultContext , sampler, right, vn, vi)
39- # same as above but no rng
40- return assume (Random. default_rng (), sampler, right, vn, vi)
6+ y = getindex_internal (vi, vn)
7+ f = from_maybe_linked_internal_transform (vi, vn, right)
8+ x, logjac = with_logabsdet_jacobian (f, y)
9+ vi = accumulate_assume!! (vi, x, logjac, vn, right)
10+ return x, vi
4111end
42-
4312function tilde_assume (context:: PrefixContext , right, vn, vi)
4413 # Note that we can't use something like this here:
4514 # new_vn = prefix(context, vn)
@@ -53,12 +22,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi)
5322 new_vn, new_context = prefix_and_strip_contexts (context, vn)
5423 return tilde_assume (new_context, right, new_vn, vi)
5524end
56- function tilde_assume (
57- rng:: Random.AbstractRNG , context:: PrefixContext , sampler, right, vn, vi
58- )
59- new_vn, new_context = prefix_and_strip_contexts (context, vn)
60- return tilde_assume (rng, new_context, sampler, right, new_vn, vi)
61- end
6225
6326"""
6427 tilde_assume!!(context, right, vn, vi)
@@ -78,17 +41,6 @@ function tilde_assume!!(context, right, vn, vi)
7841end
7942
8043# observe
81- """
82- tilde_observe!!(context::SamplingContext, right, left, vi)
83-
84- Handle observed constants with a `context` associated with a sampler.
85-
86- Falls back to `tilde_observe!!(context.context, right, left, vi)`.
87- """
88- function tilde_observe!! (context:: SamplingContext , right, left, vn, vi)
89- return tilde_observe!! (context. context, right, left, vn, vi)
90- end
91-
9244function tilde_observe!! (context:: AbstractContext , right, left, vn, vi)
9345 return tilde_observe!! (childcontext (context), right, left, vn, vi)
9446end
@@ -121,59 +73,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi)
12173 vi = accumulate_observe!! (vi, right, left, vn)
12274 return left, vi
12375end
124-
125- function assume (:: Random.AbstractRNG , spl:: Sampler , dist)
126- return error (" DynamicPPL.assume: unmanaged inference algorithm: $(typeof (spl)) " )
127- end
128-
129- # fallback without sampler
130- function assume (dist:: Distribution , vn:: VarName , vi)
131- y = getindex_internal (vi, vn)
132- f = from_maybe_linked_internal_transform (vi, vn, dist)
133- x, logjac = with_logabsdet_jacobian (f, y)
134- vi = accumulate_assume!! (vi, x, logjac, vn, dist)
135- return x, vi
136- end
137-
138- # TODO : Remove this thing.
139- # SampleFromPrior and SampleFromUniform
140- function assume (
141- rng:: Random.AbstractRNG ,
142- sampler:: Union{SampleFromPrior,SampleFromUniform} ,
143- dist:: Distribution ,
144- vn:: VarName ,
145- vi:: VarInfoOrThreadSafeVarInfo ,
146- )
147- if haskey (vi, vn)
148- # Always overwrite the parameters with new ones for `SampleFromUniform`.
149- if sampler isa SampleFromUniform || is_flagged (vi, vn, " del" )
150- # TODO (mhauru) Is it important to unset the flag here? The `true` allows us
151- # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
152- # if that's okay.
153- unset_flag! (vi, vn, " del" , true )
154- r = init (rng, dist, sampler)
155- f = to_maybe_linked_internal_transform (vi, vn, dist)
156- # TODO (mhauru) This should probably be call a function called setindex_internal!
157- vi = BangBang. setindex!! (vi, f (r), vn)
158- setorder! (vi, vn, get_num_produce (vi))
159- else
160- # Otherwise we just extract it.
161- r = vi[vn, dist]
162- end
163- else
164- r = init (rng, dist, sampler)
165- if istrans (vi)
166- f = to_linked_internal_transform (vi, vn, dist)
167- vi = push!! (vi, vn, f (r), dist)
168- # By default `push!!` sets the transformed flag to `false`.
169- vi = settrans!! (vi, true , vn)
170- else
171- vi = push!! (vi, vn, r, dist)
172- end
173- end
174-
175- # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
176- logjac = logabsdetjac (istrans (vi, vn) ? link_transform (dist) : identity, r)
177- vi = accumulate_assume!! (vi, r, - logjac, vn, dist)
178- return r, vi
179- end
0 commit comments