Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/GP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,12 @@ an instance of [`Node`](@ref). The main `kwargs` (all optional) are:
changepoints::Bool = true
# Observation noise level.
noise::Union{Nothing,Float64} = nothing
# Parameter customization.
prior::Dict{Any,Any} = Dict(
:gamma => Dict(:scale=>2., :mu=>0., :sigma=>1.),
:period => Dict(:mu=>-1.5, :sigma=>1.),
:wildcard => Dict(:mu=>-1.5, :sigma=>1.),
)
end

# Convert index of a node to its depth in tree.
Expand Down
2 changes: 1 addition & 1 deletion src/Greedy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ function greedy_search_initialize(
leaf_node_types = get_leaf_node_types(config)
observations = Gen.choicemap((:xs, xs))
if !isnothing(config.noise)
observations[:noise] = Model.untransform_param(:noise, config.noise)
observations[:noise] = Model.untransform_param(:noise, config.noise, config)
end
trace, = Gen.generate(Model.model, (ts, config,), observations)
# Optimize each base kernel.
Expand Down
38 changes: 29 additions & 9 deletions src/Model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,34 @@ function untransform_logit_normal(param::Real, scale::Real, mu::Real, sigma::Rea
return (log(param / (scale - param)) - mu) / sigma
end

transform_param(field::Symbol, z::Real) = @match field begin
:gamma => transform_logit_normal(z, 2, 0, 1)
_ => transform_log_normal(z, -1.5, 1.)
function transform_param(field::Symbol, z::Real, config::GP.GPConfig)
@match field begin
:gamma => transform_logit_normal(z,
config.prior[:gamma][:scale],
config.prior[:gamma][:mu],
config.prior[:gamma][:sigma])
:period => transform_log_normal(z,
config.prior[:period][:mu],
config.prior[:period][:sigma])
wildcard => transform_log_normal(z,
config.prior[:wildcard][:mu],
config.prior[:wildcard][:sigma])
end
end

untransform_param(field::Symbol, param::Real) = @match field begin
:gamma => untransform_logit_normal(param, 2, 0, 1)
_ => untransform_log_normal(param, -1.5, 1.)
function untransform_param(field::Symbol, param::Real, config::GP.GPConfig)
@match field begin
:gamma => untransform_logit_normal(param,
config.prior[:gamma][:scale],
config.prior[:gamma][:mu],
config.prior[:gamma][:sigma])
:period => untransform_log_normal(param,
config.prior[:period][:mu],
config.prior[:period][:sigma])
wildcard => untransform_log_normal(param,
config.prior[:wildcard][:mu],
config.prior[:wildcard][:sigma])
end
end

"""Return distribution over node types at a given index."""
Expand Down Expand Up @@ -71,7 +91,7 @@ end
params = []
for field in fieldnames(NodeType)
log_param = {(idx, field)} ~ normal(0, 1)
param = transform_param(field, log_param)
param = transform_param(field, log_param, config)
push!(params, param)
end
node = NodeType(params...)
Expand All @@ -93,7 +113,7 @@ end
# but we should allow such traces to have probability zero
# (for inference) rather than force an assertion error.
location = {(idx, :location)} ~ normal(0, 1)
param = transform_param(:location, location)
param = transform_param(:location, location, config)
child1 = Gen.get_child(idx, 1, config.max_branch)
child2 = Gen.get_child(idx, 2, config.max_branch)
left_node = {*} ~ covariance_prior(child1, config)
Expand All @@ -111,7 +131,7 @@ end
n = length(ts)
covariance_fn = {:tree} ~ covariance_prior(1, config)
noise ~ normal(0, 1)
noise = transform_param(:noise, noise) + JITTER
noise = transform_param(:noise, noise, config) + JITTER
cov_matrix = GP.compute_cov_matrix_vectorized(covariance_fn, noise, ts)
xs ~ mvnormal(zeros(n), cov_matrix)
return covariance_fn
Expand Down
11 changes: 6 additions & 5 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ function GPModel(
# Initialize the particle filter.
observations = Gen.choicemap((:xs, y_numeric))
if !isnothing(config.noise)
observations[:noise] = Model.untransform_param(:noise, config.noise)
observations[:noise] = Model.untransform_param(:noise, config.noise, config)
end
pf_state = Gen.initialize_particle_filter(
Model.model, (ds_numeric, config), observations, n_particles)
Expand Down Expand Up @@ -163,7 +163,7 @@ given in the transformed space over which parameter inference is performed
"""
function observation_noise_variances(model::GPModel; reparameterize::Bool=true)
noises = [t[:noise] for t in model.pf_state.traces]
noises = Model.transform_param.(:noise, noises) .+ AutoGP.Model.JITTER
noises = Model.transform_param.(:noise, noises, Ref(model.config)) .+ AutoGP.Model.JITTER
if reparameterize
noises = Transforms.unapply_var.([model.y_transform], noises)
end
Expand Down Expand Up @@ -436,7 +436,7 @@ function add_data!(model::GPModel, ds::IndexType, y::Vector{<:Real})
# Prepare observations.
observations = Gen.choicemap((:xs, y_numeric))
if !isnothing(model.config.noise)
observations[:noise] = Model.untransform_param(:noise, model.config.noise)
observations[:noise] = Model.untransform_param(:noise, model.config.noise, model.config)
end
# Run SMC step.
Inference.smc_step!(model.pf_state, (ds_numeric, model.config), observations)
Expand All @@ -461,7 +461,7 @@ function remove_data!(model::GPModel, ds::IndexType)
# Prepare observations.
observations = Gen.choicemap((:xs, y_numeric))
if !isnothing(model.config.noise)
observations[:noise] = Model.untransform_param(:noise, config.noise)
observations[:noise] = Model.untransform_param(:noise, config.noise, model.config)
end
# Run SMC step.
Inference.smc_step!(model.pf_state, (ds_numeric, model.config), observations)
Expand Down Expand Up @@ -736,7 +736,8 @@ function decompose(model::GPModel)
# ERROR: type GPConfig has no field WhiteNoise
# noises = Model.transform_param.(
# :noise,
# [trace[:noise] for trace in model.pf_state.traces],)
# [trace[:noise] for trace in model.pf_state.traces],
# model.config)
# .+ AutoGP.Model.JITTER
#
# TODO: Use GPModel(model, kernels) instead of duplicating here.
Expand Down
2 changes: 1 addition & 1 deletion src/inference_smc_anneal_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ function run_smc_anneal_data(
@timeit elapsed begin
observations = Gen.choicemap()
if !isnothing(config.noise)
observations[:noise] = Model.untransform_param(:noise, config.noise)
observations[:noise] = Model.untransform_param(:noise, config.noise, config.prior)
end
state = Gen.initialize_particle_filter(
model,
Expand Down
10 changes: 5 additions & 5 deletions src/inference_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ function predict_mvn(
trace::Gen.Trace,
ts::Vector{Float64};
noise_pred::Union{Nothing, Float64}=nothing)
ts_train = Gen.get_args(trace)[1]
ts_train, config = Gen.get_args(trace)
xs_train = trace[:xs]
cov_fn = trace[]
noise = Model.transform_param(:noise, trace[:noise]) + Model.JITTER
noise = Model.transform_param(:noise, trace[:noise], config) + Model.JITTER
return Distributions.MvNormal(cov_fn, noise, ts_train, xs_train, ts; noise_pred=noise_pred)
end

Expand Down Expand Up @@ -216,7 +216,7 @@ function node_to_choicemap(node::LeafNode, idx::Int, config::GPConfig; params=no
if isnothing(params) || params
for field in fieldnames(NodeType)
param = getfield(node, field)
choices[(idx, field)] = Model.untransform_param(field, param)
choices[(idx, field)] = Model.untransform_param(field, param, config)
end
end
return choices
Expand All @@ -236,7 +236,7 @@ function node_to_choicemap(node::ChangePoint, idx::Int, config::GPConfig; params
choices = Gen.choicemap()
choices[(idx, :node_type)] = node_to_integer(node, config)
if isnothing(params) || params
choices[(idx, :location)] = Model.untransform_param(:location, node.location)
choices[(idx, :location)] = Model.untransform_param(:location, node.location, config)
end
idx_l = Gen.get_child(idx, 1, config.max_branch)
idx_r = Gen.get_child(idx, 2, config.max_branch)
Expand Down Expand Up @@ -278,7 +278,7 @@ function node_to_trace(
choicemap_node = Gen.choicemap()
Gen.set_submap!(choicemap_node, :tree, node_to_choicemap(node, config))
constraints = merge(choicemap_node, choicemap_obs)
constraints[:noise] = Model.untransform_param(:noise, noise)
constraints[:noise] = Model.untransform_param(:noise, noise, config)
constraints[:xs] = xs
return Gen.generate(Model.model, (ts, config), constraints)[1]
end
12 changes: 7 additions & 5 deletions test/experiment_hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ Random.seed!(15)
"""Plot forecast of posterior predictive distribution."""
function plot_posterior_forecast(trace, ts_obs, ts_test, xs_obs, xs_test)
node = trace[]
noise = Model.transform_param(:noise, trace[:noise])
config = Gen.get_args(trace)[2]
noise = Model.transform_param(:noise, trace[:noise], config)
dist = Distributions.MvNormal(node, noise + Model.JITTER, ts_obs, xs_obs, ts_test)
mu = Distributions.mean(dist)
bounds = Distributions.quantile(dist, [[.1, .9]])
Expand Down Expand Up @@ -119,7 +120,7 @@ function test_predictive_likelihood_agrees(config, constraints, ts_obs, xs_obs,
constraints_obs[:xs] = xs_obs
trace_obs, weight_obs = Gen.generate(Model.model, (ts_obs, config), constraints_obs)
# Ensure agreement.
noise = Model.transform_param(:noise, trace_obs[:noise]) + Model.JITTER
noise = Model.transform_param(:noise, trace_obs[:noise], config) + Model.JITTER
dist = Distributions.MvNormal(trace_obs[], noise, ts_obs, xs_obs, ts_test)
lp_test_ll = Distributions.logpdf(dist, xs_test)
lp_test_bayes = weight_joint - weight_obs
Expand Down Expand Up @@ -166,7 +167,8 @@ end

"""Report metrics from the hmc trace."""
function compute_inference_metrics(trace_infer)
noise = Model.transform_param(:noise, trace_infer[:noise])
config = Gen.get_args(trace_infer)[2]
noise = Model.transform_param(:noise, trace_infer[:noise], config)
state = (GP.pretty(trace_infer[]), noise)
score = Gen.get_score(trace_infer)
dist = Distributions.MvNormal(trace_infer[], noise + Model.JITTER, ts_obs, xs_obs, ts_test)
Expand All @@ -188,8 +190,8 @@ test_constrain_structure(config)

# Select benchmark.
(node_true, noise_true) = BENCHMARKS[2]
xi_true = Model.untransform_param(:noise, noise_true)
@assert isapprox(noise_true, Model.transform_param(:noise, xi_true))
xi_true = Model.untransform_param(:noise, noise_true, config)
@assert isapprox(noise_true, Model.transform_param(:noise, xi_true, config))

# Simulate ground-truth trace.
(n, n_obs) = (1000, 200)
Expand Down