
Remove DistributionsAD #2613

@penelopeysm

I decided to finally look into whether we still need DistributionsAD. As far as I can tell, the only things in it that still see use are filldist and arraydist.

My hypothesis is that we can completely replace them with the following:

# Current definitions
filldist = DistributionsAD.filldist
arraydist = DistributionsAD.arraydist

# Alternative, much simpler, definitions
filldist2(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...))
arraydist2 = Distributions.product_distribution
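
As a quick correctness check, separate from performance, the following sketch confirms that the alternative definitions behave like product distributions: the draw has the expected shape, and `logpdf` equals the sum of the componentwise log densities. The particular distributions and sizes are arbitrary choices for illustration. It assumes only Distributions.jl and FillArrays.jl, as in the benchmarking code.

```julia
using Distributions
using FillArrays: Fill

# Alternative definitions proposed in this issue.
filldist2(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...))
arraydist2 = Distributions.product_distribution

# filldist2: a 2×3 array of i.i.d. Beta(2, 2) draws.
fd2 = filldist2(Beta(2, 2), 2, 3)
x = rand(fd2)
# The joint log density should equal the sum of the elementwise log densities.
lp_fill = logpdf(fd2, x)
lp_fill_ref = sum(logpdf.(Beta(2, 2), x))
println("filldist2:  size = ", size(x), ", logpdf matches: ", lp_fill ≈ lp_fill_ref)

# arraydist2: independent but non-identical univariate components.
dists = [Beta(2, 2), InverseGamma(2, 3), Normal()]
ad2 = arraydist2(dists)
y = rand(ad2)
lp_arr = logpdf(ad2, y)
lp_arr_ref = sum(logpdf.(dists, y))
println("arraydist2: length = ", length(y), ", logpdf matches: ", lp_arr ≈ lp_arr_ref)
```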

To test this, I benchmarked rand, logpdf, and the gradient of logpdf (with our standard AD backends: ForwardDiff, ReverseDiff, and Mooncake) using both the existing and the proposed implementations.

Benchmarking code
using Distributions
using DistributionsAD: DistributionsAD
using Chairmarks
using FillArrays: Fill
using DifferentiationInterface: AutoForwardDiff, AutoReverseDiff, AutoMooncake, prepare_gradient, gradient
import ForwardDiff
import ReverseDiff
import Mooncake

# Current definitions
filldist = DistributionsAD.filldist
arraydist = DistributionsAD.arraydist

# Alternative, much simpler, definitions
filldist2(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...))
arraydist2 = Distributions.product_distribution

# AD backends to test
backends = Dict(
    "FD" => AutoForwardDiff(),
    "RD" => AutoReverseDiff(),
    "MC" => AutoMooncake(),
)

println("\n")
println("filldist")
println("========")

# Benchmark filldist
for dist in [Normal(), Beta(2, 2), MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0]), Wishart(7, [1.0 0.5; 0.5 1.0])]
    println("\n$(typeof(dist))")
    fd = filldist(dist, 2, 3)
    fd2 = filldist2(dist, 2, 3)
    fd_rand = @be rand($fd)
    fd2_rand = @be rand($fd2)
    println("       rand   filldist/filldist2: $(median(fd_rand).time/median(fd2_rand).time)")
    r = rand(fd)
    fd_logp = @be logpdf($fd, $r)
    fd2_logp = @be logpdf($fd2, $r)
    println("       logpdf filldist/filldist2: $(median(fd_logp).time/median(fd2_logp).time)")
    if !(dist isa Wishart)
        for (name, adtype) in backends
            f = Base.Fix1(logpdf, fd)
            f2 = Base.Fix1(logpdf, fd2)
            prep = prepare_gradient(f, adtype, r)
            prep2 = prepare_gradient(f2, adtype, r)
            fd_grad_logp = @be gradient($f, $prep, $adtype, $r)
            fd2_grad_logp = @be gradient($f2, $prep2, $adtype, $r)
            println("   $name ∇logpdf filldist/filldist2: $(median(fd_grad_logp).time/median(fd2_grad_logp).time)")
        end
    end
end

println("\n")
println("arraydist")
println("=========")

# Benchmark arraydist
for dists in [
    [Normal(0.0), Normal(1.0), Normal(2.0, 3.0), Normal(4.0, 0.2)],
    [Beta(2, 2), InverseGamma(2, 3), Normal()],
    [MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0]), MvNormal([2.0, 4.0], [1.0 0.5; 0.5 1.0])],
]
    println("\ndistributions starting with: $(typeof(dists[1]))")
    ad = arraydist(dists)
    ad2 = arraydist2(dists)
    ad_rand = @be rand($ad)
    ad2_rand = @be rand($ad2)
    println("       rand   arraydist/arraydist2: $(median(ad_rand).time/median(ad2_rand).time)")
    r = rand(ad)
    ad_logp = @be logpdf($ad, $r)
    ad2_logp = @be logpdf($ad2, $r)
    println("       logpdf arraydist/arraydist2: $(median(ad_logp).time/median(ad2_logp).time)")
    for (name, adtype) in backends
        f = Base.Fix1(logpdf, ad)
        f2 = Base.Fix1(logpdf, ad2)
        prep = prepare_gradient(f, adtype, r)
        prep2 = prepare_gradient(f2, adtype, r)
        ad_grad_logp = @be gradient($f, $prep, $adtype, $r)
        ad2_grad_logp = @be gradient($f2, $prep2, $adtype, $r)
        println("   $name ∇logpdf arraydist/arraydist2: $(median(ad_grad_logp).time/median(ad2_grad_logp).time)")
    end
end

Results

As can be verified from the benchmarking code above, each number is a ratio of median run times (existing / proposed). A number < 1 therefore means that DistributionsAD is faster, and a number > 1 means that the proposed definition is faster.
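
Written out, with $t$ the run time measured by `@be`, each reported value is the ratio below. As a worked example, the logpdf ratio for the vector of Normals in the arraydist results, $r \approx 0.364$, means the proposed definition takes roughly $1/0.364 \approx 2.7$ times as long as DistributionsAD.

```latex
r = \frac{\operatorname{median}(t_{\text{DistributionsAD}})}{\operatorname{median}(t_{\text{proposed}})},
\qquad r < 1 \iff \text{DistributionsAD is faster}

r \approx 0.364 \;\Rightarrow\; \operatorname{median}(t_{\text{proposed}}) \approx \frac{\operatorname{median}(t_{\text{DistributionsAD}})}{0.364} \approx 2.74\,\operatorname{median}(t_{\text{DistributionsAD}})
```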

filldist
========

Normal{Float64}
       rand   filldist/filldist2: 1.0085484029484029
       logpdf filldist/filldist2: 1.289915841908632
   FD ∇logpdf filldist/filldist2: 1.1131279951870339
   RD ∇logpdf filldist/filldist2: 1.0306648131490719
   MC ∇logpdf filldist/filldist2: 2.3058095238095238

Beta{Float64}
       rand   filldist/filldist2: 1.069383434601855
       logpdf filldist/filldist2: 1.011111111111111
   FD ∇logpdf filldist/filldist2: 1.0779808895335319
   RD ∇logpdf filldist/filldist2: 1.0
   MC ∇logpdf filldist/filldist2: 1.2243054833003848

FullNormal
       rand   filldist/filldist2: 1.0057641122777223
       logpdf filldist/filldist2: 1.0158160093929138
   FD ∇logpdf filldist/filldist2: 1.014606050887321
   RD ∇logpdf filldist/filldist2: 0.9965654224944137
   MC ∇logpdf filldist/filldist2: 1.0277829296338885

Wishart{Float64, PDMats.PDMat{Float64, Matrix{Float64}}, Int64}
       rand   filldist/filldist2: 1.0972668192219681
       logpdf filldist/filldist2: 1.0186094420600857


arraydist
=========

distributions starting with: Normal{Float64}
       rand   arraydist/arraydist2: 0.8166216049657906
       logpdf arraydist/arraydist2: 0.3643514465925137
   FD ∇logpdf arraydist/arraydist2: 0.7815570133385658
   RD ∇logpdf arraydist/arraydist2: 0.8234206667941607
   MC ∇logpdf arraydist/arraydist2: 0.46417765631907404

distributions starting with: Beta{Float64}
       rand   arraydist/arraydist2: 1.0020691994572593
       logpdf arraydist/arraydist2: 1.0045024289620155
   FD ∇logpdf arraydist/arraydist2: 1.020335676643539
   RD ∇logpdf arraydist/arraydist2: 1.0028416872089838
   MC ∇logpdf arraydist/arraydist2: 0.9901037037037037

distributions starting with: FullNormal
       rand   arraydist/arraydist2: 1.7908701400560223
       logpdf arraydist/arraydist2: 1.137092822555559
   FD ∇logpdf arraydist/arraydist2: 0.8979361936193618
   RD ∇logpdf arraydist/arraydist2: 1.0244971751412428
   MC ∇logpdf arraydist/arraydist2: 1.041990500863558

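The only case where DistributionsAD provides an obvious benefit is the first arraydist() case, i.e. a vector of Normals, where every ratio is below 1 (from about 0.36 for logpdf up to about 0.82 for rand and the ReverseDiff gradient).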

In fact, DistributionsAD.arraydist(Normal...) doesn't return a custom type at all. It returns Distributions.Product (which is deprecated in Distributions.jl).

In contrast, arraydist2, i.e. Distributions.product_distribution, goes one step further and returns an MvNormal (specifically a DiagNormal):

julia> arraydist([Normal(0.0), Normal(1.0), Normal(2.0, 3.0), Normal(4.0, 0.2)])
Product{Continuous, Normal{Float64}, Vector{Normal{Float64}}}(v=Normal{Float64}[Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=1.0, σ=1.0), Normal{Float64}(μ=2.0, σ=3.0), Normal{Float64}(μ=4.0, σ=0.2)])

julia> arraydist2([Normal(0.0), Normal(1.0), Normal(2.0, 3.0), Normal(4.0, 0.2)])
DiagNormal(
dim: 4
μ: [0.0, 1.0, 2.0, 4.0]
Σ: [1.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0; 0.0 0.0 9.0 0.0; 0.0 0.0 0.0 0.04000000000000001]
)
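
The type difference can be checked directly. This sketch, assuming Distributions.jl v0.25 as in the package status listed in this issue, shows that product_distribution turns a vector of Normals into a multivariate normal, while a mixed vector stays a generic product distribution. The two representations of the Normals still agree on logpdf, so the gap is a performance question rather than a correctness one.

```julia
using Distributions

normals = [Normal(0.0), Normal(1.0), Normal(2.0, 3.0), Normal(4.0, 0.2)]
mixed = [Beta(2, 2), InverseGamma(2, 3), Normal()]

d_normals = product_distribution(normals)  # converted into a DiagNormal
d_mixed = product_distribution(mixed)      # stays a generic product distribution

println(d_normals isa MvNormal)  # true
println(d_mixed isa MvNormal)    # false

# The MvNormal log density equals the sum of the univariate Normal log densities.
x = rand(d_normals)
println(logpdf(d_normals, x) ≈ sum(logpdf.(normals, x)))
```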

One could argue, then, that in this case Distributions.product_distribution has an inefficient implementation for Normals, which should be fixed upstream in Distributions.jl. I have reported this upstream: JuliaStats/Distributions.jl#1989

Even in the current state, without any upstream fixes, I think we can:

  1. deprecate / remove DistributionsAD
  2. add my suggested definitions of filldist and arraydist to DynamicPPL (or maybe AbstractPPLDistributionsExt)
  3. add an arraydist specialisation for Normals
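
A minimal sketch of steps 2 and 3, assuming the definitions live in a hypothetical module `ProposedDists` (the real home would be DynamicPPL or AbstractPPLDistributionsExt, as suggested above). For the Normal specialisation it constructs `Distributions.ProductDistribution` directly, so that the result is not converted into a `DiagNormal`. Note that the Distributions.jl docstring recommends `product_distribution` over calling this constructor directly, and whether this choice actually matches DistributionsAD's performance is an assumption that would need to be checked with the benchmarking code above.

```julia
module ProposedDists  # hypothetical module name, for illustration only

using Distributions
using FillArrays: Fill

export filldist, arraydist

# Generic definitions proposed in this issue.
filldist(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...))
arraydist(dists::AbstractArray{<:Distribution}) = product_distribution(dists)

# Hypothetical specialisation: keep a vector of Normals as a product of
# univariate Normals instead of letting product_distribution convert it
# into a DiagNormal. Its performance benefit is an untested assumption.
arraydist(dists::AbstractVector{<:Normal}) = Distributions.ProductDistribution(dists)

end # module

using .ProposedDists
using Distributions

normals = [Normal(0.0), Normal(1.0), Normal(2.0, 3.0), Normal(4.0, 0.2)]
d = arraydist(normals)  # dispatches to the Normal specialisation
x = rand(d)
println(typeof(d))
println(logpdf(d, x) ≈ sum(logpdf.(normals, x)))
```

Because `AbstractVector{<:Normal}` is strictly more specific than `AbstractArray{<:Distribution}`, Julia's dispatch picks the specialisation for vectors of Normals and the generic definition for everything else; dropping the specialisation later is a one-line change.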

If the performance issue is fixed upstream in Distributions.jl, we can then remove the arraydist specialisation as well.

(ppl) pkg> st
Status `~/ppl/Project.toml`
  [0ca39b1e] Chairmarks v1.3.1
  [a0c0ee7d] DifferentiationInterface v0.7.1
  [31c24e10] Distributions v0.25.120
  [ced4e74d] DistributionsAD v0.6.58
  [1a297f60] FillArrays v1.13.0
  [f6369f11] ForwardDiff v1.0.1
  [da2b9cff] Mooncake v0.4.137
  [37e2e3b7] ReverseDiff v1.16.1
