Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ Lorenz system.

### 1. Generate data

As a general first step wee fix the random seed for reproducibilty
As a general first step we fix the random seed for reproducibility

```julia
using Random
Expand Down
1 change: 1 addition & 0 deletions docs/src/api/layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

```@docs
ESNCell
ES2NCell
```

## Reservoir computing with cellular automata
Expand Down
110 changes: 110 additions & 0 deletions docs/src/examples/model_es2n.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Building a model to add to ReservoirComputing.jl

This example showcases how to build custom models so that they can also
be included in ReservoirComputing.jl. In this example we will build an
edge of stability echo state network [`ES2N`](@ref). Since the model is
already available in the library, we will change the names of the cells and
models to avoid conflicts.

## Building an ES2NCell

Building a ReservoirComputing.jl model largely follows the Lux.jl model
approach.

```@example es2n_scratch
using ReservoirComputing
using ConcreteStructs
using Static

using ReservoirComputing: IntegerType, BoolType, InputType, has_bias, _wrap_layers
import ReservoirComputing: initialparameters

# Custom cell implementing the edge of stability echo state network (ES2N).
# Mirrors the library's `ES2NCell` field-for-field; `@concrete` narrows every
# field to a concrete type so the struct stays specialization-friendly.
@concrete struct CustomES2NCell <: AbstractEchoStateNetworkCell
    activation        # element-wise nonlinearity φ (e.g. tanh)
    in_dims <: IntegerType    # input dimension
    out_dims <: IntegerType   # reservoir (hidden state) dimension
    init_bias         # initializer for the bias vector (used only with bias)
    init_reservoir    # initializer for the reservoir matrix W_r
    init_input        # initializer for the input matrix W_in
    init_orthogonal   # initializer for the orthogonal matrix O
    init_state        # initializer for a fresh hidden state
    proximity         # proximity coefficient β of the ES2N update
    use_bias <: StaticBool    # compile-time bias flag
end

# User-facing constructor: `CustomES2NCell(in => out, activation; kwargs...)`.
# Defaults match the library `ES2NCell`; `static(use_bias)` lifts the flag to
# the type domain so `has_bias` can resolve at compile time.
function CustomES2NCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType},
        activation = tanh; use_bias::BoolType = False(), init_bias = zeros32,
        init_reservoir = rand_sparse, init_input = scaled_rand,
        init_state = randn32, init_orthogonal = orthogonal,
        proximity::AbstractFloat = 1.0)
    return CustomES2NCell(activation, in_dims, out_dims, init_bias, init_reservoir,
        init_input, init_orthogonal, init_state, proximity, static(use_bias))
end

# Sample the cell's parameters: input matrix (out×in), reservoir matrix
# (out×out), orthogonal matrix (out×out), and — only when `use_bias=true` —
# a bias vector of length `out_dims`.
function initialparameters(rng::AbstractRNG, esn::CustomES2NCell)
    ps = (input_matrix = esn.init_input(rng, esn.out_dims, esn.in_dims),
        reservoir_matrix = esn.init_reservoir(rng, esn.out_dims, esn.out_dims),
        orthogonal_matrix = esn.init_orthogonal(rng, esn.out_dims, esn.out_dims))
    if has_bias(esn)
        ps = merge(ps, (bias = esn.init_bias(rng, esn.out_dims),))
    end
    return ps
end

# Forward pass implementing the ES2N update
#   h_new = (1 - β) * O * h + β * φ(W_in * u + W_r * h [+ b])
# where β is the proximity coefficient and O the orthogonal matrix.
function (esn::CustomES2NCell)((inp, (hidden_state,))::InputType, ps, st::NamedTuple)
    T = eltype(inp)
    if has_bias(esn)
        candidate_h = esn.activation.(ps.input_matrix * inp .+
            ps.reservoir_matrix * hidden_state .+ ps.bias)
    else
        candidate_h = esn.activation.(ps.input_matrix * inp .+
            ps.reservoir_matrix * hidden_state)
    end
    # `one(T)` (instead of `T(1.0)`) obtains the unit in the input eltype
    # without a Float64 round-trip — the type-stable idiom used by the
    # library's own `ES2NCell` implementation.
    h_new = (one(T) - esn.proximity) .* ps.orthogonal_matrix * hidden_state .+
        esn.proximity .* candidate_h
    return (h_new, (h_new,)), st
end
```

You will notice that some definitions are missing. For instance, we did not
dispatch over `initialstates`. This is because the `AbstractEchoStateNetworkCell`
subtyping takes care of a lot of these smaller functions already.

## Building the full ES2N model

Now you can build a full model in two different ways:
- Leveraging [`ReservoirComputer`](@ref)
- Building from scratch with a proper `CustomES2N` struct

```@example es2n_scratch
# Assemble a complete ES2N model by composing the generic `ReservoirComputer`
# wrapper with the custom cell and a linear readout.
function CustomES2NApproach1(in_dims::IntegerType, res_dims::IntegerType,
        out_dims::IntegerType, activation = tanh;
        readout_activation = identity,
        state_modifiers = (),
        kwargs...)
    reservoir = StatefulLayer(CustomES2NCell(in_dims => res_dims, activation; kwargs...))
    readout = LinearReadout(res_dims => out_dims, readout_activation)
    return ReservoirComputer(reservoir, state_modifiers, readout)
end
```

```@example es2n_scratch
# Hand-rolled ES2N container. The tuple type parameter of
# `AbstractEchoStateNetwork` names the sub-layers so the generic model
# machinery can traverse them in order.
@concrete struct CustomES2NApproach2 <:
                 AbstractEchoStateNetwork{(:reservoir, :states_modifiers, :readout)}
    reservoir          # stateful recurrent cell
    states_modifiers   # tuple of state-transformation layers
    readout            # trainable readout layer
end

# Outer constructor: wraps the cell in a `StatefulLayer`, normalizes the state
# modifiers into a tuple of wrapped layers, and attaches a linear readout.
function CustomES2NApproach2(in_dims::IntegerType, res_dims::IntegerType,
        out_dims::IntegerType, activation = tanh;
        readout_activation = identity,
        state_modifiers = (),
        kwargs...)
    reservoir = StatefulLayer(CustomES2NCell(in_dims => res_dims, activation; kwargs...))
    # Accept a single modifier, a tuple, or a vector; always store a tuple.
    if state_modifiers isa Tuple || state_modifiers isa AbstractVector
        modifiers = Tuple(state_modifiers)
    else
        modifiers = (state_modifiers,)
    end
    readout = LinearReadout(res_dims => out_dims, readout_activation)
    return CustomES2NApproach2(reservoir, _wrap_layers(modifiers), readout)
end
```
15 changes: 15 additions & 0 deletions docs/src/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,18 @@ @article{Gauthier2021
year = {2021},
month = sep
}

@article{Ceni2025,
title = {Edge of Stability Echo State Network},
volume = {36},
ISSN = {2162-2388},
url = {http://dx.doi.org/10.1109/TNNLS.2024.3400045},
DOI = {10.1109/tnnls.2024.3400045},
number = {4},
journal = {IEEE Transactions on Neural Networks and Learning Systems},
publisher = {Institute of Electrical and Electronics Engineers (IEEE)},
author = {Ceni, Andrea and Gallicchio, Claudio},
year = {2025},
month = apr,
pages = {7555--7564}
}
7 changes: 5 additions & 2 deletions src/ReservoirComputing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ include("reservoircomputer.jl")
include("layers/basic.jl")
include("layers/lux_layers.jl")
include("layers/esn_cell.jl")
include("layers/es2n_cell.jl")
include("layers/svmreadout.jl")
#general
include("states.jl")
Expand All @@ -36,6 +37,7 @@ include("inits/inits_input.jl")
include("inits/inits_reservoir.jl")
#full models
include("models/esn_generics.jl")
include("models/es2n.jl")
include("models/esn.jl")
include("models/esn_deep.jl")
include("models/esn_delay.jl")
Expand All @@ -45,7 +47,8 @@ include("models/ngrc.jl")
include("extensions/reca.jl")

export ReservoirComputer
export ESNCell, StatefulLayer, LinearReadout, ReservoirChain, Collect, collectstates,
export ESNCell, ES2NCell
export StatefulLayer, LinearReadout, ReservoirChain, Collect, collectstates,
DelayLayer, NonlinearFeaturesLayer
export SVMReadout
export Pad, Extend, NLAT1, NLAT2, NLAT3, PartialSquare, ExtendedSquare
Expand All @@ -59,7 +62,7 @@ export block_diagonal, chaotic_init, cycle_jumps, delay_line, delayline_backward
export add_jumps!, backward_connection!, delay_line!, reverse_simple_cycle!,
scale_radius!, self_loop!, simple_cycle!
export train, train!, predict, resetcarry!, polynomial_monomials
export ESN, DeepESN, DelayESN, HybridESN
export ES2N, ESN, DeepESN, DelayESN, HybridESN
export NGRC
#ext
export RECACell, RECA
Expand Down
123 changes: 123 additions & 0 deletions src/layers/es2n_cell.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Common supertype for echo-state-style recurrent cells; subtypes inherit the
# shared recurrent-cell machinery (e.g. `initialstates`) from
# `AbstractReservoirRecurrentCell`.
abstract type AbstractEchoStateNetworkCell <: AbstractReservoirRecurrentCell end

@doc raw"""
    ES2NCell(in_dims => out_dims, [activation];
        use_bias=False(), init_bias=zeros32,
        init_reservoir=rand_sparse, init_input=scaled_rand,
        init_state=randn32, init_orthogonal=orthogonal,
        proximity=1.0)

Edge of Stability Echo State Network (ES2N) cell [Ceni2025](@cite).

## Equations

```math
\begin{aligned}
    x(t) = \beta\, \phi\!\left( \rho\, \mathbf{W}_r x(t-1) + \omega\,
        \mathbf{W}_{in} u(t) \right) + (1-\beta)\, \mathbf{O}\, x(t-1),
\end{aligned}
```
## Arguments

  - `in_dims`: Input dimension.
  - `out_dims`: Reservoir (hidden state) dimension.
  - `activation`: Activation function. Default: `tanh`.

## Keyword arguments

  - `use_bias`: Whether to include a bias term. Default: `false`.
  - `init_bias`: Initializer for the bias. Used only if `use_bias=true`.
    Default is `zeros32`.
  - `init_reservoir`: Initializer for the reservoir matrix `W_res`.
    Default is [`rand_sparse`](@ref).
  - `init_orthogonal`: Initializer for the orthogonal matrix `O`.
    Default is [`orthogonal`](@ref).
  - `init_input`: Initializer for the input matrix `W_in`.
    Default is [`scaled_rand`](@ref).
  - `init_state`: Initializer for the hidden state when an external
    state is not provided. Default is `randn32`.
  - `proximity`: Proximity coefficient `β ∈ (0,1]` (the `β` of the update
    equation above). Default: `1.0`.

## Inputs

  - **Case 1:** `x :: AbstractArray (in_dims, batch)`
    A fresh state is created via `init_state`; the call is forwarded to Case 2.
  - **Case 2:** `(x, (h,))` where `h :: AbstractArray (out_dims, batch)`
    Computes the update and returns the new state.

In both cases, the forward returns `((h_new, (h_new,)), st_out)` where `st_out`
contains any updated internal state.

## Returns

  - Output/hidden state `h_new :: out_dims` and state tuple `(h_new,)`.
  - Updated layer state (NamedTuple).

## Parameters

  - `input_matrix :: (out_dims × in_dims)` — `W_in`
  - `reservoir_matrix :: (out_dims × out_dims)` — `W_res`
  - `orthogonal_matrix :: (out_dims × out_dims)` — `O`
  - `bias :: (out_dims,)` — present only if `use_bias=true`

## States

Created by `initialstates(rng, esn)`:

  - `rng`: a replicated RNG used to sample initial hidden states when needed.
"""
@concrete struct ES2NCell <: AbstractEchoStateNetworkCell
    activation
    in_dims <: IntegerType
    out_dims <: IntegerType
    init_bias
    init_reservoir
    init_input
    init_orthogonal
    init_state
    proximity
    use_bias <: StaticBool
end

# User-facing constructor: `ES2NCell(in => out, activation; kwargs...)`.
# `static(use_bias)` lifts the bias flag into the type domain so `has_bias`
# resolves at compile time.
function ES2NCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType},
        activation = tanh; use_bias::BoolType = False(), init_bias = zeros32,
        init_reservoir = rand_sparse, init_input = scaled_rand,
        init_state = randn32, init_orthogonal = orthogonal,
        proximity::AbstractFloat = 1.0)
    return ES2NCell(activation, in_dims, out_dims, init_bias, init_reservoir,
        init_input, init_orthogonal, init_state, proximity, static(use_bias))
end

# Sample the cell's parameters: input matrix (out×in), reservoir matrix
# (out×out), orthogonal matrix (out×out), plus a bias vector of length
# `out_dims` when `use_bias=true`.
function initialparameters(rng::AbstractRNG, esn::ES2NCell)
    params = (input_matrix = esn.init_input(rng, esn.out_dims, esn.in_dims),
        reservoir_matrix = esn.init_reservoir(rng, esn.out_dims, esn.out_dims),
        orthogonal_matrix = esn.init_orthogonal(rng, esn.out_dims, esn.out_dims))
    if has_bias(esn)
        params = merge(params, (bias = esn.init_bias(rng, esn.out_dims),))
    end
    return params
end

# Forward pass implementing the ES2N update
#   h_new = (1 - β) * O * h + β * φ(W_in * u + W_r * h [+ b])
# where β is the proximity coefficient and O the orthogonal matrix.
function (esn::ES2NCell)((inp, (hidden_state,))::InputType, ps, st::NamedTuple)
    T = eltype(inp)
    # Linear pre-activation; the bias is folded in only when present.
    preactivation = ps.input_matrix * inp .+ ps.reservoir_matrix * hidden_state
    if has_bias(esn)
        preactivation = preactivation .+ ps.bias
    end
    candidate = esn.activation.(preactivation)
    new_state = (one(T) - esn.proximity) .* ps.orthogonal_matrix * hidden_state .+
        esn.proximity .* candidate
    return (new_state, (new_state,)), st
end

# Compact printer: shows dimensions, and only the non-default settings
# (proximity ≠ 1, bias disabled).
function Base.show(io::IO, esn::ES2NCell)
    print(io, "ES2NCell($(esn.in_dims) => $(esn.out_dims)")
    # `isone` replaces the convoluted `esn.proximity != eltype(esn.proximity)(1.0)`;
    # for a scalar float the two comparisons are equivalent, but `isone`
    # states the intent directly and avoids the eltype round-trip.
    if !isone(esn.proximity)
        print(io, ", proximity=$(esn.proximity)")
    end
    has_bias(esn) || print(io, ", use_bias=false")
    print(io, ")")
end
Loading
Loading