diff --git a/src/multivariate/solvers/zeroth_order/particle_swarm.jl b/src/multivariate/solvers/zeroth_order/particle_swarm.jl index e2fa3146a..bc33bf973 100644 --- a/src/multivariate/solvers/zeroth_order/particle_swarm.jl +++ b/src/multivariate/solvers/zeroth_order/particle_swarm.jl @@ -104,9 +104,9 @@ function initial_state(method::ParticleSwarm, options, d, initial_x::AbstractArr c2 = T(2) w = T(1) - X = Array{T,2}(undef, n, n_particles) - V = Array{T,2}(undef, n, n_particles) - X_best = Array{T,2}(undef, n, n_particles) + X = similar_axis(initial_x, n_particles) + V = similar_axis(initial_x, n_particles) + X_best = similar_axis(initial_x, n_particles) dx = zeros(T, n) score = zeros(T, n_particles) x = copy(initial_x) @@ -465,7 +465,7 @@ end function compute_cost!(f, n_particles::Int, - X::Matrix, + X::AbstractMatrix, score::Vector) for i in 1:n_particles @@ -473,3 +473,12 @@ function compute_cost!(f, end nothing end + +""" + similar_axis(x, n_particles) + +Equivalent to `Base.similar(x, length(x), n_particles)` for `x::Array`. Provide a method of this function to preserve a special axis for subtypes of `AbstractArray`. +""" +function similar_axis(x, n_particles) + similar(x, length(x), n_particles) +end \ No newline at end of file