@@ -3,6 +3,8 @@ import Base.getindex
33using SparseArrays
44using Setfield
55using Setfield: PropertyLens, get
6+ using DensityInterface
7+ using Random
68
79"""
810 GraphInfo
@@ -222,7 +224,7 @@ function Base.getindex(m::Model, vn::VarName)
222224end
223225
224226"""
225- set_node_value!(m::Model, ind::VarName, value::T) where Takes
227+ set_node_value!(m::Model, ind::VarName, value::T) where T
226228
227229Change the value of the node.
228230
@@ -231,7 +233,7 @@ Change the value of the node.
231233```jl-doctest
232234julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
233235 μ = (1.0, () -> 1.0, :Logical),
234- y = (0.0, (μ, s2) -> MvNormal (μ, sqrt(s2)), :Stochastic))
236+ y = (0.0, (μ, s2) -> Normal (μ, sqrt(s2)), :Stochastic))
235237Nodes:
236238μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#38#41"(), kind = :Logical)
237239s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#37#40"(), kind = :Stochastic)
@@ -271,31 +273,58 @@ julia> get_node_value(m, @varname s2)
271273"""
272274
273275function get_node_value (m:: Model , ind:: VarName )
274- v = getproperty (m[ind], : value )
276+ v = get (m[ind], @lens _ . value)
275277 v[]
276278end
277- # Base.get(m::Model, ind::VarName, field::Symbol) = field==:value ? getvalue(m, ind) : getproperty(m[ind],field)
279+
280+ function get_node_value (m:: Model , ind:: NTuple{N, Symbol} ) where N
281+ # [get_node_value(m, VarName{S}()) for S in ind]
282+ values = Vector {Union{Float64, Array{Float64}}} ()
283+ for i in ind
284+ push! (values, get_node_value (m, VarName {i} ()))
285+ end
286+ values
287+ end
288+
289+ """
290+ get_node_ref_value(m::Model, ind::VarName)
291+ get_node_ref_value(m::Model, ind::NTuple{N, Symbol})
292+
293+ Return the mutable Ref value associated with a node or tuple
294+ of nodes.
295+ """
296+ function get_node_ref_value (m:: Model , ind:: VarName )
297+ get (m[ind], @lens _. value)
298+ end
299+
300+ function get_node_ref_value (m:: Model , ind:: NTuple{N, Symbol} ) where N
301+ values = Vector{Union{Base. RefValue{Float64}, Base. RefValue{Vector{Float64}}}}()
302+ for i in ind
303+ push! (values, get_node_ref_value (m, VarName {i} ()))
304+ end
305+ values
306+ end
278307
279308"""
280309 get_node_input(m::Model, ind::VarName)
281310
282311Retrieve the inputs/parents of a node, as given by model DAG.
283312"""
284- get_node_input (m:: Model , ind:: VarName ) = getproperty (m[ind], : input )
313+ get_node_input (m:: Model , ind:: VarName ) = get (m[ind], @lens _ . input)
285314
286315"""
287316 get_node_input(m::Model, ind::VarName)
288317
289318Retrieve the evaluation function for a node.
290319"""
291- get_node_eval (m:: Model , ind:: VarName ) = getproperty (m[ind], : eval )
320+ get_node_eval (m:: Model , ind:: VarName ) = get (m[ind], @lens _ . eval)
292321
293322"""
294323 get_nodekind(m::Model, ind::VarName)
295324
296325Retrieve the type of the node, i.e. stochastic or logical.
297326"""
298- get_nodekind (m:: Model , ind:: VarName ) = getproperty (m[ind], : kind )
327+ get_nodekind (m:: Model , ind:: VarName ) = get (m[ind], @lens _ . kind)
299328
300329"""
301330 get_dag(m::Model)
@@ -310,16 +339,48 @@ get_dag(m::Model) = m.g.A
310339Returns a `Vector{Symbol}` containing the sorted vertices
311340of the DAG.
312341"""
313- get_sorted_vertices (m:: Model ) = getproperty (m. g, :sorted_vertices )
342+ get_sorted_vertices (m:: Model ) = get (m. g, @lens _. sorted_vertices)
343+
344+
345+ """
346+ get_model_values(m::Model)
347+
348+ Returns a Named Tuple of nodes and node values.
349+ """
350+ function get_model_values (m:: Model{T} ) where T
351+ NamedTuple {T} (get_node_value (m, T))
352+ end
353+
354+ """
355+ get_model_ref_values(m::Model)
356+
357+ Returns a Named Tuple of nodes and node Ref values.
358+ """
359+ function get_model_ref_values (m:: Model{T} ) where T
360+ NamedTuple {T} (get_node_ref_value (m, T))
361+ end
362+
363+ """
364+ set_model_values!(m::Model, values::NamedTuple)
314365
366+ Changes the values of the `Model` node values to those
367+ given by a Named Tuple of node symboles and new values.
368+ """
369+ function set_model_values! (m:: Model{T} , values:: NamedTuple{T} ) where T
370+ for vn in keys (m)
371+ if get_nodekind (m, vn) != :Observations
372+ set_node_value! (m, vn, get (values, vn))
373+ end
374+ end
375+ end
315376# iterators
316377
317378function Base. iterate (m:: Model , state= 1 )
318379 state > length (get_sorted_vertices (m)) ? nothing : (m[VarName {m.g.sorted_vertices[state]} ()], state+ 1 )
319380end
320381
321382Base. eltype (m:: Model ) = NamedTuple{fieldnames (GraphInfo)[1 : 4 ]}
322- Base. IteratorEltype (m:: Model ) = HasEltype ()
383+ Base. IteratorEltype (m:: Model ) = Base . HasEltype ()
323384
324385Base. keys (m:: Model ) = (VarName {n} () for n in m. g. sorted_vertices)
325386Base. values (m:: Model ) = Base. Generator (identity, m)
@@ -333,4 +394,156 @@ function Base.show(io::IO, m::Model)
333394 for node in get_sorted_vertices (m)
334395 print (io, " $node = " , m[VarName {node} ()], " \n " )
335396 end
397+ end
398+
399+ """
400+ rand!(rng::AbstractRNG, m::Model)
401+
402+ Draw random samples from the model and mutate the node values.
403+
404+ # Examples
405+
406+ ```jl-doctest
407+ julia> import AbstractPPL.GraphPPL: Model, rand!
408+ using Distributions
409+
410+ julia> using Random; Random.seed!(1234)
411+ TaskLocalRNG()
412+
413+ julia> m = Model(s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
414+ μ = (1.0, () -> 1.0, :Logical),
415+ y = (0.0, (μ, s2) -> Normal(μ, sqrt(s2)), :Stochastic))
416+ Nodes:
417+ μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
418+ s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#5#8"(), kind = :Stochastic)
419+ y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#7#10"(), kind = :Stochastic)
420+
421+
422+ julia> rand!(m)
423+ Nodes:
424+ μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
425+ s2 = (input = (), value = Base.RefValue{Float64}(2.7478186975593846), eval = var"#5#8"(), kind = :Stochastic)
426+ y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.3044653509044275), eval = var"#7#10"(), kind = :Stochastic)
427+ ```
428+ """
429+ function Random. rand! (rng:: AbstractRNG , m:: AbstractPPL.GraphPPL.Model{T} ) where T
430+ for vn in keys (m)
431+ input, _, f, kind = m[vn]
432+ input_values = get_node_value (m, input)
433+ if kind == :Stochastic || kind == :Observations
434+ set_node_value! (m, vn, rand (rng, f (input_values... )))
435+ else
436+ set_node_value! (m, vn, f (input_values... ))
437+ end
438+ end
439+ m
440+ end
441+
442+ function Random. rand! (m:: AbstractPPL.GraphPPL.Model{T} ) where T
443+ rand! (Random. GLOBAL_RNG, m)
444+ end
445+
446+ """
447+ rand!(rng::AbstractRNG, m::Model)
448+
449+ Draw random samples from the model and mutate the node values.
450+
451+ # Examples
452+
453+ ```jl-doctest
454+ julia> using Random; Random.seed!(1234)
455+
456+ julia> import AbstractPPL.GraphPPL: Model, rand
457+ [ Info: Precompiling AbstractPPL [7a57a42e-76ec-4ea3-a279-07e840d6d9cf]
458+
459+ julia> using Distributions
460+
461+ julia> m = Model(s2 = (1.0, () -> InverseGamma(2.0,3.0), :Stochastic),
462+ μ = (0.0, () -> 1.0, :Logical),
463+ y = (0.0, (μ, s2) -> Normal(μ, sqrt(s2)), :Stochastic))
464+ Nodes:
465+ μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
466+ s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#5#8"(), kind = :Stochastic)
467+ y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#7#10"(), kind = :Stochastic)
468+
469+ julia> rand(m)
470+ (μ = 1.0, s2 = 1.0907695400401212, y = 0.05821954440386368)
471+ ```
472+ """
473+ function Random. rand (rng:: AbstractRNG , sm:: Random.SamplerTrivial{Model{Tnames, Tinput, Tvalue, Teval, Tkind}} ) where {Tnames, Tinput, Tvalue, Teval, Tkind}
474+ m = deepcopy (sm[])
475+ get_model_values (rand! (rng, m))
476+ end
477+
478+ """
479+ logdensityof(m::Model)
480+
481+ Evaluate the log-densinty of the model.
482+
483+ # Examples
484+
485+ ```jl-doctest
486+ julia> using Random; Random.seed!(1234)
487+ MersenneTwister(1234)
488+
489+ julia> import AbstractPPL.GraphPPL: Model, logdensityof
490+ [ Info: Precompiling AbstractPPL [7a57a42e-76ec-4ea3-a279-07e840d6d9cf]
491+
492+ julia> using Distributions
493+
494+ julia> m = Model(s2 = (1.0, () -> InverseGamma(2.0,3.0), :Stochastic),
495+ μ = (0.0, () -> 1.0, :Logical),
496+ y = (0.0, (μ, s2) -> Normal(μ, sqrt(s2)), :Stochastic))
497+ Nodes:
498+ μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
499+ s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#5#8"(), kind = :Stochastic)
500+ y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#7#10"(), kind = :Stochastic)
501+
502+ julia> logdensityof(m)
503+ -1.721713955868453
504+ ```
505+ """
506+ function DensityInterface. logdensityof (m:: AbstractPPL.GraphPPL.Model )
507+ logdensityof (m, get_model_values (m))
508+ end
509+
510+ """
511+ logdensityof(m::Model{T}, v::NamedTuple{T})
512+
513+ Evaluate the log-densinty of the model.
514+
515+ # Examples
516+
517+ ```jl-doctest
518+ julia> using Random; Random.seed!(1234)
519+ MersenneTwister(1234)
520+
521+ julia> import AbstractPPL.GraphPPL: Model, logdensityof, get_model_values
522+ [ Info: Precompiling AbstractPPL [7a57a42e-76ec-4ea3-a279-07e840d6d9cf]
523+
524+ julia> using Distributions
525+
526+ julia> m = Model(s2 = (1.0, () -> InverseGamma(2.0,3.0), :Stochastic),
527+ μ = (0.0, () -> 1.0, :Logical),
528+ y = (0.0, (μ, s2) -> Normal(μ, sqrt(s2)), :Stochastic))
529+ Nodes:
530+ μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#6#9"(), kind = :Logical)
531+ s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#5#8"(), kind = :Stochastic)
532+ y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#7#10"(), kind = :Stochastic)
533+
534+ julia> logdensityof(m, get_model_values(m))
535+ -1.721713955868453
536+ """
537+ function DensityInterface. logdensityof (m:: AbstractPPL.GraphPPL.Model{T} , v:: NamedTuple{T, V} ) where {T, V}
538+ lp = 0.0
539+ for vn in keys (m)
540+ input, _, f, kind = m[vn]
541+ input_values = get_node_value (m, input)
542+ value = get (v, vn)
543+ if kind == :Stochastic || kind == :Observations
544+ # check whether this is a constrained variable #TODO use bijectors.jl
545+ lp += logdensityof (f (input_values... ), value)
546+ end
547+ end
548+ lp
336549end
0 commit comments