From f8339433a710f965ade381ed75148bd011d70047 Mon Sep 17 00:00:00 2001 From: Anas Abdelrehim <73660335+AnasAbdelR@users.noreply.github.com> Date: Fri, 14 Jul 2023 15:28:58 -0400 Subject: [PATCH 1/4] add default regularization --- src/ensemble.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ensemble.jl b/src/ensemble.jl index 41c07a0c..41e6265c 100644 --- a/src/ensemble.jl +++ b/src/ensemble.jl @@ -17,11 +17,11 @@ dataset on which the ensembler should be trained on. This function currently assumes that `sol.t` matches the time points of all measurements in `data_ensem`! """ -function ensemble_weights(sol::EnsembleSolution, data_ensem) +function ensemble_weights(sol::EnsembleSolution, data_ensem; lambda = 1e-8) obs = first.(data_ensem) predictions = reduce(vcat, reduce(hcat,[sol[i][s] for i in 1:length(sol)]) for s in obs) data = reduce(vcat, [data_ensem[i][2] isa Tuple ? data_ensem[i][2][2] : data_ensem[i][2] for i in 1:length(data_ensem)]) - weights = predictions \ data + weights = (predictions*predictions' .+ lambda*I) \ (data*predictions') end function bayesian_ensemble(probs, ps, datas; @@ -46,4 +46,4 @@ function bayesian_ensemble(probs, ps, datas; @info "$(length(all_probs)) total models" enprob = EnsembleProblem(all_probs) -end \ No newline at end of file +end From 365d0540864a6100daaed2f5a26b7f1694bcecfc Mon Sep 17 00:00:00 2001 From: Anas Abdelrehim <73660335+AnasAbdelR@users.noreply.github.com> Date: Fri, 14 Jul 2023 16:48:34 -0400 Subject: [PATCH 2/4] add truncated SVD --- src/ensemble.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/ensemble.jl b/src/ensemble.jl index 41e6265c..2688490b 100644 --- a/src/ensemble.jl +++ b/src/ensemble.jl @@ -17,11 +17,16 @@ dataset on which the ensembler should be trained on. This function currently assumes that `sol.t` matches the time points of all measurements in `data_ensem`! """ -function ensemble_weights(sol::EnsembleSolution, data_ensem; lambda = 1e-8) +function ensemble_weights(sol::EnsembleSolution, data_ensem; rank = size(data_ensem,2)) obs = first.(data_ensem) predictions = reduce(vcat, reduce(hcat,[sol[i][s] for i in 1:length(sol)]) for s in obs) data = reduce(vcat, [data_ensem[i][2] isa Tuple ? data_ensem[i][2][2] : data_ensem[i][2] for i in 1:length(data_ensem)]) - weights = (predictions*predictions' .+ lambda*I) \ (data*predictions') + F = svd(data) + # Truncate SVD + U, S, V = F.U[:, 1:rank], F.S[1:rank], F.V[:, 1:rank] + # Compute pseudo-inverse of A from truncated SVD + pinv = (V * Diagonal(1 ./ S) * U') + weights = data*A_pinv end function bayesian_ensemble(probs, ps, datas; From 25145a148690e25eb64bb23a09b97db892815975 Mon Sep 17 00:00:00 2001 From: Anas Abdelrehim <73660335+AnasAbdelR@users.noreply.github.com> Date: Fri, 14 Jul 2023 16:56:09 -0400 Subject: [PATCH 3/4] Update ensemble.jl - order of operations --- src/ensemble.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ensemble.jl b/src/ensemble.jl index 2688490b..5958490a 100644 --- a/src/ensemble.jl +++ b/src/ensemble.jl @@ -24,9 +24,7 @@ function ensemble_weights(sol::EnsembleSolution, data_ensem; rank = size(data_en F = svd(data) # Truncate SVD U, S, V = F.U[:, 1:rank], F.S[1:rank], F.V[:, 1:rank] - # Compute pseudo-inverse of A from truncated SVD - pinv = (V * Diagonal(1 ./ S) * U') - weights = data*A_pinv + weights = (((data*V)*Diagonal(1 ./ S)) * U') end function bayesian_ensemble(probs, ps, datas; From 7bd19b9a6662288e4d59aea69ead3d4f451a0aa8 Mon Sep 17 00:00:00 2001 From: Anas Abdelrehim <73660335+AnasAbdelR@users.noreply.github.com> Date: Fri, 14 Jul 2023 17:10:54 -0400 Subject: [PATCH 4/4] Update ensemble.jl: fix default --- src/ensemble.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ensemble.jl b/src/ensemble.jl index 5958490a..f6bc761a 100644 --- a/src/ensemble.jl +++ b/src/ensemble.jl @@ -17,7 +17,7 @@ dataset on which the ensembler should be trained on. This function currently assumes that `sol.t` matches the time points of all measurements in `data_ensem`! """ -function ensemble_weights(sol::EnsembleSolution, data_ensem; rank = size(data_ensem,2)) +function ensemble_weights(sol::EnsembleSolution, data_ensem; rank = Int(round(length(last(first(data_ensem).second))/2))) obs = first.(data_ensem) predictions = reduce(vcat, reduce(hcat,[sol[i][s] for i in 1:length(sol)]) for s in obs) data = reduce(vcat, [data_ensem[i][2] isa Tuple ? data_ensem[i][2][2] : data_ensem[i][2] for i in 1:length(data_ensem)])