In the current ('new') Gibbs implementation, there are these two branches of code:
Lines 179 to 185 in 23b92eb:

```julia
elseif has_conditioned_gibbs(context, vn)
    # Short-circuit the tilde assume if `vn` is present in `context`.
    value, lp, _ = DynamicPPL.tilde_assume(
        child_context, right, vn, get_global_varinfo(context)
    )
    value, lp, vi
else
```
Lines 212 to 217 in 23b92eb:

```julia
elseif has_conditioned_gibbs(context, vn)
    value, lp, _ = DynamicPPL.tilde_assume(
        child_context, right, vn, get_global_varinfo(context)
    )
    value, lp, vi
else
```
(For additional context, this topic came up when attempting to fix the Gibbs sampler for [email protected]. It is not immediately clear how to fix the code above to work with accumulators, and my hunch was that the entire thing could basically be deleted. Testing this out was what led me down this path.)
When sampling with a component sampler, these branches handle variables that belong to other component samplers. For example, with `Gibbs(:a => Sampler1(), :b => Sampler2())`, this path is triggered when `Sampler1()` encounters `b ~ dist`.
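To make this concrete, here is a minimal toy setup (my own, purely for illustration; the model name `ab_demo` is made up) in which this code path should be hit: while the `:a` component sampler runs the model, it encounters `b ~ Normal(a, 1)`, and `b` is treated as conditioned because it belongs to the `:b` component sampler.

```julia
using Turing, Random

# Hypothetical two-variable model. `a` and `b` are assigned to different
# component samplers below, so whichever sampler is currently running sees
# the other variable as conditioned, i.e. it goes through the branches above.
@model function ab_demo()
    a ~ Normal(0, 1)
    b ~ Normal(a, 1)
end

sample(Xoshiro(468), ab_demo(), Gibbs(:a => MH(), :b => MH()), 10)
```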
In Gibbs sampling, the aim for `Sampler1` is to sample new values of `a` from the conditional distribution p(a | b). My reading of this is that, if the value of `b` is really fixed, then adding `logpdf(dist, b)` here should not make any difference to the behaviour of the Gibbs sampler. In other words, we could simply ignore the `lp` term in that branch: instead of returning `value, lp, vi` we could return `value, 0.0, vi`.
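Concretely, the change I have in mind would amount to something like this (just a sketch mirroring the second quoted branch; I haven't worked out how it interacts with the accumulator machinery):

```julia
elseif has_conditioned_gibbs(context, vn)
    value, _, _ = DynamicPPL.tilde_assume(
        child_context, right, vn, get_global_varinfo(context)
    )
    # Drop the conditioned variable's log-prob contribution entirely.
    value, 0.0, vi
else
```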
I have tested this with a couple of models and so far this hypothesis has held up. For example:
```julia
using Turing, Random

@model function f(y)
    sd ~ truncated(Cauchy(0, 5); lower = 0)
    mean ~ Normal(0, 5)
    y .~ Normal(mean, sd)
end

model = f(rand(Xoshiro(468), 10) .+ 2)
describe(sample(Xoshiro(468), model, Gibbs(:mean => MH(), :sd => HMC(0.1, 20)), 1000))
```
The above code block always returns the same values in the chain regardless of whether `lp` is included or thrown away.
This makes sense because, e.g., in the MH sampler the extra log-prob associated with `sd` is simply a constant term, and it cancels out when calculating the acceptance probability of the transition, which only depends on `mean`. (It's possible to insert some extra printing to prove that this is the case.) Likewise, in the HMC sampler the extra log-prob associated with `mean` is again constant, so the gradients with respect to `sd` are unaffected.
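Here is a minimal, self-contained sketch of that argument outside of Turing (the stand-in log-density `logp`, the constant `c`, and the points `x0`, `x1` are all made up for illustration; it assumes ForwardDiff is available in the environment):

```julia
using ForwardDiff

# Stand-in log-density for the variable the component sampler is updating,
# plus a constant `c` representing the log-prob of the variable held fixed.
logp(x) = -0.5 * x^2
c = -1.234

# MH: the acceptance ratio exp(logp(x1) - logp(x0)) is unchanged by adding
# the same constant to both log-densities.
x0, x1 = 0.3, 0.7
ratio_without = exp(logp(x1) - logp(x0))
ratio_with    = exp((logp(x1) + c) - (logp(x0) + c))
@assert ratio_without ≈ ratio_with

# HMC: the gradient of x -> logp(x) + c equals the gradient of logp, because
# `c` does not depend on x, so the leapfrog trajectories are identical too.
@assert ForwardDiff.derivative(x -> logp(x) + c, x0) ≈ ForwardDiff.derivative(logp, x0)
```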
Now, here is a mildly more complicated model:
```julia
J = 8
y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]

@model function pdb_eight_schools_noncentered(J, y, sigma, ::Type{TV}=Vector{Float64}) where {TV}
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower = 0)
    theta_trans = TV(undef, J)
    theta = TV(undef, J)
    for i = 1:J
        theta_trans[i] ~ Normal(0, 1)
        theta[i] = theta_trans[i] * tau + mu
        y[i] ~ Normal(theta[i], sigma[i])
    end
end

model = pdb_eight_schools_noncentered(J, y, sigma)
chn = sample(Xoshiro(468), model, Gibbs((:mu, :tau) => HMC(0.1, 20), :theta_trans => MH()), 20000)
```
As it turns out, my hunch is still correct: ignoring the logp term doesn't affect sampling at all.
But, if we use the centred version of the model:
```julia
J = 8
y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]

@model function pdb_eight_schools_centered(J, y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower = 0)
    theta = Vector{Float64}(undef, J)
    for i = 1:J
        theta[i] ~ Normal(mu, tau)
        y[i] ~ Normal(theta[i], sigma[i])
    end
end

model = pdb_eight_schools_centered(J, y, sigma)
chn = sample(Xoshiro(468), model, Gibbs((:mu, :tau) => HMC(0.1, 20), :theta => MH()), 20000)
```
This is what happens if you ignore the `lp`:
```
Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

          mu   -1.1596    0.1706    0.0140   146.2745   506.1209    1.0051       14.1369
         tau    0.2833    0.1349    0.0106   108.8022    44.8088    1.0147       10.5153
    theta[1]   -1.1828    0.2561    0.0219   139.9138    84.1313    1.0278       13.5222
    theta[2]   -1.1828    0.2593    0.0209   156.7183   110.0804    1.0026       15.1463
    theta[3]   -1.1403    0.2654    0.0255   108.3782    88.5591    1.0102       10.4744
    theta[4]   -1.1749    0.2591    0.0225   137.1817   123.8335    1.0181       13.2581
    theta[5]   -1.1389    0.2703    0.0262   130.3532   361.2658    1.0047       12.5982
    theta[6]   -1.1323    0.2470    0.0210   139.2951   128.0701    1.0145       13.4624
    theta[7]   -1.1378    0.2242    0.0148   239.9480   232.9650    1.0307       23.1901
    theta[8]   -1.1797    0.2700    0.0250   118.2959    72.6893    1.0019       11.4329
```
And this is what happens if you don't ignore the `lp`:
```
Summary Statistics
  parameters      mean        std      mcse     ess_bulk     ess_tail      rhat   ess_per_sec
      Symbol   Float64    Float64   Float64      Float64      Float64   Float64       Float64

          mu   -0.0720     5.0446    0.1779     803.6753    1798.8552    1.0000       99.1824
         tau   36.9890   915.2464    7.4628   10549.5014   10327.4102    1.0000     1301.9254
    theta[1]   -4.1584     2.1548    0.0693     950.5938     753.0534    1.0012      117.3138
    theta[2]   -4.3377     2.0876    0.0545    1462.1444    1024.8905    1.0007      180.4448
    theta[3]   -4.3305     2.1344    0.0584    1314.5257    1027.4664    1.0016      162.2270
    theta[4]   -4.3065     2.1558    0.0553    1509.2291    1129.3466    1.0005      186.2556
    theta[5]   -4.3102     2.1356    0.0621    1205.8376     537.2088    1.0000      148.8137
    theta[6]   -4.3503     2.1383    0.0596    1272.5214     975.1728    1.0030      157.0432
    theta[7]   -4.0175     2.0842    0.0548    1411.4961     739.3024    1.0000      174.1943
    theta[8]   -4.3535     2.1945    0.0760     877.4227     669.0904    1.0003      108.2837
```
The quality of the samples themselves is complete rubbish. That I'm not too worried about, because the centred model is known to be harder to sample from. What worries me more is that the results differ at all: I don't quite understand why this happens for this model and not the others, and I'm mildly concerned that there is a correctness issue somewhere here.