
InitContext, part 4 - Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values #984

Draft
wants to merge 5 commits into base: py/init-prior-uniform
Changes from all commits

14 changes: 7 additions & 7 deletions docs/src/api.md
@@ -456,11 +456,6 @@ AbstractPPL.evaluate!!

This method mutates the `varinfo` used for execution.
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:

```@docs
DynamicPPL.evaluate_and_sample!!
```

The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
Contexts are subtypes of `AbstractPPL.AbstractContext`.
@@ -475,7 +470,12 @@ InitContext

### VarInfo initialisation

`InitContext` is used to initialise, or overwrite, values in a VarInfo.
The function `init!!` is used to initialise, or overwrite, values in a VarInfo.
It is essentially a thin wrapper around `evaluate!!` with an `InitContext`.

```@docs
DynamicPPL.init!!
```
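
As a rough illustration of the API documented above, the following sketch shows how `init!!` might be called to (re)initialise a `VarInfo` from the prior. The `demo` model is an assumption made up for this example and is not part of this diff.

```julia
using DynamicPPL, Distributions, Random

# A toy model, purely for illustration.
@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

model = demo()
varinfo = VarInfo(model)

# Replace every value in `varinfo` with a fresh draw from its prior.
# `init!!` returns the model's return value and the updated varinfo.
retval, varinfo = DynamicPPL.init!!(Random.default_rng(), model, varinfo, DynamicPPL.PriorInit())
```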

To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained.
There are three concrete strategies provided in DynamicPPL:
@@ -514,7 +514,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu
```@docs
DynamicPPL.initialstep
DynamicPPL.loadstate
DynamicPPL.initialsampler
DynamicPPL.init_strategy
```

Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.
15 changes: 5 additions & 10 deletions ext/DynamicPPLJETExt.jl
@@ -21,22 +21,17 @@ end
function DynamicPPL.Experimental._determine_varinfo_jet(
model::DynamicPPL.Model; only_ddpl::Bool=true
)
# Use SamplingContext to test type stability.
sampling_model = DynamicPPL.contextualize(
model, DynamicPPL.SamplingContext(model.context)
)

# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(sampling_model)
varinfo = DynamicPPL.typed_varinfo(model)

# Let's make sure that both evaluation and sampling doesn't result in type errors.
# Let's make sure that evaluation doesn't result in type errors.
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
sampling_model, varinfo; only_ddpl
model, varinfo; only_ddpl
)

if !issuccess
# Useful information for debugging.
@debug "Evaluaton with typed varinfo failed with the following issues:"
@debug "Evaluation with typed varinfo failed with the following issues:"
@debug result
end

@@ -46,7 +41,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(sampling_model)
DynamicPPL.untyped_varinfo(model)
end
end

38 changes: 27 additions & 11 deletions ext/DynamicPPLMCMCChainsExt.jl
@@ -28,7 +28,7 @@ end

function _check_varname_indexing(c::MCMCChains.Chains)
return DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using `VarName`s.")
error("This `Chains` object does not support indexing using `VarName`s.")
end

function DynamicPPL.getindex_varname(
@@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
return keys(c.info.varname_to_symbol)
end

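# Collect the values of one MCMC sample (at `sample_idx` in chain `chain_idx`) into a
# `Dict{VarName,Any}` keyed by the chain's `VarName`s.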
function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
_check_varname_indexing(c)
d = Dict{DynamicPPL.VarName,Any}()
for vn in DynamicPPL.varnames(c)
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
end
return d
end

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Expand Down Expand Up @@ -114,9 +123,15 @@ function DynamicPPL.predict(

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
predictive_samples = map(iters) do (sample_idx, chain_idx)
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))

# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`
_, varinfo = DynamicPPL.init!!(
rng,
model,
varinfo,
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
)
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
varname_vals = mapreduce(
collect,
@@ -248,13 +263,14 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
# Update the varinfo with the current sample and make variables not present in `chain`
# to be sampled.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to the `model`.
model(deepcopy(varinfo))
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`, and
# return the model's retval.
retval, _ = DynamicPPL.init!!(
model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit())
)
retval
end
end

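The common pattern in both `predict` and `returned` above is `ParamsInit` with a `PriorInit` fallback: variables found in the extracted dictionary keep those values, and anything missing is drawn from its prior. A minimal sketch of that behaviour in isolation, using a made-up `demo` model (an assumption, not part of this diff):

```julia
using DynamicPPL, Distributions, Random

@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

model = demo()
varinfo = VarInfo(model)

# `x` is fixed from the supplied values; `y` is absent, so it is resampled from its prior.
values_dict = Dict{VarName,Any}(@varname(x) => 0.5)
retval, varinfo = DynamicPPL.init!!(
    Random.default_rng(),
    model,
    varinfo,
    DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
)
```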
7 changes: 7 additions & 0 deletions src/context_implementations.jl
@@ -28,6 +28,13 @@ end
function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi)
return assume(rng, sampler, right, vn, vi)
end
function tilde_assume(rng::Random.AbstractRNG, ::InitContext, sampler, right, vn, vi)
@warn(
"Encountered SamplingContext->InitContext. This method will be removed in the next PR.",
)
# just pretend the `InitContext` isn't there for now.
return assume(rng, sampler, right, vn, vi)
end
function tilde_assume(::DefaultContext, sampler, right, vn, vi)
# same as above but no rng
return assume(Random.default_rng(), sampler, right, vn, vi)
2 changes: 1 addition & 1 deletion src/extract_priors.jl
@@ -116,7 +116,7 @@ function extract_priors(rng::Random.AbstractRNG, model::Model)
# workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
# can't push new variables without knowing the num_produce. Remove this when possible.
varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator()))
varinfo = last(evaluate_and_sample!!(rng, model, varinfo))
varinfo = last(init!!(rng, model, varinfo))
return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors
end

58 changes: 17 additions & 41 deletions src/model.jl
@@ -850,7 +850,7 @@ end
# ^ Weird Documenter.jl bug means that we have to write the two above separately
# as it can only detect the `function`-less syntax.
function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo())
return first(evaluate_and_sample!!(rng, model, varinfo))
return first(init!!(rng, model, varinfo))
end

"""
@@ -863,46 +863,19 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
return Threads.nthreads() > 1
end

"""
evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler])

Evaluate the `model` with the given `varinfo`, but perform sampling during the
evaluation using the given `sampler` by wrapping the model's context in a
`SamplingContext`.

If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref).

Returns a tuple of the model's return value, plus the updated `varinfo` object.
"""
function evaluate_and_sample!!(
rng::Random.AbstractRNG,
model::Model,
varinfo::AbstractVarInfo,
sampler::AbstractSampler=SampleFromPrior(),
)
sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context))
return evaluate!!(sampling_model, varinfo)
end
function evaluate_and_sample!!(
model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior()
)
return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler)
end

"""
init!!(
[rng::Random.AbstractRNG,]
[rng::Random.AbstractRNG, ]
model::Model,
varinfo::AbstractVarInfo,
[init_strategy::AbstractInitStrategy=PriorInit()]
)

Evaluate the `model` and replace the values of the model's random variables in
the given `varinfo` with new values using a specified initialisation strategy.
If the values in `varinfo` are not already present, they will be added using
that same strategy.

If `init_strategy` is not provided, defaults to PriorInit().
Evaluate the `model` and replace the values of the model's random variables
in the given `varinfo` with new values, using a specified initialisation strategy.
If the values in `varinfo` are not set, they will be added.
using a specified initialisation strategy. If `init_strategy` is not provided,
defaults to PriorInit().

Returns a tuple of the model's return value, plus the updated `varinfo` object.
"""
@@ -1049,11 +1022,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
Generate a sample of type `T` from the prior distribution of the `model`.
"""
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
x = last(
evaluate_and_sample!!(
rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())
),
)
x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())))
return values_as(x, T)
end

@@ -1231,8 +1200,15 @@ function predict(
varinfo = DynamicPPL.VarInfo(model)
return map(chain) do params_varinfo
vi = deepcopy(varinfo)
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
model(rng, vi)
# TODO(penelopeysm): Requires two model evaluations, one to extract the
# parameters and one to set them. The reason why we need values_as_in_model
# is because `params_varinfo` may well have some weird combination of
# linked/unlinked, whereas `varinfo` is always unlinked since it is
# freshly constructed.
# This is quite inefficient. It would of course be alright if
# ValuesAsInModelAccumulator was a default acc.
values_nt = values_as_in_model(model, false, params_varinfo)
_, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_nt, PriorInit()))
return vi
end
end