From f687db008da9a5222c17b518d4cdb1151ff2fa61 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 6 Jun 2025 14:42:44 +0100 Subject: [PATCH 1/9] AdvancedPS v0.7 support, work in progress --- Project.toml | 4 ++-- src/mcmc/particle_mcmc.jl | 24 ++++++++++++------------ test/Project.toml | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 6fcb779e3..0cde5c600 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,7 @@ AbstractMCMC = "5.5" Accessors = "0.1" AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8" AdvancedMH = "0.8" -AdvancedPS = "0.6.0" +AdvancedPS = "0.7" AdvancedVI = "0.4" BangBang = "0.4.2" Bijectors = "0.14, 0.15" @@ -65,7 +65,7 @@ DynamicHMC = "3.4" DynamicPPL = "0.36.3" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" -Libtask = "0.8.8" +Libtask = "0.9.1" LinearAlgebra = "1" LogDensityProblems = "2" MCMCChains = "5, 6, 7" diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index ac5cd7648..2af01e08f 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -25,9 +25,8 @@ function TracedModel( "Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.", ) end - return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( - model, sampler, varinfo, (model.f, args...) - ) + evaluator = (model.f, args...) + return TracedModel(model, sampler, varinfo, evaluator) end function AdvancedPS.advance!( @@ -71,8 +70,10 @@ function AdvancedPS.update_rng!( return trace end -function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # RNG ? - return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...) +function Libtask.TapedTask(taped_globals, model::TracedModel, args...; kwargs...) # RNG ? + return Libtask.TapedTask( + taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs... + ) end abstract type ParticleInference <: InferenceAlgorithm end @@ -402,11 +403,11 @@ end function trace_local_varinfo_maybe(varinfo) try - trace = AdvancedPS.current_trace() - return trace.model.f.varinfo + trace = Libtask.get_taped_globals(Any).other + return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo catch e # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`. - if e == KeyError(:__trace) || current_task().storage isa Nothing + if e == KeyError(:task_variable) return varinfo else rethrow(e) @@ -416,11 +417,10 @@ end function trace_local_rng_maybe(rng::Random.AbstractRNG) try - trace = AdvancedPS.current_trace() - return trace.rng + return Libtask.get_taped_globals(Any).rng catch e # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`. - if e == KeyError(:__trace) || current_task().storage isa Nothing + if e == KeyError(:task_variable) return rng else rethrow(e) @@ -485,6 +485,6 @@ function AdvancedPS.Trace( tmodel = TracedModel(model, sampler, newvarinfo, rng) newtrace = AdvancedPS.Trace(tmodel, rng) - AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace) + AdvancedPS.addreference!(newtrace.model.ctask, newtrace) return newtrace end diff --git a/test/Project.toml b/test/Project.toml index 7cab77a01..303a5453e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,7 +43,7 @@ ADTypes = "1" AbstractMCMC = "5" AbstractPPL = "0.9, 0.10, 0.11" AdvancedMH = "0.6, 0.7, 0.8" -AdvancedPS = "=0.6.0" +AdvancedPS = "0.7" AdvancedVI = "0.4" Aqua = "0.8" BangBang = "0.4" From 2366bfa5d11017745f639eb4aefa71c586ff1337 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 19 Jun 2025 15:34:34 +0100 Subject: [PATCH 2/9] Fixing particle_mcmc.jl --- Project.toml | 2 +- src/mcmc/particle_mcmc.jl | 36 ++++++++++++++++++++++-------------- test/mcmc/particle_mcmc.jl | 1 + 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 0cde5c600..511797cf0 100644 --- a/Project.toml +++ b/Project.toml @@ -65,7 +65,7 @@ DynamicHMC = "3.4" DynamicPPL = "0.36.3" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" -Libtask = "0.9.1" +Libtask = "0.9.2" LinearAlgebra = "1" LogDensityProblems = "2" MCMCChains = "5, 6, 7" diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 2af01e08f..782595909 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -58,19 +58,7 @@ function AdvancedPS.reset_logprob!(trace::TracedModel) return trace end -function AdvancedPS.update_rng!( - trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}} -) - # Extract the `args`. - args = trace.model.ctask.args - # From `args`, extract the `SamplingContext`, which contains the RNG. - sampling_context = args[3] - rng = sampling_context.rng - trace.rng = rng - return trace -end - -function Libtask.TapedTask(taped_globals, model::TracedModel, args...; kwargs...) # RNG ? +function Libtask.TapedTask(taped_globals::Any, model::TracedModel, args...; kwargs...) # RNG ? return Libtask.TapedTask( taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs... ) @@ -485,6 +473,26 @@ function AdvancedPS.Trace( tmodel = TracedModel(model, sampler, newvarinfo, rng) newtrace = AdvancedPS.Trace(tmodel, rng) - AdvancedPS.addreference!(newtrace.model.ctask, newtrace) + AdvancedPS.addreference!(newtrace) return newtrace end + +# We need to tell Libtask which calls may have `produce` calls within them. In practice most +# of these won't be needed, because of inline and the fact that `might_produce` is only +# called on `:invoke` expressions rather than `:call`s, but since those are implementation +# details of the compiler we define a bunch of these here, starting with +# `acclogp_observe!!` which is what calls `produce`, and going up the call stack. +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true +function Libtask.might_produce( + ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}} +) + return true +end +function Libtask.might_produce( + ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}} +) + return true +end +Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true diff --git a/test/mcmc/particle_mcmc.jl b/test/mcmc/particle_mcmc.jl index 699ee6854..7a2f5fe1c 100644 --- a/test/mcmc/particle_mcmc.jl +++ b/test/mcmc/particle_mcmc.jl @@ -34,6 +34,7 @@ using Turing tested = sample(normal(), SMC(), 100) + # TODO(mhauru) This needs an explanation for why it fails. # failing test @model function fail_smc() a ~ Normal(4, 5) From b4823d968babe29060562743a9fc6be834cd5cce Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 20 Jun 2025 11:03:50 +0100 Subject: [PATCH 3/9] Remove use of AdvancedPS.addreference! --- src/mcmc/particle_mcmc.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 782595909..c1f3ab443 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -473,7 +473,6 @@ function AdvancedPS.Trace( tmodel = TracedModel(model, sampler, newvarinfo, rng) newtrace = AdvancedPS.Trace(tmodel, rng) - AdvancedPS.addreference!(newtrace) return newtrace end From d34dd3db9733d3b9ac509c4c7b8eaf8845bb8b4d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 20 Jun 2025 11:05:22 +0100 Subject: [PATCH 4/9] Improve a comment --- src/mcmc/particle_mcmc.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index c1f3ab443..c6a5fe7ca 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -477,10 +477,10 @@ function AdvancedPS.Trace( end # We need to tell Libtask which calls may have `produce` calls within them. In practice most -# of these won't be needed, because of inline and the fact that `might_produce` is only +# of these won't be needed, because of inlining and the fact that `might_produce` is only # called on `:invoke` expressions rather than `:call`s, but since those are implementation -# details of the compiler we define a bunch of these here, starting with -# `acclogp_observe!!` which is what calls `produce`, and going up the call stack. +# details of the compiler, we set a bunch of methods as might_produce = true. We start with +# `acclogp_observe!!` which is what calls `produce` and go up the call stack. Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true From 7cf8ee08d6fac64d8b8b8b71edf8f02a868805a1 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 23 Jun 2025 22:20:24 +0100 Subject: [PATCH 5/9] Update Project.toml (#2598) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index bfaa0d61a..b0a0652c9 100644 --- a/Project.toml +++ b/Project.toml @@ -85,7 +85,7 @@ Statistics = "1.6" StatsAPI = "1.6" StatsBase = "0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" -julia = "1.10.2" +julia = "1.10.8" [extras] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" From 1c6fad9b5e50adc6ad452e157ee04c29582b6bda Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 3 Jul 2025 15:11:13 +0100 Subject: [PATCH 6/9] Fix a bug and a test --- src/mcmc/particle_mcmc.jl | 2 +- test/essential/container.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index c73944b23..a81f436c8 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -58,7 +58,7 @@ function AdvancedPS.reset_logprob!(trace::TracedModel) return trace end -function Libtask.TapedTask(taped_globals::Any, model::TracedModel, args...; kwargs...) # RNG ? +function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...) return Libtask.TapedTask( taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs... ) diff --git a/test/essential/container.jl b/test/essential/container.jl index 1cb790d5a..cbd7a6fe2 100644 --- a/test/essential/container.jl +++ b/test/essential/container.jl @@ -23,8 +23,8 @@ using Turing model = test() trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG()) - # Make sure we link the traces - @test haskey(trace.model.ctask.task.storage, :__trace) + # Make sure the backreference from taped_globals to the trace is in place. + @test trace.model.ctask.taped_globals.other === trace res = AdvancedPS.advance!(trace, false) @test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 1 From df31909982b588c504db9f08a7b0b7cf1b4611b1 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 10 Jul 2025 17:26:14 +0100 Subject: [PATCH 7/9] Bump Libtask to 0.9.3 Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index da6060f66..0f232dc90 100644 --- a/Project.toml +++ b/Project.toml @@ -67,7 +67,7 @@ DynamicHMC = "3.4" DynamicPPL = "0.36.3" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" -Libtask = "0.9.2" +Libtask = "0.9.3" LinearAlgebra = "1" LogDensityProblems = "2" MCMCChains = "5, 6, 7" From dd84c3304bd984d93adad35464e8f6aa27524e60 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 11 Jul 2025 15:41:26 +0100 Subject: [PATCH 8/9] Fix seed setting, increase iterations --- test/mcmc/Inference.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index a0d442186..cf528ce51 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -44,13 +44,13 @@ using Turing end # Should also be stable with an explicit RNG - seed = 5 - rng = Random.MersenneTwister(seed) + local_seed = 5 + rng = Random.MersenneTwister(local_seed) for sampler in samplers - Random.seed!(rng, seed) + Random.seed!(rng, local_seed) chain1 = sample(rng, model, sampler, MCMCThreads(), 10, 4) - Random.seed!(rng, seed) + Random.seed!(rng, local_seed) chain2 = sample(rng, model, sampler, MCMCThreads(), 10, 4) @test chain1.value == chain2.value @@ -256,9 +256,9 @@ using Turing pg = PG(10) gibbs = Gibbs(:p => HMC(0.2, 3), :x => PG(10)) - chn_s = sample(StableRNG(seed), testbb(obs), smc, 200) - chn_p = sample(StableRNG(seed), testbb(obs), pg, 200) - chn_g = sample(StableRNG(seed), testbb(obs), gibbs, 200) + chn_s = sample(StableRNG(seed), testbb(obs), smc, 2_000) + chn_p = sample(StableRNG(seed), testbb(obs), pg, 2_000) + chn_g = sample(StableRNG(seed), testbb(obs), gibbs, 2_000) check_numerical(chn_s, [:p], [meanp]; atol=0.05) check_numerical(chn_p, [:x], [meanp]; atol=0.1) @@ -647,7 +647,7 @@ using Turing @model function e(x=1.0) return x ~ Normal() end - # Can't test with HMC/NUTS because some AD backends error; see + # Can't test with HMC/NUTS because some AD backends error; see # https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/802 @test sample(e(), IS(), 100) isa MCMCChains.Chains end From c0fb9412c29c4aaf52a31f2218be44cd238b8e97 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 11 Jul 2025 16:08:38 +0100 Subject: [PATCH 9/9] Increate a test iteration count --- test/mcmc/ess.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index 4afeeb985..e918b3a51 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -64,7 +64,7 @@ using Turing @varname(mu1) => ESS(), @varname(mu2) => ESS(), ) - chain = sample(StableRNG(seed), MoGtest_default, alg, 2000) + chain = sample(StableRNG(seed), MoGtest_default, alg, 5000) check_MoGtest_default(chain; atol=0.1) end