diff --git a/Project.toml b/Project.toml
index 201c4e7370..0f232dc90d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -54,7 +54,7 @@ AbstractPPL = "0.11, 0.12"
 Accessors = "0.1"
 AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8"
 AdvancedMH = "0.8"
-AdvancedPS = "0.6.0"
+AdvancedPS = "0.7"
 AdvancedVI = "0.4"
 BangBang = "0.4.2"
 Bijectors = "0.14, 0.15"
@@ -67,7 +67,7 @@ DynamicHMC = "3.4"
 DynamicPPL = "0.36.3"
 EllipticalSliceSampling = "0.5, 1, 2"
 ForwardDiff = "0.10.3, 1"
-Libtask = "0.8.8"
+Libtask = "0.9.3"
 LinearAlgebra = "1"
 LogDensityProblems = "2"
 MCMCChains = "5, 6, 7"
@@ -85,7 +85,7 @@ Statistics = "1.6"
 StatsAPI = "1.6"
 StatsBase = "0.32, 0.33, 0.34"
 StatsFuns = "0.8, 0.9, 1"
-julia = "1.10.2"
+julia = "1.10.8"
 
 [extras]
 DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl
index ffc1019519..a81f436c87 100644
--- a/src/mcmc/particle_mcmc.jl
+++ b/src/mcmc/particle_mcmc.jl
@@ -25,9 +25,8 @@ function TracedModel(
             "Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.",
         )
     end
-    return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
-        model, sampler, varinfo, (model.f, args...)
-    )
+    evaluator = (model.f, args...)
+    return TracedModel(model, sampler, varinfo, evaluator)
 end
 
 function AdvancedPS.advance!(
@@ -59,20 +58,10 @@ function AdvancedPS.reset_logprob!(trace::TracedModel)
     return trace
 end
 
-function AdvancedPS.update_rng!(
-    trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}
-)
-    # Extract the `args`.
-    args = trace.model.ctask.args
-    # From `args`, extract the `SamplingContext`, which contains the RNG.
-    sampling_context = args[3]
-    rng = sampling_context.rng
-    trace.rng = rng
-    return trace
-end
-
-function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # RNG ?
-    return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...)
+function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...)
+    return Libtask.TapedTask(
+        taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs...
+    )
 end
 
 abstract type ParticleInference <: InferenceAlgorithm end
@@ -402,11 +391,11 @@ end
 
 function trace_local_varinfo_maybe(varinfo)
     try
-        trace = AdvancedPS.current_trace()
-        return trace.model.f.varinfo
+        trace = Libtask.get_taped_globals(Any).other
+        return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo
     catch e
         # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
-        if e == KeyError(:__trace) || current_task().storage isa Nothing
+        if e == KeyError(:task_variable)
             return varinfo
         else
             rethrow(e)
@@ -416,11 +405,10 @@ end
 
 function trace_local_rng_maybe(rng::Random.AbstractRNG)
     try
-        trace = AdvancedPS.current_trace()
-        return trace.rng
+        return Libtask.get_taped_globals(Any).rng
     catch e
         # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
-        if e == KeyError(:__trace) || current_task().storage isa Nothing
+        if e == KeyError(:task_variable)
             return rng
         else
             rethrow(e)
@@ -481,6 +469,25 @@ function AdvancedPS.Trace(
     tmodel = TracedModel(model, sampler, newvarinfo, rng)
     newtrace = AdvancedPS.Trace(tmodel, rng)
-    AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
     return newtrace
 end
+
+# We need to tell Libtask which calls may have `produce` calls within them. In practice most
+# of these won't be needed, because of inlining and the fact that `might_produce` is only
+# called on `:invoke` expressions rather than `:call`s, but since those are implementation
+# details of the compiler, we set a bunch of methods as might_produce = true. We start with
+# `acclogp_observe!!` which is what calls `produce` and go up the call stack.
+Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true
+Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
+Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true
+function Libtask.might_produce(
+    ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}}
+)
+    return true
+end
+function Libtask.might_produce(
+    ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}}
+)
+    return true
+end
+Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true
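For reviewers new to Libtask 0.9, the sketch below illustrates the two interface points this file now leans on: taped globals passed as the first `TapedTask` argument and read back with `get_taped_globals`, plus the `might_produce` trait that tells the tape compiler which calls to recurse into. The constructor shape and the names `get_taped_globals`, `set_taped_globals!`, and `might_produce` follow the Libtask 0.9 README; `wrapped_produce` and `counter` are made-up names, so treat this as a sketch rather than Turing code.

```julia
using Libtask

# `produce` is reachable only through a wrapper here, mimicking how Turing's
# observe-time `produce` sits several call layers below `DynamicPPL.evaluate!!`.
wrapped_produce(x) = produce(x)

function counter()
    # Re-read the task-local globals before each `produce` so that updates
    # made between resumptions are visible.
    wrapped_produce(Libtask.get_taped_globals(Int) + 1)
    wrapped_produce(Libtask.get_taped_globals(Int) + 2)
    return nothing
end

# Opt the wrapper into tape recursion, mirroring the `might_produce`
# definitions above. For a function this small, inlining likely makes the
# override redundant, which is exactly why the diff marks the whole call stack.
Libtask.might_produce(::Type{<:Tuple{typeof(wrapped_produce),Vararg}}) = true

t = Libtask.TapedTask(10, counter)  # first argument is the taped globals
consume(t)                          # produces 10 + 1 = 11
Libtask.set_taped_globals!(t, 20)   # swap the globals between resumptions
consume(t)                          # produces 20 + 2 = 22
```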
diff --git a/test/Project.toml b/test/Project.toml
index ee964817c5..bb49654a10 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -43,7 +43,7 @@ ADTypes = "1"
 AbstractMCMC = "5"
 AbstractPPL = "0.11, 0.12"
 AdvancedMH = "0.6, 0.7, 0.8"
-AdvancedPS = "=0.6.0"
+AdvancedPS = "0.7"
 AdvancedVI = "0.4"
 Aqua = "0.8"
 BangBang = "0.4"
diff --git a/test/essential/container.jl b/test/essential/container.jl
index 1cb790d5ae..cbd7a6fe2b 100644
--- a/test/essential/container.jl
+++ b/test/essential/container.jl
@@ -23,8 +23,8 @@ using Turing
         model = test()
         trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG())
 
-        # Make sure we link the traces
-        @test haskey(trace.model.ctask.task.storage, :__trace)
+        # Make sure the backreference from taped_globals to the trace is in place.
+        @test trace.model.ctask.taped_globals.other === trace
 
         res = AdvancedPS.advance!(trace, false)
         @test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 1
diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl
index a0d4421869..cf528ce517 100644
--- a/test/mcmc/Inference.jl
+++ b/test/mcmc/Inference.jl
@@ -44,13 +44,13 @@ using Turing
         end
 
         # Should also be stable with an explicit RNG
-        seed = 5
-        rng = Random.MersenneTwister(seed)
+        local_seed = 5
+        rng = Random.MersenneTwister(local_seed)
         for sampler in samplers
-            Random.seed!(rng, seed)
+            Random.seed!(rng, local_seed)
             chain1 = sample(rng, model, sampler, MCMCThreads(), 10, 4)
-            Random.seed!(rng, seed)
+            Random.seed!(rng, local_seed)
             chain2 = sample(rng, model, sampler, MCMCThreads(), 10, 4)
 
             @test chain1.value == chain2.value
@@ -256,9 +256,9 @@ using Turing
         pg = PG(10)
         gibbs = Gibbs(:p => HMC(0.2, 3), :x => PG(10))
 
-        chn_s = sample(StableRNG(seed), testbb(obs), smc, 200)
-        chn_p = sample(StableRNG(seed), testbb(obs), pg, 200)
-        chn_g = sample(StableRNG(seed), testbb(obs), gibbs, 200)
+        chn_s = sample(StableRNG(seed), testbb(obs), smc, 2_000)
+        chn_p = sample(StableRNG(seed), testbb(obs), pg, 2_000)
+        chn_g = sample(StableRNG(seed), testbb(obs), gibbs, 2_000)
 
         check_numerical(chn_s, [:p], [meanp]; atol=0.05)
         check_numerical(chn_p, [:x], [meanp]; atol=0.1)
@@ -647,7 +647,7 @@ using Turing
         @model function e(x=1.0)
             return x ~ Normal()
         end
-        # Can't test with HMC/NUTS because some AD backends error; see 
+        # Can't test with HMC/NUTS because some AD backends error; see
        # https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/802
         @test sample(e(), IS(), 100) isa MCMCChains.Chains
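Aside from the `local_seed` rename, the invariant the reworked `Inference.jl` test enforces is worth spelling out: one explicit RNG, re-seeded to the same value before each run, must reproduce the chain exactly. A stripped-down sketch of the same pattern, with plain `rand` standing in for `sample`:

```julia
using Random

rng = Random.MersenneTwister(5)  # explicit RNG, as in the test

Random.seed!(rng, 5)             # reset before the first run
run1 = rand(rng, 4)

Random.seed!(rng, 5)             # reset again before the second run
run2 = rand(rng, 4)

@assert run1 == run2             # same seed, same stream: identical output
```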
diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl
index 4afeeb9852..e918b3a512 100644
--- a/test/mcmc/ess.jl
+++ b/test/mcmc/ess.jl
@@ -64,7 +64,7 @@ using Turing
             @varname(mu1) => ESS(),
             @varname(mu2) => ESS(),
         )
-        chain = sample(StableRNG(seed), MoGtest_default, alg, 2000)
+        chain = sample(StableRNG(seed), MoGtest_default, alg, 5000)
         check_MoGtest_default(chain; atol=0.1)
     end
diff --git a/test/mcmc/particle_mcmc.jl b/test/mcmc/particle_mcmc.jl
index 699ee68547..7a2f5fe1c7 100644
--- a/test/mcmc/particle_mcmc.jl
+++ b/test/mcmc/particle_mcmc.jl
@@ -34,6 +34,7 @@ using Turing
 
         tested = sample(normal(), SMC(), 100)
 
+        # TODO(mhauru) This needs an explanation for why it fails.
         # failing test
         @model function fail_smc()
             a ~ Normal(4, 5)
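For context on what these particle-MCMC tests exercise, a self-contained run looks like the following. The model and draw counts are illustrative only, not taken from the test suite:

```julia
using Turing

@model function demo()
    m ~ Normal(0, 1)    # latent variable, resampled by the particle sweep
    1.5 ~ Normal(m, 1)  # literal observation; each observe triggers a `produce`
    return m
end

chain_smc = sample(demo(), SMC(), 100)   # sequential Monte Carlo
chain_pg = sample(demo(), PG(10), 100)   # particle Gibbs with 10 particles
```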