Skip to content

The big InitContext PR. Do not merge! #967

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,6 @@ Requires = "1"
Statistics = "1"
Test = "1.6"
julia = "1.10.8"

[sources]
AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "py/hasgetvalue"}
28 changes: 20 additions & 8 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -466,23 +466,35 @@ The behaviour of a model execution can be changed with evaluation contexts, whic
Contexts are subtypes of `AbstractPPL.AbstractContext`.

```@docs
SamplingContext
DefaultContext
PrefixContext
ConditionContext
InitContext
```

### Samplers
### VarInfo initialisation

`InitContext` is used to initialise, or overwrite, values in a VarInfo.

To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained.
There are three concrete strategies provided in DynamicPPL:

```@docs
PriorInit
UniformInit
ParamsInit
```

In DynamicPPL two samplers are defined that are used to initialize unobserved random variables:
[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution.
If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method.

```@docs
SampleFromPrior
SampleFromUniform
DynamicPPL.AbstractInitStrategy
DynamicPPL.init
```

Additionally, a generic sampler for inference is implemented.
### Samplers

In DynamicPPL a generic sampler for inference is implemented.

```@docs
Sampler
Expand All @@ -493,7 +505,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu
```@docs
DynamicPPL.initialstep
DynamicPPL.loadstate
DynamicPPL.initialsampler
DynamicPPL.init_strategy
```

Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.
Expand Down
2 changes: 0 additions & 2 deletions ext/DynamicPPLEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ else
using ..EnzymeCore
end

@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true

# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme
# only checks whether such a method exists, and never runs it.
@inline EnzymeCore.EnzymeRules.inactive_noinl(::typeof(DynamicPPL.istrans), args...) =
Expand Down
15 changes: 5 additions & 10 deletions ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,17 @@ end
function DynamicPPL.Experimental._determine_varinfo_jet(
model::DynamicPPL.Model; only_ddpl::Bool=true
)
# Use SamplingContext to test type stability.
sampling_model = DynamicPPL.contextualize(
model, DynamicPPL.SamplingContext(model.context)
)

# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(sampling_model)
varinfo = DynamicPPL.typed_varinfo(model)

# Let's make sure that both evaluation and sampling doesn't result in type errors.
# Let's make sure that evaluation doesn't result in type errors.
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
sampling_model, varinfo; only_ddpl
model, varinfo; only_ddpl
)

if !issuccess
# Useful information for debugging.
@debug "Evaluaton with typed varinfo failed with the following issues:"
@debug "Evaluation with typed varinfo failed with the following issues:"
@debug result
end

Expand All @@ -46,7 +41,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(sampling_model)
DynamicPPL.untyped_varinfo(model)
end
end

Expand Down
38 changes: 27 additions & 11 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ end

function _check_varname_indexing(c::MCMCChains.Chains)
return DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using `VarName`s.")
error("This `Chains` object does not support indexing using `VarName`s.")
end

function DynamicPPL.getindex_varname(
Expand All @@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
return keys(c.info.varname_to_symbol)
end

function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
_check_varname_indexing(c)
d = Dict{DynamicPPL.VarName,Any}()
for vn in DynamicPPL.varnames(c)
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
end
return d
end

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Expand Down Expand Up @@ -114,9 +123,15 @@ function DynamicPPL.predict(

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
predictive_samples = map(iters) do (sample_idx, chain_idx)
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))

# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`
_, varinfo = DynamicPPL.init!!(
rng,
model,
varinfo,
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
)
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
varname_vals = mapreduce(
collect,
Expand Down Expand Up @@ -248,13 +263,14 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
# Update the varinfo with the current sample and make variables not present in `chain`
# to be sampled.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to the `model`.
model(deepcopy(varinfo))
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`, and
# return the model's retval.
retval, _ = DynamicPPL.init!!(
model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit())
)
retval
end
end

Expand Down
14 changes: 9 additions & 5 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using DocStringExtensions
using Random: Random

# For extending
import AbstractPPL: predict
import AbstractPPL: predict, hasvalue, getvalue

# TODO: Remove these when it's possible.
import Bijectors: link, invlink
Expand Down Expand Up @@ -97,18 +97,21 @@ export AbstractVarInfo,
values_as_in_model,
# Samplers
Sampler,
SampleFromPrior,
SampleFromUniform,
# LogDensityFunction
LogDensityFunction,
# Contexts
contextualize,
SamplingContext,
DefaultContext,
PrefixContext,
ConditionContext,
assume,
tilde_assume,
# Initialisation
InitContext,
AbstractInitStrategy,
PriorInit,
UniformInit,
ParamsInit,
# Pseudo distributions
NamedDist,
NoDist,
Expand Down Expand Up @@ -170,11 +173,12 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
# Necessary forward declarations
include("utils.jl")
include("chains.jl")
include("contexts.jl")
include("contexts/init.jl")
include("model.jl")
include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("submodel.jl")
include("varnamedvector.jl")
include("accumulators.jl")
Expand Down
107 changes: 5 additions & 102 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,14 @@
# assume
"""
tilde_assume(context::SamplingContext, right, vn, vi)

Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
accumulate the log probability, and return the sampled value with a context associated
with a sampler.

Falls back to
```julia
tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
```
"""
function tilde_assume(context::SamplingContext, right, vn, vi)
return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
end

function tilde_assume(context::AbstractContext, args...)
return tilde_assume(childcontext(context), args...)
end
function tilde_assume(::DefaultContext, right, vn, vi)
return assume(right, vn, vi)
end

function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...)
return tilde_assume(rng, childcontext(context), args...)
end
function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi)
return assume(rng, sampler, right, vn, vi)
end
function tilde_assume(::DefaultContext, sampler, right, vn, vi)
# same as above but no rng
return assume(Random.default_rng(), sampler, right, vn, vi)
y = getindex_internal(vi, vn)
f = from_maybe_linked_internal_transform(vi, vn, right)
x, logjac = with_logabsdet_jacobian(f, y)
vi = accumulate_assume!!(vi, x, logjac, vn, right)
return x, vi
end

function tilde_assume(context::PrefixContext, right, vn, vi)
# Note that we can't use something like this here:
# new_vn = prefix(context, vn)
Expand All @@ -46,12 +22,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi)
new_vn, new_context = prefix_and_strip_contexts(context, vn)
return tilde_assume(new_context, right, new_vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi
)
new_vn, new_context = prefix_and_strip_contexts(context, vn)
return tilde_assume(rng, new_context, sampler, right, new_vn, vi)
end

"""
tilde_assume!!(context, right, vn, vi)
Expand All @@ -71,17 +41,6 @@ function tilde_assume!!(context, right, vn, vi)
end

# observe
"""
tilde_observe!!(context::SamplingContext, right, left, vi)

Handle observed constants with a `context` associated with a sampler.

Falls back to `tilde_observe!!(context.context, right, left, vi)`.
"""
function tilde_observe!!(context::SamplingContext, right, left, vn, vi)
return tilde_observe!!(context.context, right, left, vn, vi)
end

function tilde_observe!!(context::AbstractContext, right, left, vn, vi)
return tilde_observe!!(childcontext(context), right, left, vn, vi)
end
Expand Down Expand Up @@ -114,59 +73,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi)
vi = accumulate_observe!!(vi, right, left, vn)
return left, vi
end

function assume(::Random.AbstractRNG, spl::Sampler, dist)
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
end

# fallback without sampler
function assume(dist::Distribution, vn::VarName, vi)
y = getindex_internal(vi, vn)
f = from_maybe_linked_internal_transform(vi, vn, dist)
x, logjac = with_logabsdet_jacobian(f, y)
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
return x, vi
end

# TODO: Remove this thing.
# SampleFromPrior and SampleFromUniform
function assume(
rng::Random.AbstractRNG,
sampler::Union{SampleFromPrior,SampleFromUniform},
dist::Distribution,
vn::VarName,
vi::VarInfoOrThreadSafeVarInfo,
)
if haskey(vi, vn)
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
# if that's okay.
unset_flag!(vi, vn, "del", true)
r = init(rng, dist, sampler)
f = to_maybe_linked_internal_transform(vi, vn, dist)
# TODO(mhauru) This should probably be call a function called setindex_internal!
vi = BangBang.setindex!!(vi, f(r), vn)
setorder!(vi, vn, get_num_produce(vi))
else
# Otherwise we just extract it.
r = vi[vn, dist]
end
else
r = init(rng, dist, sampler)
if istrans(vi)
f = to_linked_internal_transform(vi, vn, dist)
vi = push!!(vi, vn, f(r), dist)
# By default `push!!` sets the transformed flag to `false`.
vi = settrans!!(vi, true, vn)
else
vi = push!!(vi, vn, r, dist)
end
end

# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
return r, vi
end
Loading