Skip to content

Adding support for Manifolds to Nelder-Mead #872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
* Improve handling of alternative number types in univariate optimization
* Add conditional likelihood example to docs
* Improve Fminbox trace printing.
* Add support for manifolds in NelderMead.

# Optim v0.17.2 release notes
* Fix some typos
Expand Down
2 changes: 1 addition & 1 deletion docs/src/algo/manifolds.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ The following meta-manifolds construct manifolds out of pre-existing ones:

See `test/multivariate/manifolds.jl` for usage examples.

Implementing new manifolds is as simple as adding methods `project_tangent!(M::YourManifold,g,x)` and `retract!(M::YourManifold,x)`. If you implement another manifold or optimization method, please contribute a PR!
Implementing new manifolds is as simple as adding methods `project_tangent!(M::YourManifold,g,x)` and `retract!(M::YourManifold,x)`. Nelder-Mead only requires `retract!`. If you implement another manifold or optimization method, please contribute a PR!

## References
The Geometry of Algorithms with Orthogonality Constraints, Alan Edelman, Tomás A. Arias, Steven T. Smith, SIAM. J. Matrix Anal. & Appl., 20(2), 303–353
Expand Down
3 changes: 2 additions & 1 deletion docs/src/algo/nelder_mead.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ Nelder-Mead is currently the standard algorithm when no derivatives are provided
## Constructor
```julia
NelderMead(; parameters = AdaptiveParameters(),
initial_simplex = AffineSimplexer())
initial_simplex = AffineSimplexer(),
manifold = Flat())
```
The keywords in the constructor are used to control the following parts of the
solver:
Expand Down
32 changes: 25 additions & 7 deletions src/multivariate/solvers/zeroth_order/nelder_mead.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,16 @@ parameters(P::FixedParameters, n::Integer) = (P.α, P.β, P.γ, P.δ)
# Settings for the Nelder-Mead solver.
#   initial_simplex: object passed to `simplexer` to build the starting simplex
#   parameters:      object passed to `parameters` to obtain (α, β, γ, δ)
#   manifold:        manifold that trial points are retracted onto via `retract!`
# The manifold type is a type parameter rather than an abstractly typed field, so
# that `method.manifold` is concretely typed and `retract!` can dispatch statically.
struct NelderMead{Ts <: Simplexer, Tp <: NMParameters, Tm <: Manifold} <: ZerothOrderOptimizer
    initial_simplex::Ts
    parameters::Tp
    manifold::Tm
end

"""
# NelderMead
## Constructor
```julia
NelderMead(; parameters = AdaptiveParameters(),
initial_simplex = AffineSimplexer())
NelderMead(; manifold = Flat(),
parameters = AdaptiveParameters(),
initial_simplex = AffineSimplexer())
```

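The constructor takes the following keywords, in addition to `manifold` (the manifold onto which trial points are retracted; defaults to `Flat()`):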
Expand All @@ -72,21 +74,22 @@ point with a better point. More information can be found in [1], [2] or [3].
- [2] Lagarias, Jeffrey C., et al. "Convergence properties of the Nelder–Mead simplex method in low dimensions." SIAM Journal on Optimization 9.1 (1998): 112-147
- [3] Gao, Fuchang and Lixing Han (2010). "Implementing the Nelder-Mead simplex algorithm with adaptive parameters". Computational Optimization and Applications. doi:10.1007/s10589-010-9329-3
"""
function NelderMead(; kwargs...)
function NelderMead(; manifold::Manifold=Flat(), kwargs...)
KW = Dict(kwargs)
if haskey(KW, :initial_simplex) || haskey(KW, :parameters)
initial_simplex, parameters = AffineSimplexer(), AdaptiveParameters()
haskey(KW, :initial_simplex) && (initial_simplex = KW[:initial_simplex])
haskey(KW, :parameters) && (parameters = KW[:parameters])
return NelderMead(initial_simplex, parameters)
return NelderMead(initial_simplex, parameters, manifold)
else
return NelderMead(AffineSimplexer(), AdaptiveParameters())
return NelderMead(AffineSimplexer(), AdaptiveParameters(), manifold)
end
end

Base.summary(::NelderMead) = "Nelder-Mead"

# centroid except h-th vertex
# Do not retract onto the manifold here: this function does not receive the method
# (and hence the manifold) as input, so callers must `retract!` the result themselves.
function centroid!(c::AbstractArray{T}, simplex, h=0) where T
n = length(c)
fill!(c, zero(T))
Expand Down Expand Up @@ -149,10 +152,15 @@ mutable struct NelderMeadState{Tx, T, Tfs} <: ZerothOrderState
end

function initial_state(method::NelderMead, options, d, initial_x)
retract!(method.manifold, initial_x)

T = eltype(initial_x)
n = length(initial_x)
m = n + 1
simplex = simplexer(method.initial_simplex, initial_x)
simplex = simplexer(method.initial_simplex, initial_x)
@inbounds for i in 1:length(simplex)
retract!(method.manifold, simplex[i])
end
f_simplex = zeros(T, m)

value!!(d, first(simplex))
Expand All @@ -168,10 +176,13 @@ function initial_state(method::NelderMead, options, d, initial_x)

α, β, γ, δ = parameters(method.parameters, n)

init_centroid = centroid(simplex, i_order[m])
retract!(method.manifold, init_centroid)

NelderMeadState(copy(initial_x), # Variable to hold final minimizer value for MultivariateOptimizationResults
m, # Number of vertices in the simplex
simplex, # Maintain simplex in state.simplex
centroid(simplex, i_order[m]), # Maintain centroid in state.centroid
init_centroid, # Maintain centroid in state.centroid
copy(initial_x), # Store cache in state.x_lowest
copy(initial_x), # Store cache in state.x_second_highest
copy(initial_x), # Store cache in state.x_highest
Expand All @@ -194,6 +205,7 @@ function update_state!(f::F, state::NelderMeadState{T}, method::NelderMead) wher
n, m = length(state.x), state.m

centroid!(state.x_centroid, state.simplex, state.i_order[m])
retract!(method.manifold, state.x_centroid)
copyto!(state.x_lowest, state.simplex[state.i_order[1]])
copyto!(state.x_second_highest, state.simplex[state.i_order[n]])
copyto!(state.x_highest, state.simplex[state.i_order[m]])
Expand All @@ -205,13 +217,15 @@ function update_state!(f::F, state::NelderMeadState{T}, method::NelderMead) wher
@inbounds for j in 1:n
state.x_reflect[j] = state.x_centroid[j] + state.α * (state.x_centroid[j]-state.x_highest[j])
end
retract!(method.manifold, state.x_reflect)

f_reflect = value(f, state.x_reflect)
if f_reflect < state.f_lowest
# Compute an expansion
@inbounds for j in 1:n
state.x_cache[j] = state.x_centroid[j] + state.β *(state.x_reflect[j] - state.x_centroid[j])
end
retract!(method.manifold, state.x_cache)
f_expand = value(f, state.x_cache)

if f_expand < f_reflect
Expand Down Expand Up @@ -240,6 +254,7 @@ function update_state!(f::F, state::NelderMeadState{T}, method::NelderMead) wher
@simd for j in 1:n
@inbounds state.x_cache[j] = state.x_centroid[j] + state.γ * (state.x_reflect[j]-state.x_centroid[j])
end
retract!(method.manifold, state.x_cache)
f_outside_contraction = value(f, state.x_cache)
if f_outside_contraction < f_reflect
copyto!(state.simplex[state.i_order[m]], state.x_cache)
Expand All @@ -255,6 +270,7 @@ function update_state!(f::F, state::NelderMeadState{T}, method::NelderMead) wher
@simd for j in 1:n
@inbounds state.x_cache[j] = state.x_centroid[j] - state.γ *(state.x_reflect[j] - state.x_centroid[j])
end
retract!(method.manifold, state.x_cache)
f_inside_contraction = value(f, state.x_cache)
if f_inside_contraction < f_highest
copyto!(state.simplex[state.i_order[m]], state.x_cache)
Expand All @@ -271,6 +287,7 @@ function update_state!(f::F, state::NelderMeadState{T}, method::NelderMead) wher
for i = 2:m
ord = state.i_order[i]
copyto!(state.simplex[ord], state.x_lowest + state.δ*(state.simplex[ord]-state.x_lowest))
retract!(method.manifold, state.simplex[ord])
state.f_simplex[ord] = value(f, state.simplex[ord])
end
step_type = "shrink"
Expand All @@ -284,6 +301,7 @@ end
function after_while!(f, state, method::NelderMead, options)
sortperm!(state.i_order, state.f_simplex)
x_centroid_min = centroid(state.simplex, state.i_order[state.m])
retract!(method.manifold, x_centroid_min)
f_centroid_min = value(f, x_centroid_min)
f_min, i_f_min = findmin(state.f_simplex)
x_min = state.simplex[i_f_min]
Expand Down
8 changes: 8 additions & 0 deletions test/multivariate/manifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,11 @@
@test minpow[:,2] == minprod[n+1:2n]
end
end

# Nelder-Mead constrained to a manifold. With A = ones(2,2) the objective equals
# (x[1] + x[2])^2, whose minimum value is 0.
@testset "Manifolds zeroth_order" begin
    A = ones(2,2)
    fmanif(x) = dot(x,A*x)
    res = Optim.optimize(fmanif, [1.0;0.0], NelderMead(manifold=Optim.Sphere()))
    @test Optim.converged(res)
    @test Optim.minimum(res) < 1e-3
    # The objective is also 0 at the origin and along x[1] == -x[2], so the bound
    # above would pass even if the manifold were ignored. Also check that the
    # minimizer stayed on the manifold (assumes Optim.Sphere is the unit sphere —
    # TODO confirm).
    xmin = Optim.minimizer(res)
    @test dot(xmin, xmin) ≈ 1
end