InitContext, part 4 - Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values #984

Draft · wants to merge 5 commits into base: py/init-prior-uniform

Conversation

@penelopeysm (Member) commented Jul 10, 2025

Part 1: Adding hasvalue and getvalue to AbstractPPL
Part 2: Removing hasvalue and getvalue from DynamicPPL
Part 3: Introducing InitContext and init!!

This is part 4/N of #967.


In Part 3 we introduced InitContext. This PR makes use of the functionality in there to replace a bunch of code that no longer needs to exist:

  • setval_and_resample! followed by model evaluation: This process was used for predict and returned, to manually store certain values in the VarInfo, which would be used in the subsequent model evaluation. We can now do this in a single step using ParamsInit.
  • initialize_values!!: Very similar to the above. It would manually set values inside the varinfo and then trigger an extra model evaluation to update the logp field. Again, this is directly replaced with ParamsInit.
  • evaluate_and_sample!!: direct one-to-one replacement with init!!.
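
As a rough sketch of what this replacement looks like (the toy model here is a placeholder; init!! and the init-strategy types are the ones introduced in Part 3), the old two-step "set values, then re-evaluate" pattern collapses into a single call:

```julia
using DynamicPPL, Distributions, Random

@model function demo()
    x ~ Normal()
    y ~ Normal(x)
end

model = demo()
rng = Random.default_rng()
varinfo = VarInfo(model)

# Old pattern (sketch): manually write values into the varinfo, then
# run an extra model evaluation to refresh logp.
#   setval_and_resample!(varinfo, values, keys)
#   ... followed by model evaluation ...

# New pattern: a single init!! call with ParamsInit does both steps.
# Variables missing from the NamedTuple fall back to the second
# strategy (here PriorInit, i.e. sampling from the prior).
retval, varinfo = DynamicPPL.init!!(
    rng, model, varinfo, DynamicPPL.ParamsInit((; x=1.0), DynamicPPL.PriorInit())
)
```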

There is one API change associated with this: the initial_params kwarg to sample must now be an AbstractInitStrategy. It's still optional (it will usually default to PriorInit). However, there are two implications:

  • initial_params can no longer be a vector of parameters. It must be ParamsInit(::NamedTuple) or ParamsInit(::AbstractDict{VarName}).
  • Because ParamsInit expects values in unlinked space, initial_params must now always be specified in unlinked space. Previously, it had to be specified in a way that matched the linking status of the underlying varinfo.

I consider both of these to be major wins for clarity. (One might argue that vectors are more convenient. Sure, you can get a vector with vi[:], but now you can just do values_as(vi, Dict{VarName,Any}) instead.)
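
Concretely, a sampling call under the new API would look something like the following (the model and sampler here are placeholders, and the old-style call is shown only for contrast):

```julia
using DynamicPPL

# Before: a raw vector, whose interpretation depended on the linking
# status of the underlying varinfo.
#   sample(model, sampler, 1000; initial_params=[1.0, 2.0])

# After: an AbstractInitStrategy, always in unlinked (user-facing) space.
sample(model, sampler, 1000; initial_params=DynamicPPL.ParamsInit((; x=1.0, y=2.0)))

# A Dict keyed by VarName works too:
sample(
    model, sampler, 1000;
    initial_params=DynamicPPL.ParamsInit(Dict(@varname(x) => 1.0, @varname(y) => 2.0)),
)

# To recover parameters from an existing varinfo as a Dict
# (rather than as a vector via vi[:]):
params = values_as(varinfo, Dict{VarName,Any})
```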

Test failures

The remaining test failures in CI relate to an unfortunate interaction where SamplingContext can contain InitContext. This horrible monstrosity will be removed in the next PR.

Closes

Closes #774
Closes #797
Closes #983
Closes TuringLang/Turing.jl#2476
Closes TuringLang/Turing.jl#1775

github-actions bot commented Jul 10, 2025

Benchmark Report for Commit 7a8e7e3

Computer Information

Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                  8.9 |                 1.5 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                627.0 |                38.3 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                394.9 |                50.3 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |                980.1 |                32.4 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               6111.7 |                27.4 |
|           Smorgasbord |       201 | reversediff |             typed |   true |                997.2 |                38.9 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |                970.4 |                 4.2 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5545.5 |                 3.9 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |                931.7 |                 9.0 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              62422.5 |                 3.5 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               8054.8 |                10.0 |
|               Dynamic |        10 |    mooncake |             typed |   true |                124.4 |                12.9 |
|              Submodel |         1 |    mooncake |             typed |   true |                 12.4 |                 6.2 |
|                   LDA |        12 | reversediff |             typed |   true |               1276.6 |                 1.7 |

Comment on lines 126 to 134
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`
_, varinfo = last(
    DynamicPPL.init!!(
        rng,
        model,
        varinfo,
        DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
    ),
)
penelopeysm (Member, Author) commented:
Note that, if the chain does not store varnames inside its info field, chain_sample_to_varname_dict will fail.

I don't consider this to be a problem right now because every chain obtained via Turing's sample() will contain varnames:

https://github.com/TuringLang/Turing.jl/blob/1aa95ac91a115569c742bab74f7b751ed1450309/src/mcmc/Inference.jl#L288-L290

So this is only a problem if you manually construct a chain and try to call predict on it, which I think is a highly unlikely workflow (and I'm happy to wait for people to complain if it fails).

However, it's obviously ugly. The only good way around this is, as I suggested before, to rework MCMCChains.jl.

@penelopeysm changed the title from "Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values" to "InitContext, part 4 - Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values" Jul 10, 2025
@penelopeysm force-pushed the py/init-prior-uniform branch 2 times, most recently from 025aa8b to b55c1e1 on July 10, 2025 14:24
@penelopeysm force-pushed the py/actually-use-init branch 5 times, most recently from b72c3bf to 92d3542 on July 10, 2025 15:57
@penelopeysm mentioned this pull request Jul 10, 2025
@penelopeysm force-pushed the py/actually-use-init branch 4 times, most recently from 7438b23 to d55d378 on July 10, 2025 16:56
@penelopeysm force-pushed the py/actually-use-init branch from d55d378 to a392451 on July 10, 2025 17:33
@penelopeysm force-pushed the py/actually-use-init branch from a392451 to 12d93e5 on July 10, 2025 17:44
@penelopeysm force-pushed the py/actually-use-init branch from 12d93e5 to 7a8e7e3 on July 10, 2025 17:47