
Conversation

@penelopeysm (Member) commented Oct 9, 2025

This PR:

  • makes predict(model, chn; include_all) work: previously it would error unless the rng argument was also specified (a minimal sketch follows this list)
  • cuts out one model evaluation per MCMC iteration when predicting, which should make things noticeably faster
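
A minimal sketch of the fixed behaviour, using a toy model and chain (the names here are illustrative, not taken from the PR itself):

using DynamicPPL, Distributions, MCMCChains, Random

@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

# A chain holding 'sampled values' of x; y is left to be predicted
chain = MCMCChains.Chains(
    rand(Normal(), 100),
    [:x];
    info=(; varname_to_symbol=Dict(@varname(x) => :x)),
)

# Previously this errored because no rng was given; with this PR it works,
# and the explicit-rng form continues to work too
DynamicPPL.predict(demo(), chain; include_all=true)
DynamicPPL.predict(Random.default_rng(), demo(), chain; include_all=true)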

Benchmarks using the example from the DPPL test suite:

using DynamicPPL, Distributions, MCMCChains, Chairmarks

@model function linear_reg(x, y, σ=0.1)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
    # Insert a := block to test that it is not included in predictions
    return σ2 := σ^2
end

# Construct a chain with 'sampled values' of β
ground_truth_β = 2
β_chain = MCMCChains.Chains(
    rand(Normal(ground_truth_β, 0.002), 1000),
    [:β];
    info=(; varname_to_symbol=Dict(@varname(β) => :β)),
)

# Generate predictions from that chain
xs_test = [10 + 0.1, 10 + 2 * 0.1]
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
@be DynamicPPL.predict(m_lin_reg_test, β_chain)

# breaking (base branch)
julia> @be DynamicPPL.predict(m_lin_reg_test, β_chain)
Benchmark: 5 samples with 1 evaluation
 min    19.908 ms (284965 allocs: 12.245 MiB)
 median 20.093 ms (284965 allocs: 12.245 MiB)
 mean   20.468 ms (285165 allocs: 12.453 MiB)
 max    22.045 ms (285965 allocs: 13.283 MiB)

# this PR
julia> @be DynamicPPL.predict(m_lin_reg_test, β_chain)
Benchmark: 8 samples with 1 evaluation
 min    12.631 ms (194998 allocs: 6.853 MiB)
 median 12.656 ms (194998 allocs: 6.853 MiB)
 mean   12.778 ms (194998 allocs: 6.853 MiB)
 max    13.553 ms (194998 allocs: 6.853 MiB)

github-actions bot (Contributor) commented Oct 9, 2025

DynamicPPL.jl documentation for PR #1068 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1068/

codecov bot commented Oct 9, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.47%. Comparing base (ec65b4f) to head (b0f1e96).
⚠️ Report is 1 commit behind head on breaking.

Additional details and impacted files
@@             Coverage Diff              @@
##           breaking    #1068      +/-   ##
============================================
+ Coverage     82.38%   82.47%   +0.09%     
============================================
  Files            42       42              
  Lines          3820     3824       +4     
============================================
+ Hits           3147     3154       +7     
+ Misses          673      670       -3     


@sunxd3 (Member) left a comment:

LG

@penelopeysm (Member, Author) commented Oct 9, 2025

If this is of any interest: for small models, MCMCChains generally does better than FlexiChains (I think this is mainly because constructing a FlexiChain from an array-of-dicts is quite slow, since the internal storage is dict-of-array), but when there are parameters that are very long vectors, MCMCChains scales very poorly because it has to split them up.

Note that this doesn't apply to the model in the top comment: in that model each element of y is a separate varname, so it gets split up regardless of which chain backend is used.
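
As a hedged illustration of that splitting (toy numbers, not the benchmark itself): a vector-valued parameter stored in an MCMCChains chain is flattened into one column per element, so a length-1000 y becomes 1000 separate names, whereas a FlexiChain can keep the whole vector under a single VarName key.

using MCMCChains

vals = rand(10, 3, 1)  # 10 iterations, 3 flattened elements, 1 chain
chn = MCMCChains.Chains(vals, [Symbol("y[1]"), Symbol("y[2]"), Symbol("y[3]")])
names(chn)  # 3-element Vector{Symbol}: one entry per element of y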

### SETUP - using this PR
using DynamicPPL, Distributions, DistributionsAD, FlexiChains, MCMCChains, Chairmarks

@model function f(N, σ=0.1)
    m ~ Normal(0, 1)
    y ~ filldist(Normal(m, 1), N)
end

# Construct a chain with 'sampled values' of m
ground_truth_m = 2
randvals = rand(Normal(ground_truth_m, 0.002), 1000)
mchain = MCMCChains.Chains(
    randvals,
    [:m];
    info=(; varname_to_symbol=Dict(@varname(m) => :m)),
)
fchain = FlexiChains.FlexiChain{VarName}(
    1000, 1,
    Dict(FlexiChains.Parameter(@varname(m)) => randvals)
)

m1 = f(1)
@be DynamicPPL.predict(m1, fchain; include_all=false)
@be DynamicPPL.predict(m1, mchain; include_all=false)

m1000 = f(1000)
@be DynamicPPL.predict(m1000, fchain; include_all=false)
@be DynamicPPL.predict(m1000, mchain; include_all=false)

### RESULTS

julia> @be DynamicPPL.predict(m1, fchain; include_all=false)
Benchmark: 7 samples with 1 evaluation
 min    14.465 ms (193039 allocs: 8.451 MiB)
 median 15.053 ms (193039 allocs: 8.451 MiB)
 mean   16.036 ms (193039 allocs: 8.451 MiB, 5.14% gc time)
 max    22.861 ms (193039 allocs: 8.451 MiB, 35.95% gc time)

julia> @be DynamicPPL.predict(m1, mchain; include_all=false)
Benchmark: 9 samples with 1 evaluation
 min    11.015 ms (159456 allocs: 6.226 MiB)
 median 11.365 ms (159456 allocs: 6.226 MiB)
 mean   12.149 ms (159456 allocs: 6.226 MiB, 4.55% gc time)
 max    18.945 ms (159456 allocs: 6.226 MiB, 40.98% gc time)

julia> m1000 = f(1000)
Model{typeof(f), (:N, :σ), (), (), Tuple{Int64, Float64}, Tuple{}, DefaultContext}(f, (N = 1000, σ = 0.1), NamedTuple(), DefaultContext())

julia> @be DynamicPPL.predict(m1000, fchain; include_all=false)
Benchmark: 3 samples with 1 evaluation
        27.579 ms (202050 allocs: 63.224 MiB, 20.14% gc time)
        46.832 ms (202050 allocs: 63.224 MiB, 52.75% gc time)
        163.931 ms (202050 allocs: 63.224 MiB, 86.04% gc time)

julia> @be DynamicPPL.predict(m1000, mchain; include_all=false)
Benchmark: 1 sample with 1 evaluation
        2.648 s (25296094 allocs: 826.041 MiB, 9.72% gc time, without a warmup)

penelopeysm merged commit 0fa5540 into breaking on Oct 9, 2025
19 of 21 checks passed
penelopeysm deleted the py/includeall branch on Oct 9, 2025 at 18:42
penelopeysm added a commit that referenced this pull request Oct 21, 2025
* Bump minor version

* bump benchmarks compat

* add a skeletal changelog

* `InitContext`, part 3 - Introduce `InitContext` (#981)

* Implement InitContext

* Fix loading order of modules; move `prefix(::Model)` to model.jl

* Add tests for InitContext behaviour

* inline `rand(::Distributions.Uniform)`

Note that, apart from being simpler code, Distributions.Uniform also
doesn't allow the lower and upper bounds to be exactly equal (but we
might like to keep that option open in DynamicPPL, e.g. if the user
wants to initialise all values to the same value in linked space).

* Document

* Add a test to check that `init!!` doesn't change linking

* Fix `push!` for VarNamedVector

This should have been changed in #940, but slipped through as the file
wasn't listed as one of the changed files.

* Add some line breaks

Co-authored-by: Markus Hauru <[email protected]>

* Add the option of no fallback for ParamsInit

* Improve docstrings

* typo

* `p.default` -> `p.fallback`

* Rename `{Prior,Uniform,Params}Init` -> `InitFrom{Prior,Uniform,Params}`

---------

Co-authored-by: Markus Hauru <[email protected]>

* use `varname_leaves` from AbstractPPL instead (#1030)

* use `varname_leaves` from AbstractPPL instead

* add changelog entry

* fix import

* tidy occurrences of varname_leaves as well (#1031)

* `InitContext`, part 4 - Use `init!!` to replace `evaluate_and_sample!!`, `predict`, `returned`, and `initialize_values` (#984)

* Replace `evaluate_and_sample!!` -> `init!!`

* Use `ParamsInit` for `predict`; remove `setval_and_resample!` and friends

* Use `init!!` for initialisation

* Paper over the `Sampling->Init` context stack (pending removal of SamplingContext)

* Remove SamplingContext from JETExt to avoid triggering `Sampling->Init` pathway

* Remove `predict` on vector of VarInfo

* Fix some tests

* Remove duplicated test

* Simplify context testing

* Rename FooInit -> InitFromFoo

* Fix JETExt

* Fix JETExt properly

* Fix tests

* Improve comments

* Remove duplicated tests

* Docstring improvements

Co-authored-by: Markus Hauru <[email protected]>

* Concretise `chain_sample_to_varname_dict` using chain value type

* Clarify testset name

* Re-add comment that shouldn't have vanished

* Fix stale Requires dep

* Fix default_varinfo/initialisation for odd models

* Add comment to src/sampler.jl

Co-authored-by: Markus Hauru <[email protected]>

---------

Co-authored-by: Markus Hauru <[email protected]>

* `InitContext`, part 5 - Remove `SamplingContext`, `SampleFrom{Prior,Uniform}`, `{tilde_,}assume` (#985)

* Remove `SamplingContext` for good

* Remove `tilde_assume` as well

* Split up tilde_observe!! for Distribution / Submodel

* Tidy up tilde-pipeline methods and docstrings

* Fix tests

* fix ambiguity

* Add changelog

* Update HISTORY.md

Co-authored-by: Markus Hauru <[email protected]>

---------

Co-authored-by: Markus Hauru <[email protected]>

* fix missing import

* Shuffle context code around and remove dead code (#1050)

* Delete the `"del"` flag (#1058)

* Delete del

* Fix a typo

* Add HISTORY entry about del

* Fixes for Turing 0.41 (#1057)

* setleafcontext(model, ctx) and various other fixes

* fix a bug

* Add warning for `initial_parameters=...`

* Remove `resume_from` and `default_chain_type` (#1061)

* Remove resume_from

* Format

* Fix test

* remove initial_params warning

* Allow more flexible `initial_params` (#1064)

* Enable NamedTuple/Dict initialisation

* Add more tests

* fix include_all kwarg for predict, improve perf (#1068)

* Fix `include_all` for predict

* Fix include_all for predict, some perf improvements

* Replace `Metadata.flags` with `Metadata.trans` (#1060)

* Replace Medata.flags with Metadata.trans

* Fix a bug

* Fix a typo

* Fix two bugs

* Rename trans to is_transformed

* Rename islinked to is_transformed, remove duplication

* Change pointwise_logdensities default key type to VarName (#1071)

* Change pointwise_logdensities default key type to VarName

* Fix a doctest

* Fix DynamicPPL / MCMCChains methods (#1076)

* Reimplement pointwise_logdensities (almost completely)

* Move logjoint, logprior, ... as well

* Fix imports, etc

* Remove tests that are failing (yes I learnt this from Claude)

* Changelog

* logpdf

* fix docstrings

* allow dict output

* changelog

* fix some comments

* fix tests

* Fix more imports

* Remove stray n

Co-authored-by: Markus Hauru <[email protected]>

* Expand `logprior`, `loglikelihood`, and `logjoint` docstrings

---------

Co-authored-by: Markus Hauru <[email protected]>

* Remove `Sampler` and its interface (#1037)

* Remove `Sampler` and `initialstep`

* Actually just remove the entire file

* forgot one function

* Move sampling test utils to Turing as well

* Update changelog to correctly reflect changes

* [skip ci] Make changelog headings more consistent

---------

Co-authored-by: Markus Hauru <[email protected]>
