diff --git a/Project.toml b/Project.toml
index 7504b246..ec323a1b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "ReservoirComputing"
 uuid = "7c2d2b1e-3dd4-11ea-355a-8f6a8116e294"
 authors = ["Francesco Martinuzzi"]
-version = "0.12.5"
+version = "0.12.6"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
diff --git a/README.md b/README.md
index 1ee4aa68..333b4afa 100644
--- a/README.md
+++ b/README.md
@@ -73,7 +73,7 @@ Lorenz system.
 
 ### 1. Generate data
 
-As a general first step wee fix the random seed for reproducibilty
+As a general first step we fix the random seed for reproducibility
 
 ```julia
 using Random
diff --git a/docs/Project.toml b/docs/Project.toml
index 26799d84..3406a2c4 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -1,5 +1,6 @@
 [deps]
 CellularAutomata = "878138dc-5b27-11ea-1a71-cb95d38d6b29"
+ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
 DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
@@ -8,6 +9,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReservoirComputing = "7c2d2b1e-3dd4-11ea-355a-8f6a8116e294"
+Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 
 [compat]
diff --git a/docs/make.jl b/docs/make.jl
index f2ea3988..648bf420 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -1,7 +1,7 @@
 using Documenter, DocumenterCitations, DocumenterInterLinks, ReservoirComputing
 
-#cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force = true)
-#cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force = true)
+cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force = true)
+cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force = true)
 
 ENV["PLOTS_TEST"] = "true"
 ENV["GKSwstype"] = "100"
diff --git a/docs/pages.jl b/docs/pages.jl
index fc499114..13ca5498 100644
--- a/docs/pages.jl
+++ b/docs/pages.jl
@@ -9,6 +9,9 @@ pages = [
         "Deep Echo State Networks" => "tutorials/deep_esn.md",
         #"Hybrid Echo State Networks" => "tutorials/hybrid.md",
         "Reservoir Computing with Cellular Automata" => "tutorials/reca.md"],
+    "Examples" => Any[
+        "Building a model to add to ReservoirComputing.jl" => "examples/model_es2n.md",
+    ],
     "API Documentation" => Any[
         "Layers" => "api/layers.md",
         "Models" => "api/models.md",
diff --git a/docs/src/api/layers.md b/docs/src/api/layers.md
index 50f9778d..ae9bfc1e 100644
--- a/docs/src/api/layers.md
+++ b/docs/src/api/layers.md
@@ -22,6 +22,7 @@
 ```@docs
 ESNCell
+ES2NCell
 ```
 
 ## Reservoir computing with cellular automata
diff --git a/docs/src/api/models.md b/docs/src/api/models.md
index 7fadf777..78facf63 100644
--- a/docs/src/api/models.md
+++ b/docs/src/api/models.md
@@ -3,6 +3,7 @@
 ## Echo State Networks
 
 ```@docs
+ES2N
 ESN
 DeepESN
 DelayESN
diff --git a/docs/src/examples/model_es2n.md b/docs/src/examples/model_es2n.md
new file mode 100644
index 00000000..974cd9b5
--- /dev/null
+++ b/docs/src/examples/model_es2n.md
@@ -0,0 +1,150 @@
+# Building a model to add to ReservoirComputing.jl
+
+This example showcases how to build a custom model so that it can also be
+included in ReservoirComputing.jl. We will build an edge of stability
+echo state network [`ES2N`](@ref). Since the model is already available in
+the library, we rename the cell and the model here to avoid name clashes.
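+
+For reference, the state update we are going to implement is the one given in
+the [`ES2N`](@ref) docstring; in the code below, the spectral radius ``\rho``
+and the input scaling ``\omega`` are absorbed into the reservoir and input
+initializers:
+
+```math
+x(t) = \beta\, \phi\!\left( \rho\, \mathbf{W}_r x(t-1) + \omega\,
+\mathbf{W}_{in} u(t) \right) + (1-\beta)\, \mathbf{O}\, x(t-1)
+```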
+
+## Building an ES2NCell
+
+Building a ReservoirComputing.jl model largely follows the Lux.jl model
+approach.
+
+```@example es2n_scratch
+using ReservoirComputing
+using ConcreteStructs
+using Static
+using Random
+
+using ReservoirComputing: IntegerType, BoolType, InputType, has_bias, _wrap_layers
+import ReservoirComputing: initialparameters
+
+@concrete struct CustomES2NCell <: ReservoirComputing.AbstractEchoStateNetworkCell
+    activation
+    in_dims <: IntegerType
+    out_dims <: IntegerType
+    init_bias
+    init_reservoir
+    init_input
+    init_orthogonal
+    init_state
+    proximity
+    use_bias <: StaticBool
+end
+
+function CustomES2NCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType},
+        activation = tanh; use_bias::BoolType = False(), init_bias = zeros32,
+        init_reservoir = rand_sparse, init_input = scaled_rand,
+        init_state = randn32, init_orthogonal = orthogonal,
+        proximity::AbstractFloat = 1.0)
+    return CustomES2NCell(activation, in_dims, out_dims, init_bias, init_reservoir,
+        init_input, init_orthogonal, init_state, proximity, static(use_bias))
+end
+
+function initialparameters(rng::Random.AbstractRNG, esn::CustomES2NCell)
+    ps = (input_matrix = esn.init_input(rng, esn.out_dims, esn.in_dims),
+        reservoir_matrix = esn.init_reservoir(rng, esn.out_dims, esn.out_dims),
+        orthogonal_matrix = esn.init_orthogonal(rng, esn.out_dims, esn.out_dims))
+    if has_bias(esn)
+        ps = merge(ps, (bias = esn.init_bias(rng, esn.out_dims),))
+    end
+    return ps
+end
+
+function (esn::CustomES2NCell)((inp, (hidden_state,))::InputType, ps, st::NamedTuple)
+    T = eltype(inp)
+    if has_bias(esn)
+        candidate_h = esn.activation.(ps.input_matrix * inp .+
+                                      ps.reservoir_matrix * hidden_state .+ ps.bias)
+    else
+        candidate_h = esn.activation.(ps.input_matrix * inp .+
+                                      ps.reservoir_matrix * hidden_state)
+    end
+    h_new = (one(T) - esn.proximity) .* ps.orthogonal_matrix * hidden_state .+
+            esn.proximity .* candidate_h
+    return (h_new, (h_new,)), st
+end
+```
+
+You will notice that some definitions are missing. For instance, we did not
+dispatch over `initialstates`. This is because subtyping
+`AbstractEchoStateNetworkCell` already provides many of these smaller
+methods for us.
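+
+If you ever do need to customize one of them, the sketch below shows what an
+explicit `initialstates` could look like. It assumes the LuxCore convention
+of carrying a replicated RNG in the layer state, which is what the cell uses
+to sample a fresh hidden state when none is provided; treat it as
+illustrative rather than as the exact fallback definition:
+
+```julia
+import ReservoirComputing: initialstates
+import LuxCore
+
+function initialstates(rng::Random.AbstractRNG, ::CustomES2NCell)
+    # Store a replica of the RNG so initial hidden states can be sampled
+    # reproducibly whenever the cell is called without a carried state.
+    return (rng = LuxCore.replicate(rng),)
+end
+```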
+
+## Building the full ES2N model
+
+Now you can build a full model in two different ways:
+
+  - Leveraging [`ReservoirComputer`](@ref)
+  - Building from scratch with a proper `CustomES2N` struct
+
+```@example es2n_scratch
+function CustomES2NApproach1(in_dims, res_dims,
+        out_dims, activation = tanh;
+        readout_activation = identity,
+        state_modifiers = (),
+        kwargs...)
+    return ReservoirComputer(StatefulLayer(CustomES2NCell(in_dims => res_dims, activation; kwargs...)),
+        state_modifiers, LinearReadout(res_dims => out_dims, readout_activation))
+end
+```
+
+The first approach reuses the generic [`ReservoirComputer`](@ref) container.
+The second defines a dedicated struct, which additionally lets you
+specialize methods such as `Base.show` for the new model:
+
+```@example es2n_scratch
+@concrete struct CustomES2NApproach2 <:
+                 ReservoirComputing.AbstractEchoStateNetwork{(:reservoir, :states_modifiers, :readout)}
+    reservoir
+    states_modifiers
+    readout
+end
+
+function CustomES2NApproach2(in_dims::Int, res_dims::Int,
+        out_dims::Int, activation = tanh;
+        readout_activation = identity,
+        state_modifiers = (),
+        kwargs...)
+    cell = StatefulLayer(CustomES2NCell(in_dims => res_dims, activation; kwargs...))
+    mods_tuple = state_modifiers isa Tuple || state_modifiers isa AbstractVector ?
+                 Tuple(state_modifiers) : (state_modifiers,)
+    mods = _wrap_layers(mods_tuple)
+    ro = LinearReadout(res_dims => out_dims, readout_activation)
+    return CustomES2NApproach2(cell, mods, ro)
+end
+```
+
+Now we can use the model like any other in ReservoirComputing.jl, following
+the example on the getting started page:
+
+```@example es2n_scratch
+using OrdinaryDiffEq
+using Plots
+
+Random.seed!(42)
+rng = MersenneTwister(17)
+
+function lorenz(du, u, p, t)
+    du[1] = p[1] * (u[2] - u[1])
+    du[2] = u[1] * (p[2] - u[3]) - u[2]
+    du[3] = u[1] * u[2] - p[3] * u[3]
+end
+
+prob = ODEProblem(lorenz, [1.0f0, 0.0f0, 0.0f0], (0.0, 200.0), [10.0f0, 28.0f0, 8/3])
+data = Array(solve(prob, ABM54(); dt=0.02))
+shift = 300
+train_len = 5000
+predict_len = 1250
+
+input_data = data[:, shift:(shift + train_len - 1)]
+target_data = data[:, (shift + 1):(shift + train_len)]
+test = data[:, (shift + train_len):(shift + train_len + predict_len - 1)]
+
+esn = CustomES2NApproach2(3, 300, 3; init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300),
+    state_modifiers=NLAT2)
+
+ps, st = setup(rng, esn)
+ps, st = train!(esn, input_data, target_data, ps, st)
+output, st = predict(esn, predict_len, ps, st; initialdata=test[:, 1])
+
+plot(transpose(output)[:, 1], transpose(output)[:, 2], transpose(output)[:, 3];
+    label="predicted")
+plot!(transpose(test)[:, 1], transpose(test)[:, 2], transpose(test)[:, 3];
+    label="actual")
+```
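+
+Beyond the visual comparison, a simple error metric gives a quick
+quantitative read on the prediction. The snippet below is a minimal sketch
+that assumes the columns of `output` align one-to-one with the columns of
+`test`:
+
+```julia
+# Root-mean-square error over the full prediction horizon.
+rmse = sqrt(sum(abs2, output .- test) / length(test))
+println("RMSE over $predict_len steps: ", rmse)
+```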
diff --git a/docs/src/refs.bib b/docs/src/refs.bib
index 155f0246..b588d8f6 100644
--- a/docs/src/refs.bib
+++ b/docs/src/refs.bib
@@ -369,3 +369,18 @@ @article{Gauthier2021
   year = {2021},
   month = sep
 }
+
+@article{Ceni2025,
+  title = {Edge of Stability Echo State Network},
+  volume = {36},
+  ISSN = {2162-2388},
+  url = {http://dx.doi.org/10.1109/TNNLS.2024.3400045},
+  DOI = {10.1109/tnnls.2024.3400045},
+  number = {4},
+  journal = {IEEE Transactions on Neural Networks and Learning Systems},
+  publisher = {Institute of Electrical and Electronics Engineers (IEEE)},
+  author = {Ceni, Andrea and Gallicchio, Claudio},
+  year = {2025},
+  month = apr,
+  pages = {7555--7564}
+}
diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl
index ba92ea2d..af4702b4 100644
--- a/src/ReservoirComputing.jl
+++ b/src/ReservoirComputing.jl
@@ -25,6 +25,7 @@ include("reservoircomputer.jl")
 include("layers/basic.jl")
 include("layers/lux_layers.jl")
 include("layers/esn_cell.jl")
+include("layers/es2n_cell.jl")
 include("layers/svmreadout.jl")
 #general
 include("states.jl")
@@ -36,6 +37,7 @@ include("inits/inits_input.jl")
 include("inits/inits_reservoir.jl")
 #full models
 include("models/esn_generics.jl")
+include("models/es2n.jl")
 include("models/esn.jl")
 include("models/esn_deep.jl")
 include("models/esn_delay.jl")
@@ -45,7 +47,8 @@ include("models/ngrc.jl")
 include("extensions/reca.jl")
 
 export ReservoirComputer
-export ESNCell, StatefulLayer, LinearReadout, ReservoirChain, Collect, collectstates,
+export ESNCell, ES2NCell
+export StatefulLayer, LinearReadout, ReservoirChain, Collect, collectstates,
        DelayLayer, NonlinearFeaturesLayer
 export SVMReadout
 export Pad, Extend, NLAT1, NLAT2, NLAT3, PartialSquare, ExtendedSquare
@@ -59,7 +62,7 @@ export block_diagonal, chaotic_init, cycle_jumps, delay_line, delayline_backward
 export add_jumps!, backward_connection!, delay_line!, reverse_simple_cycle!,
        scale_radius!, self_loop!, simple_cycle!
 export train, train!, predict, resetcarry!, polynomial_monomials
-export ESN, DeepESN, DelayESN, HybridESN
+export ES2N, ESN, DeepESN, DelayESN, HybridESN
 export NGRC
 #ext
 export RECACell, RECA
diff --git a/src/layers/es2n_cell.jl b/src/layers/es2n_cell.jl
new file mode 100644
index 00000000..8fd3b221
--- /dev/null
+++ b/src/layers/es2n_cell.jl
@@ -0,0 +1,123 @@
+abstract type AbstractEchoStateNetworkCell <: AbstractReservoirRecurrentCell end
+
+@doc raw"""
+    ES2NCell(in_dims => out_dims, [activation];
+        use_bias=False(), init_bias=zeros32,
+        init_reservoir=rand_sparse, init_input=scaled_rand,
+        init_state=randn32, init_orthogonal=orthogonal,
+        proximity=1.0)
+
+Edge of Stability Echo State Network (ES2N) cell [Ceni2025](@cite).
+
+## Equations
+
+```math
+\begin{aligned}
+x(t) = \beta\, \phi\!\left( \rho\, \mathbf{W}_r x(t-1) + \omega\,
+\mathbf{W}_{in} u(t) \right) + (1-\beta)\, \mathbf{O}\, x(t-1),
+\end{aligned}
+```
+
+## Arguments
+
+  - `in_dims`: Input dimension.
+  - `out_dims`: Reservoir (hidden state) dimension.
+  - `activation`: Activation function. Default: `tanh`.
+
+## Keyword arguments
+
+  - `use_bias`: Whether to include a bias term. Default: `false`.
+  - `init_bias`: Initializer for the bias. Used only if `use_bias=true`.
+    Default is `zeros32`.
+  - `init_reservoir`: Initializer for the reservoir matrix `W_res`.
+    Default is [`rand_sparse`](@ref).
+  - `init_orthogonal`: Initializer for the orthogonal matrix `O`.
+    Default is [`orthogonal`](@ref).
+  - `init_input`: Initializer for the input matrix `W_in`.
+    Default is [`scaled_rand`](@ref).
+  - `init_state`: Initializer for the hidden state when an external
+    state is not provided. Default is `randn32`.
+  - `proximity`: Proximity coefficient `β ∈ (0,1]`. Default: `1.0`.
+
+## Inputs
+
+  - **Case 1:** `x :: AbstractArray (in_dims, batch)`
+    A fresh state is created via `init_state`; the call is forwarded to Case 2.
+  - **Case 2:** `(x, (h,))` where `h :: AbstractArray (out_dims, batch)`
+    Computes the update and returns the new state.
+
+In both cases, the forward returns `((h_new, (h_new,)), st_out)` where `st_out`
+contains any updated internal state.
+
+## Returns
+
+  - Output/hidden state `h_new :: out_dims` and state tuple `(h_new,)`.
+  - Updated layer state (NamedTuple).
+
+## Parameters
+
+  - `input_matrix :: (out_dims × in_dims)` — `W_in`
+  - `reservoir_matrix :: (out_dims × out_dims)` — `W_res`
+  - `orthogonal_matrix :: (out_dims × out_dims)` — `O`
+  - `bias :: (out_dims,)` — present only if `use_bias=true`
+
+## States
+
+Created by `initialstates(rng, esn)`:
+
+  - `rng`: a replicated RNG used to sample initial hidden states when needed.
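+
+## Example
+
+A minimal usage sketch:
+
+```julia
+using ReservoirComputing, Random
+
+es2n = ES2NCell(3 => 100; proximity = 0.7f0)
+ps, st = setup(MersenneTwister(42), es2n)
+x = rand(Float32, 3)
+# Case 1 call: an initial hidden state is sampled internally.
+(h, (carry,)), st = es2n(x, ps, st)
+```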
+"""
+@concrete struct ES2NCell <: AbstractEchoStateNetworkCell
+    activation
+    in_dims <: IntegerType
+    out_dims <: IntegerType
+    init_bias
+    init_reservoir
+    init_input
+    init_orthogonal
+    init_state
+    proximity
+    use_bias <: StaticBool
+end
+
+function ES2NCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType},
+        activation = tanh; use_bias::BoolType = False(), init_bias = zeros32,
+        init_reservoir = rand_sparse, init_input = scaled_rand,
+        init_state = randn32, init_orthogonal = orthogonal,
+        proximity::AbstractFloat = 1.0)
+    return ES2NCell(activation, in_dims, out_dims, init_bias, init_reservoir,
+        init_input, init_orthogonal, init_state, proximity, static(use_bias))
+end
+
+function initialparameters(rng::AbstractRNG, esn::ES2NCell)
+    ps = (input_matrix = esn.init_input(rng, esn.out_dims, esn.in_dims),
+        reservoir_matrix = esn.init_reservoir(rng, esn.out_dims, esn.out_dims),
+        orthogonal_matrix = esn.init_orthogonal(rng, esn.out_dims, esn.out_dims))
+    if has_bias(esn)
+        ps = merge(ps, (bias = esn.init_bias(rng, esn.out_dims),))
+    end
+    return ps
+end
+
+function (esn::ES2NCell)((inp, (hidden_state,))::InputType, ps, st::NamedTuple)
+    T = eltype(inp)
+    if has_bias(esn)
+        candidate_h = esn.activation.(ps.input_matrix * inp .+
+                                      ps.reservoir_matrix * hidden_state .+ ps.bias)
+    else
+        candidate_h = esn.activation.(ps.input_matrix * inp .+
+                                      ps.reservoir_matrix * hidden_state)
+    end
+    h_new = (one(T) - esn.proximity) .* ps.orthogonal_matrix * hidden_state .+
+            esn.proximity .* candidate_h
+    return (h_new, (h_new,)), st
+end
+
+function Base.show(io::IO, esn::ES2NCell)
+    print(io, "ES2NCell($(esn.in_dims) => $(esn.out_dims)")
+    if esn.proximity != one(esn.proximity)
+        print(io, ", proximity=$(esn.proximity)")
+    end
+    has_bias(esn) || print(io, ", use_bias=false")
+    print(io, ")")
+end
diff --git a/src/models/es2n.jl b/src/models/es2n.jl
new file mode 100644
index 00000000..266e3860
--- /dev/null
+++ b/src/models/es2n.jl
@@ -0,0 +1,119 @@
+@doc raw"""
+    ES2N(in_dims, res_dims, out_dims, activation=tanh;
+        proximity=1.0, init_reservoir=rand_sparse, init_input=scaled_rand,
+        init_bias=zeros32, init_state=randn32, use_bias=False(),
+        state_modifiers=(), readout_activation=identity,
+        init_orthogonal=orthogonal)
+
+Edge of Stability Echo State Network (ES2N) [Ceni2025](@cite).
+
+## Equations
+
+```math
+\begin{aligned}
+x(t) = \beta\, \phi\!\left( \rho\, \mathbf{W}_r x(t-1) + \omega\,
+\mathbf{W}_{in} u(t) \right) + (1-\beta)\, \mathbf{O}\, x(t-1),
+\end{aligned}
+```
+
+## Arguments
+
+  - `in_dims`: Input dimension.
+  - `res_dims`: Reservoir (hidden state) dimension.
+  - `out_dims`: Output dimension.
+  - `activation`: Reservoir activation (for [`ES2NCell`](@ref)). Default: `tanh`.
+
+## Keyword arguments
+
+  - `proximity`: Proximity coefficient `β ∈ (0,1]`. Default: `1.0`.
+  - `init_reservoir`: Initializer for `W_res`. Default: [`rand_sparse`](@ref).
+  - `init_input`: Initializer for `W_in`. Default: [`scaled_rand`](@ref).
+  - `init_orthogonal`: Initializer for `O`. Default: [`orthogonal`](@ref).
+  - `init_bias`: Initializer for reservoir bias (used if `use_bias=true`).
+    Default: `zeros32`.
+  - `init_state`: Initializer used when an external state is not provided.
+    Default: `randn32`.
+  - `use_bias`: Whether the reservoir uses a bias term. Default: `false`.
+  - `state_modifiers`: A layer or collection of layers applied to the reservoir
+    state before the readout. Accepts a single layer, an `AbstractVector`, or a
+    `Tuple`. Default: empty `()`.
+  - `readout_activation`: Activation for the linear readout. Default: `identity`.
+
+## Inputs
+
+  - `x :: AbstractArray (in_dims, batch)`
+
+## Returns
+
+  - Output `y :: (out_dims, batch)`.
+  - Updated layer state (NamedTuple).
+
+## Parameters
+
+  - `reservoir` — parameters of the internal [`ES2NCell`](@ref), including:
+    - `input_matrix :: (res_dims × in_dims)` — `W_in`
+    - `reservoir_matrix :: (res_dims × res_dims)` — `W_res`
+    - `orthogonal_matrix :: (res_dims × res_dims)` — `O`
+    - `bias :: (res_dims,)` — present only if `use_bias=true`
+  - `states_modifiers` — a `Tuple` with parameters for each modifier layer (may be empty).
+  - `readout` — parameters of [`LinearReadout`](@ref), typically:
+    - `weight :: (out_dims × res_dims)` — `W_out`
+    - `bias :: (out_dims,)` — `b_out` (if the readout uses bias)
+
+> Exact field names for modifiers/readout follow their respective layer
+> definitions.
+
+## States
+
+  - `reservoir` — states for the internal [`ES2NCell`](@ref) (e.g. `rng` used to
+    sample initial hidden states).
+  - `states_modifiers` — a `Tuple` with states for each modifier layer.
+  - `readout` — states for [`LinearReadout`](@ref).
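+
+## Example
+
+A minimal construction sketch; training and prediction then follow the same
+`train!`/`predict` workflow shown in the tutorials:
+
+```julia
+using ReservoirComputing, Random
+
+es2n = ES2N(3, 300, 3; proximity = 0.9)
+ps, st = setup(MersenneTwister(42), es2n)
+```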
+"""
+@concrete struct ES2N <:
+                 AbstractEchoStateNetwork{(:reservoir, :states_modifiers, :readout)}
+    reservoir
+    states_modifiers
+    readout
+end
+
+function ES2N(in_dims::IntegerType, res_dims::IntegerType,
+        out_dims::IntegerType, activation = tanh;
+        readout_activation = identity,
+        state_modifiers = (),
+        kwargs...)
+    cell = StatefulLayer(ES2NCell(in_dims => res_dims, activation; kwargs...))
+    mods_tuple = state_modifiers isa Tuple || state_modifiers isa AbstractVector ?
+                 Tuple(state_modifiers) : (state_modifiers,)
+    mods = _wrap_layers(mods_tuple)
+    ro = LinearReadout(res_dims => out_dims, readout_activation)
+    return ES2N(cell, mods, ro)
+end
+
+function Base.show(io::IO, esn::ES2N)
+    print(io, "ES2N(\n")
+
+    print(io, " reservoir = ")
+    show(io, esn.reservoir)
+    print(io, ",\n")
+
+    print(io, " state_modifiers = ")
+    if isempty(esn.states_modifiers)
+        print(io, "()")
+    else
+        print(io, "(")
+        for (i, m) in enumerate(esn.states_modifiers)
+            i > 1 && print(io, ", ")
+            show(io, m)
+        end
+        print(io, ")")
+    end
+    print(io, ",\n")
+
+    print(io, " readout = ")
+    show(io, esn.readout)
+    print(io, "\n)")
+
+    return
+end
diff --git a/src/models/esn.jl b/src/models/esn.jl
index 1a70360c..b3460893 100644
--- a/src/models/esn.jl
+++ b/src/models/esn.jl
@@ -42,7 +42,7 @@ Reservoir (passed to [`ESNCell`](@ref)):
   - `leak_coefficient`: Leak rate `α ∈ (0,1]`. Default: `1.0`.
   - `init_reservoir`: Initializer for `W_res`. Default: [`rand_sparse`](@ref).
   - `init_input`: Initializer for `W_in`. Default: [`scaled_rand`](@ref).
-  - `init_bias`: Initializer for reservoir bias (used iff `use_bias=true`).
+  - `init_bias`: Initializer for reservoir bias (used if `use_bias=true`).
     Default: `zeros32`.
   - `init_state`: Initializer used when an external state is not provided.
     Default: `randn32`.
diff --git a/src/reservoircomputer.jl b/src/reservoircomputer.jl
index c5a70626..0fd36541 100644
--- a/src/reservoircomputer.jl
+++ b/src/reservoircomputer.jl
@@ -29,11 +29,23 @@ features, and install trained readout weights.
 
   - `(y, st′)` where `y` is the readout output and `st′` contains the updated
     states of the reservoir, modifiers, and readout.
 """
-@concrete struct ReservoirComputer <:
-                 AbstractReservoirComputer{(:reservoir, :states_modifiers, :readout)}
-    reservoir
-    states_modifiers
-    readout
+struct ReservoirComputer{R, S, L} <:
+       AbstractReservoirComputer{(:reservoir, :states_modifiers, :readout)}
+    reservoir::R
+    states_modifiers::S
+    readout::L
+
+    function ReservoirComputer(reservoir::R, state_modifiers::S, readout::L) where {R, S, L}
+        mods_tuple = state_modifiers isa Tuple || state_modifiers isa AbstractVector ?
+                     Tuple(state_modifiers) : (state_modifiers,)
+        mods = _wrap_layers(mods_tuple)
+
+        return new{R, typeof(mods), L}(reservoir, mods, readout)
+    end
+end
+
+function ReservoirComputer(reservoir, readout)
+    return ReservoirComputer(reservoir, (), readout)
 end
 
 function initialparameters(rng::AbstractRNG, rc::AbstractReservoirComputer)
@@ -98,27 +110,30 @@ function addreadout!(::AbstractReservoirComputer, output_matrix::AbstractMatrix,
 end
 
 function Base.show(io::IO, rc::ReservoirComputer)
-    print(io, "ReservoirComputer(")
+    print(io, "ReservoirComputer(\n")
 
-    print(io, "reservoir = ")
+    print(io, " reservoir = ")
     show(io, rc.reservoir)
+    print(io, ",\n")
 
-    nmods = length(rc.states_modifiers)
-    if nmods == 0
-        print(io, ", state_modifiers = ()")
+    print(io, " state_modifiers = ")
+    if isempty(rc.states_modifiers)
+        print(io, "()")
     else
-        print(io, ", state_modifiers = (")
+        print(io, "(")
         for (i, m) in enumerate(rc.states_modifiers)
            i > 1 && print(io, ", ")
            show(io, m)
        end
        print(io, ")")
    end
+    print(io, ",\n")
 
-    print(io, ", readout = ")
+    print(io, " readout = ")
     show(io, rc.readout)
+    print(io, "\n)")
 
-    print(io, ")")
+    return
 end
 
 @doc raw"""
diff --git a/test/layers/test_esncell.jl b/test/layers/test_esncell.jl
index 7ca90976..f7c89796 100644
--- a/test/layers/test_esncell.jl
+++ b/test/layers/test_esncell.jl
@@ -9,113 +9,190 @@ const _Z32 = m -> zeros(Float32, m)
 const _O32 = (rng, m) -> zeros(Float32, m)
 const _W_I = (rng, m, n) -> _I32(m, n)
 const _W_ZZ = (rng, m, n) -> zeros(Float32, m, n)
+
 function init_state3(rng::AbstractRNG, m::Integer, B::Integer)
     B == 1 ? zeros(Float32, m) : zeros(Float32, m, B)
 end
 
-@testset "ESNCell: constructor & show" begin
-    esn = ESNCell(3 => 5; leak_coefficient = 0.3, use_bias = False())
-    io = IOBuffer()
-    show(io, esn)
-    shown = String(take!(io))
-    @test occursin("ESNCell(3 => 5", shown)
-    @test occursin("leak_coefficient=0.3", shown)
-    @test occursin("use_bias=false", shown)
-end
+cell_name(::Type{C}) where {C} = string(nameof(C))
 
-@testset "ESNCell: initialparameters shapes & bias flag" begin
-    rng = MersenneTwister(1)
+mix_kw(::Type{ESNCell}) = :leak_coefficient
+mix_kw(::Type{ES2NCell}) = :proximity
 
-    esn_nobias = ESNCell(3 => 4; use_bias = False(),
-        init_input = _W_I, init_reservoir = _W_I, init_bias = _O32)
+# Labels exactly as printed by show():
+mix_label(::Type{ESNCell}) = "leak_coefficient"
+mix_label(::Type{ES2NCell}) = "proximity"
 
-    ps_nb = initialparameters(rng, esn_nobias)
-    @test haskey(ps_nb, :input_matrix)
-    @test haskey(ps_nb, :reservoir_matrix)
-    @test !haskey(ps_nb, :bias)
-    @test size(ps_nb.input_matrix) == (4, 3)
-    @test size(ps_nb.reservoir_matrix) == (4, 4)
+default_extra_ctor_kwargs(::Type{ESNCell}) = NamedTuple()
+default_extra_ctor_kwargs(::Type{ES2NCell}) = (init_orthogonal = _W_I,)
 
-    esn_bias = ESNCell(3 => 4; use_bias = True(),
-        init_input = _W_I, init_reservoir = _W_I, init_bias = _O32)
+extra_param_keys(::Type{ESNCell}) = ()
+extra_param_keys(::Type{ES2NCell}) = (:orthogonal_matrix,)
 
-    ps_b = initialparameters(rng, esn_bias)
-    @test haskey(ps_b, :bias)
-    @test length(ps_b.bias) == 4
-end
+function build_cell(::Type{C}, in_dims::Integer, out_dims::Integer;
+        activation = tanh,
+        mix::Real = 1.0,
+        use_bias = False(),
+        init_input = _W_I,
+        init_reservoir = _W_I,
+        init_bias = _O32,
+        init_state = _Z32,
+        extra::NamedTuple = NamedTuple()
+) where {C}
+    base = (use_bias = use_bias,
+        init_input = init_input,
+        init_reservoir = init_reservoir,
+        init_bias = init_bias,
+        init_state = init_state)
 
-@testset "ESNCell: initialstates carries RNG replica" begin
-    rng = MersenneTwister(2)
-    esn = ESNCell(2 => 2)
-    st = initialstates(rng, esn)
-    @test haskey(st, :rng)
-end
+    mixnt = NamedTuple{(mix_kw(C),)}((mix,))
 
-@testset "ESNCell: forward (vector) — identity + leak=1 gives linear map" begin
-    esn = ESNCell(3 => 3, identity; use_bias = False(),
-        init_input = _W_I, init_reservoir = _W_I, init_bias = _O32,
-        init_state = _Z32, leak_coefficient = 1.0)
-
-    ps = initialparameters(MersenneTwister(0), esn)
-    st = NamedTuple()
-    x = Float32[1, 2, 3]
-    h0 = zeros(Float32, 3)
-
-    (y_tuple, st2) = esn((x, (h0,)), ps, st)
-    y, (hcarry,) = y_tuple
-    @test y ≈ x
-    @test hcarry ≈ y
-    @test st2 === st
-end
+    kw = merge(base, default_extra_ctor_kwargs(C), mixnt, extra)
 
-@testset "ESNCell: forward (vector) — leak extremes" begin
-    esn0 = ESNCell(3 => 3, identity; use_bias = False(),
-        init_input = _W_I, init_reservoir = _W_I, init_bias = _O32,
-        init_state = _Z32, leak_coefficient = 0.0)
-
-    ps0 = initialparameters(MersenneTwister(0), esn0)
-    x = Float32[10, 20, 30]
-    h0 = Float32[4, 5, 6]
-    (y0_tuple, _) = esn0((x, (h0,)), ps0, NamedTuple())
-    y0, _ = y0_tuple
-    @test y0 ≈ h0
-
-    esn1 = ESNCell(3 => 3, identity; use_bias = True(),
-        init_input = _W_I, init_reservoir = _W_ZZ, init_bias = (rng, m) -> ones(Float32, m),
-        init_state = _Z32, leak_coefficient = 1.0)
-
-    ps1 = initialparameters(MersenneTwister(0), esn1)
-    (y1_tuple, _) = esn1((x, (zeros(Float32, 3),)), ps1, NamedTuple())
-    y1, _ = y1_tuple
-    @test y1 ≈ x .+ 1.0f0
+    return C(in_dims => out_dims, activation; kw...)
 end
 
-@testset "ESNCell: forward (matrix batch)" begin
-    esn = ESNCell(3 => 3, identity; use_bias = False(),
-        init_input = _W_I, init_reservoir = _W_I, init_bias = _O32,
-        init_state = _Z32, leak_coefficient = 1.0)
-
-    ps = initialparameters(MersenneTwister(0), esn)
-    X = Float32[1 2; 3 4; 5 6] # (3, 2)
-    H0 = zeros(Float32, 3, 2)
-
-    (Y_tuple, _) = esn((X, (H0,)), ps, NamedTuple())
-    Y, _ = Y_tuple
-    @test size(Y) == (3, 2)
-    @test Y ≈ X
+function test_echo_state_cell_contract(::Type{C}) where {C}
+    @testset "$(cell_name(C)): constructor & show" begin
+        cell = build_cell(C, 3, 5; mix = 0.3, use_bias = False())
+        io = IOBuffer()
+        show(io, cell)
+        shown = String(take!(io))
+
+        @test occursin("$(cell_name(C))(3 => 5", shown)
+        @test occursin(Regex("$(mix_label(C))=0\\.3(f0)?"), shown)
+        @test occursin("use_bias=false", shown)
+    end
+
+    @testset "$(cell_name(C)): initialparameters shapes & bias flag" begin
+        rng = MersenneTwister(1)
+
+        cell_nobias = build_cell(C, 3, 4; use_bias = False(),
+            init_input = _W_I, init_reservoir = _W_I, init_bias = _O32)
+
+        ps_nb = initialparameters(rng, cell_nobias)
+        @test haskey(ps_nb, :input_matrix)
+        @test haskey(ps_nb, :reservoir_matrix)
+        @test size(ps_nb.input_matrix) == (4, 3)
+        @test size(ps_nb.reservoir_matrix) == (4, 4)
+        @test !haskey(ps_nb, :bias)
+
+        for k in extra_param_keys(C)
+            @test haskey(ps_nb, k)
+        end
+        if C === ES2NCell
+            @test size(ps_nb.orthogonal_matrix) == (4, 4)
+        end
+
+        cell_bias = build_cell(C, 3, 4; use_bias = True(),
+            init_input = _W_I, init_reservoir = _W_I, init_bias = _O32)
+
+        ps_b = initialparameters(rng, cell_bias)
+        @test haskey(ps_b, :bias)
+        @test length(ps_b.bias) == 4
+    end
+
+    @testset "$(cell_name(C)): initialstates carries RNG replica" begin
+        rng = MersenneTwister(2)
+        cell = build_cell(C, 2, 2)
+        st = initialstates(rng, cell)
+        @test haskey(st, :rng)
+    end
+
+    @testset "$(cell_name(C)): forward (vector) — identity + mix=1 gives linear map" begin
+        cell = build_cell(C, 3, 3;
+            activation = identity,
+            mix = 1.0,
+            use_bias = False(),
+            init_input = _W_I,
+            init_reservoir = _W_ZZ,
+            init_state = _Z32)
+
+        ps = initialparameters(MersenneTwister(0), cell)
+        x = Float32[1, 2, 3]
+        h0 = zeros(Float32, 3)
+
+        (y_tuple, st2) = cell((x, (h0,)), ps, NamedTuple())
+        y, (hcarry,) = y_tuple
+        @test y ≈ x
+        @test hcarry ≈ y
+        @test st2 === NamedTuple()
+    end
+
+    @testset "$(cell_name(C)): forward (vector) — mix extremes" begin
+        cell0 = build_cell(C, 3, 3;
+            activation = identity,
+            mix = 0.0,
+            use_bias = False(),
+            init_input = _W_I,
+            init_reservoir = _W_I,
+            init_state = _Z32)
+
+        ps0 = initialparameters(MersenneTwister(0), cell0)
+        x = Float32[10, 20, 30]
+        h0 = Float32[4, 5, 6]
+        (y0_tuple, _) = cell0((x, (h0,)), ps0, NamedTuple())
+        y0, _ = y0_tuple
+        @test y0 ≈ h0
+
+        cell1 = build_cell(C, 3, 3;
+            activation = identity,
+            mix = 1.0,
+            use_bias = True(),
+            init_input = _W_I,
+            init_reservoir = _W_ZZ,
+            init_bias = (rng, m) -> ones(Float32, m),
+            init_state = _Z32)
+
+        ps1 = initialparameters(MersenneTwister(0), cell1)
+        (y1_tuple, _) = cell1((x, (zeros(Float32, 3),)), ps1, NamedTuple())
+        y1, _ = y1_tuple
+        @test y1 ≈ x .+ 1.0f0
+    end
+
+    @testset "$(cell_name(C)): forward (matrix batch)" begin
+        cell = build_cell(C, 3, 3;
+            activation = identity,
+            mix = 1.0,
+            use_bias = False(),
+            init_input = _W_I,
+            init_reservoir = _W_ZZ,
+            init_state = _Z32)
+
+        ps = initialparameters(MersenneTwister(0), cell)
+        X = Float32[1 2; 3 4; 5 6] # (3, 2)
+        H0 = zeros(Float32, 3, 2)
+
+        (Y_tuple, _) = cell((X, (H0,)), ps, NamedTuple())
+        Y, _ = Y_tuple
+        @test size(Y) == (3, 2)
+        @test Y ≈ X
+    end
+
+    @testset "$(cell_name(C)): outer call computes its own initial hidden state" begin
+        rng = MersenneTwister(123)
+        cell = build_cell(C, 2, 2;
+            activation = identity,
+            mix = 1.0,
+            use_bias = False(),
+            init_input = _W_I,
+            init_reservoir = _W_ZZ,
+            init_state = init_state3)
+
+        ps = initialparameters(rng, cell)
+        st = initialstates(rng, cell)
+
+        x = Float32[7, 9]
+        (y_tuple, st2) = cell(x, ps, st)
+        y, _ = y_tuple
+
+        @test y ≈ x
+        @test haskey(st2, :rng)
+    end
 end
 
-@testset "ESNCell: outer call computes its own initial hidden state" begin
-    rng = MersenneTwister(123)
-    esn = ESNCell(2 => 2, identity; use_bias = False(),
-        init_input = _W_I, init_reservoir = _W_ZZ,
-        init_state = init_state3, leak_coefficient = 1.0)
-
-    ps = initialparameters(rng, esn)
-    st = initialstates(rng, esn)
-    x = Float32[7, 9]
-    (y_tuple, st2) = esn(x, ps, st)
-    y, _ = y_tuple
-    @test y ≈ x
-    @test haskey(st2, :rng)
+@testset "AbstractEchoStateNetworkCell contract" begin
+    for C in (ESNCell, ES2NCell)
+        test_echo_state_cell_contract(C)
+    end
 end
diff --git a/test/models/test_esn.jl b/test/models/test_esn.jl
index 76fdbb7e..1f58cb86 100644
--- a/test/models/test_esn.jl
+++ b/test/models/test_esn.jl
@@ -9,6 +9,7 @@ const _Z32 = m -> zeros(Float32, m)
 const _O32 = (rng, m) -> zeros(Float32, m)
 const _W_I = (rng, m, n) -> _I32(m, n)
 const _W_ZZ = (rng, m, n) -> zeros(Float32, m, n)
+
 function init_state3(rng::AbstractRNG, m::Integer, B::Integer)
     B == 1 ? zeros(Float32, m) : zeros(Float32, m, B)
 end
@@ -20,162 +21,217 @@ function _with_identity_readout(ps::NamedTuple; out_dims::Integer, in_dims::Integer)
     return merge(ps, (readout = ro_ps,))
 end
 
-@testset "ESN: constructor & parameter/state shapes" begin
-    rng = MersenneTwister(42)
-
-    in_dims, res_dims, out_dims = 3, 5, 4
-    esn = ESN(in_dims, res_dims, out_dims, identity;
-        use_bias = False(),
-        init_input = _W_I,
-        init_reservoir = _W_ZZ,
-        init_bias = _O32,
-        init_state = init_state3,
-        leak_coefficient = 1.0)
-
-    ps, st = setup(rng, esn)
-
-    @test haskey(ps, :reservoir)
-    @test haskey(ps.reservoir, :input_matrix)
-    @test haskey(ps.reservoir, :reservoir_matrix)
-    @test !haskey(ps.reservoir, :bias)
-    @test size(ps.reservoir.input_matrix) == (res_dims, in_dims)
-    @test size(ps.reservoir.reservoir_matrix) == (res_dims, res_dims)
-
-    @test haskey(ps, :readout)
-    @test haskey(ps.readout, :weight)
-    @test size(ps.readout.weight) == (out_dims, res_dims)
-
-    @test haskey(st, :reservoir)
-    @test haskey(st, :states_modifiers)
-    @test haskey(st, :readout)
-    @test st.states_modifiers isa Tuple
-end
-
-@testset "ESN: forward (vector) with identity pipeline -> y == x (dimensions matched)" begin
-    rng = MersenneTwister(0)
-    D = 3
-    esn = ESN(D, D, D, identity;
-        use_bias = False(),
-        init_input = _W_I,
-        init_reservoir = _W_ZZ,
-        init_bias = _O32,
-        init_state = init_state3,
-        leak_coefficient = 1.0)
-
-    ps, st = setup(rng, esn)
-    ps = _with_identity_readout(ps; out_dims = D, in_dims = D)
-
-    x = Float32[1, 2, 3]
-
-    X = reshape(x, :, 1)
-    Y, st2 = esn(X, ps, st)
-
-    @test size(Y) == (D, 1)
-    @test vec(Y) ≈ x
-    @test haskey(st2, :reservoir) && haskey(st2, :states_modifiers) && haskey(st2, :readout)
-end
-
-@testset "ESN: forward (batch matrix) with identity pipeline -> Y == X" begin
-    rng = MersenneTwister(1)
-    D, B = 3, 2
-    esn = ESN(D, D, D, identity;
-        use_bias = False(),
-        init_input = _W_I,
-        init_reservoir = _W_ZZ,
-        init_bias = _O32,
-        init_state = init_state3,
-        leak_coefficient = 1.0)
-
-    ps, st = setup(rng, esn)
-    ps = _with_identity_readout(ps; out_dims = D, in_dims = D)
-
-    X = Float32[1 2; 3 4; 5 6]
-    Y, _ = esn(X, ps, st)
-
-    @test size(Y) == (D, B)
-    @test Y ≈ X
-end
+model_name(::Type{M}) where {M} = string(nameof(M))
 
-@testset "ESN: state_modifiers are applied (single modifier doubles features)" begin
-    rng = MersenneTwister(2)
-    D = 3
-    esn = ESN(D, D, D, identity;
-        state_modifiers = (x -> 2.0f0 .* x,),
-        use_bias = False(),
-        init_input = _W_I,
-        init_reservoir = _W_ZZ,
-        init_bias = _O32,
-        init_state = init_state3,
-        leak_coefficient = 1.0)
+mix_kw(::Type{ESN}) = :leak_coefficient
+mix_kw(::Type{ES2N}) = :proximity
 
-    ps, st = setup(rng, esn)
-    ps = _with_identity_readout(ps; out_dims = D, in_dims = D)
+reservoir_param_keys(::Type{ESN}) = (:input_matrix, :reservoir_matrix)
+reservoir_param_keys(::Type{ES2N}) = (:input_matrix, :reservoir_matrix, :orthogonal_matrix)
 
-    x = Float32[1, 2, 3]
-    y, _ = esn(x, ps, st)
-    @test y ≈ 2.0f0 .* x
-end
+default_reservoir_kwargs(::Type{ESN}) = NamedTuple()
+default_reservoir_kwargs(::Type{ES2N}) = (init_orthogonal = _W_I,)
 
-@testset "ESN: multiple state_modifiers apply in order" begin
-    rng = MersenneTwister(3)
-    D = 3
-    mods = (x -> x .+ 1.0f0, x -> 3.0f0 .* x)
-
-    esn = ESN(D, D, D, identity;
-        state_modifiers = mods,
+function build_model(::Type{M}, in_dims::Int, res_dims::Int, out_dims::Int, activation;
+        state_modifiers = (),
+        readout_activation = identity,
+        mix::Real = 1.0,
         use_bias = False(),
         init_input = _W_I,
         init_reservoir = _W_ZZ,
         init_bias = _O32,
         init_state = init_state3,
-        leak_coefficient = 1.0)
-
-    ps, st = setup(rng, esn)
-    ps = _with_identity_readout(ps; out_dims = D, in_dims = D)
-
-    x = Float32[0, 1, 2]
-    y, _ = esn(x, ps, st)
-    @test y ≈ 3.0f0 .* (x .+ 1.0f0)
+        extra::NamedTuple = NamedTuple()
+) where {M}
+    base = (use_bias = use_bias,
+        init_input = init_input,
+        init_reservoir = init_reservoir,
+        init_bias = init_bias,
+        init_state = init_state)
+
+    mixnt = NamedTuple{(mix_kw(M),)}((mix,))
+    kw = merge(base, default_reservoir_kwargs(M), mixnt, extra)
+
+    return M(in_dims, res_dims, out_dims, activation;
+        state_modifiers = state_modifiers,
+        readout_activation = readout_activation,
+        kw...)
 end
 
-@testset "ESN: outer call computes its own initial hidden state through ESNCell" begin
-    rng = MersenneTwister(123)
-    D = 2
-    esn = ESN(D, D, D, identity;
-        use_bias = False(),
-        init_input = _W_I,
-        init_reservoir = _W_ZZ,
-        init_state = init_state3,
-        leak_coefficient = 1.0)
-
-    ps, st = setup(rng, esn)
-    ps = _with_identity_readout(ps; out_dims = D, in_dims = D)
-
-    x = Float32[7, 9]
-    y, st2 = esn(x, ps, st)
-
-    @test y ≈ x
-    @test haskey(st2, :reservoir)
-    @test haskey(st2, :states_modifiers)
-    @test haskey(st2, :readout)
+function test_esn_family_contract(::Type{M}) where {M}
+    @testset "$(model_name(M)): constructor & parameter/state shapes" begin
+        rng = MersenneTwister(42)
+        in_dims, res_dims, out_dims = 3, 5, 4
+
+        model = build_model(M, in_dims, res_dims, out_dims, identity;
+            use_bias = False(),
+            init_input = _W_I,
+            init_reservoir = _W_ZZ,
+            init_bias = _O32,
+            init_state = init_state3,
+            mix = 1.0)
+
+        ps, st = setup(rng, model)
+
+        @test haskey(ps, :reservoir)
+        for k in reservoir_param_keys(M)
+            @test haskey(ps.reservoir, k)
+        end
+        @test !haskey(ps.reservoir, :bias)
+        @test size(ps.reservoir.input_matrix) == (res_dims, in_dims)
+        @test size(ps.reservoir.reservoir_matrix) == (res_dims, res_dims)
+        if M === ES2N
+            @test size(ps.reservoir.orthogonal_matrix) == (res_dims, res_dims)
+        end
+
+        @test haskey(ps, :readout)
+        @test haskey(ps.readout, :weight)
+        @test size(ps.readout.weight) == (out_dims, res_dims)
+
+        @test haskey(st, :reservoir)
+        @test haskey(st, :states_modifiers)
+        @test haskey(st, :readout)
+        @test st.states_modifiers isa Tuple
+    end
+
+    @testset "$(model_name(M)): forward (vector) with identity pipeline -> y == x (dimensions matched)" begin
+        rng = MersenneTwister(0)
+        D = 3
+
+        model = build_model(M, D, D, D, identity;
+            use_bias = False(),
+            init_input = _W_I,
+            init_reservoir = _W_ZZ,
+            init_bias = _O32,
+            init_state = init_state3,
+            mix = 1.0)
+
+        ps, st = setup(rng, model)
+        ps = _with_identity_readout(ps; out_dims = D, in_dims = D)
+
+        x = Float32[1, 2, 3]
+        X = reshape(x, :, 1)
+
+        Y, st2 = model(X, ps, st)
+
+        @test size(Y) == (D, 1)
+        @test vec(Y) ≈ x
+        @test haskey(st2, :reservoir) && haskey(st2, :states_modifiers) &&
+              haskey(st2, :readout)
+    end
+
+    @testset "$(model_name(M)): forward (batch matrix) with identity pipeline -> Y == X" begin
+        rng = MersenneTwister(1)
+        D, B = 3, 2
+
+        model = build_model(M, D, D, D, identity;
+            use_bias = False(),
+            init_input = _W_I,
+            init_reservoir = _W_ZZ,
+            init_bias = _O32,
+            init_state = init_state3,
+            mix = 1.0)
+
+        ps, st = setup(rng, model)
+        ps = _with_identity_readout(ps; out_dims = D, in_dims = D)
+
+        X = Float32[1 2; 3 4; 5 6]
+        Y, _ = model(X, ps, st)
+
+        @test size(Y) == (D, B)
+        @test Y ≈ X
+    end
+
+    @testset "$(model_name(M)): state_modifiers are applied (single modifier doubles features)" begin
+        rng = MersenneTwister(2)
+        D = 3
+
+        model = build_model(M, D, D, D, identity;
+            state_modifiers = (x -> 2.0f0 .* x,),
+            use_bias = False(),
+            init_input = _W_I,
+            init_reservoir = _W_ZZ,
+            init_bias = _O32,
+            init_state = init_state3,
+            mix = 1.0)
+
+        ps, st = setup(rng, model)
+        ps = _with_identity_readout(ps; out_dims = D, in_dims = D)
+
+        x = Float32[1, 2, 3]
+        y, _ = model(x, ps, st)
+        @test y ≈ 2.0f0 .* x
+    end
+
+    @testset "$(model_name(M)): multiple state_modifiers apply in order" begin
+        rng = MersenneTwister(3)
+        D = 3
+        mods = (x -> x .+ 1.0f0, x -> 3.0f0 .* x)
+
+        model = build_model(M, D, D, D, identity;
+            state_modifiers = mods,
+            use_bias = False(),
+            init_input = _W_I,
+            init_reservoir = _W_ZZ,
+            init_bias = _O32,
+            init_state = init_state3,
+            mix = 1.0)
+
+        ps, st = setup(rng, model)
+        ps = _with_identity_readout(ps; out_dims = D, in_dims = D)
+
+        x = Float32[0, 1, 2]
+        y, _ = model(x, ps, st)
+        @test y ≈ 3.0f0 .* (x .+ 1.0f0)
+    end
+
+    @testset "$(model_name(M)): outer call computes its own initial hidden state through reservoir cell" begin
+        rng = MersenneTwister(123)
+        D = 2
+
+        model = build_model(M, D, D, D, identity;
+            use_bias = False(),
+            init_input = _W_I,
+            init_reservoir = _W_ZZ,
+            init_state = init_state3,
+            mix = 1.0)
+
+        ps, st = setup(rng, model)
+        ps = _with_identity_readout(ps; out_dims = D, in_dims = D)
+
+        x = Float32[7, 9]
+        y, st2 = model(x, ps, st)
+
+        @test y ≈ x
+        @test haskey(st2, :reservoir)
+        @test haskey(st2, :states_modifiers)
+        @test haskey(st2, :readout)
+    end
+
+    @testset "$(model_name(M)): readout_activation is honored" begin
+        rng = MersenneTwister(4)
+        D = 3
+
+        model = build_model(M, D, D, D, identity;
+            readout_activation = x -> max.(x, 0.0f0),
+            use_bias = False(),
+            init_input = _W_I,
+            init_reservoir = _W_ZZ,
+            init_bias = _O32,
+            init_state = init_state3,
+            mix = 1.0)
+
+        ps, st = setup(rng, model)
+        ps = _with_identity_readout(ps; out_dims = D, in_dims = D)
+
+        x = Float32[-1, 0.5, -3]
+        y, _ = model(x, ps, st)
+        @test y ≈ max.(x, 0.0f0)
+    end
 end
 
-@testset "ESN: readout_activation is honored" begin
-    rng = MersenneTwister(4)
-    D = 3
-    esn = ESN(D, D, D, identity;
-        readout_activation = x -> max.(x, 0.0f0),
-        use_bias = False(),
-        init_input = _W_I,
-        init_reservoir = _W_ZZ,
-        init_bias = _O32,
-        init_state = init_state3,
-        leak_coefficient = 1.0)
-
-    ps, st = setup(rng, esn)
-    ps = _with_identity_readout(ps; out_dims = D, in_dims = D)
-
-    x = Float32[-1, 0.5, -3]
-    y, _ = esn(x, ps, st)
-    @test y ≈ max.(x, 0.0f0)
+@testset "ESN-family model contract" begin
+    for M in (ESN, ES2N)
+        test_esn_family_contract(M)
+    end
 end