
Correctness of Gibbs implementation (?) #2627

@penelopeysm

Description


In the current ('new') Gibbs implementation, there are these two branches of code:

Turing.jl/src/mcmc/gibbs.jl

Lines 179 to 185 in 23b92eb

elseif has_conditioned_gibbs(context, vn)
    # Short-circuit the tilde assume if `vn` is present in `context`.
    value, lp, _ = DynamicPPL.tilde_assume(
        child_context, right, vn, get_global_varinfo(context)
    )
    value, lp, vi
else

Turing.jl/src/mcmc/gibbs.jl

Lines 212 to 217 in 23b92eb

elseif has_conditioned_gibbs(context, vn)
    value, lp, _ = DynamicPPL.tilde_assume(
        child_context, right, vn, get_global_varinfo(context)
    )
    value, lp, vi
else

(For additional context, this topic came up when attempting to fix the Gibbs sampler for [email protected]. It is not immediately clear how to fix the code above to work with accumulators, and my hunch was that the entire thing could basically be deleted. Testing this out was what led me down this path.)

When sampling with a component sampler, these branches are responsible for variables that are handled by other samplers. For example, in Gibbs(:a => Sampler1(), :b => Sampler2()), this path would be triggered when Sampler1() encounters b ~ dist.

In Gibbs sampling the aim for Sampler1 is to sample new values of a from the conditional distribution p(a | b). My reading of this is that, if the value of b is really fixed, then adding the value of logpdf(dist, b) here should not make a difference to the behaviour of the Gibbs sampler. In other words, we could simply ignore the lp term in that branch: instead of returning value, lp, vi we could return value, 0.0, vi.
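
To spell this out (the notation here is mine, not anything from the code): write c for that lp value and log p̃(a) for the log density the component sampler actually accumulates. Under my assumption that c is a fixed number while a is being updated, the constant drops out of everything the samplers compute:

$$\log \tilde p(a) = \log p(a \mid b) + c \;\implies\; \log \tilde p(a') - \log \tilde p(a) = \log p(a' \mid b) - \log p(a \mid b) \quad\text{and}\quad \nabla_a \log \tilde p(a) = \nabla_a \log p(a \mid b)$$

so MH acceptance ratios and HMC gradients should come out the same whether c is added or set to zero.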

I have tested this with a couple of models and so far this hypothesis has held up. For example:

using Turing, Random

@model function f(y)
    sd ~ truncated(Cauchy(0, 5); lower = 0)
    mean ~ Normal(0, 5)
    y .~ Normal(mean, sd)
end

model = f(rand(Xoshiro(468), 10) .+ 2)
describe(sample(Xoshiro(468), model, Gibbs(:mean => MH(), :sd => HMC(0.1, 20)), 1000))

The above code block always returns the same values in the chain regardless of whether lp is included or thrown away.

This makes sense because (e.g.) in the MH sampler, the extra log-prob associated with sd is simply a constant term and it will cancel out when calculating the acceptance probability of the transition, which only depends on mean. It's possible to insert some extra printing to prove that this is the case. And in the HMC sampler, the extra log-prob associated with mean is again constant and so the gradients with respect to sd will be unaffected.
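
To double-check the cancellation claim outside of Turing, here is a small standalone sketch. It just uses the textbook Metropolis acceptance probability for a symmetric proposal (logtarget and accept_prob are names made up for this example, not anything from Turing's MH implementation) and shows that adding sd's log-prob as a constant leaves the acceptance probability unchanged:

using Distributions, Random

# Log conditional density of `mean` given a fixed `sd` and data `y`
# (the Normal(0, 5) prior on `mean` plus the likelihood from the model above).
logtarget(m, sd, y) = logpdf(Normal(0, 5), m) + sum(logpdf.(Normal(m, sd), y))

# Metropolis acceptance probability for a symmetric proposal m -> m_new, with an
# optional constant `c` added to the log target (standing in for the log-prob of
# the conditioned variable `sd`).
accept_prob(m, m_new, sd, y; c = 0.0) =
    min(1.0, exp((logtarget(m_new, sd, y) + c) - (logtarget(m, sd, y) + c)))

y = rand(Xoshiro(468), 10) .+ 2
sd = 1.3
c_sd = logpdf(truncated(Cauchy(0, 5); lower = 0), sd)
accept_prob(2.0, 0.0, sd, y) ≈ accept_prob(2.0, 0.0, sd, y; c = c_sd)  # true: `c` cancels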

Now, here is a mildly more complicated model:

J = 8
y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
@model function pdb_eight_schools_noncentered(J, y, sigma, ::Type{TV}=Vector{Float64}) where {TV}
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower = 0)
    theta_trans = TV(undef, J)
    theta = TV(undef, J)
    for i = 1:J
        theta_trans[i] ~ Normal(0, 1)
        theta[i] = theta_trans[i] * tau + mu
        y[i] ~ Normal(theta[i], sigma[i])
    end
end
model = pdb_eight_schools_noncentered(J, y, sigma)

chn = sample(Xoshiro(468), model, Gibbs((:mu, :tau) => HMC(0.1, 20), :theta_trans => MH()), 20000)

As it turns out, my hunch still holds: ignoring the lp term doesn't affect sampling at all.

But, if we use the centred version of the model:

J = 8
y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
@model function pdb_eight_schools_centered(J, y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower = 0)
    theta = Vector{Float64}(undef, J)
    for i = 1:J
        theta[i] ~ Normal(mu, tau)
        y[i] ~ Normal(theta[i], sigma[i])
    end
end
model = pdb_eight_schools_centered(J, y, sigma)

chn = sample(Xoshiro(468), model, Gibbs((:mu, :tau) => HMC(0.1, 20), :theta => MH()), 20000)

this is what happens if you ignore the lp:

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

          mu   -1.1596    0.1706    0.0140   146.2745   506.1209    1.0051       14.1369
         tau    0.2833    0.1349    0.0106   108.8022    44.8088    1.0147       10.5153
    theta[1]   -1.1828    0.2561    0.0219   139.9138    84.1313    1.0278       13.5222
    theta[2]   -1.1828    0.2593    0.0209   156.7183   110.0804    1.0026       15.1463
    theta[3]   -1.1403    0.2654    0.0255   108.3782    88.5591    1.0102       10.4744
    theta[4]   -1.1749    0.2591    0.0225   137.1817   123.8335    1.0181       13.2581
    theta[5]   -1.1389    0.2703    0.0262   130.3532   361.2658    1.0047       12.5982
    theta[6]   -1.1323    0.2470    0.0210   139.2951   128.0701    1.0145       13.4624
    theta[7]   -1.1378    0.2242    0.0148   239.9480   232.9650    1.0307       23.1901
    theta[8]   -1.1797    0.2700    0.0250   118.2959    72.6893    1.0019       11.4329

And this is what happens if you don't ignore the lp:

Summary Statistics
  parameters      mean        std      mcse     ess_bulk     ess_tail      rhat   ess_per_sec
      Symbol   Float64    Float64   Float64      Float64      Float64   Float64       Float64

          mu   -0.0720     5.0446    0.1779     803.6753    1798.8552    1.0000       99.1824
         tau   36.9890   915.2464    7.4628   10549.5014   10327.4102    1.0000     1301.9254
    theta[1]   -4.1584     2.1548    0.0693     950.5938     753.0534    1.0012      117.3138
    theta[2]   -4.3377     2.0876    0.0545    1462.1444    1024.8905    1.0007      180.4448
    theta[3]   -4.3305     2.1344    0.0584    1314.5257    1027.4664    1.0016      162.2270
    theta[4]   -4.3065     2.1558    0.0553    1509.2291    1129.3466    1.0005      186.2556
    theta[5]   -4.3102     2.1356    0.0621    1205.8376     537.2088    1.0000      148.8137
    theta[6]   -4.3503     2.1383    0.0596    1272.5214     975.1728    1.0030      157.0432
    theta[7]   -4.0175     2.0842    0.0548    1411.4961     739.3024    1.0000      174.1943
    theta[8]   -4.3535     2.1945    0.0760     877.4227     669.0904    1.0003      108.2837

The quality of the samples themselves is complete rubbish. I'm not too worried about that in itself, since the centred model is known to be harder to sample from.

What worries me more is that the results are different at all. I don't quite understand why this happens for this model and not for the others, and I'm mildly concerned that there is a correctness issue somewhere here.
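
For reference, here are the log joints of the two parameterisations, written out directly from the models above (N(· | m, s) is a Normal density with standard deviation s; p(μ) and p(τ) stand for the Normal(0, 5) and truncated-Cauchy(0, 5) prior densities):

Non-centred:

$$\log p(\mu, \tau, \theta^{\mathrm{trans}}, y) = \log p(\mu) + \log p(\tau) + \sum_{i=1}^{J} \log \mathcal{N}(\theta^{\mathrm{trans}}_i \mid 0, 1) + \sum_{i=1}^{J} \log \mathcal{N}(y_i \mid \mu + \tau\,\theta^{\mathrm{trans}}_i,\ \sigma_i)$$

Centred:

$$\log p(\mu, \tau, \theta, y) = \log p(\mu) + \log p(\tau) + \sum_{i=1}^{J} \log \mathcal{N}(\theta_i \mid \mu, \tau) + \sum_{i=1}^{J} \log \mathcal{N}(y_i \mid \theta_i, \sigma_i)$$

(The terms in the third sum are the ones being short-circuited for the conditioned variables when the (mu, tau) sampler runs.)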
