From 54446c17a50c948a63877de64170a64307a9aa92 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 23:20:10 +0100 Subject: [PATCH 1/7] Implement InitContext --- src/DynamicPPL.jl | 7 ++ src/contexts/init.jl | 180 +++++++++++++++++++++++++++++++++++++++++++ src/model.jl | 33 ++++++++ test/contexts.jl | 12 ++- 4 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 src/contexts/init.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 4c2f0bd00..6b20899d9 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -109,6 +109,12 @@ export AbstractVarInfo, ConditionContext, assume, tilde_assume, + # Initialisation + InitContext, + AbstractInitStrategy, + PriorInit, + UniformInit, + ParamsInit, # Pseudo distributions NamedDist, NoDist, @@ -175,6 +181,7 @@ include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") +include("contexts/init.jl") include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") diff --git a/src/contexts/init.jl b/src/contexts/init.jl new file mode 100644 index 000000000..580b1a666 --- /dev/null +++ b/src/contexts/init.jl @@ -0,0 +1,180 @@ +""" + AbstractInitStrategy + +Abstract type representing the possible ways of initialising new values for +the random variables in a model (e.g., when creating a new VarInfo). +""" +abstract type AbstractInitStrategy end + +""" + init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy) + +Generate a new value for a random variable with the given distribution. + +!!! warning "Values must be unlinked" + The values returned by `init` are always in the untransformed space, i.e., + they must be within the support of the original distribution. That means that, + for example, `init(rng, dist, u::UniformInit)` will in general return values that + are outside the range [u.lower, u.upper]. +""" +function init end + +""" + PriorInit() + +Obtain new values by sampling from the prior distribution. +""" +struct PriorInit <: AbstractInitStrategy end +init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit) = rand(rng, dist) + +""" + UniformInit() + UniformInit(lower, upper) + +Obtain new values by first transforming the distribution of the random variable +to unconstrained space, and then sampling a value uniformly between `lower` and +`upper`. + +If unspecified, defaults to `(lower, upper) = (-2, 2)`, which mimics Stan's +default initialisation strategy. + +# References + +[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization) +""" +struct UniformInit{T<:AbstractFloat} <: AbstractInitStrategy + lower::T + upper::T + function UniformInit(lower::T, upper::T) where {T<:AbstractFloat} + lower > upper && + throw(ArgumentError("`lower` must be less than or equal to `upper`")) + return new{T}(lower, upper) + end + UniformInit() = UniformInit(-2.0, 2.0) +end +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit) + b = Bijectors.bijector(dist) + sz = Bijectors.output_size(b, size(dist)) + y = rand(rng, Uniform(u.lower, u.upper), sz) + b_inv = Bijectors.inverse(b) + x = b_inv(y) + # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398 + if x isa Array{<:Any,0} + x = x[] + end + return x +end + +""" + ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy=PriorInit()) + ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit()) + +Obtain new values by extracting them from the given dictionary or NamedTuple. +The parameter `default` specifies how new values are to be obtained if they +cannot be found in `params`, or they are specified as `missing`. The default +for `default` is `PriorInit()`. + +!!! note + These values must be provided in the space of the untransformed distribution. +""" +struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy + params::P + default::S + function ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy) + return new{typeof(params),typeof(default)}(params, default) + end + ParamsInit(params::AbstractDict{<:VarName}) = ParamsInit(params, PriorInit()) + function ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit()) + return ParamsInit(to_varname_dict(params), default) + end +end +function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit) + # TODO(penelopeysm): We should do a check to make sure that all of the + # parameters in `p.params` were actually used, and either warn or error if + # they aren't. This is non-trivial (we need to use something like + # varname_leaves), so I'm going to defer it to a later PR. + return if hasvalue(p.params, vn, dist) + x = getvalue(p.params, vn, dist) + if x === missing + init(rng, vn, dist, p.default) + else + # TODO(penelopeysm): We could also check that the type of x matches + # the dist? + x + end + else + init(rng, vn, dist, p.default) + end +end + +""" + InitContext( + [rng::Random.AbstractRNG=Random.default_rng()], + [strategy::AbstractInitStrategy=PriorInit()], + ) + +A leaf context that indicates that new values for random variables are +currently being obtained through sampling. Used e.g. when initialising a fresh +VarInfo. Note that, if `leafcontext(model.context) isa InitContext`, then +`evaluate!!(model, varinfo)` will override all values in the VarInfo. +""" +struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext + rng::R + strategy::S + function InitContext( + rng::Random.AbstractRNG, strategy::AbstractInitStrategy=PriorInit() + ) + return new{typeof(rng),typeof(strategy)}(rng, strategy) + end + function InitContext(strategy::AbstractInitStrategy=PriorInit()) + return InitContext(Random.default_rng(), strategy) + end +end +NodeTrait(::InitContext) = IsLeaf() + +function tilde_assume( + ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo +) + in_varinfo = haskey(vi, vn) + # `init()` always returns values in original space, i.e. possibly + # constrained + x = init(ctx.rng, vn, dist, ctx.strategy) + # Determine whether to insert a transformed value into the VarInfo. + # If the VarInfo alrady had a value for this variable, we will + # keep the same linked status as in the original VarInfo. If not, we + # check the rest of the VarInfo to see if other variables are linked. + # istrans(vi) returns true if vi is nonempty and all variables in vi + # are linked. + insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi) + f = if insert_transformed_value + to_linked_internal_transform(vi, vn, dist) + else + to_internal_transform(vi, vn, dist) + end + # TODO(penelopeysm): We would really like to do: + # y, logjac = with_logabsdet_jacobian(f, x) + # Unfortunately, `to_{linked_}internal_transform` returns a function that + # always converts x to a vector, i.e., if dist is univariate, f(x) will be + # a vector of length 1. It would be nice if we could unify these. + y = f(x) + logjac = logabsdetjac(insert_transformed_value ? link_transform(dist) : identity, x) + # Add the new value to the VarInfo. `push!!` errors if the value already + # exists, hence the need for setindex!!. + if in_varinfo + vi = setindex!!(vi, y, vn) + else + vi = push!!(vi, vn, y, dist) + end + # Neither of these set the `trans` flag so we have to do it manually if + # necessary. + insert_transformed_value && settrans!!(vi, true, vn) + # `accumulate_assume!!` wants untransformed values as the second argument. + vi = accumulate_assume!!(vi, x, -logjac, vn, dist) + # We always return the untransformed value here, as that will determine + # what the lhs of the tilde-statement is set to. + return x, vi +end + +function tilde_observe!!(::InitContext, right, left, vn, vi) + return tilde_observe!!(DefaultContext(), right, left, vn, vi) +end diff --git a/src/model.jl b/src/model.jl index 72a7ac294..f14744bf2 100644 --- a/src/model.jl +++ b/src/model.jl @@ -854,6 +854,39 @@ function evaluate_and_sample!!( return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) end +""" + init!!( + [rng::Random.AbstractRNG,] + model::Model, + varinfo::AbstractVarInfo, + [init_strategy::AbstractInitStrategy=PriorInit()] + ) + +Evaluate the `model` and replace the values of the model's random variables in +the given `varinfo` with new values using a specified initialisation strategy. +If the values in `varinfo` are not already present, they will be added using +that same strategy. + +If `init_strategy` is not provided, defaults to PriorInit(). + +Returns a tuple of the model's return value, plus the updated `varinfo` object. +""" +function init!!( + rng::Random.AbstractRNG, + model::Model, + varinfo::AbstractVarInfo, + init_strategy::AbstractInitStrategy=PriorInit(), +) + new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) + new_model = contextualize(model, new_context) + return evaluate!!(new_model, varinfo) +end +function init!!( + model::Model, varinfo::AbstractVarInfo, init_strategy::AbstractInitStrategy=PriorInit() +) + return init!!(Random.default_rng(), model, varinfo, init_strategy) +end + """ evaluate!!(model::Model, varinfo) diff --git a/test/contexts.jl b/test/contexts.jl index 597ab736c..be976aad4 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,5 +1,5 @@ using Test, DynamicPPL, Accessors -using AbstractPPL: getoptic +using AbstractPPL: getoptic, hasvalue, getvalue using DynamicPPL: leafcontext, setleafcontext, @@ -431,4 +431,14 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test fixed(c6) == Dict(@varname(a.b.d) => 2) end end + + @testset "InitContext" begin + @testset "PriorInit" begin end + + @testset "UniformInit" begin end + + @testset "ParamsInit" begin end + + @testset "rng is respected (at least with PriorInit" begin end + end end From 1ef1a9285f62b9ca64b1ed8739516874aa1c8701 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 23:25:10 +0100 Subject: [PATCH 2/7] Fix loading order of modules; move `prefix(::Model)` to model.jl --- src/DynamicPPL.jl | 4 ++-- src/contexts.jl | 35 ----------------------------------- src/model.jl | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 37 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 6b20899d9..bb6af996e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -176,12 +176,12 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") include("chains.jl") +include("contexts.jl") +include("contexts/init.jl") include("model.jl") include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") -include("contexts.jl") -include("contexts/init.jl") include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") diff --git a/src/contexts.jl b/src/contexts.jl index addadfa1a..cd9876768 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -280,41 +280,6 @@ function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName return vn, setchildcontext(ctx, new_ctx) end -""" - prefix(model::Model, x::VarName) - prefix(model::Model, x::Val{sym}) - prefix(model::Model, x::Any) - -Return `model` but with all random variables prefixed by `x`, where `x` is either: -- a `VarName` (e.g. `@varname(a)`), -- a `Val{sym}` (e.g. `Val(:a)`), or -- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that - this will introduce runtime overheads so is not recommended unless absolutely - necessary. - -# Examples - -```jldoctest -julia> using DynamicPPL: prefix - -julia> @model demo() = x ~ Dirac(1) -demo (generic function with 2 methods) - -julia> rand(prefix(demo(), @varname(my_prefix))) -(var"my_prefix.x" = 1,) - -julia> rand(prefix(demo(), Val(:my_prefix))) -(var"my_prefix.x" = 1,) -``` -""" -prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) -function prefix(model::Model, x::Val{sym}) where {sym} - return contextualize(model, PrefixContext(VarName{sym}(), model.context)) -end -function prefix(model::Model, x) - return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) -end - """ ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} diff --git a/src/model.jl b/src/model.jl index f14744bf2..6be4eb383 100644 --- a/src/model.jl +++ b/src/model.jl @@ -799,6 +799,41 @@ julia> # Now `a.x` will be sampled. """ fixed(model::Model) = fixed(model.context) +""" + prefix(model::Model, x::VarName) + prefix(model::Model, x::Val{sym}) + prefix(model::Model, x::Any) + +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. + +# Examples + +```jldoctest +julia> using DynamicPPL: prefix + +julia> @model demo() = x ~ Dirac(1) +demo (generic function with 2 methods) + +julia> rand(prefix(demo(), @varname(my_prefix))) +(var"my_prefix.x" = 1,) + +julia> rand(prefix(demo(), Val(:my_prefix))) +(var"my_prefix.x" = 1,) +``` +""" +prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) +function prefix(model::Model, x::Val{sym}) where {sym} + return contextualize(model, PrefixContext(VarName{sym}(), model.context)) +end +function prefix(model::Model, x) + return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) +end + """ (model::Model)([rng, varinfo]) From a90d95e8dadd76604a1ffe18155265ed5b8a2239 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 00:21:53 +0100 Subject: [PATCH 3/7] Add tests for InitContext behaviour --- src/contexts/init.jl | 12 +-- test/contexts.jl | 183 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 179 insertions(+), 16 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 580b1a666..6ff276d21 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -147,17 +147,11 @@ function tilde_assume( # are linked. insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi) f = if insert_transformed_value - to_linked_internal_transform(vi, vn, dist) + link_transform(dist) else - to_internal_transform(vi, vn, dist) + identity end - # TODO(penelopeysm): We would really like to do: - # y, logjac = with_logabsdet_jacobian(f, x) - # Unfortunately, `to_{linked_}internal_transform` returns a function that - # always converts x to a vector, i.e., if dist is univariate, f(x) will be - # a vector of length 1. It would be nice if we could unify these. - y = f(x) - logjac = logabsdetjac(insert_transformed_value ? link_transform(dist) : identity, x) + y, logjac = with_logabsdet_jacobian(f, x) # Add the new value to the VarInfo. `push!!` errors if the value already # exists, hence the need for setindex!!. if in_varinfo diff --git a/test/contexts.jl b/test/contexts.jl index be976aad4..5768757bb 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -20,8 +20,9 @@ using DynamicPPL: hasconditioned_nested, getconditioned_nested, collapse_prefix_stack, - prefix_cond_and_fixed_variables, - getvalue + prefix_cond_and_fixed_variables +using LinearAlgebra: I +using Random: Xoshiro using EnzymeCore @@ -103,7 +104,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # sometimes only the main symbol (e.g. it contains `x` when # `vn` is `x[1]`) for vn in conditioned_vns - val = DynamicPPL.getvalue(conditioned_values, vn) + val = getvalue(conditioned_values, vn) # These VarNames are present in the conditioning values, so # we should always be able to extract the value. @test hasconditioned_nested(context, vn) @@ -433,12 +434,180 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end @testset "InitContext" begin - @testset "PriorInit" begin end + empty_varinfos = [ + VarInfo(), + DynamicPPL.typed_varinfo(VarInfo()), + VarInfo(DynamicPPL.VarNamedVector()), + DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), + SimpleVarInfo(), + SimpleVarInfo(Dict{VarName,Any}()), + ] + + @model function test_init_model() + x ~ Normal() + y ~ MvNormal(fill(x, 2), I) + 1.0 ~ Normal() + return nothing + end + function test_generating_new_values(strategy::AbstractInitStrategy) + @testset "generating new values: $(typeof(strategy))" begin + # Check that init!! can generate values that weren't there + # previously. + model = test_init_model() + for empty_vi in empty_varinfos + this_vi = deepcopy(empty_vi) + _, vi = DynamicPPL.init!!(model, this_vi, strategy) + @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) + x, y = vi[@varname(x)], vi[@varname(y)] + @test x isa Real + @test y isa AbstractVector{<:Real} + @test length(y) == 2 + (; logprior, loglikelihood) = getlogp(vi) + @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == + logprior + @test logpdf(Normal(), 1.0) == loglikelihood + end + end + end + function test_replacing_values(strategy::AbstractInitStrategy) + @testset "replacing old values: $(typeof(strategy))" begin + # Check that init!! can overwrite values that were already there. + model = test_init_model() + for empty_vi in empty_varinfos + # start by generating some rubbish values + vi = deepcopy(empty_vi) + old_x, old_y = 100000.00, [300000.00, 500000.00] + push!!(vi, @varname(x), old_x, Normal()) + push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) + # then overwrite it + _, new_vi = DynamicPPL.init!!(model, vi, strategy) + new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] + # check that the values are (presumably) different + @test old_x != new_x + @test old_y != new_y + end + end + end + function test_rng_respected(strategy::AbstractInitStrategy) + @testset "check that RNG is respected: $(typeof(strategy))" begin + model = test_init_model() + for empty_vi in empty_varinfos + _, vi1 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi2 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi3 = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), strategy + ) + @test vi1[@varname(x)] == vi2[@varname(x)] + @test vi1[@varname(y)] == vi2[@varname(y)] + @test vi1[@varname(x)] != vi3[@varname(x)] + @test vi1[@varname(y)] != vi3[@varname(y)] + end + end + end - @testset "UniformInit" begin end + @testset "PriorInit" begin + test_generating_new_values(PriorInit()) + test_replacing_values(PriorInit()) + test_rng_respected(PriorInit()) + + @testset "check that values are within support" begin + # Not many other sensible checks we can do for priors. + @model just_unif() = x ~ Uniform(0.0, 1e-7) + for _ in 1:100 + _, vi = DynamicPPL.init!!(just_unif(), VarInfo(), PriorInit()) + @test vi[@varname(x)] isa Real + @test 0.0 <= vi[@varname(x)] <= 1e-7 + end + end + end - @testset "ParamsInit" begin end + @testset "UniformInit" begin + test_generating_new_values(UniformInit()) + test_replacing_values(UniformInit()) + test_rng_respected(UniformInit()) + + @testset "check that bounds are respected" begin + @testset "unconstrained" begin + umin, umax = -1.0, 1.0 + @model just_norm() = x ~ Normal() + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_norm(), VarInfo(), UniformInit(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test umin <= vi[@varname(x)] <= umax + end + end + @testset "constrained" begin + umin, umax = -1.0, 1.0 + @model just_beta() = x ~ Beta(2, 2) + inv_bijector = inverse(Bijectors.bijector(Beta(2, 2))) + tmin, tmax = inv_bijector(umin), inv_bijector(umax) + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_beta(), VarInfo(), UniformInit(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test tmin <= vi[@varname(x)] <= tmax + end + end + end + end - @testset "rng is respected (at least with PriorInit" begin end + @testset "ParamsInit" begin + @testset "given full set of parameters" begin + # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I) + my_x, my_y = 1.0, [2.0, 3.0] + params_nt = (; x=my_x, y=my_y) + params_dict = Dict(@varname(x) => my_x, @varname(y) => my_y) + model = test_init_model() + for empty_vi in empty_varinfos + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), ParamsInit(params_nt) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_nt = getlogp(vi) + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), ParamsInit(params_dict) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_dict = getlogp(vi) + @test logp_nt == logp_dict + end + end + + @testset "given only partial parameters" begin + # In this case, we expect `ParamsInit` to use the value of x, and + # generate a new value for y. + my_x = 1.0 + params_nt = (; x=my_x) + params_dict = Dict(@varname(x) => my_x) + model = test_init_model() + for empty_vi in empty_varinfos + _, vi = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), ParamsInit(params_nt) + ) + @test vi[@varname(x)] == my_x + nt_y = vi[@varname(y)] + @test nt_y isa AbstractVector{<:Real} + @test length(nt_y) == 2 + _, vi = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), ParamsInit(params_dict) + ) + @test vi[@varname(x)] == my_x + dict_y = vi[@varname(y)] + @test dict_y isa AbstractVector{<:Real} + @test length(dict_y) == 2 + # the values should be different since we used different seeds + @test dict_y != nt_y + end + end + end end end From 001a05aad080b18cc2d57187b3e0983a09f8b984 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 00:30:07 +0100 Subject: [PATCH 4/7] inline `rand(::Distributions.Uniform)` Note that, apart from being simpler code, Distributions.Uniform also doesn't allow the lower and upper bounds to be exactly equal (but we might like to keep that option open in DynamicPPL, e.g. if the user wants to initialise all values to the same value in linked space). --- src/contexts/init.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 6ff276d21..3b7007f51 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -38,6 +38,8 @@ to unconstrained space, and then sampling a value uniformly between `lower` and If unspecified, defaults to `(lower, upper) = (-2, 2)`, which mimics Stan's default initialisation strategy. +Requires that `lower <= upper`. + # References [Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization) @@ -55,7 +57,7 @@ end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit) b = Bijectors.bijector(dist) sz = Bijectors.output_size(b, size(dist)) - y = rand(rng, Uniform(u.lower, u.upper), sz) + y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...)) b_inv = Bijectors.inverse(b) x = b_inv(y) # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398 From b55c1e17f97ae518d1d149122e1fb1055557183f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 00:46:43 +0100 Subject: [PATCH 5/7] Document --- docs/src/api.md | 21 +++++++++++++++++++++ src/contexts/init.jl | 20 ++++++++++---------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index e918a095c..3d5c681cf 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -470,6 +470,27 @@ SamplingContext DefaultContext PrefixContext ConditionContext +InitContext +``` + +### VarInfo initialisation + +`InitContext` is used to initialise, or overwrite, values in a VarInfo. + +To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. +There are three concrete strategies provided in DynamicPPL: + +```@docs +PriorInit +UniformInit +ParamsInit +``` + +If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. + +```@docs +DynamicPPL.AbstractInitStrategy +DynamicPPL.init ``` ### Samplers diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 3b7007f51..2b87b533b 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -32,11 +32,11 @@ init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit) = rand UniformInit(lower, upper) Obtain new values by first transforming the distribution of the random variable -to unconstrained space, and then sampling a value uniformly between `lower` and -`upper`. +to unconstrained space, then sampling a value uniformly between `lower` and +`upper`, and transforming that value back to the original space. -If unspecified, defaults to `(lower, upper) = (-2, 2)`, which mimics Stan's -default initialisation strategy. +If `lower` and `upper` are unspecified, they default to `(-2, 2)`, which mimics +Stan's default initialisation strategy. Requires that `lower <= upper`. @@ -91,17 +91,17 @@ struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy end end function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit) - # TODO(penelopeysm): We should do a check to make sure that all of the - # parameters in `p.params` were actually used, and either warn or error if - # they aren't. This is non-trivial (we need to use something like - # varname_leaves), so I'm going to defer it to a later PR. + # TODO(penelopeysm): It would be nice to do a check to make sure that all + # of the parameters in `p.params` were actually used, and either warn or + # error if they aren't. This is actually quite non-trivial though because + # the structure of Dicts in particular can have arbitrary nesting. return if hasvalue(p.params, vn, dist) x = getvalue(p.params, vn, dist) if x === missing init(rng, vn, dist, p.default) else - # TODO(penelopeysm): We could also check that the type of x matches - # the dist? + # TODO(penelopeysm): Since x is user-supplied, maybe we could also + # check here that the type / size of x matches the dist? x end else From 5da6d855024c70a0cfc09f27e4747973396fc353 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 19 Jul 2025 23:36:33 +0100 Subject: [PATCH 6/7] Add a test to check that `init!!` doesn't change linking --- test/contexts.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/contexts.jl b/test/contexts.jl index 5768757bb..3819c4564 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -508,11 +508,29 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end end + function test_link_status_respected(strategy::AbstractInitStrategy) + @testset "check that varinfo linking is preserved: $(typeof(strategy))" begin + @model logn() = a ~ LogNormal() + model = logn() + vi = VarInfo(model) + linked_vi = DynamicPPL.link!!(vi, model) + _, new_vi = DynamicPPL.init!!(model, linked_vi, strategy) + @test DynamicPPL.istrans(new_vi) + # this is the unlinked value, since it uses `getindex` + a = new_vi[@varname(a)] + # logp should correspond to the transformed value + @test isapprox(DynamicPPL.getlogjoint(new_vi), logpdf(Normal(), log(a))) + @test isapprox( + only(DynamicPPL.getindex_internal(new_vi, @varname(a))), log(a) + ) + end + end @testset "PriorInit" begin test_generating_new_values(PriorInit()) test_replacing_values(PriorInit()) test_rng_respected(PriorInit()) + test_link_status_respected(PriorInit()) @testset "check that values are within support" begin # Not many other sensible checks we can do for priors. @@ -529,6 +547,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() test_generating_new_values(UniformInit()) test_replacing_values(UniformInit()) test_rng_respected(UniformInit()) + test_link_status_respected(UniformInit()) @testset "check that bounds are respected" begin @testset "unconstrained" begin @@ -559,6 +578,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end @testset "ParamsInit" begin + test_link_status_respected(ParamsInit((; a=1.0))) + test_link_status_respected(ParamsInit(Dict(@varname(a) => 1.0))) + @testset "given full set of parameters" begin # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I) my_x, my_y = 1.0, [2.0, 3.0] From 2f7eba8cdd7003708d1d316280aecbb8b3739f0c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 01:58:18 +0100 Subject: [PATCH 7/7] Fix `push!` for VarNamedVector This should have been changed in #940, but slipped through as the file wasn't listed as one of the changed files. --- src/varnamedvector.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 5de0874c9..8095f4475 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -766,9 +766,7 @@ function update_internal!( return nothing end -# TODO(mhauru) The num_produce argument is used by Particle Gibbs. -# Remove this method as soon as possible. -function BangBang.push!(vnv::VarNamedVector, vn, val, dist, num_produce) +function BangBang.push!(vnv::VarNamedVector, vn, val, dist) f = from_vec_transform(dist) return setindex_internal!(vnv, tovec(val), vn, f) end