Description
I decided to finally look into whether we still need DistributionsAD. As far as I can tell, the only things in there that still see use are filldist and arraydist.
My hypothesis is that we can completely replace them with the following:
# Current definitions
filldist = DistributionsAD.filldist
arraydist = DistributionsAD.arraydist
# Alternative, much simpler, definitions
filldist2(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...))
arraydist2 = Distributions.product_distribution
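As a quick sanity check (illustrative only; the variable names are mine and this snippet is not part of the proposal itself), the two constructions should agree on their log-densities:
using Distributions
using DistributionsAD: DistributionsAD
using FillArrays: Fill

# Proposed definition, repeated here so the snippet is self-contained
filldist2(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...))

d_old = DistributionsAD.filldist(Normal(), 2, 3)   # existing definition
d_new = filldist2(Normal(), 2, 3)                  # proposed definition
x = rand(d_old)                                    # a 2×3 matrix of samples
logpdf(d_old, x) ≈ logpdf(d_new, x)                # expected to be true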
To test this, I benchmarked rand, logpdf, and the gradient of logpdf (with our standard AD backends) using the existing implementation and the proposed implementation.
Benchmarking code
using Distributions
using DistributionsAD: DistributionsAD
using Chairmarks
using FillArrays: Fill
using DifferentiationInterface: AutoForwardDiff, AutoReverseDiff, AutoMooncake, prepare_gradient, gradient
import ForwardDiff
import ReverseDiff
import Mooncake
# Current definitions
filldist = DistributionsAD.filldist
arraydist = DistributionsAD.arraydist
# Alternative, much simpler, definitions
filldist2(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...))
arraydist2 = Distributions.product_distribution
# AD backends to test
backends = Dict(
    "FD" => AutoForwardDiff(),
    "RD" => AutoReverseDiff(),
    "MC" => AutoMooncake(),
)
println("\n")
println("filldist")
println("========")
# Benchmark filldist
for dist in [Normal(), Beta(2, 2), MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0]), Wishart(7, [1.0 0.5; 0.5 1.0])]
    println("\n$(typeof(dist))")
    fd = filldist(dist, 2, 3)
    fd2 = filldist2(dist, 2, 3)
    fd_rand = @be rand($fd)
    fd2_rand = @be rand($fd2)
    println(" rand filldist/filldist2: $(median(fd_rand).time/median(fd2_rand).time)")
    r = rand(fd)
    fd_logp = @be logpdf($fd, $r)
    fd2_logp = @be logpdf($fd2, $r)
    println(" logpdf filldist/filldist2: $(median(fd_logp).time/median(fd2_logp).time)")
    if !(dist isa Wishart)
        for (name, adtype) in backends
            f = Base.Fix1(logpdf, fd)
            f2 = Base.Fix1(logpdf, fd2)
            prep = prepare_gradient(f, adtype, r)
            prep2 = prepare_gradient(f2, adtype, r)
            fd_grad_logp = @be gradient($f, $prep, $adtype, $r)
            fd2_grad_logp = @be gradient($f2, $prep2, $adtype, $r)
            println(" $name ∇logpdf filldist/filldist2: $(median(fd_grad_logp).time/median(fd2_grad_logp).time)")
        end
    end
end
println("\n")
println("arraydist")
println("=========")
# Benchmark arraydist
for dists in [
    [Normal(0.0), Normal(1.0), Normal(2.0, 3.0), Normal(4.0, 0.2)],
    [Beta(2, 2), InverseGamma(2, 3), Normal()],
    [MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0]), MvNormal([2.0, 4.0], [1.0 0.5; 0.5 1.0])],
]
    println("\ndistributions starting with: $(typeof(dists[1]))")
    ad = arraydist(dists)
    ad2 = arraydist2(dists)
    ad_rand = @be rand($ad)
    ad2_rand = @be rand($ad2)
    println(" rand arraydist/arraydist2: $(median(ad_rand).time/median(ad2_rand).time)")
    r = rand(ad)
    ad_logp = @be logpdf($ad, $r)
    ad2_logp = @be logpdf($ad2, $r)
    println(" logpdf arraydist/arraydist2: $(median(ad_logp).time/median(ad2_logp).time)")
    for (name, adtype) in backends
        f = Base.Fix1(logpdf, ad)
        f2 = Base.Fix1(logpdf, ad2)
        prep = prepare_gradient(f, adtype, r)
        prep2 = prepare_gradient(f2, adtype, r)
        ad_grad_logp = @be gradient($f, $prep, $adtype, $r)
        ad2_grad_logp = @be gradient($f2, $prep2, $adtype, $r)
        println(" $name ∇logpdf arraydist/arraydist2: $(median(ad_grad_logp).time/median(ad2_grad_logp).time)")
    end
end
Results
As can be verified from the benchmarking code above, the numbers reported are ratios of median times (existing / proposed), i.e. a number < 1 means that the existing DistributionsAD implementation is faster.
filldist
========
Normal{Float64}
rand filldist/filldist2: 1.0085484029484029
logpdf filldist/filldist2: 1.289915841908632
FD ∇logpdf filldist/filldist2: 1.1131279951870339
RD ∇logpdf filldist/filldist2: 1.0306648131490719
MC ∇logpdf filldist/filldist2: 2.3058095238095238
Beta{Float64}
rand filldist/filldist2: 1.069383434601855
logpdf filldist/filldist2: 1.011111111111111
FD ∇logpdf filldist/filldist2: 1.0779808895335319
RD ∇logpdf filldist/filldist2: 1.0
MC ∇logpdf filldist/filldist2: 1.2243054833003848
FullNormal
rand filldist/filldist2: 1.0057641122777223
logpdf filldist/filldist2: 1.0158160093929138
FD ∇logpdf filldist/filldist2: 1.014606050887321
RD ∇logpdf filldist/filldist2: 0.9965654224944137
MC ∇logpdf filldist/filldist2: 1.0277829296338885
Wishart{Float64, PDMats.PDMat{Float64, Matrix{Float64}}, Int64}
rand filldist/filldist2: 1.0972668192219681
logpdf filldist/filldist2: 1.0186094420600857
arraydist
=========
distributions starting with: Normal{Float64}
rand arraydist/arraydist2: 0.8166216049657906
logpdf arraydist/arraydist2: 0.3643514465925137
FD ∇logpdf arraydist/arraydist2: 0.7815570133385658
RD ∇logpdf arraydist/arraydist2: 0.8234206667941607
MC ∇logpdf arraydist/arraydist2: 0.46417765631907404
distributions starting with: Beta{Float64}
rand arraydist/arraydist2: 1.0020691994572593
logpdf arraydist/arraydist2: 1.0045024289620155
FD ∇logpdf arraydist/arraydist2: 1.020335676643539
RD ∇logpdf arraydist/arraydist2: 1.0028416872089838
MC ∇logpdf arraydist/arraydist2: 0.9901037037037037
distributions starting with: FullNormal
rand arraydist/arraydist2: 1.7908701400560223
logpdf arraydist/arraydist2: 1.137092822555559
FD ∇logpdf arraydist/arraydist2: 0.8979361936193618
RD ∇logpdf arraydist/arraydist2: 1.0244971751412428
MC ∇logpdf arraydist/arraydist2: 1.041990500863558
The only case where DistributionsAD provides an obvious benefit is the first arraydist case, i.e. a vector of Normals.
In fact, DistributionsAD.arraydist applied to a vector of Normals doesn't actually return a custom type: it returns Distributions.Product (which is deprecated in Distributions.jl). In contrast, arraydist2, i.e. Distributions.product_distribution, goes one step further and returns an MvNormal:
julia> arraydist([Normal(0.0), Normal(1.0), Normal(2.0, 3.0), Normal(4.0, 0.2)])
Product{Continuous, Normal{Float64}, Vector{Normal{Float64}}}(v=Normal{Float64}[Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=1.0, σ=1.0), Normal{Float64}(μ=2.0, σ=3.0), Normal{Float64}(μ=4.0, σ=0.2)])
julia> arraydist2([Normal(0.0), Normal(1.0), Normal(2.0, 3.0), Normal(4.0, 0.2)])
DiagNormal(
dim: 4
μ: [0.0, 1.0, 2.0, 4.0]
Σ: [1.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0; 0.0 0.0 9.0 0.0; 0.0 0.0 0.0 0.04000000000000001]
)
One could argue, then, that in this case Distributions.product_distribution has an inefficient implementation for Normals and should be fixed upstream in Distributions.jl. I reported this upstream: JuliaStats/Distributions.jl#1989
Even in the current state, without any upstream fixes, I think we can:
- deprecate / remove DistributionsAD
- add my suggested definitions of filldist and arraydist to DynamicPPL (or maybe AbstractPPLDistributionsExt)
- add an arraydist specialisation for Normals (see the sketch below)
If the performance is fixed upstream in Distributions, then we can further remove the arraydist specialisation.
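For concreteness, here is a minimal sketch of what such a specialisation could look like, assuming arraydist is defined as a plain function in DynamicPPL rather than as an alias for product_distribution. The use of Distributions.Product as the fast path is only an illustration (it is what DistributionsAD currently returns, but it is deprecated), not a final design:
# Sketch only: the generic method defers to product_distribution, while the
# Normal-specific method keeps the cheaper Product representation that
# DistributionsAD currently returns. Product is deprecated in Distributions.jl,
# so this would only be a stopgap until the upstream fix lands.
arraydist(dists::AbstractArray{<:Distribution}) = Distributions.product_distribution(dists)
arraydist(dists::AbstractVector{<:Normal}) = Distributions.Product(dists)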
Package versions used for the benchmarks:
(ppl) pkg> st
Status `~/ppl/Project.toml`
[0ca39b1e] Chairmarks v1.3.1
[a0c0ee7d] DifferentiationInterface v0.7.1
[31c24e10] Distributions v0.25.120
[ced4e74d] DistributionsAD v0.6.58
[1a297f60] FillArrays v1.13.0
[f6369f11] ForwardDiff v1.0.1
[da2b9cff] Mooncake v0.4.137
[37e2e3b7] ReverseDiff v1.16.1