Skip to content

IndexKeyMap #192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@ authors = ["SciML"]
version = "0.1.10"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GlobalSensitivity = "af5da776-676b-467e-8baf-acd8249e4f0f"
@@ -15,12 +16,14 @@ OptimizationBBO = "3e6eede4-6085-4f62-9a71-46d9bc1eb92b"
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
OptimizationNLopt = "4e6fcdb7-1186-4e1f-a706-475e75c168bb"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLExpectations = "afe9f18d-7609-4d0e-b7b7-af0cb72b8ea8"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
DiffEqBase = "6.127.0"
DifferentialEquations = "7"
Distributions = "0.25"
GlobalSensitivity = "2"
@@ -32,7 +35,7 @@ OptimizationMOI = "0.1"
OptimizationNLopt = "0.1"
Plots = "1"
Reexport = "1"
SciMLBase = "1.93.1"
SciMLBase = "1.93.3"
SciMLExpectations = "2"
Turing = "0.22, 0.23, 0.24"
julia = "1.6"
100 changes: 71 additions & 29 deletions docs/src/tutorials/ensemble_modeling.md
Original file line number Diff line number Diff line change
@@ -80,7 +80,7 @@ prototype problem, which we are effectively ignoring for our use case.
Thus a simple `EnsembleProblem` which ensembles the three models built above is as follows:

```@example ensemble
probs = [prob, prob2, prob3]
probs = [prob, prob2, prob3];
enprob = EnsembleProblem(probs)
```

@@ -95,7 +95,7 @@ We can access the 3 solutions as `sol[i]` respectively. Let's get the time serie
for `S` from each of the models:

```@example ensemble
sol[:,S]
sol[:, S]
```

## Building a Dataset
@@ -107,9 +107,9 @@ interface on the ensemble solution.
```@example ensemble
weights = [0.2, 0.5, 0.3]
data = [
S => vec(sum(stack(weights .* sol[:,S]), dims = 2)),
I => vec(sum(stack(weights .* sol[:,I]), dims = 2)),
R => vec(sum(stack(weights .* sol[:,R]), dims = 2)),
S => vec(sum(stack(weights .* sol[:, S]), dims = 2)),
I => vec(sum(stack(weights .* sol[:, I]), dims = 2)),
R => vec(sum(stack(weights .* sol[:, R]), dims = 2)),
]
```

@@ -131,27 +131,27 @@ scatter!(data[3][2])
Now let's split that into training, ensembling, and forecast sections:

```@example ensemble
fullS = vec(sum(stack(weights .* sol[:,S]),dims=2))
fullI = vec(sum(stack(weights .* sol[:,I]),dims=2))
fullR = vec(sum(stack(weights .* sol[:,R]),dims=2))
fullS = vec(sum(stack(weights .* sol[:, S]), dims = 2))
fullI = vec(sum(stack(weights .* sol[:, I]), dims = 2))
fullR = vec(sum(stack(weights .* sol[:, R]), dims = 2))

t_train = 0:14
data_train = [
S => (t_train,fullS[1:15]),
I => (t_train,fullI[1:15]),
R => (t_train,fullR[1:15]),
S => (t_train, fullS[1:15]),
I => (t_train, fullI[1:15]),
R => (t_train, fullR[1:15]),
]
t_ensem = 0:21
data_ensem = [
S => (t_ensem,fullS[1:22]),
I => (t_ensem,fullI[1:22]),
R => (t_ensem,fullR[1:22]),
S => (t_ensem, fullS[1:22]),
I => (t_ensem, fullI[1:22]),
R => (t_ensem, fullR[1:22]),
]
t_forecast = 0:30
data_forecast = [
S => (t_forecast,fullS),
I => (t_forecast,fullI),
R => (t_forecast,fullR),
S => (t_forecast, fullS),
I => (t_forecast, fullI),
R => (t_forecast, fullR),
]
```

@@ -160,10 +160,10 @@ data_forecast = [
Now let's perform a Bayesian calibration on each of the models. This gives us multiple parameterizations for each model, which then gives an ensemble which is `parameterizations x models` in size.

```@example ensemble
probs = [prob, prob2, prob3]
probs = [prob, prob2, prob3];
ps = [[β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3]
datas = [data_train,data_train,data_train]
enprobs = bayesian_ensemble(probs, ps, datas)
datas = [data_train, data_train, data_train]
enprobs = bayesian_ensemble(probs, ps, datas, nchains=2, niter=200)
```

Let's see how each of our models in the ensemble compare against the data when changed
@@ -192,8 +192,8 @@ Now let's train the ensemble model. We will do that by solving a bit further tha
calibration step. Let's build that solution data:

```@example ensemble
plot(sol;idxs = S)
scatter!(t_ensem,data_ensem[1][2][2])
plot(sol; idxs = S)
scatter!(t_ensem, data_ensem[1][2][2])
```

We can obtain the optimal weights for ensembling by solving a linear regression of
@@ -208,14 +208,14 @@ Now we can extrapolate forward with these ensemble weights as follows:

```@example ensemble
sol = solve(enprobs; saveat = t_ensem);
ensem_prediction = sum(stack(ensem_weights .* sol[:,S]), dims = 2)
ensem_prediction = sum(stack(ensem_weights .* sol[:, S]), dims = 2)
plot(sol; idxs = S, color = :blue)
plot!(t_ensem, ensem_prediction, lw = 5, color = :red)
scatter!(t_ensem, data_ensem[1][2][2])
```

```@example ensemble
ensem_prediction = sum(stack(ensem_weights .* sol[:,I]), dims = 2)
ensem_prediction = sum(stack(ensem_weights .* sol[:, I]), dims = 2)
plot(sol; idxs = I, color = :blue)
plot!(t_ensem, ensem_prediction, lw = 3, color = :red)
scatter!(t_ensem, data_ensem[2][2][2])
@@ -226,26 +226,68 @@ scatter!(t_ensem, data_ensem[2][2][2])
Once we have obtained the ensemble model, we can forecast ahead with it:

```@example ensemble
forecast_probs = [remake(enprobs.prob[i]; tspan = (t_train[1],t_forecast[end])) for i in 1:length(enprobs.prob)]
forecast_probs = [remake(enprobs.prob[i]; tspan = (t_train[1], t_forecast[end]))
for i in 1:length(enprobs.prob)];
fit_enprob = EnsembleProblem(forecast_probs)

sol = solve(fit_enprob; saveat = t_forecast);
ensem_prediction = sum(stack(ensem_weights .* sol[:,S]), dims = 2)
ensem_prediction = sum(stack(ensem_weights .* sol[:, S]), dims = 2)
plot(sol; idxs = S, color = :blue)
plot!(t_forecast, ensem_prediction, lw = 3, color = :red)
scatter!(t_forecast, data_forecast[1][2][2])
```

```@example ensemble
ensem_prediction = sum(stack(ensem_weights .* sol[:, I]), dims = 2)
plot(sol; idxs = I, color = :blue)
plot!(t_forecast, ensem_prediction, lw = 3, color = :red)
scatter!(t_forecast, data_forecast[2][2][2])
```

```@example ensemble
ensem_prediction = sum(stack(ensem_weights .* sol[:, R]), dims = 2)
plot(sol; idxs = R, color = :blue)
plot!(t_forecast, ensem_prediction, lw = 3, color = :red)
scatter!(t_forecast, data_forecast[3][2][2])
```

## Training the "Super Ensemble" Model

The standard ensemble model first calibrates each model in an ensemble and then uses the calibrated models
as the basis for a prediction via a linear combination. The super ensemble performs the Bayesian estimation
on the full combination of models, including the weights vector, as a single Bayesian posterior calculation.
While this has the downside that the prediction of any single model is not necessarily predictive of the
whole, in some cases this ensemble model may be more effective.

To train this model, simply use `bayesian_datafit` on the ensemble. This looks like:

```@example ensemble
probs = [prob, prob2, prob3];
ps = [[β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3]

super_enprob, ensem_weights = bayesian_datafit(probs, ps, data_ensem)
```

And now we can forecast with this model:

```@example ensemble
sol = solve(super_enprob; saveat = t_forecast);
ensem_prediction = sum(stack(ensem_weights .* sol[:, S]), dims = 2)
plot(sol; idxs = S, color = :blue)
plot!(t_forecast, ensem_prediction, lw = 3, color = :red)
scatter!(t_forecast, data_forecast[1][2][2])
```

```@example ensemble
ensem_prediction = sum(stack([ensem_weights[i] * sol[i][I] for i in 1:length(forecast_probs)]), dims = 2)
ensem_prediction = sum(stack(ensem_weights .* sol[:, I]), dims = 2)
plot(sol; idxs = I, color = :blue)
plot!(t_forecast, ensem_prediction, lw = 3, color = :red)
scatter!(t_forecast, data_forecast[2][2][2])
```

```@example ensemble
ensem_prediction = sum(stack([ensem_weights[i] * sol[i][R] for i in 1:length(forecast_probs)]), dims = 2)
ensem_prediction = sum(stack(ensem_weights .* sol[:, R]), dims = 2)
plot(sol; idxs = R, color = :blue)
plot!(t_forecast, ensem_prediction, lw = 3, color = :red)
scatter!(t_forecast, data_forecast[3][2][2])
```
```
2 changes: 2 additions & 0 deletions src/EasyModelAnalysis.jl
Original file line number Diff line number Diff line change
@@ -10,8 +10,10 @@ using GlobalSensitivity, Turing
using SciMLExpectations
@reexport using Plots
using SciMLBase.EnsembleAnalysis
using Random

include("basics.jl")
include("keyindexmap.jl")
include("datafit.jl")
include("sensitivity.jl")
include("threshold.jl")
339 changes: 304 additions & 35 deletions src/datafit.jl

Large diffs are not rendered by default.

38 changes: 25 additions & 13 deletions src/ensemble.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
function naivemap(f, ::EnsembleThreads, arg0, args...)
t = Vector{Task}(undef, length(arg0))
for (n, a) in enumerate(arg0)
t[n] = let an = map(Base.Fix2(Base.getindex, n), args)
Threads.@spawn f(a, an...)
end
end
return identity.(map(fetch, t))
end
function naivemap(f, ::EnsembleSerial, args...)
map(f, args...)
end


"""
ensemble_weights(sol::EnsembleSolution, data_ensem)
@@ -19,31 +33,29 @@ dataset on which the ensembler should be trained on.
"""
function ensemble_weights(sol::EnsembleSolution, data_ensem)
obs = first.(data_ensem)
predictions = reduce(vcat, reduce(hcat,[sol[i][s] for i in 1:length(sol)]) for s in obs)
data = reduce(vcat, [data_ensem[i][2] isa Tuple ? data_ensem[i][2][2] : data_ensem[i][2] for i in 1:length(data_ensem)])
weights = predictions \ data
predictions = reduce(vcat,
reduce(hcat, [sol[i][s] for i in 1:length(sol)]) for s in obs)
data = reduce(vcat,
[data_ensem[i][2] isa Tuple ? data_ensem[i][2][2] : data_ensem[i][2]
for i in 1:length(data_ensem)])
weights = predictions \ data
end

function bayesian_ensemble(probs, ps, datas;
noise_prior = InverseGamma(2, 3),
mcmcensemble::AbstractMCMC.AbstractMCMCEnsemble = Turing.MCMCSerial(),
ensemblealg::SciMLBase.BasicEnsembleAlgorithm = EnsembleThreads(),
nchains = 4,
niter = 1_000,
keep = 100)

fits = map(probs, ps, datas) do prob, p, data
bayesian_datafit(prob, p, data; noise_prior, mcmcensemble, nchains, niter)
end

models = map(probs, fits) do prob, fit
[remake(prob, p = Pair.(first.(fit), getindex.(last.(fit), i))) for i in length(fit[1][2])-keep:length(fit[1][2])]
models = naivemap(ensemblealg, probs, ps, datas) do prob, p, data
bayesian_datafit(prob, p, data; noise_prior, ensemblealg, nchains, niter)
end

@info "Calibrations are complete"

all_probs = reduce(vcat,models)
all_probs = reduce(vcat, map(prob -> prob.prob, models))

@info "$(length(all_probs)) total models"

enprob = EnsembleProblem(all_probs)
end
end
43 changes: 43 additions & 0 deletions src/keyindexmap.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@

struct IndexKeyMap
indices::Vector{Int}
end

# probs support
function IndexKeyMap(prob, keys)
params = ModelingToolkit.parameters(prob.f.sys)
indices = Vector{Int}(undef, length(keys))
for i in eachindex(keys)
indices[i] = findfirst(Base.Fix1(isequal, keys[i]), params)
end
return IndexKeyMap(indices)
end

Base.@propagate_inbounds function (ikm::IndexKeyMap)(prob::SciMLBase.AbstractDEProblem,
v::AbstractVector)
@boundscheck checkbounds(v, length(ikm.indices))
def = prob.p
ret = Vector{Base.promote_eltype(v, def)}(undef, length(def))
copyto!(ret, def)
for (i, j) in enumerate(ikm.indices)
@inbounds ret[j] = v[i]
end
return ret
end
function _remake(prob, ikm::IndexKeyMap, pprior, tspan::Tuple{Number, Number} = prob.tspan)
p = ikm(prob, pprior)
remake(prob; tspan, p)
end

# data support
function IndexKeyMap(prob, data::AbstractVector{<:Pair})
states = ModelingToolkit.states(prob.f.sys)
indices = Vector{Int}(undef, length(data))
for i in eachindex(data)
indices[i] = findfirst(Base.Fix1(isequal, data[i].first), states)
end
return IndexKeyMap(indices)
end
function (ikm::IndexKeyMap)(sol::SciMLBase.AbstractTimeseriesSolution)
(@view(sol[i, :]) for i in ikm.indices)
end
11 changes: 6 additions & 5 deletions test/datafit.jl
Original file line number Diff line number Diff line change
@@ -88,15 +88,16 @@ tsave = collect(10.0:10.0:100.0)
sol_data = solve(prob, saveat = tsave)
data = [x => sol_data[x], z => sol_data[z]]
p_prior ==> Normal(26.8, 0.1), β => Normal(2.7, 0.1)]
p_posterior = @time bayesian_datafit(prob, p_prior, tsave, data)
@test var.(getfield.(p_prior, :second)) >= var.(getfield.(p_posterior, :second))
p_posterior = @time bayesian_datafit(prob, p_prior, tsave, data, nchains = 2, niter = 100)
solve(p_posterior)
# @test var.(getfield.(p_prior, :second)) >= var.(getfield.(p_posterior, :second))

tsave1 = collect(10.0:10.0:100.0)
sol_data1 = solve(prob, saveat = tsave1)
tsave2 = collect(10.0:13.5:100.0)
sol_data2 = solve(prob, saveat = tsave2)
data_with_t = [x => (tsave1, sol_data1[x]), z => (tsave2, sol_data2[z])]

p_posterior = @time bayesian_datafit(prob, p_prior, data_with_t)
@test var.(getfield.(p_prior, :second)) >= var.(getfield.(p_posterior, :second))

p_posterior = @time bayesian_datafit(prob, p_prior, data_with_t, nchains = 2, niter = 100)
solve(p_posterior)
# @test var.(getfield.(p_prior, :second)) >= var.(getfield.(p_posterior, :second))
47 changes: 30 additions & 17 deletions test/ensemble.jl
Original file line number Diff line number Diff line change
@@ -56,37 +56,50 @@ sol = solve(enprob; saveat = 1);

weights = [0.2, 0.5, 0.3]

fullS = vec(sum(stack(weights .* sol[:,S]),dims=2))
fullI = vec(sum(stack(weights .* sol[:,I]),dims=2))
fullR = vec(sum(stack(weights .* sol[:,R]),dims=2))
fullS = vec(sum(stack(weights .* sol[:, S]), dims = 2))
fullI = vec(sum(stack(weights .* sol[:, I]), dims = 2))
fullR = vec(sum(stack(weights .* sol[:, R]), dims = 2))

t_train = 0:14
data_train = [
S => (t_train,fullS[1:15]),
I => (t_train,fullI[1:15]),
R => (t_train,fullR[1:15]),
S => (t_train, fullS[1:15]),
I => (t_train, fullI[1:15]),
R => (t_train, fullR[1:15]),
]
t_ensem = 0:21
data_ensem = [
S => (t_ensem,fullS[1:22]),
I => (t_ensem,fullI[1:22]),
R => (t_ensem,fullR[1:22]),
S => (t_ensem, fullS[1:22]),
I => (t_ensem, fullI[1:22]),
R => (t_ensem, fullR[1:22]),
]
t_forecast = 0:30
data_forecast = [
S => (t_forecast,fullS),
I => (t_forecast,fullI),
R => (t_forecast,fullR),
S => (t_forecast, fullS),
I => (t_forecast, fullI),
R => (t_forecast, fullR),
]

sol = solve(enprob; saveat = t_ensem);

@test ensemble_weights(sol, data_ensem) [0.2, 0.5, 0.3]

probs = [prob, prob2, prob3]
ps = [[β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3]
datas = [data_train,data_train,data_train]
enprobs = bayesian_ensemble(probs, ps, datas)
probs = (prob, prob2, prob3)
ps = Tuple([β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3)
datas = (data_train, data_train, data_train)
enprobs = bayesian_ensemble(probs, ps, datas, nchains = 2, niter = 200)

sol = solve(enprobs; saveat = t_ensem);
ensemble_weights(sol, data_ensem)
ensemble_weights(sol, data_ensem)

# only supports one datas
ensembleofweightedensembles = bayesian_datafit(probs,
ps,
data_train,
nchains = 2,
niter = 200)

@test length(ensembleofweightedensembles.prob[1].prob) ==
length(ensembleofweightedensembles.prob[1].weights) == length(ps)
for prob in ensembleofweightedensembles.prob
@test sum(prob.weights) 1.0
end