From a3e3ede4071b7dea7981707a0da48c2f90dbf327 Mon Sep 17 00:00:00 2001 From: Feras Saad Date: Mon, 26 Jan 2026 18:28:21 -0500 Subject: [PATCH] Make Model.transform_param accept config for priors --- src/GP.jl | 6 +++++ src/Greedy.jl | 2 +- src/Model.jl | 38 ++++++++++++++++++++++++-------- src/api.jl | 11 ++++----- src/inference_smc_anneal_data.jl | 2 +- src/inference_utils.jl | 10 ++++----- test/experiment_hmc.jl | 12 +++++----- 7 files changed, 55 insertions(+), 26 deletions(-) diff --git a/src/GP.jl b/src/GP.jl index 2fe2aaf..8e2b3b1 100644 --- a/src/GP.jl +++ b/src/GP.jl @@ -1129,6 +1129,12 @@ an instance of [`Node`](@ref). The main `kwargs` (all optional) are: changepoints::Bool = true # Observation noise level. noise::Union{Nothing,Float64} = nothing + # Parameter customization. + prior::Dict{Any,Any} = Dict( + :gamma => Dict(:scale=>2., :mu=>0., :sigma=>1.), + :period => Dict(:mu=>-1.5, :sigma=>1.), + :wildcard => Dict(:mu=>-1.5, :sigma=>1.), + ) end # Convert index of a node to its depth in tree. diff --git a/src/Greedy.jl b/src/Greedy.jl index d65acdf..195f6aa 100644 --- a/src/Greedy.jl +++ b/src/Greedy.jl @@ -394,7 +394,7 @@ function greedy_search_initialize( leaf_node_types = get_leaf_node_types(config) observations = Gen.choicemap((:xs, xs)) if !isnothing(config.noise) - observations[:noise] = Model.untransform_param(:noise, config.noise) + observations[:noise] = Model.untransform_param(:noise, config.noise, config) end trace, = Gen.generate(Model.model, (ts, config,), observations) # Optimize each base kernel. diff --git a/src/Model.jl b/src/Model.jl index f635d5d..25b9380 100644 --- a/src/Model.jl +++ b/src/Model.jl @@ -32,14 +32,34 @@ function untransform_logit_normal(param::Real, scale::Real, mu::Real, sigma::Rea return (log(param / (scale - param)) - mu) / sigma end -transform_param(field::Symbol, z::Real) = @match field begin - :gamma => transform_logit_normal(z, 2, 0, 1) - _ => transform_log_normal(z, -1.5, 1.) 
+function transform_param(field::Symbol, z::Real, config::GP.GPConfig) + @match field begin + :gamma => transform_logit_normal(z, + config.prior[:gamma][:scale], + config.prior[:gamma][:mu], + config.prior[:gamma][:sigma]) + :period => transform_log_normal(z, + config.prior[:period][:mu], + config.prior[:period][:sigma]) + _ => transform_log_normal(z, + config.prior[:wildcard][:mu], + config.prior[:wildcard][:sigma]) + end end -untransform_param(field::Symbol, param::Real) = @match field begin - :gamma => untransform_logit_normal(param, 2, 0, 1) - _ => untransform_log_normal(param, -1.5, 1.) +function untransform_param(field::Symbol, param::Real, config::GP.GPConfig) + @match field begin + :gamma => untransform_logit_normal(param, + config.prior[:gamma][:scale], + config.prior[:gamma][:mu], + config.prior[:gamma][:sigma]) + :period => untransform_log_normal(param, + config.prior[:period][:mu], + config.prior[:period][:sigma]) + _ => untransform_log_normal(param, + config.prior[:wildcard][:mu], + config.prior[:wildcard][:sigma]) + end end """Return distribution over node types at a given index.""" @@ -71,7 +91,7 @@ end params = [] for field in fieldnames(NodeType) log_param = {(idx, field)} ~ normal(0, 1) - param = transform_param(field, log_param) + param = transform_param(field, log_param, config) push!(params, param) end node = NodeType(params...) @@ -93,7 +113,7 @@ end # but we should allow such traces to have probability zero # (for inference) rather than force an assertion error.
location = {(idx, :location)} ~ normal(0, 1) - param = transform_param(:location, location) + param = transform_param(:location, location, config) child1 = Gen.get_child(idx, 1, config.max_branch) child2 = Gen.get_child(idx, 2, config.max_branch) left_node = {*} ~ covariance_prior(child1, config) @@ -111,7 +131,7 @@ end n = length(ts) covariance_fn = {:tree} ~ covariance_prior(1, config) noise ~ normal(0, 1) - noise = transform_param(:noise, noise) + JITTER + noise = transform_param(:noise, noise, config) + JITTER cov_matrix = GP.compute_cov_matrix_vectorized(covariance_fn, noise, ts) xs ~ mvnormal(zeros(n), cov_matrix) return covariance_fn diff --git a/src/api.jl b/src/api.jl index a998da4..94cd48e 100644 --- a/src/api.jl +++ b/src/api.jl @@ -103,7 +103,7 @@ function GPModel( # Initialize the particle filter. observations = Gen.choicemap((:xs, y_numeric)) if !isnothing(config.noise) - observations[:noise] = Model.untransform_param(:noise, config.noise) + observations[:noise] = Model.untransform_param(:noise, config.noise, config) end pf_state = Gen.initialize_particle_filter( Model.model, (ds_numeric, config), observations, n_particles) @@ -163,7 +163,7 @@ given in the transformed space over which parameter inference is performed """ function observation_noise_variances(model::GPModel; reparameterize::Bool=true) noises = [t[:noise] for t in model.pf_state.traces] - noises = Model.transform_param.(:noise, noises) .+ AutoGP.Model.JITTER + noises = Model.transform_param.(:noise, noises, Ref(model.config)) .+ AutoGP.Model.JITTER if reparameterize noises = Transforms.unapply_var.([model.y_transform], noises) end @@ -436,7 +436,7 @@ function add_data!(model::GPModel, ds::IndexType, y::Vector{<:Real}) # Prepare observations. 
observations = Gen.choicemap((:xs, y_numeric)) if !isnothing(model.config.noise) - observations[:noise] = Model.untransform_param(:noise, model.config.noise) + observations[:noise] = Model.untransform_param(:noise, model.config.noise, model.config) end # Run SMC step. Inference.smc_step!(model.pf_state, (ds_numeric, model.config), observations) @@ -461,7 +461,7 @@ function remove_data!(model::GPModel, ds::IndexType) # Prepare observations. observations = Gen.choicemap((:xs, y_numeric)) if !isnothing(model.config.noise) - observations[:noise] = Model.untransform_param(:noise, config.noise) + observations[:noise] = Model.untransform_param(:noise, model.config.noise, model.config) end # Run SMC step. Inference.smc_step!(model.pf_state, (ds_numeric, model.config), observations) @@ -736,7 +736,8 @@ function decompose(model::GPModel) # ERROR: type GPConfig has no field WhiteNoise # noises = Model.transform_param.( # :noise, - # [trace[:noise] for trace in model.pf_state.traces],) + # [trace[:noise] for trace in model.pf_state.traces], + # Ref(model.config)) # .+ AutoGP.Model.JITTER # # TODO: Use GPModel(model, kernels) instead of duplicating here.
diff --git a/src/inference_smc_anneal_data.jl b/src/inference_smc_anneal_data.jl index 548be5f..d4bd144 100644 --- a/src/inference_smc_anneal_data.jl +++ b/src/inference_smc_anneal_data.jl @@ -180,7 +180,7 @@ function run_smc_anneal_data( @timeit elapsed begin observations = Gen.choicemap() if !isnothing(config.noise) - observations[:noise] = Model.untransform_param(:noise, config.noise) + observations[:noise] = Model.untransform_param(:noise, config.noise, config) end state = Gen.initialize_particle_filter( model, diff --git a/src/inference_utils.jl b/src/inference_utils.jl index 8528fb8..df0e74a 100644 --- a/src/inference_utils.jl +++ b/src/inference_utils.jl @@ -175,10 +175,10 @@ function predict_mvn( trace::Gen.Trace, ts::Vector{Float64}; noise_pred::Union{Nothing, Float64}=nothing) - ts_train = Gen.get_args(trace)[1] + ts_train, config = Gen.get_args(trace) xs_train = trace[:xs] cov_fn = trace[] - noise = Model.transform_param(:noise, trace[:noise]) + Model.JITTER + noise = Model.transform_param(:noise, trace[:noise], config) + Model.JITTER return Distributions.MvNormal(cov_fn, noise, ts_train, xs_train, ts; noise_pred=noise_pred) end @@ -216,7 +216,7 @@ function node_to_choicemap(node::LeafNode, idx::Int, config::GPConfig; params=no if isnothing(params) || params for field in fieldnames(NodeType) param = getfield(node, field) - choices[(idx, field)] = Model.untransform_param(field, param) + choices[(idx, field)] = Model.untransform_param(field, param, config) end end return choices @@ -236,7 +236,7 @@ function node_to_choicemap(node::ChangePoint, idx::Int, config::GPConfig; params choices = Gen.choicemap() choices[(idx, :node_type)] = node_to_integer(node, config) if isnothing(params) || params - choices[(idx, :location)] = Model.untransform_param(:location, node.location) + choices[(idx, :location)] = Model.untransform_param(:location, node.location, config) end idx_l = Gen.get_child(idx, 1, config.max_branch) idx_r = Gen.get_child(idx, 2,
config.max_branch) @@ -278,7 +278,7 @@ function node_to_trace( choicemap_node = Gen.choicemap() Gen.set_submap!(choicemap_node, :tree, node_to_choicemap(node, config)) constraints = merge(choicemap_node, choicemap_obs) - constraints[:noise] = Model.untransform_param(:noise, noise) + constraints[:noise] = Model.untransform_param(:noise, noise, config) constraints[:xs] = xs return Gen.generate(Model.model, (ts, config), constraints)[1] end diff --git a/test/experiment_hmc.jl b/test/experiment_hmc.jl index 13969e2..5c6d244 100644 --- a/test/experiment_hmc.jl +++ b/test/experiment_hmc.jl @@ -32,7 +32,8 @@ Random.seed!(15) """Plot forecast of posterior predictive distribution.""" function plot_posterior_forecast(trace, ts_obs, ts_test, xs_obs, xs_test) node = trace[] - noise = Model.transform_param(:noise, trace[:noise]) + config = Gen.get_args(trace)[2] + noise = Model.transform_param(:noise, trace[:noise], config) dist = Distributions.MvNormal(node, noise + Model.JITTER, ts_obs, xs_obs, ts_test) mu = Distributions.mean(dist) bounds = Distributions.quantile(dist, [[.1, .9]]) @@ -119,7 +120,7 @@ function test_predictive_likelihood_agrees(config, constraints, ts_obs, xs_obs, constraints_obs[:xs] = xs_obs trace_obs, weight_obs = Gen.generate(Model.model, (ts_obs, config), constraints_obs) # Ensure agreement. 
- noise = Model.transform_param(:noise, trace_obs[:noise]) + Model.JITTER + noise = Model.transform_param(:noise, trace_obs[:noise], config) + Model.JITTER dist = Distributions.MvNormal(trace_obs[], noise, ts_obs, xs_obs, ts_test) lp_test_ll = Distributions.logpdf(dist, xs_test) lp_test_bayes = weight_joint - weight_obs @@ -166,7 +167,8 @@ end """Report metrics from the hmc trace.""" function compute_inference_metrics(trace_infer) - noise = Model.transform_param(:noise, trace_infer[:noise]) + config = Gen.get_args(trace_infer)[2] + noise = Model.transform_param(:noise, trace_infer[:noise], config) state = (GP.pretty(trace_infer[]), noise) score = Gen.get_score(trace_infer) dist = Distributions.MvNormal(trace_infer[], noise + Model.JITTER, ts_obs, xs_obs, ts_test) @@ -188,8 +190,8 @@ test_constrain_structure(config) # Select benchmark. (node_true, noise_true) = BENCHMARKS[2] -xi_true = Model.untransform_param(:noise, noise_true) -@assert isapprox(noise_true, Model.transform_param(:noise, xi_true)) +xi_true = Model.untransform_param(:noise, noise_true, config) +@assert isapprox(noise_true, Model.transform_param(:noise, xi_true, config)) # Simulate ground-truth trace. (n, n_obs) = (1000, 200)