Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ julia = "1.8"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[targets]
test = ["Test"]
test = ["Test", "Serialization"]
6 changes: 6 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,9 @@ AutoGP.observation_noise_variances
AutoGP.decompose
AutoGP.extract_kernel
```

## [Serialization](@id model_serialization)

```@docs
Base.Dict(model::AutoGP.GPModel)
```
77 changes: 75 additions & 2 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,9 @@ function add_data!(model::GPModel, ds::IndexType, y::Vector{<:Real})
y_numeric = Transforms.apply(model.y_transform, model.y)
# Prepare observations.
observations = Gen.choicemap((:xs, y_numeric))
!isnothing(model.config.noise) && (observations[:noise] = trace[:noise])
if !isnothing(model.config.noise)
observations[:noise] = Model.untransform_param(:noise, model.config.noise)
end
# Run SMC step.
Inference.smc_step!(model.pf_state, (ds_numeric, model.config), observations)
end
Expand All @@ -456,7 +458,9 @@ function remove_data!(model::GPModel, ds::IndexType)
y_numeric = Transforms.apply(model.y_transform, model.y)
# Prepare observations.
observations = Gen.choicemap((:xs, y_numeric))
!isnothing(model.config.noise) && (observations[:noise] = trace[:noise])
if !isnothing(model.config.noise)
observations[:noise] = Model.untransform_param(:noise, config.noise)
end
# Run SMC step.
Inference.smc_step!(model.pf_state, (ds_numeric, model.config), observations)
end
Expand Down Expand Up @@ -800,3 +804,72 @@ function extract_kernel(model::GPModel, t::Type{T}; retain::Bool=true) where T <
new_model.pf_state.log_weights = model.pf_state.log_weights
return new_model
end

# Serialization

"""
Base.Dict(model::GPModel)

Convert a [`GPModel`](@ref) into a dictionary that can be saved and
loaded from disk, as shown in the following example.

# Example
```
using AutoGP, Dates, Serialization
model = AutoGP.GPModel([Date("2025-01-01"), Date("2025-01-02")], [1.0, 2.0])
serialize("model.autogp", Dict(model))
loaded_model = AutoGP.GPModel(deserialize("model.autogp"))
```
"""
function Base.Dict(model::GPModel)
kernels = covariance_kernels(model; reparameterize=false)
noises = observation_noise_variances(model; reparameterize=false)
m = Dict([
# pf_state
"pf_state" => Dict([
"log_weights" => model.pf_state.log_weights,
"log_ml_est" => model.pf_state.log_ml_est,
]),
# kernels and noise
"kernels" => kernels,
"noises" => noises,
# serialize other fields
"config" => model.config,
"ds" => model.ds,
"y" => model.y,
"ds_transform" => model.ds_transform,
"y_transform" => model.y_transform,
])
return m
end


function GPModel(m::Base.Dict{String, Any})
n_particles = length(m["kernels"])
ds_numeric = Transforms.apply(m["ds_transform"], to_numeric.(m["ds"]))
y_numeric = Transforms.apply(m["y_transform"], m["y"])
observations = Gen.choicemap((:xs, y_numeric))
pf_state = Gen.initialize_particle_filter(
Model.model, (ds_numeric, m["config"]), observations, n_particles)
# Set the pf_state.
pf_state.log_weights = m["pf_state"]["log_weights"]
pf_state.log_ml_est = m["pf_state"]["log_ml_est"]
for (i, (kernel, noise)) in enumerate(zip(m["kernels"], m["noises"]))
pf_state.traces[i] = Inference.node_to_trace(
kernel, m["config"], ds_numeric, y_numeric, noise)
end
# Return the GP model.
return GPModel(
pf_state,
m["config"],
collect(m["ds"]),
collect(m["y"]),
m["ds_transform"],
m["y_transform"])
end

# https://juliaio.github.io/JLD2.jl/stable/customserialization/
# struct GPModelSerialization x::Dict end
# JLD2.writeas(::Type{GPModel}) = GPModelSerialization
# JLD2.wconvert(::Type{GPModelSerialization}, model::GPModel) = GPModelSerialization(Dict(model))
# JLD2.rconvert(::Type{GPModel}, g::GPModelSerialization) = GPModel(g.x)
15 changes: 15 additions & 0 deletions src/inference_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,18 @@ function node_to_trace(node::Node, trace::Gen.Trace)
constraints,
)[1]
end

function node_to_trace(
node::Node,
config::GPConfig,
ts::Vector{Float64},
xs::Vector{Float64},
noise::Float64)
choicemap_obs = Gen.choicemap((:xs, xs))
choicemap_node = Gen.choicemap()
Gen.set_submap!(choicemap_node, :tree, node_to_choicemap(node, config))
constraints = merge(choicemap_node, choicemap_obs)
constraints[:noise] = Model.untransform_param(:noise, noise)
constraints[:xs] = xs
return Gen.generate(Model.model, (ts, config), constraints)[1]
end
5 changes: 3 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using Test
using AutoGP

@testset "AutoGP" begin
@testset "test_GP.jl" begin include("test_GP.jl") end
@testset "test_api.jl" begin include("test_api.jl") end
@testset "test_GP.jl" begin include("test_GP.jl") end
@testset "test_api.jl" begin include("test_api.jl") end
@testset "test_serialize.jl" begin include("test_serialize.jl") end
end
87 changes: 87 additions & 0 deletions test/test_serialize.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2023 Google LLC
# Copyright 2025 CMU Probabilistic Computing Systems Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

using Test
using Dates
using Parameters
using Serialization

using AutoGP

function load_via_seralize(model::AutoGP.GPModel)
dict = mktemp() do path, io
serialize(path, Dict(model))
return deserialize(path)
end
return AutoGP.GPModel(dict)
end

# function load_via_jld2(model::AutoGP.GPModel)
# return mktempdir() do dir
# path = joinpath(dir, "model.jld2")
# save(path, "model", model)
# load(path, "model")
# end
# end

@testset "test_serialize" begin

function check_model_same(model1::AutoGP.GPModel, model2::AutoGP.GPModel)
@test model1.ds_transform == model2.ds_transform
@test model1.y_transform == model2.y_transform
@test model1.ds == model2.ds
@test model1.y == model2.y
@test type2dict(model1.config) == type2dict(model2.config)
kernels1 = AutoGP.covariance_kernels(model1)
kernels2 = AutoGP.covariance_kernels(model2)
@test all(kernels1 .≈ kernels2)
noises1 = AutoGP.observation_noise_variances(model1)
noises2 = AutoGP.observation_noise_variances(model2)
@test all(isapprox.(noises1, noises2, rtol=1e-3)) # Why precision loss?
weights1 = AutoGP.particle_weights(model1)
weights2 = AutoGP.particle_weights(model2)
@test all(isapprox.(weights1, weights2, atol=1e-4)) # Why precision loss?
end

# Initialize toy model.
model1 = AutoGP.GPModel([Date("2025-01-01"), Date("2025-01-02")], [1.0, 2.0])
AutoGP.fit_smc!(model1; n_mcmc=5, n_hmc=5, schedule=[2])

# Write and load from disk.
for model2 in [load_via_seralize(model1)]

# Check initial models agree
check_model_same(model1, model2)

# Add data.
AutoGP.add_data!(model1, [Date("2025-02-03")], [3.0]);
AutoGP.add_data!(model2, [Date("2025-02-03")], [3.0]);
check_model_same(model1, model2)

# Remove data.
AutoGP.remove_data!(model1, [Date("2025-01-01")])
AutoGP.remove_data!(model2, [Date("2025-01-01")])
check_model_same(model1, model2)

# Infer with same seed.
AutoGP.seed!(5)
AutoGP.fit_smc!(model1; n_mcmc=5, n_hmc=5, schedule=[2])
AutoGP.seed!(5)
AutoGP.fit_smc!(model2; n_mcmc=5, n_hmc=5, schedule=[2])
check_model_same(model1, model2)

end

end