diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
new file mode 100644
index 00000000..a176bd3e
--- /dev/null
+++ b/.github/workflows/docs.yml
@@ -0,0 +1,65 @@
+name: Documentation
+
+on:
+ push:
+ branches:
+ - main
+ tags:
+ - 'v*'
+ pull_request:
+ branches:
+ - dev
+ workflow_dispatch:
+
+permissions:
+ contents: read
+ pages: write
+ id-token: write
+
+# Allow only one concurrent deployment
+concurrency:
+ group: "pages"
+ cancel-in-progress: false
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # Needed for hatch-vcs to determine version
+
+ - name: Install uv
+ uses: astral-sh/setup-uv@v6
+ with:
+ enable-cache: true
+ python-version: "3.12"
+
+ - name: Install the project
+ run: uv sync --only-group docs
+
+ - name: Build documentation
+ run: |
+ cd docs
+ uv run make html
+
+ - name: Add .nojekyll file
+ run: touch docs/build/html/.nojekyll
+
+ - name: Upload artifact
+ uses: actions/upload-pages-artifact@v3
+ with:
+ path: 'docs/build/html'
+
+ deploy:
+ # Only deploy on push to main or release tags
+ if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/v'))
+ environment:
+ name: github-pages
+ url: ${{ steps.deployment.outputs.page_url }}
+ runs-on: ubuntu-latest
+ needs: build
+ steps:
+ - name: Deploy to GitHub Pages
+ id: deployment
+ uses: actions/deploy-pages@v4
diff --git a/.gitignore b/.gitignore
index 2212bfef..b78272ee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -70,8 +70,10 @@ instance/
# Sphinx documentation
docs/_build/
+docs/build/
docs/source/_build
docs/source/generated
+docs/source/api/generated
# PyBuilder
.pybuilder/
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 00000000..d0c3cbf1
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/img/HybridBufferBasic.svg b/docs/img/HybridBufferBasic.svg
new file mode 100644
index 00000000..710963de
--- /dev/null
+++ b/docs/img/HybridBufferBasic.svg
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/docs/img/HybridBufferOverflow.svg b/docs/img/HybridBufferOverflow.svg
new file mode 100644
index 00000000..e5de0a83
--- /dev/null
+++ b/docs/img/HybridBufferOverflow.svg
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 00000000..9534b018
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/source/_templates/autosummary/module.rst b/docs/source/_templates/autosummary/module.rst
new file mode 100644
index 00000000..fd632b42
--- /dev/null
+++ b/docs/source/_templates/autosummary/module.rst
@@ -0,0 +1,64 @@
+{{ fullname | escape | underline}}
+
+.. automodule:: {{ fullname }}
+
+ {% block attributes %}
+ {% if attributes %}
+ .. rubric:: Module Attributes
+
+ .. autosummary::
+ :toctree:
+ {% for item in attributes %}
+ {{ item }}
+ {%- endfor %}
+ {% endif %}
+ {% endblock %}
+
+ {% block functions %}
+ {% if functions %}
+ .. rubric:: Functions
+
+ {% for item in functions %}
+ .. autofunction:: {{ item }}
+ {%- endfor %}
+ {% endif %}
+ {% endblock %}
+
+ {% block classes %}
+ {% if classes %}
+ .. rubric:: Classes
+
+ {% for item in classes %}
+ .. autoclass:: {{ item }}
+ :members:
+ :undoc-members:
+ :show-inheritance:
+ :special-members: __init__
+ {%- endfor %}
+ {% endif %}
+ {% endblock %}
+
+ {% block exceptions %}
+ {% if exceptions %}
+ .. rubric:: Exceptions
+
+ {% for item in exceptions %}
+ .. autoexception:: {{ item }}
+ :members:
+ :show-inheritance:
+ {%- endfor %}
+ {% endif %}
+ {% endblock %}
+
+{% block modules %}
+{% if modules %}
+.. rubric:: Modules
+
+.. autosummary::
+ :toctree:
+ :recursive:
+{% for item in modules %}
+ {{ item }}
+{%- endfor %}
+{% endif %}
+{% endblock %}
diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst
new file mode 100644
index 00000000..1be2c645
--- /dev/null
+++ b/docs/source/api/index.rst
@@ -0,0 +1,157 @@
+API Reference
+=============
+
+This page contains the complete API reference for ``ezmsg.sigproc``.
+
+.. contents:: Modules
+ :local:
+ :depth: 1
+
+Base Processors
+---------------
+
+Core processor protocols and base classes.
+
+.. autosummary::
+ :toctree: generated
+ :recursive:
+
+ ezmsg.sigproc.base
+
+Filtering
+---------
+
+Various filter implementations for signal processing.
+
+.. autosummary::
+ :toctree: generated
+ :recursive:
+
+ ezmsg.sigproc.filter
+ ezmsg.sigproc.butterworthfilter
+ ezmsg.sigproc.cheby
+ ezmsg.sigproc.combfilter
+ ezmsg.sigproc.adaptive_lattice_notch
+ ezmsg.sigproc.firfilter
+ ezmsg.sigproc.kaiser
+ ezmsg.sigproc.ewmfilter
+ ezmsg.sigproc.filterbank
+ ezmsg.sigproc.filterbankdesign
+ ezmsg.sigproc.gaussiansmoothing
+
+Spectral Analysis
+-----------------
+
+Spectral and frequency domain analysis tools.
+
+.. autosummary::
+ :toctree: generated
+ :recursive:
+
+ ezmsg.sigproc.spectral
+ ezmsg.sigproc.spectrogram
+ ezmsg.sigproc.spectrum
+ ezmsg.sigproc.wavelets
+ ezmsg.sigproc.bandpower
+ ezmsg.sigproc.fbcca
+
+Sampling & Resampling
+---------------------
+
+Signal sampling, windowing, and resampling operations.
+
+.. autosummary::
+ :toctree: generated
+ :recursive:
+
+ ezmsg.sigproc.sampler
+ ezmsg.sigproc.window
+ ezmsg.sigproc.resample
+ ezmsg.sigproc.downsample
+ ezmsg.sigproc.decimate
+
+Signal Conditioning
+-------------------
+
+Signal preprocessing and conditioning operations.
+
+.. autosummary::
+ :toctree: generated
+ :recursive:
+
+ ezmsg.sigproc.scaler
+ ezmsg.sigproc.detrend
+ ezmsg.sigproc.activation
+ ezmsg.sigproc.quantize
+ ezmsg.sigproc.ewma
+
+Transformations
+---------------
+
+Geometric and structural transformations.
+
+.. autosummary::
+ :toctree: generated
+ :recursive:
+
+ ezmsg.sigproc.affinetransform
+ ezmsg.sigproc.transpose
+ ezmsg.sigproc.extract_axis
+ ezmsg.sigproc.slicer
+
+Signal Operations
+-----------------
+
+Aggregation, difference, and other signal operations.
+
+.. autosummary::
+ :toctree: generated
+ :recursive:
+
+ ezmsg.sigproc.aggregate
+ ezmsg.sigproc.diff
+
+Signal Generation
+-----------------
+
+Synthetic signal generators and injectors.
+
+.. autosummary::
+ :toctree: generated
+ :recursive:
+
+ ezmsg.sigproc.synth
+ ezmsg.sigproc.signalinjector
+
+Messages & Data Structures
+---------------------------
+
+Message types and data structures.
+
+.. autosummary::
+ :toctree: generated
+ :recursive:
+
+ ezmsg.sigproc.messages
+
+Math Utilities
+--------------
+
+Mathematical operations on signals.
+
+.. autosummary::
+ :toctree: generated
+ :recursive:
+
+ ezmsg.sigproc.math
+
+Utilities
+---------
+
+Helper utilities for signal processing.
+
+.. autosummary::
+ :toctree: generated
+ :recursive:
+
+ ezmsg.sigproc.util
diff --git a/docs/source/conf.py b/docs/source/conf.py
new file mode 100644
index 00000000..b2b43cdd
--- /dev/null
+++ b/docs/source/conf.py
@@ -0,0 +1,123 @@
+# Configuration file for the Sphinx documentation builder.
+
+import os
+import sys
+
+# Add the source directory to the path
+sys.path.insert(0, os.path.abspath("../../src"))
+
+# -- Project information --------------------------
+
+project = "ezmsg.sigproc"
+copyright = "2024, ezmsg Contributors"
+author = "ezmsg Contributors"
+
+# The version is managed by hatch-vcs and stored in __version__.py
+try:
+ from ezmsg.sigproc.__version__ import version as release
+except ImportError:
+ release = "unknown"
+
+# For display purposes, extract the base version without git commit info
+version = release.split("+")[0] if release != "unknown" else release
+
+# -- General configuration --------------------------
+
+extensions = [
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.napoleon",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.viewcode",
+ "sphinx.ext.duration",
+ # "sphinx_autodoc_typehints", # Disabled due to compatibility issue
+ "sphinx_copybutton",
+ "myst_parser", # For markdown files
+]
+
+templates_path = ["_templates"]
+source_suffix = {
+ ".rst": "restructuredtext",
+ ".md": "markdown",
+}
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+
+# The toctree master document
+master_doc = "index"
+
+# -- Autodoc configuration ------------------------------
+
+# Auto-generate API docs
+autosummary_generate = True
+autosummary_imported_members = False
+autodoc_typehints = "description"
+autodoc_member_order = "bysource"
+autodoc_typehints_format = "short"
+python_use_unqualified_type_names = True
+autodoc_default_options = {
+ "members": True,
+ "member-order": "bysource",
+ "special-members": "__init__",
+ "undoc-members": True,
+ "show-inheritance": True,
+}
+
+# Don't show the full module path in the docs
+add_module_names = False
+
+# -- Intersphinx configuration --------------------------
+
+intersphinx_mapping = {
+ "python": ("https://docs.python.org/3/", None),
+ "numpy": ("https://numpy.org/doc/stable/", None),
+ "scipy": ("https://scipy.org/doc/scipy/", None),
+ "ezmsg": ("https://www.ezmsg.org/ezmsg/", None),
+}
+intersphinx_disabled_domains = ["std"]
+
+# -- Options for HTML output -----------------------------
+
+html_theme = "pydata_sphinx_theme"
+html_static_path = ["_static"]
+
+# Set the base URL for the documentation
+html_baseurl = "https://www.ezmsg.org/ezmsg-sigproc/"
+
+html_theme_options = {
+ "logo": {
+ "text": f"ezmsg.sigproc {version}",
+ "link": "https://ezmsg.org", # Link back to main site
+ },
+ "header_links_before_dropdown": 4,
+ "navbar_start": ["navbar-logo"],
+ "navbar_end": ["theme-switcher", "navbar-icon-links"],
+ "icon_links": [
+ {
+ "name": "GitHub",
+ "url": "https://github.com/ezmsg-org/ezmsg-sigproc",
+ "icon": "fa-brands fa-github",
+ },
+ {
+ "name": "ezmsg.org",
+ "url": "https://www.ezmsg.org",
+ "icon": "fa-solid fa-house",
+ },
+ ],
+}
+
+# Timestamp is inserted at every page bottom in this strftime format.
+html_last_updated_fmt = "%Y-%m-%d"
+
+# -- Options for linkcode -----------------------------
+
+branch = "main"
+code_url = f"https://github.com/ezmsg-org/ezmsg-sigproc/blob/{branch}/"
+
+
+def linkcode_resolve(domain, info):
+ if domain != "py":
+ return None
+ if not info["module"]:
+ return None
+ filename = info["module"].replace(".", "/")
+ return f"{code_url}src/{filename}.py"
diff --git a/docs/source/guides/HybridBuffer.md b/docs/source/guides/HybridBuffer.md
new file mode 100644
index 00000000..2352a559
--- /dev/null
+++ b/docs/source/guides/HybridBuffer.md
@@ -0,0 +1,87 @@
+## HybridBuffer
+
+The HybridBuffer is a stateful, FIFO buffer that combines a deque for fast appends with a contiguous circular buffer for efficient, advancing reads. The synchronization between the deque and the circular buffer can be immediate, upon threshold reaching, or on demand, allowing for flexible data management strategies.
+
+This buffer is designed to be agnostic to the array library used (e.g., NumPy, CuPy, PyTorch) via the Python Array API standard.
+
+### Basic Reading and Writing Behaviour
+
+The following diagram illustrates the states of the HybridBuffer across data writes and reads when `update_strategy="on_demand"`:
+
+
+
+**Figure 1**
+
+A. In the initial state, the buffer is empty, with no data in either the deque or the circular buffer.
+ * deq_len=0; available=0, tell=0
+
+B. After we `write()` 4 samples, the deque contains the new data, but the circular buffer is still empty.
+ * deq_len=4; available=4, tell=0
+
+C. After we `write()` 4 more samples, the deque now has 2 messages, each with 4 samples, and the circular buffer remains untouched.
+ * deq_len=8; available=8, tell=0
+
+D. Panels D-F depict a single call to `read(4)` which is implemented as calls to other methods. If we don't have 4 unread samples in the circular buffer, but we do have >= 4 samples 'available' (i.e., including the deque), then a `flush()` is performed: the entirety of the data in the deque are copied to the circular buffer and the deque is cleared.
+ * deq_len=0; available=8, tell=0
+ * TODO: Currently `flush()` copies the data twice, once from the deque to a contiguous array, and then from that contiguous array to the circular buffer. This should be optimized to copy directly from the deque to the circular buffer.
+
+E. Next we `peek(4)` which returns the first 4 samples from the circular buffer; the return value may be a view on the data if the data are contiguous in the circular buffer, or a copy if the data are not contiguous. Note that the tail (read pointer) does not advance with `peek()`.
+ * deq_len=0; available=8, tell=0
+
+F. Finally, we `seek(4)` to advance the tail.
+ * deq_len=0; available=4, tell=4
+
+G. We `write()` 4 more samples, which are appended to the deque, leaving the circular buffer unchanged from the previous step.
+ * deq_len=4; available=8, tell=4
+
+H. We then `read(4)` again. This time, a `flush()` is not triggered because we have enough unread samples in the circular buffer, but `peek(4)` and `seek(4)` are still called. The read pointer advances by 4, leaving 0 unread samples in the circular buffer and 4 in the deque.
+ * deq_len=4; available=0, tell=8
+
+Note: `peek(n)` and `seek(n)`, where `n` > `n_available` will raise an error. However, `peek(None)` will return all available samples without error, and `seek(None)` will advance the tail to the end of the available data.
+
+### Overflow Behaviour
+
+The criteria to trigger an overflow are as follows:
+* the deque has more data than there is space in the circular buffer, where space is the combination of previously read samples and unwritten samples in the circular buffer.
+* the caller triggers a flush either manually (`flush()`) or by requesting (via `read`, `peek`, or `seek`) more samples than are available in the circular buffer but not more than the total size of the available samples in the buffer + available samples in the deque.
+
+
+
+**Figure 2**
+
+A. We start with a circular buffer that has been running for a while (it has wrapped around several times). At this particular moment, we have more data in the deque (12) than we have room in the buffer (8). The remaining figures describe what happens when `flush()` is called with different overflow strategies. The samples are labeled to make it easier to follow the flow of data.
+ * deq_len=12; available=20, tell=1
+
+B. "warn-overwrite": If the overflow_strategy is set to 'warn-overwrite', the HybridBuffer will log a warning and overwrite the oldest data in the circular buffer with the new data from the deque. Here, samples 'a-d' are lost.
+ * deq_len=0; available=16, tell=0
+
+C. "drop": As much as possible of the data from the deque are copied into the circular buffer, but remaining data are dropped. In this case, samples 'q-t' are lost.
+ * deq_len=0; available=16, tell=0
+
+D. "grow": The HybridBuffer will attempt to grow the circular buffer to the lesser of double its current size or the size required to accommodate all read + unread + deque data. If the buffer cannot grow (e.g., due to memory constraints; default max_size is 1GB), it will raise an error.
+ * deq_len=0; available=20, tell=8
+
+Additionally, one can configure the HybridBuffer overflow_strategy to 'raise', which will raise an error if there is insufficient space (empty or read samples) in the buffer to perform the flush.
+
+There are a few mitigations to defer flushing to help prevent overflows:
+
+* If the requested number of samples to read, peek, or seek is less than the number of unread samples in the circular buffer, then no flush is performed.
+* Helper methods `peek_at(k, allow_flush=False)` (False is default), and `peek_last()` will retrieve the target sample from the buffer-OR-deque without flushing.
+ * Be cautious relying on repeated calls to `peek_at(k, allow_flush=False)` as it scans over the items in the deque which can be slow.
+* When calling `read(n)`, if a flush is necessary, and it will cause an overflow, and the overflow could be prevented with a pre-emptive read up to `n`, then it will do the read in 2 parts. First it will call `peek(n_unread_in_buffer)` and `seek(n_unread_in_buffer)` to read the unread samples in the circular buffer. Second, it will call `peek(n_remaining)` and `seek(n_remaining)` to trigger a flush -- which should no longer cause an overflow -- then read the remaining requested samples and stitch them together.
+
+### Advanced Pointer Manipulation
+
+The previous section describes how `read`, `peek`, `seek`, and `peek_at` function in normal use cases. It is also possible to call `seek` with a negative value, which will attempt to move the tail pointer backwards over previously-read (or previously sought-over) data by that many samples. `seek` returns the number of samples that were actually moved, which may be less than the requested value if there was insufficient room. Negative seeks can only rewind into previously read data, and positive seeks can only advance into unread data, possibly including data that gets flushed from the deque.
+
+## HybridAxisBuffer
+
+The `HybridAxisBuffer` carries the semantics of the `HybridBuffer` but it is designed to handle either a `LinearAxis` or a `CoordinateAxis`. Its `write` method expects an axis object and its `peek` and `read` methods return an axis, not just the data.
+
+For a `LinearAxis`, the `HybridAxisBuffer` simply maintains the `gain`, the `offset`, and the 'number of samples available'. Since this does not store actual data, it has no capacity. If this object is intended to be synchronized with another `HybridBuffer`-using object that does have a capacity, then the other object should be manipulated first and then the number of samples actually moved should be used to call the `HybridAxisBuffer`'s methods, otherwise these objects will be out of sync.
+
+For a `CoordinateAxis`, the `HybridAxisBuffer` maintains the `data` in a `HybridBuffer` and thus behaves like a `HybridBuffer` with respect to the capacity. The returned `CoordinateAxis` object might have its `.data` field as a view on the data in the buffer, so it should not be modified in place.
+
+## HybridAxisArrayBuffer
+
+This is a convenience class that combines the `HybridAxisBuffer` and `HybridBuffer` into a single object that can be used to manage both axis and data in a single object. This class is particularly useful when you need to manage both the axis information and the data samples together, as is the case for an `AxisArray` object. Its `write` method expects an `AxisArary` object and its `peek` and `read` methods return an `AxisArray` object. Note that the return object's `.data` field might be a view on the data in the buffer so it should not be modified in place. Similarly so for the `CoordinateAxis` data.
diff --git a/docs/ProcessorsBase.md b/docs/source/guides/ProcessorsBase.md
similarity index 100%
rename from docs/ProcessorsBase.md
rename to docs/source/guides/ProcessorsBase.md
diff --git a/docs/source/guides/explanations/sigproc.rst b/docs/source/guides/explanations/sigproc.rst
new file mode 100644
index 00000000..e4dd5f9e
--- /dev/null
+++ b/docs/source/guides/explanations/sigproc.rst
@@ -0,0 +1,365 @@
+About ezmsg-sigproc Signal Processors
+###########################################
+
+
+`ezmsg-sigproc` is an `ezmsg` extension that provides a template for building signal processing classes as well a way to easily convert to ezmsg ``Unit``\ s.
+
+It also comes with a collection of pre-built signal processing classes and relevant ezmsg Units that accomplish standard signal processing tasks and are designed to work seamlessly within the ``ezmsg`` framework.
+
+A list of available signal processors and ezmsg Units can be found in the (TBD) `ezmsg-sigproc reference `_.
+
+
+|ezmsg_logo_small| Rationale For Implementation
+********************************************************
+
+Providing a flexible and extensible framework for signal processing tasks makes it
+
+- easier for users to create custom signal processors
+- easier for users to integrate with ezmsg and create ezmsg Units
+- easier to create processing pipelines in the ``ezmsg`` ecosystem
+- allows standalone use outside of an ezmsg context
+
+
+|ezmsg_logo_small| How to decide which processor template to use?
+******************************************************************
+
+We use the term "processor" to refer to any class that processes signals. We then separate processors into types based on whether or not they receive input messages (typically signal data), send output messages, or both:
+
+- A producer sends output messages, but does not receive input
+- A consumer receives input, but does not output
+- A transformer receives input and sends output.
+
+Furthermore, if a processor of any type must maintain state between processing calls (e.g., filtering, modulation, etc.), it is considered a stateful processor. For example, a producer that is stateful is called a stateful producer.
+
+Additionally, we also consider adaptive stateful transformers, which are stateful transformers that adapt their internal state based on the input signal characteristics (e.g., adaptive filters). If we would like a transformer to be asynchronous in all calls, we would use an asynchronous transformer.
+
+The decision tree for this classification is as follows:
+
+.. graphviz::
+ :align: center
+
+ digraph signal_processor_decision_tree {
+ node [shape=box, style="rounded,filled", fillcolor="#f0f0f0", fontname="Arial"];
+ edge [fontname="Arial"];
+
+ AMP [label="Multiple Processors?"];
+ ARI [label="Receives Input?"];
+ ACB [label="Single Chain / Branching?"];
+ P [label="Producer", shape=diamond, fillcolor="#27f21cff"];
+ APO [label="Produces Output?"];
+ NBC [label="no base class", style="none"];
+ ACRI [label="Receives Input?"];
+ C [label="Consumer", shape=diamond, fillcolor="#27f21cff"];
+ T [label="Transformer", shape=diamond, fillcolor="#27f21cff"];
+ PS [label="Stateful?"];
+ CS [label="Stateful?"];
+ TS [label="Stateful?"];
+ TSA [label="Adaptive?"];
+ TSAF [label="Async First?"];
+ CompositeProducer [style="none, filled", fillcolor="#effb1aff"];
+ CompositeProcessor [style="none, filled", fillcolor="#effb1aff"];
+ BaseProducer [style="none, filled", fillcolor="#effb1aff"];
+ BaseStatefulProducer [style="none, filled", fillcolor="#effb1aff"];
+ BaseConsumer [style="none, filled", fillcolor="#effb1aff"];
+ BaseStatefulConsumer [style="none, filled", fillcolor="#effb1aff"];
+ BaseTransformer [style="none, filled", fillcolor="#effb1aff"];
+ BaseAdaptiveTransformer [style="none, filled", fillcolor="#effb1aff"];
+ BaseStatefulTransformer [style="none, filled", fillcolor="#effb1aff"];
+ BaseAsyncTransformer [style="none, filled", fillcolor="#effb1aff"];
+
+ AMP -> ARI [label="no"];
+ AMP -> ACB [label="yes"];
+ ARI -> P [label="no"];
+ ARI -> APO [label="yes"];
+ ACB -> NBC [label="branching"];
+ ACB -> ACRI [label="single chain"];
+ P -> PS;
+ APO -> C [label="no"];
+ APO -> T [label="yes"];
+ ACRI -> CompositeProducer [label="no"];
+ ACRI -> CompositeProcessor [label="yes"];
+ PS -> BaseProducer [label="no"];
+ PS -> BaseStatefulProducer [label="yes"];
+ C -> CS;
+ T -> TS;
+ CS -> BaseConsumer [label="no"];
+ CS -> BaseStatefulConsumer [label="yes"];
+ TS -> BaseTransformer [label="no"];
+ TS -> TSA [label="yes"];
+ TSA -> TSAF [label="no"];
+ TSA -> BaseAdaptiveTransformer [label="yes"];
+ TSAF -> BaseStatefulTransformer [label="no"];
+ TSAF -> BaseAsyncTransformer [label="yes"];
+ }
+
+The leaf nodes in yellow are abstract base classes provided in `ezmsg.sigproc.base` for implementing standalone processors. The table below summarizes these base classes.
+
+|ezmsg_logo_small| Abstract implementations (Base Classes) for standalone processors
+***************************************************************************************
+
+
+Generic TypeVars
+===================
+
+In this table, we summarize the generic TypeVars used in the processor class protocols and abstract base classes provided in `ezmsg.sigproc.base`.
+
+.. list-table::
+ :widths: 5 20 30
+ :header-rows: 1
+
+ * - Idx
+ - Class
+ - Description
+ * - 1
+ - `MessageInType` (Mi)
+ - for messages passed to a consumer, processor, or transformer
+ * - 2
+ - `MessageOutType` (Mo)
+ - for messages returned by a producer, processor, or transformer
+ * - 3
+ - `SettingsType`
+ - bound to ``ez.Settings``
+ * - 4
+ - `StateType` (St)
+ - bound to ``ProcessorState`` which is simply ``ez.State`` with a ``hash: int`` field.
+
+
+Processor Class Protocols
+===========================
+In this table, we summarize the processor class protocols used to define the abstract base classes provided in `ezmsg.sigproc.base`. Each protocol corresponds to a specific processor type and characteristics as outlined in the decision tree above.
+
++-----+-----------------------+--------+-------+------------------------+--------+-----------------+
+| Idx | Class | Parent | State | ``__call__`` signature | @state | ``partial_fit`` |
++=====+=======================+========+=======+========================+========+=================+
+| 1 | `Processor` | \- | No | Any -> Any | \- | \- |
++-----+-----------------------+--------+-------+------------------------+--------+-----------------+
+| 2 | `Producer` | \- | No | None -> Mo | \- | \- |
++-----+-----------------------+--------+-------+------------------------+--------+-----------------+
+| 3 | `Consumer` | 1 | No | Mi -> None | \- | \- |
++-----+-----------------------+--------+-------+------------------------+--------+-----------------+
+| 4 | `Transformer` | 1 | No | Mi -> Mo | \- | \- |
++-----+-----------------------+--------+-------+------------------------+--------+-----------------+
+| 5 | `StatefulProcessor` | \- | Yes | Any -> Any | Y | \- |
++-----+-----------------------+--------+-------+------------------------+--------+-----------------+
+| 6 | `StatefulProducer` | \- | Yes | None -> Mo | Y | \- |
++-----+-----------------------+--------+-------+------------------------+--------+-----------------+
+| 7 | `StatefulConsumer` | 5 | Yes | Mi -> None | Y | \- |
++-----+-----------------------+--------+-------+------------------------+--------+-----------------+
+| 8 | `StatefulTransformer` | 5 | Yes | Mi -> Mo | Y | \- |
++-----+-----------------------+--------+-------+------------------------+--------+-----------------+
+| 9 | `AdaptiveTransformer` | 8 | Yes | Mi -> Mo | Y | Y |
++-----+-----------------------+--------+-------+------------------------+--------+-----------------+
+
+Note: ``__call__`` and ``partial_fit`` both have asynchronous alternatives: ``__acall__`` and ``apartial_fit`` respectively.
+
+
+Processor Base Classes
+========================
+
+In this table, we summarize the abstract base classes provided in `ezmsg.sigproc.base` for implementing standalone signal processors. Each base class corresponds to a specific processor type and protocol, as outlined in the decision tree above.
+
+.. list-table::
+ :widths: 5 20 5 5 30
+ :header-rows: 1
+
+ * - Idx
+ - Class
+ - Parent
+ - Protocol
+ - Features
+ * - 1
+ - ``BaseProcessor``
+ - \-
+ - 1
+ - ``__init__`` for settings; ``__call__`` (alias: ``send``) wraps abstract ``_process``.
+ * - 2
+ - ``BaseProducer``
+ - \-
+ - 2
+ - Similar to ``BaseProcessor``; ``next``/``anext`` instead of ``send``/``asend`` aliases. async first!
+ * - 3
+ - ``BaseConsumer``
+ - 1
+ - 3
+ - Overrides return type to None.
+ * - 4
+ - ``BaseTransformer``
+ - 1
+ - 4
+ - Overrides input and return types.
+ * - 5
+ - ``BaseStatefulProcessor``
+ - 1
+ - 5
+ - ``state`` setter unpickles arg; ``stateful_op`` wraps ``__call__``.
+ * - 6
+ - ``BaseStatefulProducer``
+ - 2
+ - 6
+ - ``state`` setter and getter; ``stateful_op`` wraps ``__call__`` which runs ``__acall__``.
+ * - 7
+ - ``BaseStatefulConsumer``
+ - 5
+ - 7
+ - Overrides return type to None.
+ * - 8
+ - ``BaseStatefulTransformer``
+ - 5
+ - 8
+ - Overrides input and return types.
+ * - 9
+ - ``BaseAdaptiveTransformer``
+ - 8
+ - 9
+ - Implements ``partial_fit``. ``__call__`` may call ``partial_fit`` if message has ``.trigger``.
+ * - 10
+ - ``BaseAsyncTransformer``
+ - 8
+ - 8
+ - ``__acall__`` wraps abstract ``_aprocess``; ``__call__`` runs ``__acall__``.
+ * - 11
+ - ``CompositeProcessor``
+ - 1
+ - 5
+ - Methods iterate over sequence of processors created in ``_initialize_processors``.
+ * - 12
+ - ``CompositeProducer``
+ - 2
+ - 6
+ - Similar to ``CompositeProcessor``, but first processor must be a producer.
+
+NOTES:
+
+1. Producers do not inherit from ``BaseProcessor``, so concrete implementations should subclass ``BaseProducer`` or ``BaseStatefulProducer``.
+2. For concrete implementations of non-producer processors, inherit from the base subclasses of ``BaseProcessor`` (eg. ``BaseConsumer``, ``BaseTransformer``) and from base subclasses of ``BaseStatefulProcessor``. These two processor classes are primarily used for efficient abstract base class construction.
+3. For most base classes, the async methods simply call the synchronous methods where the processor logic is expected. Exceptions are ``BaseProducer`` (and its children) and ``BaseAsyncTransformer`` which are async-first and should be strongly considered for operations that are I/O bound.
+4. For async-first classes, the logic is implemented in the async methods and the sync methods are thin wrappers around the async methods. The wrapper uses a helper method called ``run_coroutine_sync`` to run the async method in a synchronous context, but this adds some noticeable processing overhead.
+5. If you need to call your processor outside ezmsg (which uses async), and you cannot easily add an async context* in your processing, then you might want to consider duplicating the processor logic in the sync methods.
+
+ .. note:: Jupyter notebooks are async by default, so you can await async code in a notebook without any extra setup.
+
+6. ``CompositeProcessor`` and ``CompositeProducer`` are stateful, and structurally subclass the ``StatefulProcessor`` and ``StatefulProducer`` protocols, but they
+do not inherit from ``BaseStatefulProcessor`` and ``BaseStatefulProducer``. They accomplish statefulness by inheriting from the mixin abstract base class ``CompositeStateful``, which implements the state related methods: ``get_state_type``, ``state.setter``, ``state.getter``, ``_hash_message``, ``_reset_state``, and ``stateful_op`` (as well as composite processor chain related methods). However, ``BaseStatefulProcessor``, ``BaseStatefulProducer`` implement ``stateful_op`` method for a single processor in an incompatible way to what is required for composite chains of processors.
+
+
+|ezmsg_logo_small| Implementing a custom standalone processor
+****************************************************************
+
+1. Create a new settings dataclass: ``class MySettings(ez.Settings):``
+2. Create a new state dataclass:
+
+.. code-block:: python
+
+ @processor_state
+ class MyState:
+
+3. Decide on your base processor class, considering the protocol, whether it should be async-first, and other factors using the decision tree above.
+
+4. Implement the child class.
+ * The minimum implementation is ``_process`` for sync processors, ``_aprocess`` for async processors, and ``_produce`` for producers.
+ * For any stateful processor, implement ``_reset_state``.
+ * For stateful processors that need to respond to a change in the incoming data, implement ``_hash_message``.
+ * For adaptive transformers, implement ``partial_fit``.
+ * For chains of processors (``CompositeProcessor``/ ``CompositeProducer``), need to implement ``_initialize_processors``.
+ * See processors in `ezmsg.sigproc` for examples.
+5. Override non-abstract methods if you need special behaviour. For example:
+ * ``WindowTransformer`` overrides ``__init__`` to do some sanity checks on the provided settings.
+ * ``TransposeTransformer`` and ``WindowTransformer`` override ``__call__`` to provide a passthrough shortcut when the settings allow for it.
+ * ``ClockProducer`` overrides ``__call__`` in order to provide a synchronous call bypassing the default async behaviour.
+
+
+|ezmsg_logo_small| Abstract implementations (Base Classes) for ezmsg Units using processors
+**********************************************************************************************
+
+Generic TypeVars for ezmsg Units
+==================================
+
+.. list-table::
+ :widths: 5 20 30
+ :header-rows: 1
+
+ * - Idx
+ - Class
+ - Description
+ * - 5
+ - ``ProducerType``
+ - bound to ``BaseProducer`` (hence, also ``BaseStatefulProducer``, ``CompositeProducer``)
+ * - 6
+ - ``ConsumerType``
+ - bound to ``BaseConsumer``, ``BaseStatefulConsumer``
+ * - 7
+ - ``TransformerType``
+ - bound to ``BaseTransformer``, ``BaseStatefulTransformer``, ``CompositeProcessor`` (hence, also ``BaseAsyncTransformer``)
+ * - 8
+ - ``AdaptiveTransformerType``
+ - bound to ``BaseAdaptiveTransformer``
+
+
+Base Classes for ezmsg processor Units:
+==============================================================================
+
++-----+---------------------------------+---------+-----------------------------+
+| Idx | Class | Parents | Expected TypeVars |
++=====+=================================+=========+=============================+
+| 1 | ``BaseProcessorUnit`` | \- | \- |
++-----+---------------------------------+---------+-----------------------------+
+| 2 | ``BaseProducerUnit`` | \- | ``ProducerType`` |
++-----+---------------------------------+---------+-----------------------------+
+| 3 | ``BaseConsumerUnit`` | 1 | ``ConsumerType`` |
++-----+---------------------------------+---------+-----------------------------+
+| 4 | ``BaseTransformerUnit`` | 1 | ``TransformerType`` |
++-----+---------------------------------+---------+-----------------------------+
+| 5 | ``BaseAdaptiveTransformerUnit`` | 1 | ``AdaptiveTransformerType`` |
++-----+---------------------------------+---------+-----------------------------+
+
+Note, it is strongly recommended to use `BaseConsumerUnit`, `BaseTransformerUnit`, or `BaseAdaptiveTransformerUnit` for implementing concrete subclasses rather than `BaseProcessorUnit`.
+
+|ezmsg_logo_small| How to implement a custom ezmsg Unit from a standalone processor
+=====================================================================================
+
+1. Create and test custom standalone processor as above.
+2. Decide which base unit to implement.
+ * Use the "Generic TypeVars for ezmsg Units" table above to determine the expected TypeVar.
+ * Find the Expected TypeVar in the "ezmsg Units" table.
+3. Create the derived class.
+
+Often, all that is required is the following (e.g., for a custom transformer):
+
+.. code-block:: python
+
+ import ezmsg.core as ez
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.sigproc.base import BaseTransformer, BaseTransformerUnit
+
+
+ class CustomTransformerSettings(ez.Settings):
+ ...
+
+
+ class CustomTransformer(BaseTransformer[CustomTransformerSettings, AxisArray, AxisArray]):
+ def _process(self, message: AxisArray) -> AxisArray:
+ # Your processing code here...
+ return message
+
+
+ class CustomUnit(BaseTransformerUnit[
+ CustomTransformerSettings, # SettingsType
+ AxisArray, # MessageInType
+ AxisArray, # MessageOutType
+ CustomTransformer, # TransformerType
+ ]):
+ SETTINGS = CustomTransformerSettings
+
+
+.. note:: The type of ProcessorUnit is based on the internal processor and not the input or output of the unit. Input streams are allowed in ProducerUnits and output streams in ConsumerUnits. For an example of such a use case, see ``BaseCounterFirstProducerUnit`` and its subclasses. ``BaseCounterFirstProducerUnit`` has an input stream that receives a flag signal from a clock that triggers a call to the internal producer.
+
+|ezmsg_logo_small| See Also
+********************************
+
+1. `Signal Processor Documentation `_
+#. `Signal Processing Tutorial <../../tutorials/signalprocessing.html>`_
+#. `Signal Processing HOW TOs <../../how-tos/signalprocessing/main.html>`_
+
+.. |ezmsg_logo_small| image:: ../_static/_images/ezmsg_logo.png
+ :width: 40
+ :alt: ezmsg logo
\ No newline at end of file
diff --git a/docs/source/guides/how-tos/signalprocessing/adaptive.rst b/docs/source/guides/how-tos/signalprocessing/adaptive.rst
new file mode 100644
index 00000000..3dbaf7d3
--- /dev/null
+++ b/docs/source/guides/how-tos/signalprocessing/adaptive.rst
@@ -0,0 +1,4 @@
+How to implement adaptive signal processing in ezmsg?
+#######################################################
+
+(under construction)
\ No newline at end of file
diff --git a/docs/source/guides/how-tos/signalprocessing/checkpoint.rst b/docs/source/guides/how-tos/signalprocessing/checkpoint.rst
new file mode 100644
index 00000000..79fcdc46
--- /dev/null
+++ b/docs/source/guides/how-tos/signalprocessing/checkpoint.rst
@@ -0,0 +1,4 @@
+How to use checkpoints for ezmsg signal processing Units that leverage ML models?
+######################################################################################
+
+(under construction)
\ No newline at end of file
diff --git a/docs/source/guides/how-tos/signalprocessing/composite.rst b/docs/source/guides/how-tos/signalprocessing/composite.rst
new file mode 100644
index 00000000..ae05d7c3
--- /dev/null
+++ b/docs/source/guides/how-tos/signalprocessing/composite.rst
@@ -0,0 +1,4 @@
+How to efficiently chain multiple signal processors in ezmsg?
+###############################################################
+
+(under construction)
\ No newline at end of file
diff --git a/docs/source/guides/how-tos/signalprocessing/content-signalprocessing.rst b/docs/source/guides/how-tos/signalprocessing/content-signalprocessing.rst
new file mode 100644
index 00000000..4aac0758
--- /dev/null
+++ b/docs/source/guides/how-tos/signalprocessing/content-signalprocessing.rst
@@ -0,0 +1,13 @@
+Signal Processing HOW TOs
+##########################
+
+.. toctree::
+ :maxdepth: 1
+
+ processor
+ stateful
+ standalone
+ adaptive
+ composite
+ unit
+ checkpoint
diff --git a/docs/source/guides/how-tos/signalprocessing/processor.rst b/docs/source/guides/how-tos/signalprocessing/processor.rst
new file mode 100644
index 00000000..ffc6c2cc
--- /dev/null
+++ b/docs/source/guides/how-tos/signalprocessing/processor.rst
@@ -0,0 +1,4 @@
+How to write a signal processor in ezmsg?
+###############################################
+
+(under construction)
\ No newline at end of file
diff --git a/docs/source/guides/how-tos/signalprocessing/standalone.rst b/docs/source/guides/how-tos/signalprocessing/standalone.rst
new file mode 100644
index 00000000..2e8eefa8
--- /dev/null
+++ b/docs/source/guides/how-tos/signalprocessing/standalone.rst
@@ -0,0 +1,4 @@
+How to use ezmsg-sigproc signal processors outside of an ezmsg context?
+###############################################################################
+
+(under construction)
\ No newline at end of file
diff --git a/docs/source/guides/how-tos/signalprocessing/stateful.rst b/docs/source/guides/how-tos/signalprocessing/stateful.rst
new file mode 100644
index 00000000..54937c0e
--- /dev/null
+++ b/docs/source/guides/how-tos/signalprocessing/stateful.rst
@@ -0,0 +1,4 @@
+How to implement a stateful signal processor in ezmsg?
+###############################################################
+
+(under construction)
\ No newline at end of file
diff --git a/docs/source/guides/how-tos/signalprocessing/unit.rst b/docs/source/guides/how-tos/signalprocessing/unit.rst
new file mode 100644
index 00000000..a2a1b2e5
--- /dev/null
+++ b/docs/source/guides/how-tos/signalprocessing/unit.rst
@@ -0,0 +1,11 @@
+How to turn a signal processor into an ``ezmsg`` Unit?
+#######################################################
+
+To convert a signal processor to an ``ezmsg`` Unit, you can follow these steps:
+
+1. **Define the Processor**: Create a class that inherits from the appropriate signal processor template (e.g., `SignalProcessor`, `Filter`, etc.).
+2. **Implement the Processing Logic**: Override the necessary methods to implement the signal processing logic.
+3. **Define Input and Output Ports**: Use the `ezmsg` port system to define input and output ports for the signal processor.
+4. **Register the Unit**: Use the `ezmsg` registration system to register the signal processor as an `ezmsg` Unit.
+
+(under construction)
\ No newline at end of file
diff --git a/docs/source/guides/sigproc/base.rst b/docs/source/guides/sigproc/base.rst
new file mode 100644
index 00000000..7d86330a
--- /dev/null
+++ b/docs/source/guides/sigproc/base.rst
@@ -0,0 +1,65 @@
+Base Processors
+========================================
+
+Here is the API for the base processors included in the `ezmsg-sigproc` extension. For more detailed information on the design decisions behind these base processors, please refer to the :doc:`ezmsg-sigproc explainer <../../explanations/sigproc>`.
+
+
+.. autoclass:: ezmsg.sigproc.base.BaseProcessor
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.BaseProducer
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.BaseConsumer
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.BaseTransformer
+ :members:
+ :show-inheritance:
+:inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.BaseStatefulProcessor
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.BaseStatefulProducer
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.BaseStatefulConsumer
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.BaseStatefulTransformer
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.BaseAdaptiveTransformer
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.BaseAsyncTransformer
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.CompositeProcessor
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.CompositeProducer
+ :members:
+ :show-inheritance:
+ :inherited-members:
diff --git a/docs/source/guides/sigproc/content-sigproc.rst b/docs/source/guides/sigproc/content-sigproc.rst
new file mode 100644
index 00000000..79c9d5eb
--- /dev/null
+++ b/docs/source/guides/sigproc/content-sigproc.rst
@@ -0,0 +1,22 @@
+ezmsg-sigproc
+===============
+
+Timeseries signal processing implementations in ezmsg, leveraging numpy and scipy.
+Most of the methods and classes in this extension are intended to be used in building signal processing pipelines.
+They use :class:`ezmsg.util.messages.axisarray.AxisArray` as the primary data structure for passing signals between components.
+The message's data are expected to be a numpy array.
+
+.. note:: Some generators might yield valid :class:`AxisArray` messages with ``.data`` size of 0.
+This may occur when the generator receives inadequate data to produce a valid output, such as when windowing or buffering.
+
+`ezmsg-sigproc` contains two types of modules:
+
+- base processors and units that provide fundamental building blocks for signal processing pipelines
+- in-built signal processing modules that implement common signal processing techniques
+
+.. toctree::
+ :maxdepth: 1
+
+ base
+ units
+ processors
diff --git a/docs/source/guides/sigproc/processors.rst b/docs/source/guides/sigproc/processors.rst
new file mode 100644
index 00000000..dff5ca35
--- /dev/null
+++ b/docs/source/guides/sigproc/processors.rst
@@ -0,0 +1,142 @@
+In-Built Signal Processing Modules
+======================================================
+
+Here is the API reference for the in-built signal processing modules included in the `ezmsg-sigproc` extension.
+
+ezmsg.sigproc.activation
+--------------------------
+
+.. automodule:: ezmsg.sigproc.activation
+ :members:
+
+
+ezmsg.sigproc.affinetransform
+-------------------------------
+
+.. automodule:: ezmsg.sigproc.affinetransform
+ :members:
+
+
+ezmsg.sigproc.aggregate
+-------------------------
+
+.. automodule:: ezmsg.sigproc.aggregate
+ :members:
+ :undoc-members:
+
+
+ezmsg.sigproc.bandpower
+-------------------------
+
+.. automodule:: ezmsg.sigproc.bandpower
+ :members:
+
+
+ezmsg.sigproc.filter
+----------------------
+
+.. automodule:: ezmsg.sigproc.filter
+ :members:
+
+ezmsg.sigproc.butterworthfilter
+---------------------------------
+
+.. automodule:: ezmsg.sigproc.butterworthfilter
+ :members:
+
+
+ezmsg.sigproc.decimate
+------------------------
+
+.. automodule:: ezmsg.sigproc.decimate
+ :members:
+
+
+ezmsg.sigproc.downsample
+--------------------------
+
+.. automodule:: ezmsg.sigproc.downsample
+ :members:
+
+
+ezmsg.sigproc.ewmfilter
+-------------------------
+
+.. automodule:: ezmsg.sigproc.ewmfilter
+ :members:
+
+
+ezmsg.sigproc.math
+-----------------------
+
+.. automodule:: ezmsg.sigproc.math.clip
+ :members:
+
+.. automodule:: ezmsg.sigproc.math.difference
+ :members:
+
+.. automodule:: ezmsg.sigproc.math.invert
+ :members:
+
+.. automodule:: ezmsg.sigproc.math.log
+ :members:
+
+.. automodule:: ezmsg.sigproc.math.scale
+ :members:
+
+
+ezmsg.sigproc.sampler
+-----------------------
+
+.. automodule:: ezmsg.sigproc.sampler
+ :members:
+
+
+ezmsg.sigproc.scaler
+----------------------
+
+.. automodule:: ezmsg.sigproc.scaler
+ :members:
+
+
+ezmsg.sigproc.signalinjector
+------------------------------
+
+.. automodule:: ezmsg.sigproc.signalinjector
+ :members:
+
+
+ezmsg.sigproc.slicer
+-----------------------
+
+.. automodule:: ezmsg.sigproc.slicer
+ :members:
+
+
+ezmsg.sigproc.spectrum
+------------------------
+
+.. automodule:: ezmsg.sigproc.spectrum
+ :members:
+ :undoc-members:
+
+
+ezmsg.sigproc.spectrogram
+---------------------------
+
+.. automodule:: ezmsg.sigproc.spectrogram
+ :members:
+
+
+ezmsg.sigproc.synth
+---------------------
+
+.. automodule:: ezmsg.sigproc.synth
+ :members:
+
+
+ezmsg.sigproc.window
+----------------------
+
+.. automodule:: ezmsg.sigproc.window
+ :members:
\ No newline at end of file
diff --git a/docs/source/guides/sigproc/units.rst b/docs/source/guides/sigproc/units.rst
new file mode 100644
index 00000000..322e8927
--- /dev/null
+++ b/docs/source/guides/sigproc/units.rst
@@ -0,0 +1,29 @@
+Base Processor Units
+=============================
+
+Here is the API for the base processor ezmsg ``Unit``\ s included in the `ezmsg-sigproc` extension. For more detailed information on the design decisions behind these base units, please refer to the :doc:`ezmsg-sigproc explainer <../../explanations/sigproc>`.
+
+.. autoclass:: ezmsg.sigproc.base.BaseProducerUnit
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.BaseProcessorUnit
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.BaseConsumerUnit
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.BaseTransformerUnit
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. autoclass:: ezmsg.sigproc.base.BaseAdaptiveTransformerUnit
+ :members:
+ :show-inheritance:
+ :inherited-members:
diff --git a/docs/source/guides/tutorials/signalprocessing.rst b/docs/source/guides/tutorials/signalprocessing.rst
new file mode 100644
index 00000000..0156b995
--- /dev/null
+++ b/docs/source/guides/tutorials/signalprocessing.rst
@@ -0,0 +1,615 @@
+Leveraging ezmsg For Signal Processing
+###############################################
+
+`ezmsg` is a powerful framework for building signal processing applications. It provides a flexible and extensible architecture that allows users to create custom signal processors, integrate with ezmsg Units, and build complex processing pipelines.
+
+We will explore how to do this by recreating the `Downsample` signal processor unit. It will demonstrate how to create a signal processor, convert it to an ezmsg Unit, and use it in a processing pipeline. Additionally, it will provide a mini primer on the `AxisArray` class, which is the preferred ezmsg message format.
+
+.. tip:: Downsampling is a common signal processing operation that reduces the sampling rate of a signal by keeping only every nth sample. This is useful for reducing the amount of data to be processed, especially in real-time applications.
+
+
+|ezmsg_logo_small| Choosing your signal processing class
+**********************************************************
+
+We make use of the following decision tree to choose the appropriate signal processing class:
+
+.. graphviz::
+ :align: center
+
+ digraph signal_processor_decision_tree {
+ node [shape=box, style="rounded,filled", fillcolor="#f0f0f0", fontname="Arial"];
+ edge [fontname="Arial"];
+
+ AMP [label="Multiple Processors?", fontcolor="#ff0000"];
+ ARI [label="Receives Input?", fontcolor="#ff0000"];
+ ACB [label="Single Chain / Branching?"];
+ P [label="Producer", shape=diamond, fillcolor="#27f21cff"];
+ APO [label="Produces Output?", fontcolor="#ff0000"];
+ NBC [label="no base class", style="none"];
+ ACRI [label="Receives Input?"];
+ C [label="Consumer", shape=diamond, fillcolor="#27f21cff"];
+ T [label="Transformer", shape=diamond, fillcolor="#27f21cff", fontcolor="#ff0000"];
+ PS [label="Stateful?"];
+ CS [label="Stateful?"];
+ TS [label="Stateful?", fontcolor="#ff0000"];
+ TSA [label="Adaptive?", fontcolor="#ff0000"];
+ TSAF [label="Async First?", fontcolor="#ff0000"];
+ CompositeProducer [style="none, filled", fillcolor="#effb1aff"];
+ CompositeProcessor [style="none, filled", fillcolor="#effb1aff"];
+ BaseProducer [style="none, filled", fillcolor="#effb1aff"];
+ BaseStatefulProducer [style="none, filled", fillcolor="#effb1aff"];
+ BaseConsumer [style="none, filled", fillcolor="#effb1aff"];
+ BaseStatefulConsumer [style="none, filled", fillcolor="#effb1aff"];
+ BaseTransformer [style="none, filled", fillcolor="#effb1aff"];
+ BaseAdaptiveTransformer [style="none, filled", fillcolor="#effb1aff"];
+ BaseStatefulTransformer [style="none, filled", fillcolor="#effb1aff", fontcolor="#ff0000"];
+ BaseAsyncTransformer [style="none, filled", fillcolor="#effb1aff"];
+
+ AMP -> ARI [label="no", color="#ff0000", fontcolor="#ff0000"];
+ AMP -> ACB [label="yes"];
+ ARI -> P [label="no"];
+ ARI -> APO [label="yes", color="#ff0000", fontcolor="#ff0000"];
+ ACB -> NBC [label="branching"];
+ ACB -> ACRI [label="single chain"];
+ P -> PS;
+ APO -> C [label="no"];
+ APO -> T [label="yes", color="#ff0000", fontcolor="#ff0000"];
+ ACRI -> CompositeProducer [label="no"];
+ ACRI -> CompositeProcessor [label="yes"];
+ PS -> BaseProducer [label="no"];
+ PS -> BaseStatefulProducer [label="yes"];
+ C -> CS;
+ T -> TS [color="#ff0000", fontcolor="#ff0000"];
+ CS -> BaseConsumer [label="no"];
+ CS -> BaseStatefulConsumer [label="yes"];
+ TS -> BaseTransformer [label="no"];
+ TS -> TSA [label="yes", color="#ff0000", fontcolor="#ff0000"];
+ TSA -> TSAF [label="no", color="#ff0000", fontcolor="#ff0000"];
+ TSA -> BaseAdaptiveTransformer [label="yes"];
+ TSAF -> BaseStatefulTransformer [label="no", color="#ff0000", fontcolor="#ff0000"];
+ TSAF -> BaseAsyncTransformer [label="yes"];
+ }
+.. flowchart TD
+.. AMP{Multiple Processors?};
+.. AMP -->|no| ARI{Receives Input?};
+.. AMP -->|yes| ACB{Single Chain / Branching?}
+.. ARI -->|no| P(Producer);
+.. ARI -->|yes| APO{Produces Output?};
+.. ACB -->|branching| NBC[no base class];
+.. ACB -->|single chain| ACRI{Receives Input?};
+.. P --> PS{Stateful?};
+.. APO -->|no| C(Consumer);
+.. APO -->|yes| T(Transformer);
+.. ACRI -->|no| CompositeProducer;
+.. ACRI -->|yes| CompositeProcessor;
+.. PS -->|no| BaseProducer;
+.. PS -->|yes| BaseStatefulProducer;
+.. C --> CS{Stateful?};
+.. T --> TS{Stateful?};
+.. CS -->|no| BaseConsumer;
+.. CS -->|yes| BaseStatefulConsumer;
+.. TS -->|no| BaseTransformer;
+.. TS -->|yes| TSA{Adaptive?};
+.. TSA -->|no| TSAF{Async First?};
+.. TSA -->|yes| BaseAdaptiveTransformer;
+.. TSAF -->|no| BaseStatefulTransformer;
+.. TSAF -->|yes| BaseAsyncTransformer;
+
+In our case, we are creating a **single** signal processor that **receives input** and **produces output**. The decision tree indicates that we will be using a **transformer**-type base class. To continue, we need to determine if the processor is *stateful*, *adaptive* and *async first* or not.
+
+A stateful processor maintains internal state information that can affect its processing behavior, while a stateless processor does not maintain any internal state and processes each input independently. Adaptive transformers are a subtype of transformer that can adjust its settings based on trigger messages, whereas all other transformers are non-adaptive. Async first transformers prioritise asynchronous processing, meaning they can handle incoming messages without blocking, while non-async first transformers may block while processing messages.
+
+To answer whether our `Downsample` transformer is any of these types, we need to identify what we consider the settings (configuration) for the transformer and what we consider the state.
+
+A good rule of thumb is that settings are parameters used to configure the processor and are typically set once during initialization and remain constant. On the other hand, the processor state is internal data that the processor needs to maintain during its operation and can change dynamically as the processor processes data.
+
+We will see that `Downsample` is stateful, not adaptive and not async first, so we will inherit from the `BaseStatefulTransformer` class. This will become clearer as we implement the processor in the following sections.
+
+First, we need to install the `ezmsg-sigproc` package if we haven't already. This package contains the base classes for signal processing in ezmsg. You can install it using pip:
+
+.. code-block:: bash
+
+ pip install "ezmsg[sigproc]"
+
+
+|ezmsg_logo_small| Creating the `Downsample` signal processor
+*************************************************************
+
+We begin by identifying the components needed to create the `Downsample` signal processor. This includes defining the settings, state, and the main processing class itself.
+
+First create a new Python file named `downsample.py` in your root directory. In this file we will implement the `Downsample` signal processor.
+
+Add the following import statements to the top of the `downsample.py` file:
+
+.. code-block:: python
+
+ # downsample.py
+ import numpy as np
+ from ezmsg.util.messages.axisarray import (
+ AxisArray,
+ slice_along_axis,
+ replace,
+ )
+ import ezmsg.core as ez
+
+ from ezmsg.sigproc.base import (
+ BaseStatefulTransformer,
+ BaseTransformerUnit,
+ processor_state,
+ )
+
+.. note:: These are modules we will need in the implementation and will be explained as we go along. You will notice that we import `numpy` (for numerical operations), `AxisArray` (this is our class for handling multi-dimensional arrays with named axes), and from `ezmsg-sigproc`, we import the `BaseStatefulTransformer` class and the `BaseTransformerUnit` (for wrapping our processor into an ezmsg unit).
+
+
+DownsampleSettings class
+====================================
+
+To create a `Downsample` signal processor, we first define the settings for the processor. The parameters that we need to know for the transformer to operate include:
+
+- the axis along which to downsample.
+- desired rate after downsampling has occurred, or
+- the desired factor by which to downsample.
+
+Thus, your settings class will look like this:
+
+.. code-block:: python
+
+ class DownsampleSettings(ez.Settings):
+ """
+ Settings for :obj:`Downsample` node.
+ """
+
+ axis: str = "time"
+ """The name of the axis along which to downsample."""
+
+ target_rate: float | None = None
+ """Desired rate after downsampling. The actual rate will be the nearest integer factor of the input rate that is the same or higher than the target rate."""
+
+ factor: int | None = None
+ """Explicitly specify downsample factor. If specified, target_rate is ignored."""
+
+There are no ``__init__`` methods that you might expect because we are inheriting from ``ez.Settings``, which uses Python's dataclass functionality to automatically generate the ``__init__`` method based on the class attributes.
+
+.. tip:: It is very good practice to name your settings class with the name of your processor followed by `Settings`. This makes it easy to identify the settings class for a given processor.
+
+The fact that we will not ever need to change these settings implies we do not need use of an adaptive transformer.
+
+DownsampleState class
+========================
+
+For the general operation of the `Downsample` processor, we need to keep track of the downsampling factor (since this could change per message) and the index of the next message's first sample (for maintaining continuity in the downsampled output), especially when processing a stream of data.
+
+The fact that we need to maintain state information implies that we will need to use a stateful transformer.
+
+Your state class will look like this:
+
+.. code-block:: python
+
+ @processor_state
+ class DownsampleState:
+ q: int = 0
+ """The integer downsampling factor. It will be determined based on the target rate."""
+
+ s_idx: int = 0
+ """Index of the next msg's first sample into the virtual rotating ds_factor counter."""
+
+Again, our class seems to be missing an ``__init__`` method, but this is because we are using the ``@processor_state`` decorator from `ezmsg-sigproc`, which automatically generates the ``__init__`` method for us. Just another way to make our code cleaner and more maintainable.
+
+.. note:: It is very good practice to name your state class with the name of your processor followed by `State`. This makes it easy to identify the state class for a given processor.
+
+.. note:: Finally, our transformer is **not async first** as we do not need to prioritise asynchronous processing, which is usually more relevant for processors that interface with IO operations whose timing is unpredictable.
+
+|ezmsg_logo_small| DownsampleTransformer Class
+*******************************************************
+
+We have already identified that we will be using a stateful transformer, so we will inherit from the ``BaseStatefulTransformer`` class. Create the class definition as follows:
+
+.. code-block:: python
+
+ class DownsampleTransformer(
+ BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]
+ ):
+ """
+ Downsampled data simply comprise every `factor`th sample.
+ This should only be used following appropriate lowpass filtering.
+ If your pipeline does not already have lowpass filtering then consider
+ using the :obj:`Decimate` collection instead.
+ """
+
+ def _hash_message(self, message: AxisArray) -> int: ...
+
+ def _reset_state(self, message: AxisArray) -> None: ...
+
+ def _process(self, message: AxisArray) -> AxisArray: ...
+
+.. note:: The `BaseStatefulTransformer` class is a generic class that takes four type parameters: the settings type, the input message type, the output message type, and the state type. In our case, the settings type is `DownsampleSettings`, the input and output message types are both `AxisArray`, and the state type is `DownsampleState`.
+
+
+As can be seen above we must implement the following methods:
+
+- ``_hash_message``: This method is used to generate a hash for the input message. This is useful for caching and avoiding redundant processing.
+- ``_reset_state``: This method is used to reset the internal state of the processor. This is useful when starting a new processing session or when the input data changes significantly.
+- ``_process``: This is the main processing method where the downsampling logic will be implemented.
+
+The first two methods deal with the state of the processor (and are only required for stateful processors), while the third method is where the actual downsampling logic will be implemented.
+
+.. important:: ``_process`` is a necessary method for all transformers and consumers. The equivalent method for producers is called ``_produce``. For non-stateful processors, this will be the only method you need to implement if you inherit from the relevant base class. All other methods are preimplemented for you, but you can override them if needed.
+
+In order to implement these methods, we need to understand our preferred message format: `AxisArray`. This is a flexible and powerful class for handling multi-dimensional arrays with named axes, which is particularly useful for signal processing applications. I have already used `AxisArray` in our code as the input message and output message types.
+
+A detailed explanation of the `AxisArray` class is beyond the scope of this tutorial, but you can refer to the :doc:`AxisArray explainer <../explanations/axisarray>` as well as the :doc:`API reference <../reference/API/axisarray>` for more information.
+
+Brief Aside on AxisArray
+=================================
+
+An ``AxisArray`` is a multi-dimensional array with named axes. Each axis can have a name and a set of labels for its elements. This allows for more intuitive indexing and manipulation of the data.
+
+An `AxisArray` has the following attributes:
+
+- ``data``: a numpy ndarray containing the actual data.
+- ``dims``: a list of axis names.
+- ``axes``: a dictionary mapping axis names to their label information.
+- ``attrs``: a dictionary for storing additional metadata.
+- ``key``: a unique identifier for the array.
+
+Unsurprisingly, all of this must be self-consistent: the number of axis names in ``dims`` must match the number of dimensions in ``data``, and the axis names in ``axes`` should match the ones in ``dims``. The label information in ``axes`` refers to the 'value' of each axis index, e.g., for a time axis, the labels might be timestamps. We provide three commonly used axes type objects:
+
+- A ``LinearAxis``: represents a linear axis with evenly spaced values - you just need the ``offset`` (start value) and the ``gain`` (step size). An example of this would be simple numerical index (offset=0, gain=1) or regularly spaced time samples (offset=start time, gain=1/sampling rate).
+- A ``TimeAxis``: this is a `LinearAxis` that represents a time axis. Its ``unit`` attribute is by default set to seconds (s).
+- A ``CoordinateAxis``: this is our continuous/dense axis, which can represent any continuous variable, such as frequency or spatial coordinates. You provide the actual values for each index in a ``data`` array of values.
+
+The `AxisArray` class provides several methods for manipulating and accessing the data, and the one we will be using in our `Downsample` processor is ``slice_along_axis``. This method allows us to slice the array along a specified axis, which is essential for downsampling.
+
+Hashing the State
+===========================
+
+We can generate a unique hash for the input message using the `key` attribute of the `AxisArray` which we tend to use for identifying what device our data has come from as well as an identifier of the message structure (in this case, the `gain` of the axes containing the data). Since downsampling requires messages to come with linearly spaced data, our axes will either be a `LinearAxis` or a `TimeAxis`, so this attribute will exist.
+
+Our implementation of the ``_hash_message`` method will look like this:
+
+.. code-block:: python
+
+ def _hash_message(self, message: AxisArray) -> int:
+ return hash((message.axes[self.settings.axis].gain, message.key))
+
+.. note:: The idea here is that if either the gain of the axis or the key of the message changes, we are dealing with different data, so we need to reevaluate our state. Importantly, the `DownsampleTransformer` *can* be implemented in a stateless way, but this would require computing the downsampling factor and first sample index every time, and hence a much less efficient implementation.
+
+
+Resetting the State
+=================================
+
+The ``_reset_state`` method is used to reset the internal state of the processor when a message is received with a hash different than that stored by the `DownsampleTransformer`. We need to reset the downsampling factor and the index of the next message's first sample. This is important when starting a new processing session or when the input data changes shape (like a different sampling rate).
+
+We set the downsampling factor either to the one in `DownsampleSettings` if specified, else we compute it based on the target rate and the input message rate. If target rate is not specified, we default to a downsampling factor of 1 (no downsampling). If a target rate is specified, we compute the downsampling factor as the nearest integer that is the same or higher than the ratio of the input rate to the target rate. If the final downsampling factor is less than 1 (not a valid value), we set it to 1 (no downsampling).
+
+Finally, we reset the index of the next message's first sample to 0.
+
+.. code-block:: python
+
+ def _reset_state(self, message: AxisArray) -> None:
+ axis = message.get_axis(self.settings.axis)
+
+ if self.settings.factor is not None:
+ q = self.settings.factor
+ elif self.settings.target_rate is None:
+ q = 1
+ else:
+ q = int(1 / (axis.gain * self.settings.target_rate))
+ if q < 1:
+ ez.logger.warning(
+ f"Target rate {self.settings.target_rate} cannot be achieved with input rate of {1 / axis.gain}."
+ "Setting factor to 1."
+ )
+ q = 1
+ self._state.q = q
+ self._state.s_idx = 0
+
+
+.. _processing_data_tutorial:
+
+|ezmsg_logo_small| Processing the Data
+***********************************************
+
+To finish the `DownsampleTransformer` class, we need to actually process the data by downsampling.
+This is done in the ``_process`` method. We will use some of the methods provided by the `AxisArray` class to help us with this.
+
+Step 1: Getting the indices to slice the data
+=========================================================
+
+We first get the index of the axis (`axis_idx`) and the axis itself (`axis`) along which we want to downsample. We then determine the number of samples in the input message along that axis:
+
+.. code-block:: python
+
+ downsample_axis = self.settings.axis
+ axis = message.get_axis(downsample_axis)
+ axis_idx = message.get_axis_idx(downsample_axis)
+ n_samples = message.data.shape[axis_idx]
+
+Next, create a linear range of indices starting from the current index of the next message's first sample (`self._state.s_idx`) to the current index plus the number of samples in the input message. We use modulo operation with the downsampling factor (`self._state.q`) to create a virtual rotating counter. If the number of samples is greater than 0, we update the index of the next message's first sample for the next iteration. Our slice object is the indices where the virtual counter is 0, which corresponds to the samples we want to keep after downsampling:
+
+.. code-block:: python
+
+ samples = (
+ np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
+ )
+ if n_samples > 0:
+ # Update state for next iteration.
+ self._state.s_idx = samples[-1] + 1
+
+ pub_samples = np.where(samples == 0)[0]
+ if len(pub_samples) > 0:
+ n_step = pub_samples[0].item()
+ data_slice = pub_samples
+ else:
+ n_step = 0
+ data_slice = slice(None, 0, None)
+
+Here `pub_samples` corresponds to the samples we want to keep after downsampling - they are the zeros in our virtual counter. If there are any samples to publish, we set `n_step` to the first index in `pub_samples` (ie. the first zero) and `data_slice` to `pub_samples`. If there are no samples to publish, we set `n_step` to 0 and `data_slice` to an empty slice.
+
+Step 2: Slicing the data and updating the axis
+=========================================================
+
+We will create the output message by first creating our new numpy ndarray by slicing the input message's data along the specified axis using the `slice_along_axis` function from the `AxisArray` class. Then we will update the axis information to reflect the downsampling. Finally, we create a new `AxisArray` message with the downsampled data and updated axes using the ``replace`` function from the `AxisArray` class.
+
+The slicing of the data is done as follows:
+
+.. code-block:: python
+
+ slice_along_axis(message.data, sl=data_slice, axis=axis_idx)
+
+
+We also need to update the axis information to reflect the downsampling. All other axes stay as before, but the one we downsampled on (`downsample_axis`) needs to be updated. The gain of the axis is multiplied by the downsampling factor, and the offset is updated based on the number of steps taken in the virtual counter:
+
+.. code-block:: python
+
+ from ezmsg.util.messages.axisarray import replace
+
+ new_axes={
+ **message.axes,
+ downsample_axis: replace(
+ axis,
+ gain=axis.gain * self._state.q,
+ offset=axis.offset + axis.gain * n_step,
+ ),
+ }
+
+.. important:: The ``replace`` function is a utility function provided by the `AxisArray` class that allows us to create a new object with updated attributes while keeping the other attributes unchanged. It is very fast by avoiding deep copies of the entire object and safety checks that usually occur at object creation time. Its signature is ``replace(obj: T, **changes) -> T``, where `obj` is the object to be updated and `**changes` are the attributes to be updated with their new values. For performance reasons, we **strongly suggest** using the ``replace`` function whenever you are transforming an `AxisArray` message and do not need its previous state.
+
+.. tip:: If, on the contrary, you would prefer a safer (but slower) implementation, you can set the environment variable ``EZMSG_DISABLE_FAST_REPLACE=1`` before running your code. It will then use the Python `dataclasses` implementation of ``replace`` with consistency checks.
+
+
+Step 3: Creating the output message
+=========================================================
+
+Finally, we create the output message:
+
+.. code-block:: python
+
+ msg_out = replace(
+ message,
+ data=slice_along_axis(message.data, data_slice, axis=axis_idx),
+ axes={
+ **message.axes,
+ downsample_axis: replace(
+ axis,
+ gain=axis.gain * self._state.q,
+ offset=axis.offset + axis.gain * n_step,
+ ),
+ },
+ )
+
+.. note:: We used ``replace`` to create the output message, updating only the `data` and `axes` attributes while keeping the other attributes (like `dims`, `attrs`, and `key`) unchanged.
+
+Step 4: Putting it all together
+=========================================================
+
+The final implementation of the ``_process`` method looks like this:
+
+.. code-block:: python
+
+ def _process(self, message: AxisArray) -> AxisArray:
+ downsample_axis = self.settings.axis
+ axis = message.get_axis(downsample_axis)
+ axis_idx = message.get_axis_idx(downsample_axis)
+
+ n_samples = message.data.shape[axis_idx]
+ samples = (
+ np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
+ )
+ if n_samples > 0:
+ # Update state for next iteration.
+ self._state.s_idx = samples[-1] + 1
+
+ pub_samples = np.where(samples == 0)[0]
+ if len(pub_samples) > 0:
+ n_step = pub_samples[0].item()
+ data_slice = pub_samples
+ else:
+ n_step = 0
+ data_slice = slice(None, 0, None)
+ msg_out = replace(
+ message,
+ data=slice_along_axis(message.data, data_slice, axis=axis_idx),
+ axes={
+ **message.axes,
+ downsample_axis: replace(
+ axis,
+ gain=axis.gain * self._state.q,
+ offset=axis.offset + axis.gain * n_step,
+ ),
+ },
+ )
+ return msg_out
+
+
+|ezmsg_logo_small| Final DownsampleTransformer Class
+*******************************************************
+
+Confirm that your final `DownsampleTransformer` class looks like this:
+
+.. code-block:: python
+
+ class DownsampleTransformer(
+ BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]
+ ):
+ """
+ Downsampled data simply comprise every `factor`th sample.
+ This should only be used following appropriate lowpass filtering.
+ If your pipeline does not already have lowpass filtering then consider
+ using the :obj:`Decimate` collection instead.
+ """
+
+ def _hash_message(self, message: AxisArray) -> int:
+ return hash((message.axes[self.settings.axis].gain, message.key))
+
+ def _reset_state(self, message: AxisArray) -> None:
+ axis = message.get_axis(self.settings.axis)
+
+ if self.settings.factor is not None:
+ q = self.settings.factor
+ elif self.settings.target_rate is None:
+ q = 1
+ else:
+ q = int(1 / (axis.gain * self.settings.target_rate))
+ if q < 1:
+ ez.logger.warning(
+ f"Target rate {self.settings.target_rate} cannot be achieved with input rate of {1 / axis.gain}."
+ "Setting factor to 1."
+ )
+ q = 1
+ self._state.q = q
+ self._state.s_idx = 0
+
+ def _process(self, message: AxisArray) -> AxisArray:
+ downsample_axis = self.settings.axis
+ axis = message.get_axis(downsample_axis)
+ axis_idx = message.get_axis_idx(downsample_axis)
+
+ n_samples = message.data.shape[axis_idx]
+ samples = (
+ np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
+ )
+ if n_samples > 0:
+ # Update state for next iteration.
+ self._state.s_idx = samples[-1] + 1
+
+ pub_samples = np.where(samples == 0)[0]
+ if len(pub_samples) > 0:
+ n_step = pub_samples[0].item()
+ data_slice = pub_samples
+ else:
+ n_step = 0
+ data_slice = slice(None, 0, None)
+ msg_out = replace(
+ message,
+ data=slice_along_axis(message.data, data_slice, axis=axis_idx),
+ axes={
+ **message.axes,
+ downsample_axis: replace(
+ axis,
+ gain=axis.gain * self._state.q,
+ offset=axis.offset + axis.gain * n_step,
+ ),
+ },
+ )
+ return msg_out
+
+
+|ezmsg_logo_small| Using the DownsampleTransformer
+**********************************************************
+
+The `Downsample` class is now fully implemented and ready for use in signal processing pipelines.
+You can even use it outside of an ezmsg context by instantiating it directly and calling its ``_process`` method with an `AxisArray` message.
+
+.. important:: The preferred way to call the ``_process`` method is to call the instance directly; below you will see that in the line: ``msg_out = downsampler(msg_in)``. This is possible because all of the processor base classes implement the ``__call__`` method, to call the ``_process`` method internally (or ``_produce`` in the case of `Producers`).
+
+In a separate Python file in the same directory, you can test the `DownsampleTransformer` class as follows:
+
+.. code-block:: python
+
+ # test_downsample.py
+ from downsample import DownsampleTransformer, DownsampleSettings
+ import ezmsg.core as ez
+ from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
+ import numpy as np
+
+ # Create a DownsampleTransformer instance with desired settings.
+ settings = DownsampleSettings(axis="time", target_rate=50) # Target rate of 50 Hz.
+ downsampler = DownsampleTransformer(settings)
+
+ # Create a sample AxisArray message with a time axis and some data.
+ time_axis = LinearAxis(offset=0.0, gain=0.01) # 100 Hz sampling rate.
+ data = np.random.rand(1000) # 1000 samples of random data.
+ msg_in = AxisArray(
+ data=data,
+ dims=["time"],
+ axes={"time": time_axis},
+ key="example_device",
+ )
+
+ # Process the message to downsample it.
+ msg_out = downsampler(msg_in)
+
+ print(f"Input shape: {msg_in.data.shape}, Output shape: {msg_out.data.shape}")
+ print(f"Input time axis gain: {msg_in.axes['time'].gain}, Output time axis gain: {msg_out.axes['time'].gain}")
+
+Doing the above is very handy for unit testing your processor as well as for offline processing of data.
+
+.. note:: The `downsample` module in `ezmsg-sigproc` has a utility function for creating a `DownsampleTransformer` instance with the desired settings:
+
+ .. code-block:: python
+
+ def downsample(
+ axis: str = "time",
+ target_rate: float | None = None,
+ factor: int | None = None,
+ ) -> DownsampleTransformer:
+ return DownsampleTransformer(
+ DownsampleSettings(axis=axis, target_rate=target_rate, factor=factor)
+ )
+
+ After importing this utility function, lines 8 and 9 in our code above could now read:
+
+ .. code-block:: python
+
+ downsampler = downsample(axis="time", target_rate=50)
+
+Of course, the real power of `ezmsg` comes from integrating your processor into an `ezmsg` Unit and using it in a processing pipeline. We will see how to do this next.
+
+
+|ezmsg_logo_small| Creating the `Downsample ezmsg` Unit
+***********************************************************
+
+`ezmsg-sigproc` provides convenient ezmsg `Unit` wrappers for all the signal processor base classes. To do this inherit from the appropriate `ezmsg-sigproc` unit class. These are:
+
+- `BaseProducerUnit`
+- `BaseConsumerUnit`
+- `BaseTransformerUnit`
+
+The names correspond to the type of base processor class you are using. Importantly, these unit classes are agnostic to whether your processor is stateful/adaptive/async first - they will work with any of the processor base classes.
+
+Our `Downsample` processor is a stateful transformer, so we will inherit from the `BaseTransformerUnit` class.
+
+A lot of the behind-the-scenes work is done for you by the `BaseTransformerUnit` class, so we only need to write the following:
+
+.. code-block:: python
+
+ class DownsampleUnit(
+ BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]
+ ):
+ SETTINGS = DownsampleSettings
+
+
+Connecting it to other `Component`\ s and initialising the transformer are accomplished in the same way that we did in the :doc:`pipeline tutorial `.
+
+
+|ezmsg_logo_small| See Also
+************************************
+
+- `Further examples `_ can be found in the examples directory in `ezmsg`. These are examples of creating and using `ezmsg` Units and pipelines.
+- `ezmsg-sigproc` has a large number of already implemented signal processors. More information can be found at the :doc:`ezmsg-sigproc reference <../extensions/sigproc/content-sigproc>`.
+- `Downsample` class reference
+
+.. |ezmsg_logo_small| image:: ../_static/_images/ezmsg_logo.png
+ :width: 40
+ :alt: ezmsg logo
\ No newline at end of file
diff --git a/docs/source/index.rst b/docs/source/index.rst
new file mode 100644
index 00000000..5f3c32ea
--- /dev/null
+++ b/docs/source/index.rst
@@ -0,0 +1,82 @@
+ezmsg.sigproc
+==============
+
+Timeseries signal processing modules for the `ezmsg `_ framework.
+
+Overview
+--------
+
+``ezmsg-sigproc`` provides signal processing primitives built on ezmsg, leveraging numpy, scipy, pywavelets, and sparse. The package offers both standalone processors for offline analysis and Unit wrappers for streaming pipelines.
+
+Key features:
+
+* **Filtering** - Various filter implementations (Chebyshev, comb filters, etc.)
+* **Spectral analysis** - Spectrogram, spectrum, and wavelet transforms
+* **Resampling** - Downsample, decimate, and resample operations
+* **Windowing** - Sliding windows and buffering utilities
+* **Math operations** - Arithmetic, log, abs, difference, and more
+* **Signal generation** - Synthetic signal generators
+
+All modules use :class:`ezmsg.util.messages.axisarray.AxisArray` as the primary data structure for passing signals between components.
+
+.. note::
+ Processors can be used standalone for offline analysis or integrated into ezmsg pipelines for real-time streaming applications.
+
+Installation
+------------
+
+Install from PyPI:
+
+.. code-block:: bash
+
+ pip install ezmsg-sigproc
+
+Or install the latest development version:
+
+.. code-block:: bash
+
+ pip install git+https://github.com/ezmsg-org/ezmsg-sigproc@dev
+
+Dependencies
+^^^^^^^^^^^^
+
+Core dependencies:
+
+* ``ezmsg`` - Core messaging framework
+* ``numpy`` - Numerical computing
+* ``scipy`` - Scientific computing and signal processing
+* ``pywavelets`` - Wavelet transforms
+* ``sparse`` - Sparse array operations
+* ``numba`` - JIT compilation for performance
+
+Quick Start
+-----------
+
+For general ezmsg tutorials and guides, visit `ezmsg.org `_.
+
+For package-specific documentation:
+
+* **Processor Architecture** - See :doc:`guides/ProcessorsBase` for details on the processor hierarchy
+* **How-To Guides** - See :doc:`guides/how-tos/signalprocessing/content-signalprocessing` for usage patterns
+* **API Reference** - See :doc:`api/index` for complete API documentation
+
+Documentation
+-------------
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Contents:
+
+ guides/ProcessorsBase
+ guides/HybridBuffer
+ guides/how-tos/signalprocessing/content-signalprocessing
+ guides/tutorials/signalprocessing
+ guides/sigproc/content-sigproc
+ api/index
+
+
+Indices and tables
+------------------
+
+* :ref:`genindex`
+* :ref:`modindex`
diff --git a/pyproject.toml b/pyproject.toml
index 20ccbaad..ab4bea9a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,7 @@ dev = [
"jupyter>=1.1.1",
{include-group = "lint"},
{include-group = "test"},
+ {include-group = "docs"},
]
lint = [
"ruff"
@@ -37,6 +38,16 @@ test = [
"pytest-cov>=5.0.0",
"pytest>=8.3.3",
]
+docs = [
+ "sphinx>=8.1.3",
+ "pydata-sphinx-theme",
+ "sphinx_autodoc_typehints>=3.0.0",
+ "sphinx_copybutton",
+ "myst_parser",
+]
+profile = [
+ "snakeviz>=2.2.2",
+]
[build-system]
requires = ["hatchling", "hatch-vcs"]
diff --git a/src/ezmsg/sigproc/fbcca.py b/src/ezmsg/sigproc/fbcca.py
new file mode 100644
index 00000000..361d7a5f
--- /dev/null
+++ b/src/ezmsg/sigproc/fbcca.py
@@ -0,0 +1,332 @@
+import typing
+import math
+from dataclasses import field
+
+import numpy as np
+
+import ezmsg.core as ez
+from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.util import replace
+
+from .sampler import SampleTriggerMessage
+from .window import WindowTransformer, WindowSettings
+
+from .base import (
+ BaseTransformer,
+ BaseTransformerUnit,
+ CompositeProcessor,
+ BaseProcessor,
+ BaseStatefulProcessor,
+)
+
+from .kaiser import KaiserFilterSettings
+from .filterbankdesign import (
+ FilterbankDesignSettings,
+ FilterbankDesignTransformer,
+)
+
+
+class FBCCASettings(ez.Settings):
+ """
+ Settings for :obj:`FBCCATransformer`
+ """
+
+ time_dim: str
+ """
+ The time dim in the data array.
+ """
+
+ ch_dim: str
+ """
+ The channels dim in the data array.
+ """
+
+ filterbank_dim: str | None = None
+ """
+ The filter bank subband dim in the data array. If unspecified, method falls back to CCA
+ None (default): the input has no subbands; just use CCA
+ """
+
+ harmonics: int = 5
+ """
+ The number of additional harmonics beyond the fundamental to use for the 'design' matrix.
+ 5 (default): Evaluate 5 harmonics of the base frequency.
+ Many periodic signals are not pure sinusoids, and inclusion of higher harmonics can help evaluate the
+ presence of signals with higher frequency harmonic content
+ """
+
+ freqs: typing.List[float] = field(default_factory=list)
+ """
+ Frequencies (in hz) to evaluate the presence of within the input signal.
+ [] (default): an empty list; frequencies will be found within the input SampleMessages.
+ AxisArrays have no good place to put this metadata, so specify frequencies here if only AxisArrays
+ will be passed as input to the generator. If the input has a `trigger` attr of type :obj:`SampleTriggerMessage`,
+ the processor looks for the `freqs` attribute within that trigger for a list of frequencies to evaluate.
+ This field is present in the :obj:`SSVEPSampleTriggerMessage` defined in ezmsg.tasks.ssvep from the ezmsg-tasks package.
+ NOTE: Avoid frequencies that have line-noise (60 Hz/50 Hz) as a harmonic.
+ """
+
+ softmax_beta: float = 1.0
+ """
+ Beta parameter for softmax on output --> "probabilities".
+ 1.0 (default): Use the shifted softmax transformation to output 0-1 probabilities.
+ If 0.0, the maximum singular value of the SVD for each design matrix is output
+ """
+
+ target_freq_dim: str = "target_freq"
+ """
+ Name for dim to put target frequency outputs on.
+ 'target_freq' (default)
+ """
+
+ max_int_time: float = 0.0
+ """
+ Maximum integration time (in seconds) to use for calculation.
+ 0 (default): Use all time provided for the calculation.
+ Useful for artificially limiting the amount of data used for the CCA method to evaluate
+ the necessary integration time for good decoding performance
+ """
+
+
+class FBCCATransformer(BaseTransformer[FBCCASettings, AxisArray, AxisArray]):
+ """
+ A canonical-correlation (CCA) signal decoder for detection of periodic activity in multi-channel timeseries
+ recordings. It is particularly useful for detecting the presence of steady-state evoked responses in multi-channel
+ EEG data. Please see Lin et. al. 2007 for a description on the use of CCA to detect the presence of SSVEP in EEG
+ data.
+ This implementation also includes the "Filterbank" extension of the CCA decoding approach which utilizes a
+ filterbank to decompose input multi-channel EEG data into several frequency sub-bands; each of which is analyzed
+ with CCA, then combined using a weighted sum; allowing CCA to more readily identify harmonic content in EEG data.
+ Read more about this approach in Chen et. al. 2015.
+
+ ## Further reading:
+ * [Lin et. al. 2007](https://ieeexplore.ieee.org/document/4015614)
+ * [Nakanishi et. al. 2015](https://doi.org/10.1371%2Fjournal.pone.0140703)
+ * [Chen et. al. 2015](http://dx.doi.org/10.1088/1741-2560/12/4/046008)
+ """
+
+ def _process(self, message: AxisArray) -> AxisArray:
+ """
+ Input: AxisArray with at least a time_dim, and ch_dim
+ Output: AxisArray with time_dim, ch_dim, (and filterbank_dim if specified)
+ collapsed, with a new 'target_freq' dim of length 'freqs'
+ """
+
+ test_freqs: list[float] = self.settings.freqs
+ trigger = message.attrs.get("trigger", None)
+ if isinstance(trigger, SampleTriggerMessage):
+ if len(test_freqs) == 0:
+ test_freqs = getattr(trigger, "freqs", [])
+
+ if len(test_freqs) == 0:
+ raise ValueError("no frequencies to test")
+
+ time_dim_idx = message.get_axis_idx(self.settings.time_dim)
+ ch_dim_idx = message.get_axis_idx(self.settings.ch_dim)
+
+ filterbank_dim_idx = None
+ if self.settings.filterbank_dim is not None:
+ filterbank_dim_idx = message.get_axis_idx(self.settings.filterbank_dim)
+
+ # Move (filterbank_dim), time, ch to end of array
+ rm_dims = [self.settings.time_dim, self.settings.ch_dim]
+ if self.settings.filterbank_dim is not None:
+ rm_dims = [self.settings.filterbank_dim] + rm_dims
+ new_order = [i for i, dim in enumerate(message.dims) if dim not in rm_dims]
+ if filterbank_dim_idx is not None:
+ new_order.append(filterbank_dim_idx)
+ new_order.extend([time_dim_idx, ch_dim_idx])
+ out_dims = [
+ message.dims[i] for i in new_order if message.dims[i] not in rm_dims
+ ]
+ data_arr = message.data.transpose(new_order)
+
+ # Add a singleton dim for filterbank dim if we don't have one
+ if filterbank_dim_idx is None:
+ data_arr = data_arr[..., None, :, :]
+ filterbank_dim_idx = data_arr.ndim - 3
+
+ # data_arr is now (..., filterbank, time, ch)
+ # Get output shape for remaining dims and reshape data_arr for iterative processing
+ out_shape = list(data_arr.shape[:-3])
+ data_arr = data_arr.reshape([math.prod(out_shape), *data_arr.shape[-3:]])
+
+ # Create output dims and axes with added target_freq_dim
+ out_shape.append(len(test_freqs))
+ out_dims.append(self.settings.target_freq_dim)
+ out_axes = {
+ axis_name: axis
+ for axis_name, axis in message.axes.items()
+ if axis_name not in rm_dims
+ and not (
+ isinstance(axis, AxisArray.CoordinateAxis)
+ and any(d in rm_dims for d in axis.dims)
+ )
+ }
+ out_axes[self.settings.target_freq_dim] = AxisArray.CoordinateAxis(
+ np.array(test_freqs), [self.settings.target_freq_dim]
+ )
+
+ if message.data.size == 0:
+ out_data = message.data.reshape(out_shape)
+ output = replace(message, data=out_data, dims=out_dims, axes=out_axes)
+ return output
+
+ # Get time axis
+ t_ax_info = message.ax(self.settings.time_dim)
+ t = t_ax_info.values
+ t -= t[0]
+ max_samp = len(t)
+ if self.settings.max_int_time > 0:
+ max_samp = int(abs(t_ax_info.values - self.settings.max_int_time).argmin())
+ t = t[:max_samp]
+
+ calc_output = np.zeros((*data_arr.shape[:-2], len(test_freqs)))
+
+ for test_freq_idx, test_freq in enumerate(test_freqs):
+ # Create the design matrix of base frequency and requested harmonics
+ Y = np.column_stack(
+ [
+ fn(2.0 * np.pi * k * test_freq * t)
+ for k in range(1, self.settings.harmonics + 1)
+ for fn in (np.sin, np.cos)
+ ]
+ )
+
+ for test_idx, arr in enumerate(
+ data_arr
+ ): # iterate over first dim; arr is (filterbank x time x ch)
+ for band_idx, band in enumerate(
+ arr
+ ): # iterate over second dim: arr is (time x ch)
+ calc_output[test_idx, band_idx, test_freq_idx] = cca_rho_max(
+ band[:max_samp, ...], Y
+ )
+
+ # Combine per-subband canonical correlations using a weighted sum
+ # https://iopscience.iop.org/article/10.1088/1741-2560/12/4/046008
+ freq_weights = (np.arange(1, calc_output.shape[1] + 1) ** -1.25) + 0.25
+ calc_output = ((calc_output**2) * freq_weights[None, :, None]).sum(axis=1)
+
+ if self.settings.softmax_beta != 0:
+ calc_output = calc_softmax(
+ calc_output, axis=-1, beta=self.settings.softmax_beta
+ )
+
+ output = replace(
+ message,
+ data=calc_output.reshape(out_shape),
+ dims=out_dims,
+ axes=out_axes,
+ )
+
+ return output
+
+
+class FBCCA(BaseTransformerUnit[FBCCASettings, AxisArray, AxisArray, FBCCATransformer]):
+ SETTINGS = FBCCASettings
+
+
+class StreamingFBCCASettings(FBCCASettings):
+ """
+ Perform rolling/streaming FBCCA on incoming EEG.
+ Decomposes the input multi-channel timeseries data into multiple sub-bands using a FilterbankDesign Transformer,
+ then accumulates data using Window into short-time observations for analysis using an FBCCA Transformer.
+ """
+
+ window_dur: float = 4.0 # sec
+ window_shift: float = 0.5 # sec
+ window_dim: str = "fbcca_window"
+ filter_bw: float = 7.0 # Hz
+ filter_low: float = 7.0 # Hz
+ trans_bw: float = 2.0 # Hz
+ ripple_db: float = 20.0 # dB
+ subbands: int = 12
+
+
+class StreamingFBCCATransformer(
+ CompositeProcessor[StreamingFBCCASettings, AxisArray, AxisArray]
+):
+ @staticmethod
+ def _initialize_processors(
+ settings: StreamingFBCCASettings,
+ ) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
+ pipeline = {}
+
+ if settings.filterbank_dim is not None:
+ cut_freqs = (
+ np.arange(settings.subbands + 1) * settings.filter_bw
+ ) + settings.filter_low
+ filters = [
+ KaiserFilterSettings(
+ axis=settings.time_dim,
+ cutoff=(c - settings.trans_bw, cut_freqs[-1]),
+ ripple=settings.ripple_db,
+ width=settings.trans_bw,
+ pass_zero=False,
+ )
+ for c in cut_freqs[:-1]
+ ]
+
+ pipeline["filterbank"] = FilterbankDesignTransformer(
+ FilterbankDesignSettings(
+ filters=filters, new_axis=settings.filterbank_dim
+ )
+ )
+
+ pipeline["window"] = WindowTransformer(
+ WindowSettings(
+ axis=settings.time_dim,
+ newaxis=settings.window_dim,
+ window_dur=settings.window_dur,
+ window_shift=settings.window_shift,
+ zero_pad_until="shift",
+ )
+ )
+
+ pipeline["fbcca"] = FBCCATransformer(settings)
+
+ return pipeline
+
+
+class StreamingFBCCA(
+ BaseTransformerUnit[
+ StreamingFBCCASettings, AxisArray, AxisArray, StreamingFBCCATransformer
+ ]
+):
+ SETTINGS = StreamingFBCCASettings
+
+
+def cca_rho_max(X: np.ndarray, Y: np.ndarray) -> float:
+ """
+ X: (n_time, n_ch)
+ Y: (n_time, n_ref) # design matrix for one frequency
+ returns: largest canonical correlation in [0,1]
+ """
+ # Center columns
+ Xc = X - X.mean(axis=0, keepdims=True)
+ Yc = Y - Y.mean(axis=0, keepdims=True)
+
+ # Drop any zero-variance columns to avoid rank issues
+ Xc = Xc[:, Xc.std(axis=0) > 1e-12]
+ Yc = Yc[:, Yc.std(axis=0) > 1e-12]
+ if Xc.size == 0 or Yc.size == 0:
+ return 0.0
+
+ # Orthonormal bases
+ Qx, _ = np.linalg.qr(Xc, mode="reduced") # (n_time, r_x)
+ Qy, _ = np.linalg.qr(Yc, mode="reduced") # (n_time, r_y)
+
+ # Canonical correlations are the singular values of Qx^T Qy
+ with np.errstate(divide="ignore", over="ignore", invalid="ignore"):
+ s = np.linalg.svd(Qx.T @ Qy, compute_uv=False)
+ return float(s[0]) if s.size else 0.0
+
+
+def calc_softmax(cv: np.ndarray, axis: int, beta: float = 1.0):
+ # Calculate softmax with shifting to avoid overflow
+ # (https://doi.org/10.1093/imanum/draa038)
+ cv = cv - cv.max(axis=axis, keepdims=True)
+ cv = np.exp(beta * cv)
+ cv = cv / np.sum(cv, axis=axis, keepdims=True)
+ return cv
diff --git a/src/ezmsg/sigproc/filter.py b/src/ezmsg/sigproc/filter.py
index 9a3a4c54..51bd9541 100644
--- a/src/ezmsg/sigproc/filter.py
+++ b/src/ezmsg/sigproc/filter.py
@@ -263,6 +263,14 @@ def __call__(self, message: AxisArray) -> AxisArray:
axis = self.state.filter.settings.axis
fs = 1 / message.axes[axis].gain
coefs = design_fun(fs)
+
+ # Convert BA to SOS if requested
+ if coefs is not None and self.settings.coef_type == "sos":
+ if isinstance(coefs, tuple) and len(coefs) == 2:
+ # It's BA format, convert to SOS
+ b, a = coefs
+ coefs = scipy.signal.tf2sos(b, a)
+
self.state.filter.update_coefficients(
coefs, coef_type=self.settings.coef_type
)
@@ -282,6 +290,14 @@ def _reset_state(self, message: AxisArray) -> None:
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
fs = 1 / message.axes[axis].gain
coefs = design_fun(fs)
+
+ # Convert BA to SOS if requested
+ if coefs is not None and self.settings.coef_type == "sos":
+ if isinstance(coefs, tuple) and len(coefs) == 2:
+ # It's BA format, convert to SOS
+ b, a = coefs
+ coefs = scipy.signal.tf2sos(b, a)
+
new_settings = FilterSettings(
axis=axis, coef_type=self.settings.coef_type, coefs=coefs
)
diff --git a/src/ezmsg/sigproc/filterbankdesign.py b/src/ezmsg/sigproc/filterbankdesign.py
new file mode 100644
index 00000000..01f1b060
--- /dev/null
+++ b/src/ezmsg/sigproc/filterbankdesign.py
@@ -0,0 +1,136 @@
+import typing
+
+import ezmsg.core as ez
+import numpy as np
+import numpy.typing as npt
+
+from ezmsg.util.messages.util import replace
+from ezmsg.util.messages.axisarray import AxisArray
+
+from .base import (
+ BaseStatefulTransformer,
+ processor_state,
+)
+
+from .filterbank import (
+ FilterbankTransformer,
+ FilterbankSettings,
+ FilterbankMode,
+ MinPhaseMode,
+)
+
+from .kaiser import KaiserFilterSettings, kaiser_design_fun
+
+
+class FilterbankDesignSettings(ez.Settings):
+ filters: typing.Iterable[KaiserFilterSettings]
+
+ mode: FilterbankMode = FilterbankMode.CONV
+ """
+ "conv", "fft", or "auto". If "auto", the mode is determined by the size of the input data.
+ fft mode is more efficient for long kernels. However, fft mode uses non-overlapping windows and will
+ incur a delay equal to the window length, which is larger than the largest kernel.
+ conv mode is less efficient but will return data for every incoming chunk regardless of how small it is
+ and thus can provide shorter latency updates.
+ """
+
+ min_phase: MinPhaseMode = MinPhaseMode.NONE
+ """
+ If not None, convert the kernels to minimum-phase equivalents. Valid options are
+ 'hilbert', 'homomorphic', and 'homomorphic-full'. Complex filters not supported.
+ See `scipy.signal.minimum_phase` for details.
+ """
+
+ axis: str = "time"
+ """The name of the axis to operate on. This should usually be "time"."""
+
+ new_axis: str = "kernel"
+ """The name of the new axis corresponding to the kernel index."""
+
+
+@processor_state
+class FilterbankDesignState:
+ filterbank: FilterbankTransformer | None = None
+ needs_redesign: bool = False
+
+
+class FilterbankDesignTransformer(
+ BaseStatefulTransformer[
+ FilterbankDesignSettings, AxisArray, AxisArray, FilterbankDesignState
+ ],
+):
+ """
+ Transformer that designs and applies a filterbank based on Kaiser windowed FIR filters.
+ """
+
+ @classmethod
+ def get_message_type(cls, dir: str) -> type[AxisArray]:
+ if dir in ("in", "out"):
+ return AxisArray
+ else:
+ raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
+
+ def update_settings(
+ self, new_settings: typing.Optional[FilterbankDesignSettings] = None, **kwargs
+ ) -> None:
+ """
+ Update settings and mark that filter coefficients need to be recalculated.
+
+ Args:
+ new_settings: Complete new settings object to replace current settings
+ **kwargs: Individual settings to update
+ """
+ # Update settings
+ if new_settings is not None:
+ self.settings = new_settings
+ else:
+ self.settings = replace(self.settings, **kwargs)
+
+ # Set flag to trigger recalculation on next message
+ if self.state.filterbank is not None:
+ self.state.needs_redesign = True
+
+ def _calculate_kernels(self, fs: float) -> list[npt.NDArray]:
+ kernels = []
+ for filter in self.settings.filters:
+ output = kaiser_design_fun(
+ fs,
+ cutoff=filter.cutoff,
+ ripple=filter.ripple,
+ width=filter.width,
+ pass_zero=filter.pass_zero,
+ wn_hz=filter.wn_hz,
+ )
+
+ kernels.append(np.array([1.0]) if output is None else output[0])
+ return kernels
+
+ def __call__(self, message: AxisArray) -> AxisArray:
+ if self.state.filterbank is not None and self.state.needs_redesign:
+ self._reset_state(message)
+ self.state.needs_redesign = False
+ return super().__call__(message)
+
+ def _hash_message(self, message: AxisArray) -> int:
+ axis = message.dims[0] if self.settings.axis is None else self.settings.axis
+ gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
+ axis_idx = message.get_axis_idx(axis)
+ samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
+ return hash((message.key, samp_shape, gain))
+
+ def _reset_state(self, message: AxisArray) -> None:
+ axis_obj = message.axes[self.settings.axis]
+ assert isinstance(axis_obj, AxisArray.LinearAxis)
+ fs = 1 / axis_obj.gain
+ kernels = self._calculate_kernels(fs)
+ new_settings = FilterbankSettings(
+ kernels=kernels,
+ mode=self.settings.mode,
+ min_phase=self.settings.min_phase,
+ axis=self.settings.axis,
+ new_axis=self.settings.new_axis,
+ )
+ self.state.filterbank = FilterbankTransformer(settings=new_settings)
+
+ def _process(self, message: AxisArray) -> AxisArray:
+ return self.state.filterbank(message)
diff --git a/src/ezmsg/sigproc/firfilter.py b/src/ezmsg/sigproc/firfilter.py
new file mode 100644
index 00000000..18b3630c
--- /dev/null
+++ b/src/ezmsg/sigproc/firfilter.py
@@ -0,0 +1,119 @@
+import functools
+import typing
+
+import numpy as np
+import numpy.typing as npt
+import scipy.signal
+
+from .filter import (
+ FilterBaseSettings,
+ FilterByDesignTransformer,
+ BACoeffs,
+ BaseFilterByDesignTransformerUnit,
+)
+
+
+class FIRFilterSettings(FilterBaseSettings):
+ """Settings for :obj:`FIRFilter`. See scipy.signal.firwin for more details"""
+
+ # axis and coef_type are inherited from FilterBaseSettings
+
+ order: int = 0
+ """
+ Filter order/number of taps
+ """
+
+ cutoff: float | npt.ArrayLike | None = None
+ """
+ Cutoff frequency of filter (expressed in the same units as fs) OR an array of cutoff frequencies
+ (that is, band edges). In the former case, as a float, the cutoff frequency should correspond with
+ the half-amplitude point, where the attenuation will be -6dB. In the latter case, the frequencies in
+ cutoff should be positive and monotonically increasing between 0 and fs/2. The values 0 and fs/2 must
+ not be included in cutoff.
+ """
+
+ width: float | None = None
+ """
+ If width is not None, then assume it is the approximate width of the transition region (expressed in
+ the same units as fs) for use in Kaiser FIR filter design. In this case, the window argument is ignored.
+ """
+
+ window: str | None = "hamming"
+ """
+ Desired window to use. See scipy.signal.get_window for a list of windows and required parameters.
+ """
+
+ pass_zero: bool | str = True
+ """
+ If True, the gain at the frequency 0 (i.e., the “DC gain”) is 1. If False, the DC gain is 0. Can also
+ be a string argument for the desired filter type (equivalent to btype in IIR design functions).
+ {‘lowpass’, ‘highpass’, ‘bandpass’, ‘bandstop’}
+ """
+
+ scale: bool = True
+ """
+ Set to True to scale the coefficients so that the frequency response is exactly unity at a certain
+ frequency. That frequency is either:
+ * 0 (DC) if the first passband starts at 0 (i.e. pass_zero is True)
+ * fs/2 (the Nyquist frequency) if the first passband ends at fs/2
+ (i.e the filter is a single band highpass filter);
+ center of first passband otherwise
+ """
+
+ wn_hz: bool = True
+ """
+ Set False if provided Wn are normalized from 0 to 1, where 1 is the Nyquist frequency
+ """
+
+
+def firwin_design_fun(
+ fs: float,
+ order: int = 0,
+ cutoff: float | npt.ArrayLike | None = None,
+ width: float | None = None,
+ window: str | None = "hamming",
+ pass_zero: bool | str = True,
+ scale: bool = True,
+ wn_hz: bool = True,
+) -> BACoeffs | None:
+ """
+ Design an `order`th-order FIR filter and return the filter coefficients.
+ See :obj:`FIRFilterSettings` for argument description.
+
+ Returns:
+ The filter taps as designed by firwin
+ """
+ if order > 0:
+ taps = scipy.signal.firwin(
+ numtaps=order,
+ cutoff=cutoff,
+ width=width,
+ window=window,
+ pass_zero=pass_zero,
+ scale=scale,
+ fs=fs if wn_hz else None,
+ )
+ return (taps, np.array([1.0]))
+ return None
+
+
+class FIRFilterTransformer(FilterByDesignTransformer[FIRFilterSettings, BACoeffs]):
+ def get_design_function(
+ self,
+ ) -> typing.Callable[[float], BACoeffs | None]:
+ return functools.partial(
+ firwin_design_fun,
+ order=self.settings.order,
+ cutoff=self.settings.cutoff,
+ width=self.settings.width,
+ window=self.settings.window,
+ pass_zero=self.settings.pass_zero,
+ scale=self.settings.scale,
+ wn_hz=self.settings.wn_hz,
+ )
+
+
+class FIRFilter(
+ BaseFilterByDesignTransformerUnit[FIRFilterSettings, FIRFilterTransformer]
+):
+ SETTINGS = FIRFilterSettings
diff --git a/src/ezmsg/sigproc/kaiser.py b/src/ezmsg/sigproc/kaiser.py
new file mode 100644
index 00000000..1f8176a5
--- /dev/null
+++ b/src/ezmsg/sigproc/kaiser.py
@@ -0,0 +1,110 @@
+import functools
+import typing
+
+import numpy as np
+import numpy.typing as npt
+import scipy.signal
+
+from .filter import (
+ FilterBaseSettings,
+ FilterByDesignTransformer,
+ BACoeffs,
+ BaseFilterByDesignTransformerUnit,
+)
+
+
+class KaiserFilterSettings(FilterBaseSettings):
+ """Settings for :obj:`KaiserFilter`"""
+
+ # axis and coef_type are inherited from FilterBaseSettings
+
+ cutoff: float | npt.ArrayLike | None = None
+ """
+ Cutoff frequency of filter (expressed in the same units as fs) OR an array of cutoff frequencies
+ (that is, band edges). In the former case, as a float, the cutoff frequency should correspond with
+ the half-amplitude point, where the attenuation will be -6dB. In the latter case, the frequencies in
+ cutoff should be positive and monotonically increasing between 0 and fs/2. The values 0 and fs/2 must
+ not be included in cutoff.
+ """
+
+ ripple: float | None = None
+ """
+ Upper bound for the deviation (in dB) of the magnitude of the filter's frequency response from that of
+ the desired filter (not including frequencies in any transition intervals).
+ See scipy.signal.kaiserord for more information.
+ """
+
+ width: float | None = None
+ """
+ If width is not None, then assume it is the approximate width of the transition region (expressed in
+ the same units as fs) for use in Kaiser FIR filter design.
+ See scipy.signal.kaiserord for more information.
+ """
+
+ pass_zero: bool | str = True
+ """
+ If True, the gain at the frequency 0 (i.e., the “DC gain”) is 1. If False, the DC gain is 0. Can also
+ be a string argument for the desired filter type (equivalent to btype in IIR design functions).
+ {‘lowpass’, ‘highpass’, ‘bandpass’, ‘bandstop’}
+ """
+
+ wn_hz: bool = True
+ """
+ Set False if cutoff and width are normalized from 0 to 1, where 1 is the Nyquist frequency
+ """
+
+
+def kaiser_design_fun(
+ fs: float,
+ cutoff: float | npt.ArrayLike | None = None,
+ ripple: float | None = None,
+ width: float | None = None,
+ pass_zero: bool | str = True,
+ wn_hz: bool = True,
+) -> BACoeffs | None:
+ """
+ Design an `order`th-order FIR Kaiser filter and return the filter coefficients.
+ See :obj:`FIRFilterSettings` for argument description.
+
+ Returns:
+ The filter taps as designed by firwin
+ """
+ if ripple is None or width is None or cutoff is None:
+ return None
+
+ width = width / (0.5 * fs) if wn_hz else width
+ n_taps, beta = scipy.signal.kaiserord(ripple, width)
+ if n_taps % 2 == 0:
+ n_taps += 1
+ taps = scipy.signal.firwin(
+ numtaps=n_taps,
+ cutoff=cutoff,
+ window=("kaiser", beta), # type: ignore
+ pass_zero=pass_zero, # type: ignore
+ scale=False,
+ fs=fs if wn_hz else None,
+ )
+
+ return (taps, np.array([1.0]))
+
+
+class KaiserFilterTransformer(
+ FilterByDesignTransformer[KaiserFilterSettings, BACoeffs]
+):
+ def get_design_function(
+ self,
+ ) -> typing.Callable[[float], BACoeffs | None]:
+ return functools.partial(
+ kaiser_design_fun,
+ cutoff=self.settings.cutoff,
+ ripple=self.settings.ripple,
+ width=self.settings.width,
+ pass_zero=self.settings.pass_zero,
+ wn_hz=self.settings.wn_hz,
+ )
+
+
+class KaiserFilter(
+ BaseFilterByDesignTransformerUnit[KaiserFilterSettings, KaiserFilterTransformer]
+):
+ SETTINGS = KaiserFilterSettings
diff --git a/src/ezmsg/sigproc/resample.py b/src/ezmsg/sigproc/resample.py
index 5775d4b8..764e3bfe 100644
--- a/src/ezmsg/sigproc/resample.py
+++ b/src/ezmsg/sigproc/resample.py
@@ -1,13 +1,11 @@
import asyncio
-import dataclasses
+import math
import time
-import typing
import numpy as np
-import numpy.typing as npt
import scipy.interpolate
import ezmsg.core as ez
-from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
from ezmsg.util.messages.util import replace
from .base import (
@@ -15,6 +13,8 @@
BaseConsumerUnit,
processor_state,
)
+from .util.axisarray_buffer import HybridAxisArrayBuffer, HybridAxisBuffer
+from .util.buffer import UpdateStrategy
class ResampleSettings(ez.Settings):
@@ -23,7 +23,7 @@ class ResampleSettings(ez.Settings):
resample_rate: float | None = None
"""target resample rate in Hz. If None, the resample rate will be determined by the reference signal."""
- max_chunk_delay: float = 0.0
+ max_chunk_delay: float = np.inf
"""Maximum delay between outputs in seconds. If the delay exceeds this value, the transformer will extrapolate."""
fill_value: str = "extrapolate"
@@ -34,23 +34,49 @@ class ResampleSettings(ez.Settings):
See scipy.interpolate.interp1d for more options.
"""
+ buffer_duration: float = 2.0
-@dataclasses.dataclass
-class ResampleBuffer:
- data: npt.NDArray
- tvec: npt.NDArray
- template: AxisArray
- last_update: float
+ buffer_update_strategy: UpdateStrategy = "immediate"
+ """
+ The buffer update strategy. See :obj:`ezmsg.sigproc.util.buffer.UpdateStrategy`.
+ If you expect to push data much more frequently than it is resampled, then "on_demand"
+ might be more efficient. For most other scenarios, "immediate" is best.
+ """
@processor_state
class ResampleState:
- signal_buffer: ResampleBuffer | None = None
- ref_axis: tuple[typing.Union[AxisArray.TimeAxis, AxisArray.CoordinateAxis], int] = (
- AxisArray.TimeAxis(fs=1.0),
- 0,
- )
- last_t_out: float | None = None
+ src_buffer: HybridAxisArrayBuffer | None = None
+ """
+ Buffer for the incoming signal data. This is the source for training the interpolation function.
+ Its contents are rarely empty because we usually hold back some data to allow for accurate
+ interpolation and optionally extrapolation.
+ """
+
+ ref_axis_buffer: HybridAxisBuffer | None = None
+ """
+ The buffer for the reference axis (usually a time axis). The interpolation function
+ will be evaluated at the reference axis values.
+ When resample_rate is None, this buffer will be filled with the axis from incoming
+ _reference_ messages.
+ When resample_rate is not None (i.e., prescribed float resample_rate), this buffer
+ is filled with a synthetic axis that is generated from the incoming signal messages.
+ """
+
+ last_ref_ax_val: float | None = None
+ """
+ The last value of the reference axis that was returned. This helps us to know
+ what the _next_ returned value should be, and to avoid returning the same value.
+ TODO: We can eliminate this variable if we maintain "by convention" that the
+ reference axis always has 1 value at its start that we exclude from the resampling.
+ """
+
+ last_write_time: float = -np.inf
+ """
+ Wall clock time of the last write to the signal buffer.
+ This is used to determine if we need to extrapolate the reference axis
+ if we have not received an update within max_chunk_delay.
+ """
class ResampleProcessor(
@@ -60,169 +86,149 @@ def _hash_message(self, message: AxisArray) -> int:
ax_idx: int = message.get_axis_idx(self.settings.axis)
sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
ax = message.axes[self.settings.axis]
- in_fs = (1 / ax.gain) if hasattr(ax, "gain") else None
- return hash((message.key, in_fs) + sample_shape)
+ gain = ax.gain if hasattr(ax, "gain") else None
+ return hash((message.key, gain) + sample_shape)
def _reset_state(self, message: AxisArray) -> None:
"""
Reset the internal state based on the incoming message.
- If resample_rate is None, the output is driven by the reference signal.
- The input will still determine the template (except the primary axis) and the buffer.
"""
- ax_idx: int = message.get_axis_idx(self.settings.axis)
- ax = message.axes[self.settings.axis]
- in_dat = message.data
- in_tvec = (
- ax.data
- if hasattr(ax, "data")
- else ax.value(np.arange(in_dat.shape[ax_idx]))
- )
- if ax_idx != 0:
- in_dat = np.moveaxis(in_dat, ax_idx, 0)
-
- if self.settings.resample_rate is None:
- # Output is driven by input.
- # We cannot include the resampled axis until we see reference data.
- out_axes = {
- k: v for k, v in message.axes.items() if k != self.settings.axis
- }
- # last_t_out also driven by reference data.
- # self.state.last_t_out = None
- else:
- out_axes = {
- **message.axes,
- self.settings.axis: AxisArray.TimeAxis(
- fs=self.settings.resample_rate, offset=in_tvec[0]
- ),
- }
- self.state.last_t_out = in_tvec[0] - 1 / self.settings.resample_rate
- template = replace(message, data=in_dat[:0], axes=out_axes)
- self.state.signal_buffer = ResampleBuffer(
- data=in_dat[:0],
- tvec=in_tvec[:0],
- template=template,
- last_update=time.time(),
- )
-
- def _process(self, message: AxisArray) -> None:
- # The incoming message will be added to the buffer.
- buf = self.state.signal_buffer
-
- # If our outputs are driven by reference signal, create the template's output axis if not already created.
- if (
- self.settings.resample_rate is None
- and self.settings.axis not in self.state.signal_buffer.template.axes
- ):
- buf = self.state.signal_buffer
- buf.template.axes[self.settings.axis] = self.state.ref_axis[0]
- if hasattr(buf.template.axes[self.settings.axis], "gain"):
- buf.template = replace(
- buf.template,
- axes={
- **buf.template.axes,
- self.settings.axis: replace(
- buf.template.axes[self.settings.axis],
- offset=self.state.last_t_out,
- ),
- },
- )
- # Note: last_t_out was set on the first call to push_reference.
-
- # Append the new data to the buffer
- ax_idx: int = message.get_axis_idx(self.settings.axis)
- in_dat: npt.NDArray = message.data
- if ax_idx != 0:
- in_dat = np.moveaxis(in_dat, ax_idx, 0)
- ax = message.axes[self.settings.axis]
- in_tvec = (
- ax.data if hasattr(ax, "data") else ax.value(np.arange(in_dat.shape[0]))
+ self.state.src_buffer = HybridAxisArrayBuffer(
+ duration=self.settings.buffer_duration,
+ axis=self.settings.axis,
+ update_strategy=self.settings.buffer_update_strategy,
+ overflow_strategy="grow",
)
- buf.data = np.concatenate((buf.data, in_dat), axis=0)
- buf.tvec = np.hstack((buf.tvec, in_tvec))
- buf.last_update = time.time()
+ if self.settings.resample_rate is not None:
+ # If we are resampling at a prescribed rate, then we synthesize a reference axis
+ self.state.ref_axis_buffer = HybridAxisBuffer(
+ duration=self.settings.buffer_duration,
+ )
+ in_ax = message.axes[self.settings.axis]
+ out_gain = 1 / self.settings.resample_rate
+ t0 = in_ax.data[0] if hasattr(in_ax, "data") else in_ax.value(0)
+ self.state.last_ref_ax_val = t0 - out_gain
+ self.state.last_write_time = -np.inf
def push_reference(self, message: AxisArray) -> None:
ax = message.axes[self.settings.axis]
ax_idx = message.get_axis_idx(self.settings.axis)
- n_new = message.data.shape[ax_idx]
- if self.state.ref_axis[1] == 0:
- self.state.ref_axis = (ax, n_new)
- else:
- if hasattr(ax, "gain"):
- # Rate and offset don't need to change; we simply increment our sample counter.
- self.state.ref_axis = (
- self.state.ref_axis[0],
- self.state.ref_axis[1] + n_new,
- )
- else:
- # Extend our time axis with the new data.
- new_tvec = np.concatenate(
- (self.state.ref_axis[0].data, ax.data), axis=0
- )
- self.state.ref_axis = (
- replace(self.state.ref_axis[0], data=new_tvec),
- self.state.ref_axis[1] + n_new,
- )
-
- if self.settings.resample_rate is None and self.state.last_t_out is None:
- # This reference axis will become THE output axis.
- # If last_t_out has not previously been set, we set it to the sample before this reference data.
- if hasattr(self.state.ref_axis[0], "gain"):
- ref_tvec = self.state.ref_axis[0].value(np.arange(2))
- else:
- ref_tvec = self.state.ref_axis[0].data[:2]
- self.state.last_t_out = 2 * ref_tvec[0] - ref_tvec[1]
+ if self.state.ref_axis_buffer is None:
+ self.state.ref_axis_buffer = HybridAxisBuffer(
+ duration=self.settings.buffer_duration,
+ update_strategy=self.settings.buffer_update_strategy,
+ overflow_strategy="grow",
+ )
+ t0 = ax.data[0] if hasattr(ax, "data") else ax.value(0)
+ self.state.last_ref_ax_val = t0 - ax.gain
+ self.state.ref_axis_buffer.write(ax, n_samples=message.data.shape[ax_idx])
- def __next__(self) -> AxisArray:
- buf = self.state.signal_buffer
+ def _process(self, message: AxisArray) -> None:
+ """
+ Add a new data message to the buffer and update the reference axis if needed.
+ """
+ # Note: The src_buffer will copy and permute message if ax_idx != 0
+ self.state.src_buffer.write(message)
- if buf is None:
+ # If we are resampling at a prescribed rate (i.e., not by reference msgs),
+ # then we use this opportunity to extend our synthetic reference axis.
+ ax_idx = message.get_axis_idx(self.settings.axis)
+ if self.settings.resample_rate is not None and message.data.shape[ax_idx] > 0:
+ in_ax = message.axes[self.settings.axis]
+ in_t_end = (
+ in_ax.data[-1]
+ if hasattr(in_ax, "data")
+ else in_ax.value(message.data.shape[ax_idx] - 1)
+ )
+ out_gain = 1 / self.settings.resample_rate
+ prev_t_end = self.state.last_ref_ax_val
+ n_synth = math.ceil((in_t_end - prev_t_end) * self.settings.resample_rate)
+ synth_ref_axis = LinearAxis(
+ unit="s", gain=out_gain, offset=prev_t_end + out_gain
+ )
+ self.state.ref_axis_buffer.write(synth_ref_axis, n_samples=n_synth)
+
+ self.state.last_write_time = time.time()
+
+ def __next__(self) -> AxisArray:
+ if self.state.src_buffer is None or self.state.ref_axis_buffer is None:
+ # If we have not received any data, or we require reference data
+ # that we do not yet have, then return an empty template.
return AxisArray(data=np.array([]), dims=[""], axes={}, key="null")
- # buffer is empty or ref-driven && empty-reference; return the empty template
- if (buf.tvec.size == 0) or (
- self.settings.resample_rate is None and self.state.ref_axis[1] < 3
- ):
- # Note: empty template's primary axis' offset might be meaningless.
- return buf.template
-
- # Identify the output timestamps at which we will resample the buffer
- b_project = False
- if self.settings.resample_rate is None:
- # Rely on reference signal to determine output timestamps
- if hasattr(self.state.ref_axis[0], "data"):
- ref_tvec = self.state.ref_axis[0].data
- else:
- n_avail = self.state.ref_axis[1]
- ref_tvec = self.state.ref_axis[0].value(np.arange(n_avail))
+ src = self.state.src_buffer
+ ref = self.state.ref_axis_buffer
+
+ # If we have no reference or the source is insufficient for interpolation
+ # then return the empty template
+ if ref.is_empty() or src.available() < 3:
+ src_axarr = src.peek(0)
+ return replace(
+ src_axarr,
+ axes={
+ **src_axarr.axes,
+ self.settings.axis: ref.peek(0),
+ },
+ )
+
+ # Build the reference xvec.
+ # Note: The reference axis buffer may grow upon `.peek()`
+ # as it flushes data from its deque to its buffer.
+ ref_ax = ref.peek()
+ if hasattr(ref_ax, "data"):
+ ref_xvec = ref_ax.data
else:
- # Get output timestamps from resample_rate and what we've collected so far
- t_begin = self.state.last_t_out + 1 / self.settings.resample_rate
- t_end = buf.tvec[-1]
- if self.settings.max_chunk_delay > 0 and time.time() > (
- buf.last_update + self.settings.max_chunk_delay
- ):
- # We've waiting too long between pushes. We will have to extrapolate.
- b_project = True
- t_end += self.settings.max_chunk_delay
- ref_tvec = np.arange(t_begin, t_end, 1 / self.settings.resample_rate)
-
- # Which samples can we resample?
- b_ref = ref_tvec > self.state.last_t_out
+ ref_xvec = ref_ax.value(np.arange(ref.available()))
+
+ # If we do not rely on an external reference, and we have not received new data in a while,
+ # then extrapolate our reference vector out beyond the delay limit.
+ b_project = self.settings.resample_rate is not None and time.time() > (
+ self.state.last_write_time + self.settings.max_chunk_delay
+ )
+ if b_project:
+ n_append = math.ceil(self.settings.max_chunk_delay / ref_ax.gain)
+ xvec_append = ref_xvec[-1] + np.arange(1, n_append + 1) * ref_ax.gain
+ ref_xvec = np.hstack((ref_xvec, xvec_append))
+
+ # Get source to train interpolation
+ src_axarr = src.peek()
+ src_axis = src_axarr.axes[self.settings.axis]
+ x = (
+ src_axis.data
+ if hasattr(src_axis, "data")
+ else src_axis.value(np.arange(src_axarr.data.shape[0]))
+ )
+
+ # Only resample at reference values that have not been interpolated over previously.
+ b_ref = ref_xvec > self.state.last_ref_ax_val
if not b_project:
- b_ref = np.logical_and(b_ref, ref_tvec <= buf.tvec[-1])
+ # Not extrapolating -- Do not resample beyond the end of the source buffer.
+ b_ref = np.logical_and(b_ref, ref_xvec <= x[-1])
ref_idx = np.where(b_ref)[0]
- if len(ref_idx) < 2:
- # Not enough data to resample; return the empty template.
- return buf.template
+ if len(ref_idx) == 0:
+ # Nothing to interpolate over; return empty data
+ null_ref = (
+ replace(ref_ax, data=ref_ax.data[:0])
+ if hasattr(ref_ax, "data")
+ else ref_ax
+ )
+ return replace(
+ src_axarr,
+ data=src_axarr.data[:0, ...],
+ axes={**src_axarr.axes, self.settings.axis: null_ref},
+ )
+
+ xnew = ref_xvec[ref_idx]
+
+ # Identify source data indices around ref tvec with some padding for better interpolation.
+ src_start_ix = max(
+ 0, np.where(x > xnew[0])[0][0] - 2 if np.any(x > xnew[0]) else 0
+ )
+
+ x = x[src_start_ix:]
+ y = src_axarr.data[src_start_ix:]
- tnew = ref_tvec[ref_idx]
- # Slice buf to minimal range around tnew with some padding for better interpolation.
- buf_start_ix = max(0, np.searchsorted(buf.tvec, tnew[0]) - 2)
- buf_stop_ix = np.searchsorted(buf.tvec, tnew[-1], side="right") + 2
- x = buf.tvec[buf_start_ix:buf_stop_ix]
- y = buf.data[buf_start_ix:buf_stop_ix]
if (
isinstance(self.settings.fill_value, str)
and self.settings.fill_value == "last"
@@ -240,37 +246,32 @@ def __next__(self) -> AxisArray:
fill_value=fill_value,
assume_sorted=True,
)
- resampled_data = f(tnew)
- if hasattr(buf.template.axes[self.settings.axis], "data"):
- repl_axis = replace(buf.template.axes[self.settings.axis], data=tnew)
+
+ # Calculate output
+ resampled_data = f(xnew)
+
+ # Create output message
+ if hasattr(ref_ax, "data"):
+ out_ax = replace(ref_ax, data=xnew)
else:
- repl_axis = replace(buf.template.axes[self.settings.axis], offset=tnew[0])
+ out_ax = replace(ref_ax, offset=xnew[0])
result = replace(
- buf.template,
+ src_axarr,
data=resampled_data,
axes={
- **buf.template.axes,
- self.settings.axis: repl_axis,
+ **src_axarr.axes,
+ self.settings.axis: out_ax,
},
)
- # Update state to move past samples that are no longer be needed
- self.state.last_t_out = tnew[-1]
- buf.data = buf.data[max(0, buf_stop_ix - 3) :]
- buf.tvec = buf.tvec[max(0, buf_stop_ix - 3) :]
- buf.last_update = time.time()
-
- if self.settings.resample_rate is None:
- # Update self.state.ref_axis to remove samples that have been used in the output
- if hasattr(self.state.ref_axis[0], "data"):
- new_ref_ax = replace(
- self.state.ref_axis[0],
- data=self.state.ref_axis[0].data[ref_idx[-1] + 1 :],
- )
- else:
- next_offset = self.state.ref_axis[0].value(ref_idx[-1] + 1)
- new_ref_ax = replace(self.state.ref_axis[0], offset=next_offset)
- self.state.ref_axis = (new_ref_ax, self.state.ref_axis[1] - len(ref_idx))
+ # Update the state. For state buffers, seek beyond samples that are no longer needed.
+ # src: keep at least 1 sample before the final resampled value
+ seek_ix = np.where(x >= xnew[-1])[0]
+ if len(seek_ix) > 0:
+ self.state.src_buffer.seek(max(0, src_start_ix + seek_ix[0] - 1))
+ # ref: remove samples that have been sent to output
+ self.state.ref_axis_buffer.seek(ref_idx[-1] + 1)
+ self.state.last_ref_ax_val = xnew[-1]
return result
diff --git a/src/ezmsg/sigproc/sampler.py b/src/ezmsg/sigproc/sampler.py
index 0c3da1f4..d45366bd 100644
--- a/src/ezmsg/sigproc/sampler.py
+++ b/src/ezmsg/sigproc/sampler.py
@@ -1,18 +1,19 @@
import asyncio
from collections import deque
+import copy
import traceback
import typing
import numpy as np
-import numpy.typing as npt
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import (
AxisArray,
- slice_along_axis,
)
from ezmsg.util.messages.util import replace
from .util.profile import profile_subpub
+from .util.axisarray_buffer import HybridAxisArrayBuffer
+from .util.buffer import UpdateStrategy
from .util.message import SampleMessage, SampleTriggerMessage
from .base import (
BaseStatefulTransformer,
@@ -43,6 +44,7 @@ class SamplerSettings(ez.Settings):
None (default) will choose the first axis in the first input.
Note: (for now) the axis must exist in the msg .axes and be of type AxisArray.LinearAxis
"""
+
period: tuple[float, float] | None = None
"""Optional default period (in seconds) if unspecified in SampleTriggerMessage."""
@@ -51,20 +53,25 @@ class SamplerSettings(ez.Settings):
estimate_alignment: bool = True
"""
- If true, use message timestamp fields and reported sampling rate to estimate sample-accurate alignment for samples.
+ If true, use message timestamp fields and reported sampling rate to estimate
+ sample-accurate alignment for samples.
If false, sampling will be limited to incoming message rate -- "Block timing"
NOTE: For faster-than-realtime playback -- Incoming timestamps must reflect
"realtime" operation for estimate_alignment to operate correctly.
"""
+ buffer_update_strategy: UpdateStrategy = "immediate"
+ """
+ The buffer update strategy. See :obj:`ezmsg.sigproc.util.buffer.UpdateStrategy`.
+ If you expect to push data much more frequently than triggers, then "on_demand"
+ might be more efficient. For most other scenarios, "immediate" is best.
+ """
+
@processor_state
class SamplerState:
- fs: float = 0.0
- offset: float | None = None
- buffer: npt.NDArray | None = None
+ buffer: HybridAxisArrayBuffer | None = None
triggers: deque[SampleTriggerMessage] | None = None
- n_samples: int = 0
class SamplerTransformer(
@@ -73,6 +80,16 @@ class SamplerTransformer(
def __call__(
self, message: AxisArray | SampleTriggerMessage
) -> list[SampleMessage]:
+ # TODO: Currently we have a single entry point that accepts both
+ # data and trigger messages and we choose a code path based on
+ # the message type. However, in the future we will likely replace
+ # SampleTriggerMessage with an agumented form of AxisArray,
+ # leveraging its attrs field, which makes this a bit harder.
+ # We should probably force callers of this object to explicitly
+ # call `push_trigger` for trigger messages. This will also
+ # simplify typing somewhat because `push_trigger` should not
+ # return anything yet we currently have it returning an empty
+ # list just to be compatible with __call__.
if isinstance(message, AxisArray):
return super().__call__(message)
else:
@@ -82,102 +99,75 @@ def _hash_message(self, message: AxisArray) -> int:
# Compute hash based on message properties that require state reset
axis = self.settings.axis or message.dims[0]
axis_idx = message.get_axis_idx(axis)
- fs = 1.0 / message.get_axis(axis).gain
sample_shape = (
message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
)
- return hash((fs, sample_shape, axis_idx, message.key))
+ return hash((sample_shape, message.key))
def _reset_state(self, message: AxisArray) -> None:
- axis = self.settings.axis or message.dims[0]
- axis_idx = message.get_axis_idx(axis)
- axis_info = message.get_axis(axis)
- self._state.fs = 1.0 / axis_info.gain
- self._state.buffer = None
+ self._state.buffer = HybridAxisArrayBuffer(
+ duration=self.settings.buffer_dur,
+ axis=self.settings.axis or message.dims[0],
+ update_strategy=self.settings.buffer_update_strategy,
+ overflow_strategy="warn-overwrite", # True circular buffer
+ )
if self._state.triggers is None:
self._state.triggers = deque()
self._state.triggers.clear()
- self._state.n_samples = message.data.shape[axis_idx]
def _process(self, message: AxisArray) -> list[SampleMessage]:
- axis = self.settings.axis or message.dims[0]
- axis_idx = message.get_axis_idx(axis)
- axis_info = message.get_axis(axis)
- self._state.offset = axis_info.offset
-
- # Update buffer
- self._state.buffer = (
- message.data
- if self._state.buffer is None
- else np.concatenate((self._state.buffer, message.data), axis=axis_idx)
- )
+ self._state.buffer.write(message)
- # Calculate timestamps associated with buffer.
- buffer_offset = np.arange(self._state.buffer.shape[axis_idx], dtype=float)
- buffer_offset -= buffer_offset[-message.data.shape[axis_idx]]
- buffer_offset *= axis_info.gain
- buffer_offset += axis_info.offset
+ # How much data in the buffer?
+ buff_t_range = (
+ self._state.buffer.axis_first_value,
+ self._state.buffer.axis_final_value,
+ )
- # ... for each trigger, collect the message (if possible) and append to msg_out
- msg_out: list[SampleMessage] = []
- for trig in list(self._state.triggers):
+ # Process in reverse order so that we can remove triggers safely as we iterate.
+ msgs_out: list[SampleMessage] = []
+ for trig_ix in range(len(self._state.triggers) - 1, -1, -1):
+ trig = self._state.triggers[trig_ix]
if trig.period is None:
- # This trigger was malformed; drop it.
- self._state.triggers.remove(trig)
+ ez.logger.warning("Sampling failed: trigger period not specified")
+ del self._state.triggers[trig_ix]
+ continue
+
+ trig_range = trig.timestamp + np.array(trig.period)
# If the previous iteration had insufficient data for the trigger timestamp + period,
# and buffer-management removed data required for the trigger, then we will never be able
# to accommodate this trigger. Discard it. An increase in buffer_dur is recommended.
- if (trig.timestamp + trig.period[0]) < buffer_offset[0]:
+ if trig_range[0] < buff_t_range[0]:
ez.logger.warning(
- f"Sampling failed: Buffer span {buffer_offset[0]} is beyond the "
- f"requested sample period start: {trig.timestamp + trig.period[0]}"
+ f"Sampling failed: Buffer span {buff_t_range} begins beyond the "
+ f"requested sample period start: {trig_range[0]}"
)
- self._state.triggers.remove(trig)
+ del self._state.triggers[trig_ix]
+ continue
- t_start = trig.timestamp + trig.period[0]
- if t_start >= buffer_offset[0]:
- start = np.searchsorted(buffer_offset, t_start)
- stop = start + int(
- np.round(self._state.fs * (trig.period[1] - trig.period[0]))
- )
- if self._state.buffer.shape[axis_idx] > stop:
- # Trigger period fully enclosed in buffer.
- msg_out.append(
- SampleMessage(
- trigger=trig,
- sample=replace(
- message,
- data=slice_along_axis(
- self._state.buffer, slice(start, stop), axis_idx
- ),
- axes={
- **message.axes,
- axis: replace(
- axis_info, offset=buffer_offset[start]
- ),
- },
- ),
- )
- )
- self._state.triggers.remove(trig)
-
- # Trim buffer
- buf_len = int(self.settings.buffer_dur * self._state.fs)
- self._state.buffer = slice_along_axis(
- self._state.buffer, np.s_[-buf_len:], axis_idx
- )
+ if trig_range[1] > buff_t_range[1]:
+ # We don't *yet* have enough data to satisfy this trigger.
+ continue
+
+ # We know we have enough data in the buffer to satisfy this trigger.
+ buff_idx = self._state.buffer.axis_searchsorted(trig_range, side="right")
+ self._state.buffer.seek(buff_idx[0]) # FFWD to starting position.
+ buff_axarr = self._state.buffer.peek(buff_idx[1] - buff_idx[0])
+ self._state.buffer.seek(-buff_idx[0]) # Rewind it back.
+ # Note: buffer will trim itself as needed based on buffer_dur.
+
+ # Prepare output and drop trigger
+ msgs_out.append(SampleMessage(trigger=copy.copy(trig), sample=buff_axarr))
+ del self._state.triggers[trig_ix]
- return msg_out
+ msgs_out.reverse() # in-place
+ return msgs_out
def push_trigger(self, message: SampleTriggerMessage) -> list[SampleMessage]:
# Input is a trigger message that we will use to sample the buffer.
- if (
- self._state.buffer is None
- or not self._state.fs
- or self._state.offset is None
- ):
+ if self._state.buffer is None:
# We've yet to see any data; drop the trigger.
return []
@@ -194,11 +184,9 @@ def push_trigger(self, message: SampleTriggerMessage) -> list[SampleMessage]:
return []
# Check that period is compatible with buffer duration.
- max_buf_len = int(np.round(self.settings.buffer_dur * self._state.fs))
- req_buf_len = int(np.round((_period[1] - _period[0]) * self._state.fs))
- if req_buf_len >= max_buf_len:
+ if (_period[1] - _period[0]) > self.settings.buffer_dur:
ez.logger.warning(
- f"Sampling failed: {_period=} >= {self.settings.buffer_dur=}"
+ f"Sampling failed: trigger period {_period=} >= buffer capacity {self.settings.buffer_dur=}"
)
return []
@@ -206,7 +194,7 @@ def push_trigger(self, message: SampleTriggerMessage) -> list[SampleMessage]:
if not self.settings.estimate_alignment:
# Override the trigger timestamp with the next sample's likely timestamp.
trigger_ts = (
- self._state.offset + (self.state.n_samples + 1) / self._state.fs
+ self._state.buffer.axis_final_value + self._state.buffer.axis_gain
)
new_trig_msg = replace(
diff --git a/src/ezmsg/sigproc/util/axisarray_buffer.py b/src/ezmsg/sigproc/util/axisarray_buffer.py
new file mode 100644
index 00000000..c918c8df
--- /dev/null
+++ b/src/ezmsg/sigproc/util/axisarray_buffer.py
@@ -0,0 +1,379 @@
+import math
+import typing
+
+from array_api_compat import get_namespace
+import numpy as np
+from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, CoordinateAxis
+from ezmsg.util.messages.util import replace
+
+from .buffer import HybridBuffer
+
+
+Array = typing.TypeVar("Array")
+
+
+class HybridAxisBuffer:
+ """
+ A buffer that intelligently handles ezmsg.util.messages.AxisArray _axes_ objects.
+ LinearAxis is maintained internally by tracking its offset, gain, and the number
+ of samples that have passed through.
+ CoordinateAxis has its data values maintained in a `HybridBuffer`.
+
+ Args:
+ duration: The desired duration of the buffer in seconds. This is non-limiting
+ when managing a LinearAxis.
+ **kwargs: Additional keyword arguments to pass to the underlying HybridBuffer
+ (e.g., `update_strategy`, `threshold`, `overflow_strategy`, `max_size`).
+ """
+
+ _coords_buffer: HybridBuffer | None
+ _coords_template: CoordinateAxis | None
+ _coords_gain_estimate: float | None = None
+ _linear_axis: LinearAxis | None
+ _linear_n_available: int
+
+ def __init__(self, duration: float, **kwargs):
+ self.duration = duration
+ self.buffer_kwargs = kwargs
+ # Delay initialization until the first message arrives
+ self._coords_buffer = None
+ self._coords_template = None
+ self._linear_axis = None
+ self._linear_n_available = 0
+
+ @property
+ def capacity(self) -> int:
+ """The maximum number of samples that can be stored in the buffer."""
+ if self._coords_buffer is not None:
+ return self._coords_buffer.capacity
+ elif self._linear_axis is not None:
+ return int(math.ceil(self.duration / self._linear_axis.gain))
+ else:
+ return 0
+
+ def available(self) -> int:
+ if self._coords_buffer is None:
+ return self._linear_n_available
+ return self._coords_buffer.available()
+
+ def is_empty(self) -> bool:
+ return self.available() == 0
+
+ def is_full(self) -> bool:
+ if self._coords_buffer is not None:
+ return self._coords_buffer.is_full()
+ return 0 < self.capacity == self.available()
+
+ def _initialize(self, first_axis: LinearAxis | CoordinateAxis) -> None:
+ if hasattr(first_axis, "data"):
+ # Initialize a CoordinateAxis buffer
+ if len(first_axis.data) > 1:
+ _axis_gain = (first_axis.data[-1] - first_axis.data[0]) / (
+ len(first_axis.data) - 1
+ )
+ else:
+ _axis_gain = 1.0
+ self._coords_gain_estimate = _axis_gain
+ capacity = int(self.duration / _axis_gain)
+ self._coords_buffer = HybridBuffer(
+ get_namespace(first_axis.data),
+ capacity,
+ other_shape=(),
+ dtype=first_axis.data.dtype,
+ **self.buffer_kwargs,
+ )
+ self._coords_template = replace(first_axis, data=first_axis.data[:0].copy())
+ else:
+ # Initialize a LinearAxis buffer
+ self._linear_axis = replace(first_axis, offset=first_axis.offset)
+ self._linear_n_available = 0
+
+ def write(self, axis: LinearAxis | CoordinateAxis, n_samples: int) -> None:
+ if self._linear_axis is None and self._coords_buffer is None:
+ self._initialize(axis)
+
+ if self._coords_buffer is not None:
+ if axis.__class__ is not self._coords_template.__class__:
+ raise TypeError(
+ f"Buffer initialized with {self._coords_template.__class__.__name__}, "
+ f"but received {axis.__class__.__name__}."
+ )
+ self._coords_buffer.write(axis.data)
+ else:
+ if axis.__class__ is not self._linear_axis.__class__:
+ raise TypeError(
+ f"Buffer initialized with {self._linear_axis.__class__.__name__}, "
+ f"but received {axis.__class__.__name__}."
+ )
+ if axis.gain != self._linear_axis.gain:
+ raise ValueError(
+ f"Buffer initialized with gain={self._linear_axis.gain}, "
+ f"but received gain={axis.gain}."
+ )
+ if self._linear_n_available + n_samples > self.capacity:
+ # Simulate overflow by advancing the offset and decreasing
+ # the number of available samples.
+ n_to_discard = self._linear_n_available + n_samples - self.capacity
+ self.seek(n_to_discard)
+ # Update the offset corresponding to the oldest sample in the buffer
+ # by anchoring on the new offset and accounting for the samples already available.
+ self._linear_axis.offset = (
+ axis.offset - self._linear_n_available * axis.gain
+ )
+ self._linear_n_available += n_samples
+
+ def peek(self, n_samples: int | None = None) -> LinearAxis | CoordinateAxis:
+ if self._coords_buffer is not None:
+ return replace(
+ self._coords_template, data=self._coords_buffer.peek(n_samples)
+ )
+ else:
+ # Return a shallow copy.
+ return replace(self._linear_axis, offset=self._linear_axis.offset)
+
+ def seek(self, n_samples: int) -> int:
+ if self._coords_buffer is not None:
+ return self._coords_buffer.seek(n_samples)
+ else:
+ n_to_seek = min(n_samples, self._linear_n_available)
+ self._linear_n_available -= n_to_seek
+ self._linear_axis.offset += n_to_seek * self._linear_axis.gain
+ return n_to_seek
+
+ def prune(self, n_samples: int) -> int:
+ """Discards all but the last n_samples from the buffer."""
+ n_to_discard = self.available() - n_samples
+ if n_to_discard <= 0:
+ return 0
+ return self.seek(n_to_discard)
+
+ @property
+ def final_value(self) -> float | None:
+ """
+ The axis-value (timestamp, typically) of the last sample in the buffer.
+ This does not advance the read head.
+ """
+ if self._coords_buffer is not None:
+ return self._coords_buffer.peek_last()[0]
+ elif self._linear_axis is not None:
+ return self._linear_axis.value(self._linear_n_available - 1)
+ else:
+ return None
+
+ @property
+ def first_value(self) -> float | None:
+ """
+ The axis-value (timestamp, typically) of the first sample in the buffer.
+ This does not advance the read head.
+ """
+ if self.available() == 0:
+ return None
+ if self._coords_buffer is not None:
+ return self._coords_buffer.peek_at(0)[0]
+ elif self._linear_axis is not None:
+ return self._linear_axis.value(0)
+ else:
+ return None
+
+ @property
+ def gain(self) -> float | None:
+ if self._coords_buffer is not None:
+ return self._coords_gain_estimate
+ elif self._linear_axis is not None:
+ return self._linear_axis.gain
+ else:
+ return None
+
+ def searchsorted(
+ self, values: typing.Union[float, Array], side: str = "left"
+ ) -> typing.Union[int, Array]:
+ if self._coords_buffer is not None:
+ return self._coords_buffer.xp.searchsorted(
+ self._coords_buffer.peek(self.available()), values, side=side
+ )
+ else:
+ if self.available() == 0:
+ if isinstance(values, float):
+ return 0
+ else:
+ _xp = get_namespace(values)
+ return _xp.zeros_like(values, dtype=int)
+
+ f_inds = (values - self._linear_axis.offset) / self._linear_axis.gain
+ res = np.ceil(f_inds)
+ if side == "right":
+ res[np.isclose(f_inds, res)] += 1
+ return res.astype(int)
+
+
+class HybridAxisArrayBuffer:
+ """A buffer that intelligently handles ezmsg.util.messages.AxisArray objects.
+
+ This buffer defers its own initialization until the first message arrives,
+ allowing it to automatically configure its size, shape, dtype, and array backend
+ (e.g., NumPy, CuPy) based on the message content and a desired buffer duration.
+
+ Args:
+ duration: The desired duration of the buffer in seconds.
+ axis: The name of the axis to buffer along.
+ **kwargs: Additional keyword arguments to pass to the underlying HybridBuffer
+ (e.g., `update_strategy`, `threshold`, `overflow_strategy`, `max_size`).
+ """
+
+ _data_buffer: HybridBuffer | None
+ _axis_buffer: HybridAxisBuffer
+ _template_msg: AxisArray | None
+
+ def __init__(self, duration: float, axis: str = "time", **kwargs):
+ self.duration = duration
+ self._axis = axis
+ self.buffer_kwargs = kwargs
+ self._axis_buffer = HybridAxisBuffer(duration=duration, **kwargs)
+ # Delay initialization until the first message arrives
+ self._data_buffer = None
+ self._template_msg = None
+
+ def available(self) -> int:
+ """The total number of unread samples currently available in the buffer."""
+ if self._data_buffer is None:
+ return 0
+ return self._data_buffer.available()
+
+ def is_empty(self) -> bool:
+ return self.available() == 0
+
+ def is_full(self) -> bool:
+ return 0 < self._data_buffer.capacity == self.available()
+
+ @property
+ def axis_first_value(self) -> float | None:
+ """The axis-value (timestamp, typically) of the first sample in the buffer."""
+ return self._axis_buffer.first_value
+
+ @property
+ def axis_final_value(self) -> float | None:
+ """The axis-value (timestamp, typically) of the last sample in the buffer."""
+ return self._axis_buffer.final_value
+
+ def _initialize(self, first_msg: AxisArray) -> None:
+ # Create a template message that has everything except the data are length 0
+ # and the target axis is missing.
+ self._template_msg = replace(
+ first_msg,
+ data=first_msg.data[:0],
+ axes={k: v for k, v in first_msg.axes.items() if k != self._axis},
+ )
+
+ in_axis = first_msg.axes[self._axis]
+ self._axis_buffer._initialize(in_axis)
+
+ capacity = int(self.duration / self._axis_buffer.gain)
+ self._data_buffer = HybridBuffer(
+ get_namespace(first_msg.data),
+ capacity,
+ other_shape=first_msg.data.shape[1:],
+ dtype=first_msg.data.dtype,
+ **self.buffer_kwargs,
+ )
+
+ def write(self, msg: AxisArray) -> None:
+ """Adds an AxisArray message to the buffer, initializing on the first call."""
+ in_axis_idx = msg.get_axis_idx(self._axis)
+ if in_axis_idx > 0:
+ # This class assumes that the target axis is the first axis.
+ # If it is not, we move it to the front.
+ dims = list(msg.dims)
+ dims.insert(0, dims.pop(in_axis_idx))
+ _xp = get_namespace(msg.data)
+ msg = replace(msg, data=_xp.moveaxis(msg.data, in_axis_idx, 0), dims=dims)
+
+ if self._data_buffer is None:
+ self._initialize(msg)
+
+ self._data_buffer.write(msg.data)
+ self._axis_buffer.write(msg.axes[self._axis], msg.shape[0])
+
+ def peek(self, n_samples: int | None = None) -> AxisArray | None:
+ """Retrieves the oldest unread data as a new AxisArray without advancing the read head."""
+
+ if self._data_buffer is None:
+ return None
+
+ data_array = self._data_buffer.peek(n_samples)
+
+ if data_array is None:
+ return None
+
+ out_axis = self._axis_buffer.peek(n_samples)
+
+ return replace(
+ self._template_msg,
+ data=data_array,
+ axes={**self._template_msg.axes, self._axis: out_axis},
+ )
+
+ def peek_axis(
+ self, n_samples: int | None = None
+ ) -> LinearAxis | CoordinateAxis | None:
+ """Retrieves the axis data without advancing the read head."""
+ if self._data_buffer is None:
+ return None
+
+ out_axis = self._axis_buffer.peek(n_samples)
+
+ if out_axis is None:
+ return None
+
+ return out_axis
+
+ def seek(self, n_samples: int) -> int:
+ """Advances the read pointer by n_samples."""
+ if self._data_buffer is None:
+ return 0
+
+ skipped_data_count = self._data_buffer.seek(n_samples)
+ axis_skipped = self._axis_buffer.seek(skipped_data_count)
+ assert (
+ axis_skipped == skipped_data_count
+ ), f"Axis buffer skipped {axis_skipped} samples, but data buffer skipped {skipped_data_count}."
+
+ return skipped_data_count
+
+ def read(self, n_samples: int | None = None) -> AxisArray | None:
+ """Retrieves the oldest unread data as a new AxisArray and advances the read head."""
+ retrieved_axis_array = self.peek(n_samples)
+
+ if retrieved_axis_array is None or retrieved_axis_array.shape[0] == 0:
+ return None
+
+ self.seek(retrieved_axis_array.shape[0])
+
+ return retrieved_axis_array
+
+ def prune(self, n_samples: int) -> int:
+ """Discards all but the last n_samples from the buffer."""
+ if self._data_buffer is None:
+ return 0
+
+ n_to_discard = self.available() - n_samples
+ if n_to_discard <= 0:
+ return 0
+
+ return self.seek(n_to_discard)
+
+ @property
+ def axis_gain(self) -> float | None:
+ """
+ The gain of the target axis, which is the time step between samples.
+ This is typically the sampling rate (e.g., 1 / fs).
+ """
+ return self._axis_buffer.gain
+
+ def axis_searchsorted(
+ self, values: typing.Union[float, Array], side: str = "left"
+ ) -> typing.Union[int, Array]:
+ """
+ Find the indices into which the given values would be inserted
+ into the target axis data to maintain order.
+ """
+ return self._axis_buffer.searchsorted(values, side=side)
diff --git a/src/ezmsg/sigproc/util/buffer.py b/src/ezmsg/sigproc/util/buffer.py
new file mode 100644
index 00000000..254643f4
--- /dev/null
+++ b/src/ezmsg/sigproc/util/buffer.py
@@ -0,0 +1,470 @@
+import collections
+import math
+import typing
+import warnings
+
+Array = typing.TypeVar("Array")
+ArrayNamespace = typing.Any
+DType = typing.Any
+UpdateStrategy = typing.Literal["immediate", "threshold", "on_demand"]
+OverflowStrategy = typing.Literal["grow", "raise", "drop", "warn-overwrite"]
+
+
+class HybridBuffer:
+ """A stateful, FIFO buffer that combines a deque for fast appends with a
+ contiguous circular buffer for efficient, advancing reads.
+
+ This buffer is designed to be agnostic to the array library used (e.g., NumPy,
+ CuPy, PyTorch) via the Python Array API standard.
+
+ Args:
+ array_namespace: The array library (e.g., numpy, cupy) that conforms to the Array API.
+ capacity: The current maximum number of samples to store in the circular buffer.
+ other_shape: A tuple defining the shape of the non-sample dimensions.
+ dtype: The data type of the samples, belonging to the provided array_namespace.
+ update_strategy: The strategy for synchronizing the deque to the circular buffer (flushing).
+ threshold: The number of samples to accumulate in the deque before flushing.
+ Ignored if update_strategy is "immediate" or "on_demand".
+ overflow_strategy: The strategy for handling overflow when the buffer is full.
+ Options are "grow", "raise", "drop", or "warn-overwrite". If "grow" (default), the buffer will
+ increase its capacity to accommodate new samples up to max_size. If "raise", an error will be
+ raised when the buffer is full. If "drop", the overflowing samples will be ignored.
+ If "warn-overwrite", a warning will be logged then the overflowing samples will
+ overwrite previously-unread samples.
+ max_size: The maximum size of the buffer in bytes.
+ If the buffer exceeds this size, it will raise an error.
+ warn_once: If True, will only warn once on overflow when using "warn-overwrite" strategy.
+ """
+
+ def __init__(
+ self,
+ array_namespace: ArrayNamespace,
+ capacity: int,
+ other_shape: tuple[int, ...],
+ dtype: DType,
+ update_strategy: UpdateStrategy = "on_demand",
+ threshold: int = 0,
+ overflow_strategy: OverflowStrategy = "grow",
+ max_size: int = 1024**3, # 1 GB default max size
+ warn_once: bool = True,
+ ):
+ self.xp = array_namespace
+ self._capacity = capacity
+ self._deque = collections.deque()
+ self._update_strategy = update_strategy
+ self._threshold = threshold
+ self._overflow_strategy = overflow_strategy
+ self._max_size = max_size
+ self._warn_once = warn_once
+
+ self._buffer = self.xp.empty((capacity, *other_shape), dtype=dtype)
+ self._head = 0 # Write pointer
+ self._tail = 0 # Read pointer
+ self._buff_unread = 0 # Number of unread samples in the circular buffer
+ self._buff_read = 0 # Tracks samples read and still in buffer
+ self._deque_len = 0 # Number of unread samples in the deque
+ self._last_overflow = (
+ 0 # Tracks the last overflow count, overwritten or skipped
+ )
+ self._warned = False # Tracks if we've warned already (for warn_once)
+
+ @property
+ def capacity(self) -> int:
+ """The maximum number of samples that can be stored in the buffer."""
+ return self._capacity
+
+ def available(self) -> int:
+ """The total number of unread samples available (in buffer and deque)."""
+ return self._buff_unread + self._deque_len
+
+ def is_empty(self) -> bool:
+ """Returns True if there are no unread samples in the buffer or deque."""
+ return self.available() == 0
+
+ def is_full(self) -> bool:
+ """Returns True if the buffer is full and cannot _flush_ more samples without overwriting."""
+ return self._buff_unread == self._capacity
+
+ def tell(self) -> int:
+ """Returns the number of samples that have been read and are still in the buffer."""
+ return self._buff_read
+
+ def write(self, block: Array):
+ """Appends a new block (an array of samples) to the internal deque."""
+ other_shape = self._buffer.shape[1:]
+ if other_shape == (1,) and block.ndim == 1:
+ block = block[:, self.xp.newaxis]
+
+ if block.shape[1:] != other_shape:
+ raise ValueError(
+ f"Block shape {block.shape[1:]} does not match buffer's other_shape {other_shape}"
+ )
+
+ # Most overflow strategies are handled during flush, but there are a couple
+ # scenarios that can be evaluated on write to give immediate feedback.
+ new_len = self._deque_len + block.shape[0]
+ if new_len > self._capacity and self._overflow_strategy == "raise":
+ raise OverflowError(
+ f"Buffer overflow: {new_len} samples awaiting in deque exceeds buffer capacity {self._capacity}."
+ )
+ elif new_len * block.dtype.itemsize > self._max_size:
+ raise OverflowError(
+ f"deque contents would exceed max_size ({self._max_size}) on subsequent flush."
+ "Are you reading samples frequently enough?"
+ )
+
+ self._deque.append(block)
+ self._deque_len += block.shape[0]
+
+ if self._update_strategy == "immediate" or (
+ self._update_strategy == "threshold"
+ and (0 < self._threshold <= self._deque_len)
+ ):
+ self.flush()
+
+ def _estimate_overflow(self, n_samples: int) -> int:
+ """
+ Estimates the number of samples that would overflow we requested n_samples
+ from the buffer.
+ """
+ if n_samples > self.available():
+ raise ValueError(
+ f"Requested {n_samples} samples, but only {self.available()} are available."
+ )
+ n_overflow = 0
+ if self._deque and (n_samples > self._buff_unread):
+ # We would cause a flush, but would that cause an overflow?
+ n_free = self._capacity - self._buff_unread
+ n_overflow = max(0, self._deque_len - n_free)
+ return n_overflow
+
+ def read(
+ self,
+ n_samples: int | None = None,
+ ) -> Array:
+ """
+ Retrieves the oldest unread samples from the buffer with padding
+ and advances the read head.
+
+ Args:
+ n_samples: The number of samples to retrieve. If None, returns all
+ unread samples.
+
+ Returns:
+ An array containing the requested samples. This may be a view or a copy.
+ Note: The result may have more samples than the buffer.capacity as it
+ may include samples from the deque in the output.
+ """
+ n_samples = n_samples if n_samples is not None else self.available()
+ data = None
+ offset = 0
+ n_overflow = self._estimate_overflow(n_samples)
+ if n_overflow > 0:
+ first_read = self._buff_unread
+ if (n_overflow - first_read) < self.capacity or (
+ self._overflow_strategy == "drop"
+ ):
+ # We can prevent the overflow (or at least *some* if using "drop"
+ # strategy) by reading the samples in the buffer first to make room.
+ data = self.xp.empty(
+ (n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype
+ )
+ self.peek(first_read, out=data[:first_read])
+ offset += first_read
+ self.seek(first_read)
+ n_samples -= first_read
+ if data is None:
+ data = self.peek(n_samples)
+ self.seek(data.shape[0])
+ else:
+ d2 = self.peek(n_samples, out=data[offset:])
+ self.seek(d2.shape[0])
+
+ return data
+
+ def peek(self, n_samples: int | None = None, out: Array | None = None) -> Array:
+ """
+ Retrieves the oldest unread samples from the buffer with padding without
+ advancing the read head.
+
+ Args:
+ n_samples: The number of samples to retrieve. If None, returns all
+ unread samples.
+ out: Optionally, a destination array to store the samples.
+ If provided, must have shape (n_samples, *other_shape) where
+ other_shape matches the shape of the samples in the buffer.
+ If `out` is provided then the data will always be copied into it,
+ even if they are contiguous in the buffer.
+
+ Returns:
+ An array containing the requested samples. This may be a view or a copy.
+ Note: The result may have more samples than the buffer.capacity as it
+ may include samples from the deque in the output.
+ """
+ if n_samples is None:
+ n_samples = self.available()
+ elif n_samples > self.available():
+ raise ValueError(
+ f"Requested to peek {n_samples} samples, but only {self.available()} are available."
+ )
+ if out is not None and out.shape[0] < n_samples:
+ raise ValueError(
+ f"Output array shape {out.shape} is smaller than requested {n_samples} samples."
+ )
+
+ if n_samples == 0:
+ return self._buffer[:0]
+
+ self._flush_if_needed(n_samples=n_samples)
+
+ if self._tail + n_samples > self._capacity:
+ # discontiguous read (wraps around)
+ part1_len = self._capacity - self._tail
+ part2_len = n_samples - part1_len
+ out = (
+ out
+ if out is not None
+ else self.xp.empty(
+ (n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype
+ )
+ )
+ out[:part1_len] = self._buffer[self._tail :]
+ out[part1_len:] = self._buffer[:part2_len]
+ else:
+ if out is not None:
+ out[:] = self._buffer[self._tail : self._tail + n_samples]
+ else:
+ # No output array provided, just return a view
+ out = self._buffer[self._tail : self._tail + n_samples]
+
+ return out
+
+ def peek_at(self, idx: int, allow_flush: bool = False) -> Array:
+ """
+ Retrieves a specific sample from the buffer without advancing the read head.
+
+ Args:
+ idx: The index of the sample to retrieve, relative to the read head.
+ allow_flush: If True, allows flushing the deque to the buffer if the
+ requested sample is not in the buffer. If False and the sample is
+ in the deque, the sample will be retrieved from the deque (slow!).
+
+ Returns:
+ An array containing the requested sample. This may be a view or a copy.
+ """
+ if idx < 0 or idx >= self.available():
+ raise IndexError(f"Index {idx} out of bounds for unread samples.")
+
+ if not allow_flush and idx >= self._buff_unread:
+ # The requested sample is in the deque.
+ idx -= self._buff_unread
+ deq_splits = self.xp.cumsum(
+ [0] + [_.shape[0] for _ in self._deque], dtype=int
+ )
+ arr_idx = self.xp.searchsorted(deq_splits, idx, side="right") - 1
+ idx -= deq_splits[arr_idx]
+ return self._deque[arr_idx][idx : idx + 1]
+
+ self._flush_if_needed(n_samples=idx + 1)
+
+ # The requested sample is within the unread samples in the buffer.
+ idx = (self._tail + idx) % self._capacity
+ return self._buffer[idx : idx + 1]
+
+ def peek_last(self) -> Array:
+ """
+ Retrieves the last sample in the buffer without advancing the read head.
+ """
+ if self._deque:
+ return self._deque[-1][-1:]
+ elif self._buff_unread > 0:
+ idx = (self._head - 1 + self._capacity) % self._capacity
+ return self._buffer[idx : idx + 1]
+ else:
+ raise IndexError("Cannot peek last from an empty buffer.")
+
+ def seek(self, n_samples: int) -> int:
+ """
+ Advances the read head by n_samples.
+
+ Args:
+ n_samples: The number of samples to seek.
+ Will seek forward if positive or backward if negative.
+
+ Returns:
+ The number of samples actually skipped.
+ """
+ self._flush_if_needed(n_samples=n_samples)
+
+ n_to_seek = max(min(n_samples, self._buff_unread), -self._buff_read)
+
+ if n_to_seek == 0:
+ return 0
+
+ self._tail = (self._tail + n_to_seek) % self._capacity
+ self._buff_unread -= n_to_seek
+ self._buff_read += n_to_seek
+
+ return n_to_seek
+
+ def _flush_if_needed(self, n_samples: int | None = None):
+ if (
+ self._update_strategy == "on_demand"
+ and self._deque
+ and (n_samples is None or n_samples > self._buff_unread)
+ ):
+ self.flush()
+
+ def flush(self):
+ """
+ Transfers all data from the deque to the circular buffer.
+ Note: This may overwrite data depending on the overflow strategy,
+ which will invalidate previous state variables.
+ """
+ if not self._deque:
+ return
+
+ n_new = self._deque_len
+ n_free = self._capacity - self._buff_unread
+ n_overflow = max(0, n_new - n_free)
+
+ # If new data is larger than buffer and overflow strategy is "warn-overwrite",
+ # then we can take a shortcut and replace the entire buffer.
+ if n_new >= self._capacity and self._overflow_strategy == "warn-overwrite":
+ if n_overflow > 0 and (not self._warn_once or not self._warned):
+ self._warned = True
+ warnings.warn(
+ f"Buffer overflow: {n_new} samples received, but only {self._capacity - self._buff_unread} available. "
+ f"Overwriting {n_overflow} previous samples.",
+ RuntimeWarning,
+ )
+
+ # We need to grab the last `self._capacity` samples from the deque
+ samples_to_copy = self._capacity
+ copied_samples = 0
+ for block in reversed(self._deque):
+ if copied_samples >= samples_to_copy:
+ break
+ n_to_copy = min(block.shape[0], samples_to_copy - copied_samples)
+ start_idx = block.shape[0] - n_to_copy
+ self._buffer[
+ samples_to_copy - copied_samples - n_to_copy : samples_to_copy
+ - copied_samples
+ ] = block[start_idx:]
+ copied_samples += n_to_copy
+
+ self._head = 0
+ self._tail = 0
+ self._buff_unread = self._capacity
+ self._buff_read = 0
+ self._last_overflow = n_overflow
+
+ else:
+ if n_overflow > 0:
+ if self._overflow_strategy == "raise":
+ raise OverflowError(
+ f"Buffer overflow: {n_new} samples received, but only {n_free} available."
+ )
+ elif self._overflow_strategy == "warn-overwrite":
+ if not self._warn_once or not self._warned:
+ self._warned = True
+ warnings.warn(
+ f"Buffer overflow: {n_new} samples received, but only {n_free} available. "
+ f"Overwriting {n_overflow} previous samples.",
+ RuntimeWarning,
+ )
+ # Move the tail forward to make room for the new data.
+ self.seek(n_overflow)
+ # Adjust the read pointer to account for the overflow. Should always be 0.
+ self._buff_read = max(0, self._buff_read - n_overflow)
+ self._last_overflow = n_overflow
+ elif self._overflow_strategy == "drop":
+ # Drop the overflow samples from the deque
+ samples_to_drop = n_overflow
+ while samples_to_drop > 0 and self._deque:
+ block = self._deque[-1]
+ if samples_to_drop >= block.shape[0]:
+ samples_to_drop -= block.shape[0]
+ self._deque.pop()
+ else:
+ block = self._deque.pop()
+ self._deque.append(block[:-samples_to_drop])
+ samples_to_drop = 0
+ n_new -= n_overflow
+ self._last_overflow = n_overflow
+
+ elif self._overflow_strategy == "grow":
+ self._grow_buffer(self._capacity + n_new)
+ self._last_overflow = 0
+
+ # Copy data to buffer by iterating over the deque
+ for block in self._deque:
+ n_block = block.shape[0]
+ space_til_end = self._capacity - self._head
+ if n_block > space_til_end:
+ # Two-part copy (wraps around)
+ part1_len = space_til_end
+ part2_len = n_block - part1_len
+ self._buffer[self._head :] = block[:part1_len]
+ self._buffer[:part2_len] = block[part1_len:]
+ else:
+ # Single-part copy
+ self._buffer[self._head : self._head + n_block] = block
+ self._head = (self._head + n_block) % self._capacity
+
+ self._buff_unread += n_new
+ if (self._buff_read > self._tail) or (self._tail > self._head):
+ # We have wrapped around the buffer; our count of read samples
+ # is simply the buffer capacity minus the count of unread samples.
+ self._buff_read = self._capacity - self._buff_unread
+
+ self._deque.clear()
+ self._deque_len = 0
+
+ def _grow_buffer(self, min_capacity: int):
+ """
+ Grows the buffer to at least min_capacity.
+ This is a helper method for the overflow strategy "grow".
+ """
+ if self._capacity >= min_capacity:
+ return
+
+ other_shape = self._buffer.shape[1:]
+ max_capacity = self._max_size / (
+ self._buffer.dtype.itemsize * math.prod(other_shape)
+ )
+ if min_capacity > max_capacity:
+ raise OverflowError(
+ f"Cannot grow buffer to {min_capacity} samples, "
+ f"maximum capacity is {max_capacity} samples ({self._max_size} bytes)."
+ )
+
+ new_capacity = min(max_capacity, max(self._capacity * 2, min_capacity))
+ new_buffer = self.xp.empty(
+ (new_capacity, *other_shape), dtype=self._buffer.dtype
+ )
+
+ # Copy existing data to new buffer
+ total_samples = self._buff_read + self._buff_unread
+ if total_samples > 0:
+ start_idx = (self._tail - self._buff_read) % self._capacity
+ stop_idx = (self._tail + self._buff_unread) % self._capacity
+ if stop_idx > start_idx:
+ # Data is contiguous
+ new_buffer[:total_samples] = self._buffer[start_idx:stop_idx]
+ else:
+ # Data wraps around. We write it in 2 parts.
+ part1_len = self._capacity - start_idx
+ part2_len = stop_idx
+ new_buffer[:part1_len] = self._buffer[start_idx:]
+ new_buffer[part1_len : part1_len + part2_len] = self._buffer[:stop_idx]
+ # self._buff_read stays the same
+ self._tail = self._buff_read
+ # self._buff_unread stays the same
+ self._head = self._tail + self._buff_unread
+ else:
+ self._tail = 0
+ self._head = 0
+
+ self._buffer = new_buffer
+ self._capacity = new_capacity
diff --git a/src/ezmsg/sigproc/window.py b/src/ezmsg/sigproc/window.py
index 14ece978..8a25c7f4 100644
--- a/src/ezmsg/sigproc/window.py
+++ b/src/ezmsg/sigproc/window.py
@@ -209,13 +209,15 @@ def _process(self, message: AxisArray) -> AxisArray:
)
# Create a vector of buffer timestamps to track axis `offset` in output(s)
- buffer_tvec = xp.asarray(range(self._state.buffer.shape[axis_idx]), dtype=float)
+ buffer_t0 = 0.0
+ buffer_tlen = self._state.buffer.shape[axis_idx]
# Adjust so first _new_ sample at index 0.
- buffer_tvec -= buffer_tvec[-message.data.shape[axis_idx]]
+ buffer_t0 -= self._state.buffer.shape[axis_idx] - message.data.shape[axis_idx]
+
# Convert form indices to 'units' (probably seconds).
- buffer_tvec *= axis_info.gain
- buffer_tvec += axis_info.offset
+ buffer_t0 *= axis_info.gain
+ buffer_t0 += axis_info.offset
if self.settings.window_shift is not None and self._state.shift_deficit > 0:
n_skip = min(self._state.buffer.shape[axis_idx], self._state.shift_deficit)
@@ -223,7 +225,8 @@ def _process(self, message: AxisArray) -> AxisArray:
self._state.buffer = slice_along_axis(
self._state.buffer, slice(n_skip, None), axis_idx
)
- buffer_tvec = buffer_tvec[n_skip:]
+ buffer_t0 += n_skip * axis_info.gain
+ buffer_tlen -= n_skip
self._state.shift_deficit -= n_skip
# Generate outputs.
@@ -250,7 +253,9 @@ def _process(self, message: AxisArray) -> AxisArray:
+ (1,)
+ self._state.buffer.shape[axis_idx:]
)
- win_offset = buffer_tvec[-self._state.window_samples]
+ win_offset = buffer_t0 + axis_info.gain * (
+ buffer_tlen - self._state.window_samples
+ )
elif self._state.buffer.shape[axis_idx] >= self._state.window_samples:
# Deterministic window shifts.
sliding_win_fun = (
@@ -264,10 +269,7 @@ def _process(self, message: AxisArray) -> AxisArray:
axis_idx,
step=self._state.window_shift_samples,
)
- offset_view = sliding_win_fun(buffer_tvec, self._state.window_samples, 0)[
- :: self._state.window_shift_samples
- ]
- win_offset = offset_view[0, 0]
+ win_offset = buffer_t0
# Drop expired beginning of buffer and update shift_deficit
multi_shift = self._state.window_shift_samples * out_dat.shape[axis_idx]
diff --git a/tests/helpers/util.py b/tests/helpers/util.py
index 669c2833..4d9ee9c8 100644
--- a/tests/helpers/util.py
+++ b/tests/helpers/util.py
@@ -153,7 +153,7 @@ def calculate_expected_windows(
# 1:1 mode. Each input (block) yields a new output.
# If the window length is smaller than the block size then we only the tail of each block.
first = max(min(msg_block_size, data_len) - win_len, 0)
- if tvec[::msg_block_size].shape[0] < n_msgs:
+ if tvec[first::msg_block_size].shape[0] < n_msgs:
expected = np.concatenate(
(expected[:, first::msg_block_size], expected[:, -1:]), axis=1
)
diff --git a/tests/unit/buffer/test_axisarray_buffer.py b/tests/unit/buffer/test_axisarray_buffer.py
new file mode 100644
index 00000000..3a213191
--- /dev/null
+++ b/tests/unit/buffer/test_axisarray_buffer.py
@@ -0,0 +1,589 @@
+import pytest
+import numpy as np
+from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, CoordinateAxis
+
+from ezmsg.sigproc.util.axisarray_buffer import HybridAxisBuffer, HybridAxisArrayBuffer
+
+
+class TestHybridAxisBuffer:
+ """Test suite for HybridAxisBuffer"""
+
+ def test_uninitialized_state(self):
+ """Test buffer state before initialization"""
+ buf = HybridAxisBuffer(duration=1.0)
+
+ assert buf.capacity == 0
+ assert buf.available() == 0
+ assert buf.is_empty() is True
+ assert buf.is_full() is False
+ assert buf.gain is None
+ assert buf.final_value is None
+
+ def test_linear_axis_initialization(self):
+ """Test initialization with LinearAxis"""
+ buf = HybridAxisBuffer(duration=1.0)
+
+ # Create a LinearAxis with 1kHz sampling (gain=0.001)
+ axis = LinearAxis(gain=0.001, offset=0.0)
+ buf._initialize(axis)
+
+ assert buf._linear_axis is not None
+ assert buf._coords_buffer is None
+ assert buf.capacity == 1000 # 1.0 sec / 0.001 gain
+ assert buf.gain == 0.001
+ assert buf.available() == 0
+
+ def test_coordinate_axis_initialization(self):
+ """Test initialization with CoordinateAxis"""
+ buf = HybridAxisBuffer(duration=1.0)
+
+ # Create CoordinateAxis with timestamps - note it needs dims parameter
+ timestamps = np.linspace(0, 0.1, 101) # 100 intervals, 0.001 gain
+ axis = CoordinateAxis(data=timestamps, dims=["time"])
+ buf._initialize(axis)
+
+ assert buf._coords_buffer is not None
+ assert buf._linear_axis is None
+ assert buf.capacity == 1000 # 1.0 sec / 0.001 gain
+ assert buf.gain == pytest.approx(0.001)
+ assert buf.available() == 0
+
+ def test_coordinate_axis_single_sample(self):
+ """Test CoordinateAxis initialization with single sample"""
+ buf = HybridAxisBuffer(duration=1.0)
+
+ # Single timestamp should default to gain of 1.0
+ axis = CoordinateAxis(data=np.array([0.0]), dims=["time"])
+ buf._initialize(axis)
+
+ assert buf.capacity == 1 # 1.0 sec / 1.0 gain
+ assert buf.gain == 1.0
+
+ def test_linear_axis_write_and_read(self):
+ """Test writing and reading with LinearAxis"""
+ buf = HybridAxisBuffer(duration=1.0)
+
+ # Initialize with LinearAxis
+ axis1 = LinearAxis(gain=0.001, offset=0.0)
+ buf.write(axis1, n_samples=100)
+
+ assert buf.available() == 100
+ assert buf._linear_n_available == 100
+ assert buf._linear_axis.offset == 0.0
+
+ # Write more samples with different offset
+ # The expected offset for the next write was 0.1,
+ # but we provide 0.15, which causes the original
+ # samples write operation to be adjusted to 0.05
+ axis2 = LinearAxis(gain=0.001, offset=0.15)
+ buf.write(axis2, n_samples=50)
+
+ assert buf.available() == 150
+ # Offset should be adjusted to oldest sample
+ assert buf._linear_axis.offset == pytest.approx(0.05) # 0.15 - 100*0.001
+
+ # Peek at the axis
+ peeked_axis = buf.peek(50)
+ assert isinstance(peeked_axis, LinearAxis)
+ assert peeked_axis.offset == pytest.approx(0.05)
+ assert peeked_axis.gain == 0.001
+
+ # Seek forward
+ sought = buf.seek(50)
+ assert sought == 50
+ assert buf.available() == 100
+ assert buf._linear_axis.offset == pytest.approx(0.1) # 0.05 + 50*0.001
+
+ def test_coordinate_axis_write_and_read(self):
+ """Test writing and reading with CoordinateAxis"""
+ buf = HybridAxisBuffer(duration=1.0, update_strategy="immediate")
+
+ # Initialize with CoordinateAxis
+ timestamps1 = np.linspace(0, 0.099, 100)
+ axis1 = CoordinateAxis(data=timestamps1, dims=["time"])
+ buf.write(axis1, n_samples=100)
+
+ assert buf.available() == 100
+
+ # Write more samples
+ timestamps2 = np.linspace(0.1, 0.149, 50)
+ axis2 = CoordinateAxis(data=timestamps2, dims=["time"])
+ buf.write(axis2, n_samples=50)
+
+ assert buf.available() == 150
+
+ # Peek at the axis
+ peeked_axis = buf.peek(75)
+ assert isinstance(peeked_axis, CoordinateAxis)
+ assert len(peeked_axis.data) == 75
+ np.testing.assert_array_almost_equal(peeked_axis.data, timestamps1[:75])
+
+ # Seek forward
+ sought = buf.seek(50)
+ assert sought == 50
+ assert buf.available() == 100
+
+ # Peek again - should start from sample 50
+ peeked_axis = buf.peek(50)
+ np.testing.assert_array_almost_equal(peeked_axis.data, timestamps1[50:])
+
+ def test_prune(self):
+ """Test pruning samples from buffer"""
+ buf = HybridAxisBuffer(duration=1.0, update_strategy="immediate")
+
+ # Write 200 samples with LinearAxis
+ axis = LinearAxis(gain=0.001, offset=0.0)
+ buf.write(axis, n_samples=200)
+
+ assert buf.available() == 200
+
+ # Prune to keep only last 50 samples
+ pruned = buf.prune(50)
+ assert pruned == 150
+ assert buf.available() == 50
+ assert buf._linear_axis.offset == pytest.approx(0.15) # 0.0 + 150*0.001
+
+ # Try pruning more than available (should do nothing)
+ pruned = buf.prune(100)
+ assert pruned == 0
+ assert buf.available() == 50
+
+ def test_final_value_linear(self):
+ """Test getting final value with LinearAxis"""
+ buf = HybridAxisBuffer(duration=1.0)
+
+ axis = LinearAxis(gain=0.01, offset=1.0)
+ buf.write(axis, n_samples=50)
+
+ # Final value should be offset + (n_samples-1) * gain
+ expected = 1.0 + 49 * 0.01
+ assert buf.final_value == pytest.approx(expected)
+
+ def test_final_value_coordinate(self):
+ """Test getting final value with CoordinateAxis"""
+ buf = HybridAxisBuffer(duration=1.0, update_strategy="immediate")
+
+ timestamps = np.array([1.0, 1.1, 1.2, 1.3, 1.4])
+ axis = CoordinateAxis(data=timestamps, dims=["time"])
+ buf.write(axis, n_samples=5)
+
+ # Final value should be the last timestamp
+ assert buf.final_value == pytest.approx(1.4)
+
+ # Note: final_value accesses the last element directly from peek()
+ # which returns the actual data value
+
+ def test_searchsorted_linear(self):
+ """Test searchsorted with LinearAxis"""
+ buf = HybridAxisBuffer(duration=1.0)
+
+ # Create axis with samples at 0.0, 0.01, 0.02, ..., 0.49
+ axis = LinearAxis(gain=0.01, offset=0.0)
+ buf.write(axis, n_samples=50)
+
+ sim_values = axis.value(np.arange(50))
+
+ # Test single value in between
+ test_val = 0.025
+ expected = np.searchsorted(sim_values, test_val)
+ idx = buf.searchsorted(test_val)
+ assert idx == expected
+
+ # Test array of values, at least one of which should be equal to an axis value
+ values = np.array([0.015, 0.035, 0.055, 0.1])
+ for side in ["left", "right"]:
+ expected_indices = np.searchsorted(sim_values, values, side=side)
+ indices = buf.searchsorted(values, side=side)
+ np.testing.assert_array_equal(indices, expected_indices)
+
+ # Test with empty buffer
+ buf.seek(50) # Clear buffer
+ sim_values = sim_values[50:]
+ assert buf.searchsorted(test_val) == np.searchsorted(sim_values, test_val)
+ np.testing.assert_array_equal(
+ buf.searchsorted(values), np.searchsorted(sim_values, values)
+ )
+
+ def test_searchsorted_coordinate(self):
+ """Test searchsorted with CoordinateAxis"""
+ buf = HybridAxisBuffer(duration=1.0, update_strategy="immediate")
+
+ timestamps = np.array([0.0, 0.01, 0.02, 0.03, 0.04])
+ axis = CoordinateAxis(data=timestamps, dims=["time"])
+ buf.write(axis, n_samples=5)
+
+ # Test single value
+ idx = buf.searchsorted(0.015)
+ assert idx == 1 or idx == 2 # Depends on searchsorted implementation
+
+ # Test array of values
+ values = np.array([0.005, 0.025, 0.045])
+ indices = buf.searchsorted(values)
+ assert all(0 <= idx <= 5 for idx in indices)
+
+ def test_overflow_behavior(self):
+ """Test buffer overflow with LinearAxis"""
+ buf = HybridAxisBuffer(duration=0.1) # Small buffer
+
+ # Initialize with high-frequency axis (10kHz)
+ axis = LinearAxis(gain=0.0001, offset=0.0)
+ buf.write(axis, n_samples=500) # 0.05 seconds of data
+
+ assert buf.available() == 500
+ assert buf.capacity == 1000 # 0.1 sec / 0.0001
+
+ # Write more data to cause overflow
+ axis2 = LinearAxis(gain=0.0001, offset=0.1)
+ buf.write(axis2, n_samples=700) # Total would be 1200
+
+ # Even though LinearAxis doesn't have a true capacity limit,
+ # we simulate one anyway to stay in sync with sister buffers
+ # (e.g., in HybridAxisArrayBuffer)
+ assert buf.available() == 1000
+
+ # But capacity remains the same
+ assert buf.capacity == 1000
+
+ def test_mixed_axis_types_error(self):
+ """Test that mixing axis types raises an error"""
+ buf = HybridAxisBuffer(duration=1.0, update_strategy="immediate")
+
+ # Initialize with LinearAxis
+ linear_axis = LinearAxis(gain=0.001, offset=0.0)
+ buf.write(linear_axis, n_samples=10)
+
+ # Try to write CoordinateAxis - should fail
+ coord_axis = CoordinateAxis(data=np.linspace(0, 0.01, 11), dims=["time"])
+ with pytest.raises(TypeError):
+ buf.write(coord_axis)
+
+ def test_buffer_kwargs_passthrough(self):
+ """Test that kwargs are passed through to underlying buffer"""
+ buf = HybridAxisBuffer(duration=1.0, update_strategy="threshold", threshold=50)
+
+ # Initialize with CoordinateAxis to create internal buffer
+ timestamps = np.linspace(0, 0.1, 101)
+ axis = CoordinateAxis(data=timestamps, dims=["time"])
+ buf.write(axis, n_samples=101)
+
+ # Check that kwargs were passed through
+ assert buf._coords_buffer._update_strategy == "threshold"
+ assert buf._coords_buffer._threshold == 50
+
+ def test_linear_axis_value_and_index(self):
+ """Test LinearAxis value() and index() methods"""
+ buf = HybridAxisBuffer(duration=1.0)
+
+ # Create axis with specific gain and offset
+ axis = LinearAxis(gain=0.01, offset=5.0, unit="ms")
+ buf._initialize(axis)
+
+ # Test that the axis methods work correctly
+ assert axis.value(0) == 5.0
+ assert axis.value(10) == 5.1
+ assert axis.value(np.array([0, 10, 20])) == pytest.approx([5.0, 5.1, 5.2])
+
+ # Test index calculation (inverse of value)
+ assert axis.index(5.0) == 0
+ assert axis.index(5.1) == 10
+ assert axis.index(5.05) == 5 # Should round by default
+
+ # Test with numpy array
+ values = np.array([5.0, 5.1, 5.2])
+ indices = axis.index(values)
+ np.testing.assert_array_equal(indices, [0, 10, 20])
+
+ def test_edge_cases(self):
+ """Test various edge cases"""
+ # Test with very small gain (large capacity)
+ buf = HybridAxisBuffer(duration=1.0)
+ axis_small_gain = LinearAxis(gain=0.00001, offset=0.0)
+ buf._initialize(axis_small_gain)
+ assert buf.capacity == 100000 # 1.0 / 0.00001
+
+ # Test with zero duration (would cause division by zero)
+ buf2 = HybridAxisBuffer(duration=0.0)
+ axis = LinearAxis(gain=0.001, offset=0.0)
+ buf2._initialize(axis)
+ assert buf2.capacity == 0
+
+ # Test peek with None (should return all available)
+ buf3 = HybridAxisBuffer(duration=1.0)
+ axis3 = LinearAxis(gain=0.01, offset=0.0)
+ buf3.write(axis3, n_samples=50)
+ peeked = buf3.peek(None)
+ assert peeked.offset == 0.0
+ assert peeked.gain == 0.01
+
+
+@pytest.fixture
+def linear_axis_message():
+ def _create(samples=10, channels=2, fs=100.0, offset=0.0):
+ shape = (samples, channels)
+ dims = ["time", "ch"]
+ data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+ gain = 1.0 / fs if fs > 0 else 0
+ axes = {
+ "time": LinearAxis(gain=gain, offset=offset),
+ "ch": CoordinateAxis(data=np.arange(channels).astype(str), dims=["ch"]),
+ }
+ return AxisArray(data, dims, axes=axes)
+
+ return _create
+
+
+@pytest.fixture
+def coordinate_axis_message():
+ def _create(samples=10, channels=2, start_time=0.0, interval=0.01):
+ shape = (samples, channels)
+ dims = ["time", "ch"]
+ data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+ timestamps = np.arange(samples) * interval + start_time
+ axes = {
+ "time": CoordinateAxis(data=timestamps, dims=["time"]),
+ "ch": CoordinateAxis(data=np.arange(channels).astype(str), dims=["ch"]),
+ }
+ return AxisArray(data, dims, axes=axes)
+
+ return _create
+
+
+def test_deferred_initialization_linear(linear_axis_message):
+ buf = HybridAxisArrayBuffer(duration=1.0) # 1 second buffer
+ assert buf.available() == 0
+ assert buf._data_buffer is None
+ assert buf._axis_buffer is not None
+ assert buf._axis_buffer._linear_axis is None
+ assert buf._axis_buffer._coords_buffer is None
+ assert buf._template_msg is None
+
+ msg = linear_axis_message(fs=100.0)
+ buf.write(msg)
+
+ assert buf.available() == 10
+ assert buf._data_buffer is not None
+ assert buf._data_buffer.capacity == 100 # 1.0s * 100Hz
+ assert buf._axis_buffer._linear_axis.offset == 0.00
+ assert buf._template_msg is not None and buf._template_msg.dims == ["time", "ch"]
+
+
+def test_deferred_initialization_coordinate(coordinate_axis_message):
+ buf = HybridAxisArrayBuffer(duration=1.0)
+ msg = coordinate_axis_message(samples=10, interval=0.01) # Effective fs = 100Hz
+ buf.write(msg)
+
+ assert buf.available() == 10
+ assert buf._data_buffer is not None
+ assert buf._data_buffer.capacity == 100
+ assert buf._axis_buffer is not None
+ assert buf._axis_buffer.capacity == 100
+
+
+def test_add_and_get_linear(linear_axis_message):
+ buf = HybridAxisArrayBuffer(duration=1.0, update_strategy="immediate")
+ msg1 = linear_axis_message(samples=10, fs=100.0, offset=0.0)
+ buf.write(msg1)
+
+ msg2 = linear_axis_message(samples=10, fs=100.0, offset=0.1)
+ buf.write(msg2)
+
+ assert buf.available() == 20
+ retrieved_msg = buf.read(15)
+ assert retrieved_msg.shape == (15, 2)
+ assert retrieved_msg.dims == msg1.dims
+ # Last sample of msg2 is at 0.1 + 9*0.01 = 0.19. Total unread was 20.
+ # Offset of oldest sample = 0.19 - (20-1)*0.01 = 0.0
+ assert retrieved_msg.axes["time"].offset == pytest.approx(0.0)
+ expected_data = np.concatenate([msg1.data, msg2.data[:5]])
+ np.testing.assert_array_equal(retrieved_msg.data, expected_data)
+
+ # Check that the buffer now has 5 samples left
+ assert buf.available() == 5
+ remaining_msg = buf.read()
+ np.testing.assert_array_equal(remaining_msg.data, msg2.data[5:])
+
+
+def test_get_all_data_default(linear_axis_message):
+ buf = HybridAxisArrayBuffer(duration=1.0)
+ msg1 = linear_axis_message(samples=10)
+ msg2 = linear_axis_message(samples=15)
+ buf.write(msg1)
+ buf.write(msg2)
+
+ retrieved = buf.read()
+ assert retrieved.shape[0] == 25
+ assert buf.available() == 0
+
+
+def test_add_and_get_coordinate(coordinate_axis_message):
+ buf = HybridAxisArrayBuffer(duration=1.0, update_strategy="immediate")
+ msg1 = coordinate_axis_message(samples=10, start_time=0.0)
+ buf.write(msg1)
+
+ msg2 = coordinate_axis_message(samples=10, start_time=0.1)
+ buf.write(msg2)
+
+ assert buf.available() == 20
+ retrieved_msg = buf.read(15)
+ assert retrieved_msg.shape == (15, 2)
+ assert retrieved_msg.dims == msg1.dims
+
+ expected_data = np.concatenate([msg1.data, msg2.data[:5]])
+ np.testing.assert_array_equal(retrieved_msg.data, expected_data)
+
+ expected_times = np.concatenate(
+ [msg1.axes["time"].data, msg2.axes["time"].data[:5]]
+ )
+ np.testing.assert_allclose(retrieved_msg.axes["time"].data, expected_times)
+
+ assert buf.available() == 5
+
+
+def test_type_mismatch_error(linear_axis_message, coordinate_axis_message):
+ buf = HybridAxisArrayBuffer(duration=1.0)
+ buf.write(linear_axis_message())
+ with pytest.raises(TypeError):
+ buf.write(coordinate_axis_message())
+
+
+def test_peek_linear(linear_axis_message):
+ buf = HybridAxisArrayBuffer(duration=1.0, update_strategy="immediate")
+ msg1 = linear_axis_message(samples=10, fs=100.0, offset=0.0)
+ buf.write(msg1)
+ msg2 = linear_axis_message(samples=10, fs=100.0, offset=0.1)
+ buf.write(msg2)
+
+ assert buf.available() == 20
+ peeked_msg = buf.peek(15)
+ assert peeked_msg.shape == (15, 2)
+ assert peeked_msg.dims == msg1.dims
+ assert peeked_msg.axes["time"].offset == pytest.approx(0.0)
+ expected_data = np.concatenate([msg1.data, msg2.data[:5]])
+ np.testing.assert_array_equal(peeked_msg.data, expected_data)
+
+ # Assert that state has not changed
+ assert buf.available() == 20
+ # The underlying _data_buffer._tail should still be 0
+ assert buf._data_buffer._tail == 0
+
+ # Get the data to prove it was still there
+ retrieved_msg = buf.read(15)
+ np.testing.assert_array_equal(retrieved_msg.data, expected_data)
+ assert buf.available() == 5
+
+
+def test_peek_coordinate(coordinate_axis_message):
+ buf = HybridAxisArrayBuffer(duration=1.0, update_strategy="immediate")
+ msg1 = coordinate_axis_message(samples=10, start_time=0.0)
+ buf.write(msg1)
+ msg2 = coordinate_axis_message(samples=10, start_time=0.1)
+ buf.write(msg2)
+
+ assert buf.available() == 20
+ peeked_msg = buf.peek(15)
+ assert peeked_msg.shape == (15, 2)
+ assert peeked_msg.dims == msg1.dims
+ expected_data = np.concatenate([msg1.data, msg2.data[:5]])
+ np.testing.assert_array_equal(peeked_msg.data, expected_data)
+ expected_times = np.concatenate(
+ [msg1.axes["time"].data, msg2.axes["time"].data[:5]]
+ )
+ np.testing.assert_allclose(peeked_msg.axes["time"].data, expected_times)
+
+ # Assert that state has not changed
+ assert buf.available() == 20
+ assert buf._data_buffer.tell() == 0
+ assert buf._axis_buffer._coords_buffer.tell() == 0
+
+ # Get the data to prove it was still there
+ retrieved_msg = buf.read(15)
+ np.testing.assert_array_equal(retrieved_msg.data, expected_data)
+ assert buf.available() == 5
+
+
+def test_seek_linear(linear_axis_message):
+ buf = HybridAxisArrayBuffer(duration=1.0, update_strategy="immediate")
+ msg1 = linear_axis_message(samples=10, fs=100.0, offset=0.0)
+ buf.write(msg1)
+ msg2 = linear_axis_message(samples=10, fs=100.0, offset=0.1)
+ buf.write(msg2)
+
+ assert buf.available() == 20
+ skipped_count = buf.seek(10)
+ assert skipped_count == 10
+ assert buf.available() == 10
+ assert buf._data_buffer._tail == 10
+ assert buf._axis_buffer._linear_axis.offset == pytest.approx(0.1)
+
+ # Get the remaining data
+ retrieved_msg = buf.read()
+ assert retrieved_msg.shape == (10, 2)
+ np.testing.assert_array_equal(retrieved_msg.data, msg2.data)
+ # Offset should be 0.1 (start of msg2)
+ assert retrieved_msg.axes["time"].offset == pytest.approx(0.1)
+
+
+def test_seek_coordinate(coordinate_axis_message):
+ buf = HybridAxisArrayBuffer(duration=1.0, update_strategy="immediate")
+ msg1 = coordinate_axis_message(samples=10, start_time=0.0)
+ buf.write(msg1)
+ msg2 = coordinate_axis_message(samples=10, start_time=0.1)
+ buf.write(msg2)
+
+ assert buf.available() == 20
+ skipped_count = buf.seek(10)
+ assert skipped_count == 10
+ assert buf.available() == 10
+ assert buf._data_buffer._tail == 10
+ assert buf._axis_buffer.available() == 10
+
+ # Get the remaining data
+ retrieved_msg = buf.read()
+ assert retrieved_msg.shape == (10, 2)
+ np.testing.assert_array_equal(retrieved_msg.data, msg2.data)
+ np.testing.assert_allclose(retrieved_msg.axes["time"].data, msg2.axes["time"].data)
+
+
+def test_prune(linear_axis_message):
+ buf = HybridAxisArrayBuffer(duration=1.0, update_strategy="immediate")
+ buf.write(linear_axis_message(samples=20))
+ assert buf.available() == 20
+ pruned_count = buf.prune(5)
+ assert pruned_count == 15
+ assert buf.available() == 5
+ retrieved = buf.read()
+ assert retrieved.shape[0] == 5
+
+
+def test_searchsorted_linear(linear_axis_message):
+ buf = HybridAxisArrayBuffer(duration=1.0, update_strategy="immediate")
+ buf.write(linear_axis_message(samples=20, fs=100.0, offset=0.1))
+ # Buffer now has timestamps from 0.1 to 0.29
+ indices = buf.axis_searchsorted(np.array([0.1, 0.15, 0.29]))
+ np.testing.assert_array_equal(indices, np.array([0, 5, 19]))
+
+
+def test_searchsorted_coordinate(coordinate_axis_message):
+ buf = HybridAxisArrayBuffer(duration=1.0, update_strategy="immediate")
+ buf.write(coordinate_axis_message(samples=20, start_time=0.1, interval=0.01))
+ indices = buf.axis_searchsorted(np.array([0.1, 0.15, 0.29]))
+ np.testing.assert_array_equal(indices, np.array([0, 5, 19]))
+
+
+def test_permute_dims(linear_axis_message):
+ buf = HybridAxisArrayBuffer(duration=1.0, axis="time", update_strategy="immediate")
+ msg = linear_axis_message(samples=10, fs=100.0, offset=0.0)
+ # Swap the axes
+ msg.dims = ["ch", "time"]
+ msg.data = np.ascontiguousarray(msg.data.T)
+ # Write the message; it should automatically permute the dimensions back to ["time", "ch"]
+ buf.write(msg)
+ assert buf.available() == 10
+ assert buf._data_buffer is not None
+ assert buf._data_buffer.capacity == 100 # 1.0s * 100Hz
+ assert buf._axis_buffer._linear_axis.offset == 0.00
+ assert buf._template_msg is not None and buf._template_msg.dims == ["time", "ch"]
+ assert msg.dims == ["ch", "time"] # Unchanged
+ retrieved = buf.read()
+ assert retrieved.dims == ["time", "ch"]
+ assert retrieved.shape == (10, 2)
diff --git a/tests/unit/buffer/test_buffer.py b/tests/unit/buffer/test_buffer.py
new file mode 100644
index 00000000..9beeec54
--- /dev/null
+++ b/tests/unit/buffer/test_buffer.py
@@ -0,0 +1,442 @@
+import pytest
+import numpy as np
+from ezmsg.sigproc.util.buffer import HybridBuffer
+
+
+@pytest.fixture
+def buffer_params():
+ return {
+ "array_namespace": np,
+ "capacity": 100,
+ "other_shape": (2,),
+ "dtype": np.float32,
+ "update_strategy": "immediate",
+ "threshold": 0,
+ "overflow_strategy": "warn-overwrite",
+ "max_size": 1024**3, # 1 GB
+ }
+
+
+def test_initialization(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+ assert buf.available() == 0
+ assert not buf.is_full()
+ assert buf.is_empty()
+ assert buf.capacity == buffer_params["capacity"]
+ assert buf._update_strategy == buffer_params["update_strategy"]
+ assert buf._threshold == buffer_params["threshold"]
+ assert buf._overflow_strategy == buffer_params["overflow_strategy"]
+ assert buf._max_size == buffer_params["max_size"]
+ assert buf._buffer.shape[1:] == buffer_params["other_shape"]
+ assert buf._buffer.dtype == buffer_params["dtype"]
+
+
+def test_add_and_get_simple(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+ shape = (10, *buffer_params["other_shape"])
+ data = np.arange(np.prod(shape), dtype=buffer_params["dtype"]).reshape(shape)
+ buf.write(data)
+ assert buf.available() == 10
+ retrieved_data = buf.read(10)
+ np.testing.assert_array_equal(data, retrieved_data)
+
+
+def test_add_1d_message():
+ buf = HybridBuffer(
+ array_namespace=np,
+ capacity=10,
+ other_shape=(1,),
+ dtype=np.float32,
+ update_strategy="immediate",
+ )
+ data = np.arange(5, dtype=np.float32)
+ buf.write(data)
+ assert buf.available() == 5
+ retrieved = buf.read(5)
+ assert retrieved.shape == (5, 1)
+ np.testing.assert_array_equal(data, retrieved.squeeze())
+
+
+def test_get_data_raises_error(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+ data = np.zeros((10, *buffer_params["other_shape"]))
+ buf.write(data)
+ with pytest.raises(ValueError):
+ buf.read(11)
+
+
+def test_add_raises_error_on_shape(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+ wrong_shape = (10, *[d + 1 for d in buffer_params["other_shape"]])
+ data = np.zeros(wrong_shape)
+ with pytest.raises(ValueError):
+ buf.write(data)
+
+
+def test_strategy_on_demand(buffer_params):
+ buf = HybridBuffer(**{**buffer_params, "update_strategy": "on_demand"})
+
+ n_write_1 = 10
+ shape = (n_write_1, *buffer_params["other_shape"])
+ data1 = np.ones(shape)
+ buf.write(data1)
+ assert len(buf._deque) == 1
+ assert buf._buff_unread == 0 # Not synced yet
+ assert buf.available() == n_write_1
+
+ n_write_2 = 5
+ shape2 = (n_write_2, *buffer_params["other_shape"])
+ data2 = np.ones(shape2) * 2
+ buf.write(data2)
+ assert len(buf._deque) == 2
+ assert buf._buff_unread == 0
+ assert buf.available() == n_write_1 + n_write_2
+
+ n_read_1 = 7
+ n_read_2 = (n_write_1 + n_write_2) - n_read_1
+ retrieved = buf.read(n_read_1)
+ assert len(buf._deque) == 0 # Synced now
+ assert buf.available() == n_read_2
+ assert retrieved.shape == (n_read_1, *buffer_params["other_shape"])
+ np.testing.assert_array_equal(retrieved, data1[:n_read_1])
+
+ retrieved = buf.read() # Get all remaining
+ assert buf.available() == 0
+ assert retrieved.shape == (n_read_2, *buffer_params["other_shape"])
+ np.testing.assert_array_equal(retrieved[: (n_write_1 - n_read_1)], data1[n_read_1:])
+ np.testing.assert_array_equal(retrieved[(n_write_1 - n_read_1) :], data2)
+
+
+def test_strategy_immediate(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+
+ n_write_1 = 10
+ shape1 = (n_write_1, *buffer_params["other_shape"])
+ data1 = np.ones(shape1)
+ buf.write(data1)
+ assert len(buf._deque) == 0
+ assert buf._buff_unread == n_write_1
+ assert buf.available() == n_write_1
+
+ n_write_2 = 5
+ shape2 = (n_write_2, *buffer_params["other_shape"])
+ data2 = np.ones(shape2) * 2
+ buf.write(data2)
+ assert len(buf._deque) == 0
+ assert buf._buff_unread == (n_write_1 + n_write_2)
+ assert buf.available() == (n_write_1 + n_write_2)
+
+ retrieved = buf.read()
+ np.testing.assert_array_equal(retrieved[:n_write_1], data1)
+ np.testing.assert_array_equal(retrieved[n_write_1:], data2)
+
+
+def test_strategy_threshold(buffer_params):
+ new_params = {**buffer_params, "update_strategy": "threshold", "threshold": 15}
+ buf = HybridBuffer(**new_params)
+
+ shape1 = (10, *buffer_params["other_shape"])
+ data1 = np.ones(shape1)
+ buf.write(data1)
+ assert len(buf._deque) == 1
+ assert buf.available() == 10
+ assert buf._buff_unread == 0
+
+ shape2 = (4, *buffer_params["other_shape"]) # Total = 14, under threshold
+ data2 = np.ones(shape2)
+ buf.write(data2)
+ assert len(buf._deque) == 2
+ assert buf.available() == 14
+ assert buf._buff_unread == 0
+
+ shape3 = (1, *buffer_params["other_shape"]) # Total = 15, meets threshold
+ data3 = np.ones(shape3)
+ buf.write(data3)
+ assert len(buf._deque) == 0
+ assert buf.available() == 15
+ assert buf._buff_unread == 15
+
+
+def test_buffer_overflow_warn_overwrite(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+ cap = buffer_params["capacity"]
+ # Fill the buffer completely
+ buf.write(np.zeros((cap, *buffer_params["other_shape"])))
+ assert buf._head == 0
+ assert buf._tail == 0
+ assert buf.available() == cap
+
+ # Add more data to cause a wrap + overflow
+ shape = (10, *buffer_params["other_shape"])
+ data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+ with pytest.warns(RuntimeWarning):
+ buf.write(data)
+ assert buf._head == 10
+ assert buf._tail == 10 # Tail moves forward with head during overflow
+ assert buf.available() == cap
+
+ retrieved = buf.read(10)
+ assert np.all(retrieved == 0)
+
+ # Check that the oldest data was overwritten
+ reamining_buffer_data = buf.read()
+ assert reamining_buffer_data.shape == (cap - 10, *buffer_params["other_shape"])
+ # np.testing.assert_array_equal(reamining_buffer_data[-10:], data)
+ assert np.all(reamining_buffer_data[: cap - 20] == 0)
+
+
+def test_read_wrap_around(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+
+ shape1 = (80, *buffer_params["other_shape"])
+ first_data = np.arange(np.prod(shape1), dtype=np.float32).reshape(shape1)
+ buf.write(first_data)
+ assert buf._head == 80
+ assert buf._tail == 0
+
+ shape2 = (40, *buffer_params["other_shape"])
+ latest_data = np.arange(np.prod(shape2), dtype=np.float32).reshape(shape2) + 1000
+ with pytest.warns(RuntimeWarning):
+ buf.write(latest_data)
+ assert buf._head == 20
+ assert buf._tail == 20 # Tail moves forward with head during overflow
+
+ retrieved = buf.read()
+ assert buf.available() == 0
+ assert retrieved.shape == (100, *buffer_params["other_shape"])
+ np.testing.assert_array_equal(retrieved[:60], first_data[20:])
+ np.testing.assert_array_equal(retrieved[60:], latest_data)
+
+
+def test_overflow_single_message(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+ shape = (200, *buffer_params["other_shape"])
+ data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+ with pytest.warns(RuntimeWarning):
+ buf.write(data)
+ assert buf.available() == 100
+ retrieved = buf.read()
+ np.testing.assert_array_equal(data[-100:], retrieved)
+
+
+def test_get_zero_samples(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+ data = buf.read(0)
+ assert data.shape == (0, *buffer_params["other_shape"])
+
+ buf.write(np.ones((10, *buffer_params["other_shape"])))
+ data = buf.read(0)
+ assert data.shape == (0, *buffer_params["other_shape"])
+
+
+def test_nd_tensor():
+ params = {
+ "array_namespace": np,
+ "capacity": 50,
+ "other_shape": (3, 4),
+ "dtype": np.int16,
+ }
+ buf = HybridBuffer(**params)
+ shape = (10, *params["other_shape"])
+ data = np.arange(np.prod(shape), dtype=params["dtype"]).reshape(shape)
+ buf.write(data)
+ assert buf.available() == 10
+ retrieved = buf.read(10)
+ assert retrieved.shape == shape
+ np.testing.assert_array_equal(retrieved, data)
+
+
+def test_get_data_default_all(buffer_params):
+ buf = HybridBuffer(**{**buffer_params, "update_strategy": "on_demand"})
+ shape1 = (10, *buffer_params["other_shape"])
+ data1 = np.ones(shape1)
+ buf.write(data1)
+
+ shape2 = (15, *buffer_params["other_shape"])
+ data2 = np.ones(shape2) * 2
+ buf.write(data2)
+
+ # Should trigger sync and get all 25 samples
+ retrieved = buf.read()
+ assert retrieved.shape[0] == 25
+
+ expected = np.concatenate((data1, data2), axis=0)
+ np.testing.assert_array_equal(retrieved, expected)
+
+
+def test_interleaved_read_write(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+ # Add 50
+ data1 = np.arange(50 * 2).reshape(50, 2)
+ buf.write(data1)
+ assert buf.available() == 50
+
+ # Get 20
+ read1 = buf.read(20)
+ np.testing.assert_array_equal(read1, data1[:20])
+ assert buf.available() == 30
+ assert buf._tail == 20
+
+ # Add 30
+ data2 = np.arange(30 * 2).reshape(30, 2) + 1000
+ buf.write(data2)
+ assert buf.available() == 60 # 30 remaining + 30 new
+ assert buf._head == 80 # 50 + 30
+
+ # Get 60 (all remaining)
+ read2 = buf.read(60)
+ assert buf.available() == 0
+ expected_data = np.concatenate([data1[20:], data2])
+ np.testing.assert_array_equal(read2, expected_data)
+
+
+def test_read_to_empty(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+ data = np.arange(30 * 2).reshape(30, 2)
+ buf.write(data)
+ assert buf.available() == 30
+
+ _ = buf.read(30)
+ assert buf.available() == 0
+ assert buf._tail == 30
+
+ # Reading again should return empty array
+ empty_read = buf.read()
+ assert empty_read.shape[0] == 0
+
+
+def test_read_operation_wraps(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+ # Add 80 samples, tail is at 0, head is at 80
+ data1 = np.arange(80 * 2).reshape(80, 2)
+ buf.write(data1)
+
+ # Read 60 samples, tail is at 60, head is at 80
+ buf.read(60)
+ assert buf.available() == 20
+ assert buf._tail == 60
+
+ # Add 40 samples. This will wrap the head around to 20.
+ data2 = np.arange(40 * 2).reshape(40, 2) + 1000
+ buf.write(data2)
+ assert buf.available() == 60 # 20 remaining + 40 new
+ assert buf._head == 20
+
+ # Read 30 samples. This will force the read to wrap.
+ # It will read 20 from data1 (60->80) and 10 from data2 (80->90)
+ read_data = buf.read(30)
+ assert read_data.shape[0] == 30
+ assert buf.available() == 30
+ assert buf._tail == 90 # 60 + 30
+
+ expected = np.concatenate([data1[60:], data2[:10]])
+ np.testing.assert_array_equal(read_data, expected)
+
+
+def test_peek_simple(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+ data = np.arange(20 * 2).reshape(20, 2)
+ buf.write(data)
+
+ peeked_data = buf.peek(10)
+ np.testing.assert_array_equal(peeked_data, data[:10])
+
+ # Assert that state has not changed
+ assert buf.available() == 20
+ assert buf._tail == 0
+
+ # Get the data to prove it was still there
+ retrieved_data = buf.read(10)
+ np.testing.assert_array_equal(retrieved_data, data[:10])
+ assert buf.available() == 10
+
+
+def test_seek_simple(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+ data = np.arange(20 * 2).reshape(20, 2)
+ buf.write(data)
+
+ skipped = buf.seek(10)
+ assert skipped == 10
+ assert buf.available() == 10
+ assert buf._tail == 10
+
+ retrieved_data = buf.read()
+ np.testing.assert_array_equal(retrieved_data, data[10:])
+
+
+def test_peek_and_skip(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+ data = np.arange(20 * 2).reshape(20, 2)
+ buf.write(data)
+
+ peeked = buf.peek(5)
+ np.testing.assert_array_equal(peeked, data[:5])
+
+ peeked_again = buf.peek(5)
+ np.testing.assert_array_equal(peeked_again, data[:5])
+
+ skipped = buf.seek(5)
+ assert skipped == 5
+
+ retrieved = buf.read(5)
+ np.testing.assert_array_equal(retrieved, data[5:10])
+
+
+def test_tell(buffer_params):
+ buf = HybridBuffer(**buffer_params)
+
+ # 1. Initially empty. tell() should return 0.
+ assert buf.tell() == 0
+
+ # 2. Add 50 samples. tell() should return 0.
+ buf.write(np.zeros((50, 2)))
+ assert buf.tell() == 0
+
+ # 3. Read 20 samples. tell() should return 20.
+ buf.read(20)
+ assert buf.tell() == 20
+
+ # Read another 10 samples. tell() should return 30.
+ buf.read(10)
+ assert buf.tell() == 30
+
+ # Read remaining 20 samples. tell() should return 50.
+ buf.read(20)
+ assert buf.tell() == 50
+
+ # Try to read more than available. Should still return 50.
+ with pytest.raises(ValueError):
+ buf.read(1)
+ assert buf.tell() == 50
+
+ # 4. Add 80 samples -> overwrite the first 30.
+ # tell() should return 20: 50 - 30
+ final_msg = np.zeros((80, 2))
+ buf.write(final_msg)
+ assert buf.tell() == 20
+
+
+def test_peek_at(buffer_params):
+ buf = HybridBuffer(**{**buffer_params, "update_strategy": "on_demand"})
+ # Add 50 samples in 5 blocks
+ for i in range(5):
+ buf.write(np.ones((10, 2)) * i)
+
+ # Peek at a value in the buffer before flushing
+ with pytest.raises(IndexError):
+ buf.peek_at(50)
+
+ # Read some data to cause a flush
+ _ = buf.read(1)
+
+ # Test peeking at various locations
+ np.testing.assert_array_equal(buf.peek_at(0), np.ones((1, 2)) * 0)
+ np.testing.assert_array_equal(buf.peek_at(10), np.ones((1, 2)) * 1)
+ np.testing.assert_array_equal(buf.peek_at(20), np.ones((1, 2)) * 2)
+ np.testing.assert_array_equal(buf.peek_at(30), np.ones((1, 2)) * 3)
+ np.testing.assert_array_equal(buf.peek_at(48), np.ones((1, 2)) * 4)
+
+ # Test peeking out of bounds
+ with pytest.raises(IndexError):
+ buf.peek_at(49)
diff --git a/tests/unit/buffer/test_buffer_overflow.py b/tests/unit/buffer/test_buffer_overflow.py
new file mode 100644
index 00000000..8fddb8a8
--- /dev/null
+++ b/tests/unit/buffer/test_buffer_overflow.py
@@ -0,0 +1,203 @@
+import pytest
+import numpy as np
+from ezmsg.sigproc.util.buffer import HybridBuffer
+from ezmsg.sigproc.util.axisarray_buffer import HybridAxisBuffer, HybridAxisArrayBuffer
+from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, CoordinateAxis
+
+
+@pytest.fixture
+def buffer_params():
+ return {
+ "array_namespace": np,
+ "capacity": 100,
+ "other_shape": (2,),
+ "dtype": np.float32,
+ "update_strategy": "immediate",
+ "threshold": 0,
+ "max_size": 1024**3, # 1 GB
+ }
+
+
+@pytest.fixture
+def linear_axis_message():
+ def _create(samples=10, channels=2, fs=100.0, offset=0.0):
+ shape = (samples, channels)
+ dims = ["time", "ch"]
+ data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+ gain = 1.0 / fs if fs > 0 else 0
+ axes = {
+ "time": LinearAxis(gain=gain, offset=offset),
+ "ch": CoordinateAxis(data=np.arange(channels).astype(str), dims=["ch"]),
+ }
+ return AxisArray(data, dims, axes=axes)
+
+ return _create
+
+
+@pytest.fixture
+def coordinate_axis_message():
+ def _create(samples=10, channels=2, start_time=0.0, interval=0.01):
+ shape = (samples, channels)
+ dims = ["time", "ch"]
+ data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+ timestamps = np.arange(samples) * interval + start_time
+ axes = {
+ "time": CoordinateAxis(data=timestamps, dims=["time"]),
+ "ch": CoordinateAxis(data=np.arange(channels).astype(str), dims=["ch"]),
+ }
+ return AxisArray(data, dims, axes=axes)
+
+ return _create
+
+
+class TestHybridBufferOverflow:
+ def test_overflow_strategy_raise(self, buffer_params):
+ buf = HybridBuffer(**{**buffer_params, "overflow_strategy": "raise"})
+ buf.write(np.zeros((100, 2)))
+ with pytest.raises(OverflowError):
+ buf.write(np.zeros((1, 2)))
+
+ def test_overflow_strategy_drop(self, buffer_params):
+ buf = HybridBuffer(**{**buffer_params, "overflow_strategy": "drop"})
+ buf.write(np.ones((80, 2)))
+ buf.write(np.ones((30, 2)) * 2) # 10 samples should be dropped
+ assert buf.available() == 100
+ data = buf.read()
+ assert data.shape[0] == 100
+ np.testing.assert_array_equal(data[:80], np.ones((80, 2)))
+ np.testing.assert_array_equal(data[80:], np.ones((20, 2)) * 2)
+
+ def test_overflow_strategy_grow(self, buffer_params):
+ buf = HybridBuffer(**{**buffer_params, "overflow_strategy": "grow"})
+ assert buf.capacity == 100
+ buf.write(np.zeros((80, 2)))
+ assert buf.capacity == 100
+ buf.write(np.zeros((30, 2)))
+ assert buf.capacity > 100
+ assert buf.available() == 110
+
+ # Test that it fails when max_size is reached
+ buf = HybridBuffer(
+ **{
+ **buffer_params,
+ "overflow_strategy": "grow",
+ "capacity": 10,
+ "max_size": 20 * 2 * 4, # 20 samples * 2 channels * 4 bytes/float32
+ }
+ )
+ buf.write(np.zeros((10, 2)))
+ with pytest.raises(OverflowError):
+ buf.write(np.zeros((11, 2)))
+
+ def test_read_prevent_overwrite(self, buffer_params):
+ """
+ This test ensures that the read method can prevent an overwrite by reading
+ the data in two parts if a flush would cause an overflow.
+ """
+ # Scenario 1: Preventable overwrite
+ buf = HybridBuffer(
+ **{
+ **buffer_params,
+ "update_strategy": "on_demand",
+ "overflow_strategy": "raise",
+ }
+ )
+ # 1. Fill buffer with 80 samples
+ buf.write(np.zeros((80, 2)))
+ buf.flush()
+ assert buf.available() == 80
+ assert buf._buff_unread == 80
+
+ # 2. Add 30 samples to deque.
+ # Flushing now would cause an overflow of 10 samples (30 new > 20 free).
+ # This is a preventable overflow since 10 < capacity (100).
+ data_in_deque = np.arange(30 * 2).reshape(30, 2)
+ buf.write(data_in_deque)
+ assert buf.available() == 110
+
+ # 3. Reading 90 samples should trigger the two-part read.
+ # It should first read the 80 from the buffer, then flush and read 10 more.
+ read_data = buf.read(90)
+ assert read_data.shape[0] == 90
+ np.testing.assert_array_equal(read_data[:80], np.zeros((80, 2)))
+ np.testing.assert_array_equal(read_data[80:], data_in_deque[:10])
+ assert buf.available() == 20 # 20 samples remaining in the buffer
+
+ # Scenario 2: Unpreventable overwrite
+ # An overflow is unpreventable if (n_overflow - n_buffered) >= capacity
+ buf = HybridBuffer(
+ **{
+ **buffer_params,
+ "update_strategy": "on_demand",
+ "overflow_strategy": "raise",
+ }
+ )
+ # 1. Fill buffer with 10 samples
+ buf.write(np.zeros((10, 2)))
+ buf.flush()
+
+ # 2. Add 200 samples to deque.
+ # n_overflow = 200 - (100 - 10) = 110.
+ # (n_overflow - n_buffered) = 110 - 10 = 100.
+ # 100 >= 100 is True, so this should be unpreventable.
+ # In fact, the write process recognizes this so it raises an OverflowError
+ # even before we flush.
+ with pytest.raises(OverflowError):
+ buf.write(np.arange(200 * 2).reshape(200, 2))
+
+
+class TestHybridAxisBufferOverflow:
+ def test_hybrid_axis_buffer_overflow_raise(self):
+ buf = HybridAxisBuffer(
+ duration=0.1, overflow_strategy="raise", update_strategy="immediate"
+ )
+ axis = CoordinateAxis(data=np.linspace(0, 0.099, 100), dims=["time"])
+ buf.write(axis, n_samples=100)
+ with pytest.raises(OverflowError):
+ buf.write(axis, n_samples=1)
+
+ def test_hybrid_axis_buffer_overflow_drop(self):
+ buf = HybridAxisBuffer(
+ duration=0.1, overflow_strategy="drop", update_strategy="immediate"
+ )
+ axis = CoordinateAxis(data=np.linspace(0, 0.099, 100), dims=["time"])
+ buf.write(axis, n_samples=100)
+ axis2 = CoordinateAxis(data=np.linspace(0.1, 0.109, 10), dims=["time"])
+ buf.write(axis2, n_samples=10)
+ assert buf.available() == 100
+
+ def test_hybrid_axis_buffer_overflow_grow(self):
+ buf = HybridAxisBuffer(
+ duration=0.1, overflow_strategy="grow", update_strategy="immediate"
+ )
+ axis = CoordinateAxis(data=np.linspace(0, 0.099, 100), dims=["time"])
+ buf.write(axis, n_samples=100)
+ axis2 = CoordinateAxis(data=np.linspace(0.1, 0.109, 10), dims=["time"])
+ buf.write(axis2, n_samples=10)
+ assert buf.available() == 110
+
+
+class TestHybridAxisArrayBufferOverflow:
+ def test_hybrid_axis_array_buffer_overflow_raise(self, linear_axis_message):
+ buf = HybridAxisArrayBuffer(
+ duration=0.1, overflow_strategy="raise", update_strategy="immediate"
+ )
+ buf.write(linear_axis_message(samples=10, fs=100.0))
+ with pytest.raises(OverflowError):
+ buf.write(linear_axis_message(samples=1, fs=100.0))
+
+ def test_hybrid_axis_array_buffer_overflow_drop(self, linear_axis_message):
+ buf = HybridAxisArrayBuffer(
+ duration=0.1, overflow_strategy="drop", update_strategy="immediate"
+ )
+ buf.write(linear_axis_message(samples=8, fs=100.0))
+ buf.write(linear_axis_message(samples=4, fs=100.0))
+ assert buf.available() == 10
+
+ def test_hybrid_axis_array_buffer_overflow_grow(self, linear_axis_message):
+ buf = HybridAxisArrayBuffer(
+ duration=0.1, overflow_strategy="grow", update_strategy="immediate"
+ )
+ buf.write(linear_axis_message(samples=8, fs=100.0))
+ buf.write(linear_axis_message(samples=4, fs=100.0))
+ assert buf.available() == 12
diff --git a/tests/unit/test_fbcca.py b/tests/unit/test_fbcca.py
new file mode 100644
index 00000000..3b94b7e6
--- /dev/null
+++ b/tests/unit/test_fbcca.py
@@ -0,0 +1,766 @@
+import numpy as np
+import pytest
+from ezmsg.util.messages.axisarray import AxisArray
+
+from ezmsg.sigproc.fbcca import (
+ FBCCASettings,
+ FBCCATransformer,
+ StreamingFBCCASettings,
+ StreamingFBCCATransformer,
+ cca_rho_max,
+ calc_softmax,
+)
+from ezmsg.sigproc.sampler import SampleTriggerMessage
+
+
+def test_cca_rho_max_basic():
+ """Test the cca_rho_max function with basic inputs."""
+ # Create two correlated signals
+ n_time = 100
+ t = np.linspace(0, 1, n_time)
+
+ # X: signal with two channels
+ X = np.column_stack([np.sin(2 * np.pi * 10 * t), np.cos(2 * np.pi * 10 * t)])
+
+ # Y: reference signal at same frequency
+ Y = np.column_stack([np.sin(2 * np.pi * 10 * t), np.cos(2 * np.pi * 10 * t)])
+
+ rho = cca_rho_max(X, Y)
+
+ # Should be high correlation (close to 1)
+ assert 0 <= rho <= 1
+ assert rho > 0.95
+
+
+def test_cca_rho_max_uncorrelated():
+ """Test cca_rho_max with uncorrelated signals."""
+ n_time = 100
+ t = np.linspace(0, 1, n_time)
+
+ # X: signal at 10 Hz
+ X = np.column_stack([np.sin(2 * np.pi * 10 * t), np.cos(2 * np.pi * 10 * t)])
+
+ # Y: signal at different frequency (50 Hz)
+ Y = np.column_stack([np.sin(2 * np.pi * 50 * t), np.cos(2 * np.pi * 50 * t)])
+
+ rho = cca_rho_max(X, Y)
+
+ # Should be low correlation
+ assert 0 <= rho <= 1
+ assert rho < 0.5
+
+
+def test_cca_rho_max_zero_variance():
+ """Test cca_rho_max with zero-variance signals."""
+ n_time = 100
+
+ # X: constant signal (zero variance)
+ X = np.ones((n_time, 2))
+
+ # Y: normal signal
+ t = np.linspace(0, 1, n_time)
+ Y = np.column_stack([np.sin(2 * np.pi * 10 * t), np.cos(2 * np.pi * 10 * t)])
+
+ rho = cca_rho_max(X, Y)
+
+ # Should return 0 for zero-variance signal
+ assert rho == 0.0
+
+
+def test_cca_rho_max_empty():
+ """Test cca_rho_max with empty arrays."""
+ X = np.zeros((10, 0))
+ Y = np.zeros((10, 2))
+
+ rho = cca_rho_max(X, Y)
+
+ assert rho == 0.0
+
+
+@pytest.mark.parametrize("beta", [0.5, 1.0, 2.0, 5.0])
+def test_calc_softmax(beta):
+ """Test calc_softmax with different beta values."""
+ # Create test data - 1D array since calc_softmax is used on 1D in the code
+ data = np.array([1.0, 2.0, 3.0, 2.5, 1.5])
+
+ result = calc_softmax(data, axis=-1, beta=beta)
+
+ # Check output shape
+ assert result.shape == data.shape
+
+ # Check sum to 1
+ assert np.allclose(result.sum(), 1.0)
+
+ # Check all values in [0, 1]
+ assert np.all((result >= 0) & (result <= 1))
+
+ # Check higher beta makes distribution more peaked
+ if beta > 1.0:
+ # Higher beta should give more weight to maximum
+ max_idx = data.argmax()
+ assert result[max_idx] > 0.5
+
+
+def test_calc_softmax_multidim():
+ """Test calc_softmax with multi-dimensional data."""
+ # 2D array where softmax is applied along last axis
+ data = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+
+ # Need to apply along correct axis with keepdims
+ result = np.exp(data - data.max(axis=-1, keepdims=True))
+ result = result / result.sum(axis=-1, keepdims=True)
+
+ # Check output shape
+ assert result.shape == data.shape
+
+ # Check sum to 1 along axis
+ assert np.allclose(result.sum(axis=-1), 1.0)
+
+ # Check all values in [0, 1]
+ assert np.all((result >= 0) & (result <= 1))
+
+
+def test_fbcca_basic():
+ """Test basic FBCCA functionality."""
+ fs = 250.0
+ dur = 2.0
+ n_times = int(dur * fs)
+ n_channels = 4
+
+ # Create test signal with 10Hz component
+ t = np.arange(n_times) / fs
+ signal = np.column_stack(
+ [np.sin(2 * np.pi * 10 * t + i * np.pi / 4) for i in range(n_channels)]
+ ).T
+
+ # Create message
+ msg = AxisArray(
+ data=signal,
+ dims=["ch", "time"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_channels).astype(str), dims=["ch"]
+ ),
+ },
+ key="test_fbcca",
+ )
+
+ # Test frequencies
+ test_freqs = [8.0, 10.0, 12.0, 15.0]
+
+ settings = FBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ freqs=test_freqs,
+ harmonics=3,
+ )
+
+ transformer = FBCCATransformer(settings=settings)
+ result = transformer(msg)
+
+ # Check output structure
+ assert "target_freq" in result.dims
+ assert result.data.shape == (len(test_freqs),)
+
+ # Check that 10Hz has highest value
+ freq_idx_10hz = test_freqs.index(10.0)
+ assert np.argmax(result.data) == freq_idx_10hz
+
+
+def test_fbcca_with_filterbank_dim():
+ """Test FBCCA with filterbank dimension."""
+ fs = 250.0
+ dur = 2.0
+ n_times = int(dur * fs)
+ n_channels = 4
+ n_subbands = 3
+
+ # Create test signal
+ t = np.arange(n_times) / fs
+ base_signal = np.column_stack(
+ [np.sin(2 * np.pi * 10 * t + i * np.pi / 4) for i in range(n_channels)]
+ )
+
+ # Replicate across subbands
+ signal = np.stack([base_signal.T for _ in range(n_subbands)], axis=0)
+
+ msg = AxisArray(
+ data=signal,
+ dims=["subband", "ch", "time"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_channels).astype(str), dims=["ch"]
+ ),
+ "subband": AxisArray.CoordinateAxis(
+ data=np.arange(n_subbands).astype(str), dims=["subband"]
+ ),
+ },
+ key="test_fbcca_filterbank",
+ )
+
+ test_freqs = [8.0, 10.0, 12.0]
+
+ settings = FBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ filterbank_dim="subband",
+ freqs=test_freqs,
+ harmonics=3,
+ )
+
+ transformer = FBCCATransformer(settings=settings)
+ result = transformer(msg)
+
+ # Check output structure
+ assert "target_freq" in result.dims
+ assert "subband" not in result.dims # Should be collapsed
+ assert result.data.shape == (len(test_freqs),)
+
+
+def test_fbcca_with_trigger_freqs():
+ """Test FBCCA with frequencies from SampleTriggerMessage."""
+ from dataclasses import dataclass
+
+ fs = 250.0
+ dur = 2.0
+ n_times = int(dur * fs)
+ n_channels = 4
+
+ # Create test signal
+ t = np.arange(n_times) / fs
+ signal = np.column_stack(
+ [np.sin(2 * np.pi * 12 * t + i * np.pi / 4) for i in range(n_channels)]
+ ).T
+
+ # Create trigger with freqs attribute
+ @dataclass
+ class TestTrigger(SampleTriggerMessage):
+ freqs: list[float] = None
+
+ def __post_init__(self):
+ if self.freqs is None:
+ self.freqs = []
+
+ trigger = TestTrigger(period=(0, dur), freqs=[10.0, 12.0, 15.0])
+
+ msg = AxisArray(
+ data=signal,
+ dims=["ch", "time"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_channels).astype(str), dims=["ch"]
+ ),
+ },
+ attrs={"trigger": trigger},
+ key="test_fbcca_trigger",
+ )
+
+ settings = FBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ harmonics=3,
+ )
+
+ transformer = FBCCATransformer(settings=settings)
+ result = transformer(msg)
+
+ # Check that trigger frequencies were used
+ assert result.data.shape == (3,) # 3 frequencies from trigger
+ assert "target_freq" in result.dims
+
+
+def test_fbcca_no_freqs_error():
+ """Test that FBCCA raises error when no frequencies are provided."""
+ fs = 250.0
+ n_times = 500
+ n_channels = 4
+
+ signal = np.random.randn(n_channels, n_times)
+
+ msg = AxisArray(
+ data=signal,
+ dims=["ch", "time"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_channels).astype(str), dims=["ch"]
+ ),
+ },
+ key="test_no_freqs",
+ )
+
+ settings = FBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ # No freqs provided
+ )
+
+ transformer = FBCCATransformer(settings=settings)
+
+ with pytest.raises(ValueError, match="no frequencies to test"):
+ transformer(msg)
+
+
+@pytest.mark.parametrize("harmonics", [1, 3, 5, 10])
+def test_fbcca_harmonics(harmonics):
+ """Test FBCCA with different numbers of harmonics."""
+ fs = 250.0
+ dur = 2.0
+ n_times = int(dur * fs)
+ n_channels = 4
+
+ # Create signal with harmonics
+ t = np.arange(n_times) / fs
+ signal = np.column_stack(
+ [
+ np.sin(2 * np.pi * 10 * t) + 0.3 * np.sin(2 * np.pi * 20 * t)
+ for _ in range(n_channels)
+ ]
+ ).T
+
+ msg = AxisArray(
+ data=signal,
+ dims=["ch", "time"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_channels).astype(str), dims=["ch"]
+ ),
+ },
+ key="test_harmonics",
+ )
+
+ settings = FBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ freqs=[10.0, 15.0],
+ harmonics=harmonics,
+ )
+
+ transformer = FBCCATransformer(settings=settings)
+ result = transformer(msg)
+
+ assert result.data.shape == (2,)
+ # More harmonics should generally improve detection
+ assert np.argmax(result.data) == 0 # 10Hz should be detected
+
+
+@pytest.mark.parametrize("softmax_beta", [0.0, 0.5, 1.0, 2.0])
+def test_fbcca_softmax_beta(softmax_beta):
+ """Test FBCCA with different softmax beta values."""
+ fs = 250.0
+ dur = 2.0
+ n_times = int(dur * fs)
+ n_channels = 4
+
+ # Create test signal
+ t = np.arange(n_times) / fs
+ signal = np.column_stack([np.sin(2 * np.pi * 10 * t) for _ in range(n_channels)]).T
+
+ msg = AxisArray(
+ data=signal,
+ dims=["ch", "time"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_channels).astype(str), dims=["ch"]
+ ),
+ },
+ key="test_softmax",
+ )
+
+ settings = FBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ freqs=[8.0, 10.0, 12.0],
+ harmonics=3,
+ softmax_beta=softmax_beta,
+ )
+
+ transformer = FBCCATransformer(settings=settings)
+ result = transformer(msg)
+
+ assert result.data.shape == (3,)
+
+ if softmax_beta == 0.0:
+ # Beta=0 outputs raw correlations
+ assert not np.allclose(result.data.sum(), 1.0)
+ else:
+ # Beta>0 outputs softmax probabilities
+ assert np.allclose(result.data.sum(), 1.0)
+ assert np.all((result.data >= 0) & (result.data <= 1))
+
+
+def test_fbcca_max_int_time():
+ """Test FBCCA with maximum integration time limit."""
+ fs = 250.0
+ dur = 5.0
+ n_times = int(dur * fs)
+ n_channels = 4
+
+ # Create test signal
+ t = np.arange(n_times) / fs
+ signal = np.column_stack([np.sin(2 * np.pi * 10 * t) for _ in range(n_channels)]).T
+
+ msg = AxisArray(
+ data=signal,
+ dims=["ch", "time"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_channels).astype(str), dims=["ch"]
+ ),
+ },
+ key="test_max_int_time",
+ )
+
+ # Test with max_int_time set
+ settings_limited = FBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ freqs=[10.0, 12.0],
+ harmonics=3,
+ max_int_time=2.0, # Only use first 2 seconds
+ )
+
+ transformer_limited = FBCCATransformer(settings=settings_limited)
+ result_limited = transformer_limited(msg)
+
+ # Test without max_int_time
+ settings_full = FBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ freqs=[10.0, 12.0],
+ harmonics=3,
+ max_int_time=0.0, # Use all data
+ )
+
+ transformer_full = FBCCATransformer(settings=settings_full)
+ result_full = transformer_full(msg)
+
+ # Both should produce valid output
+ assert result_limited.data.shape == (2,)
+ assert result_full.data.shape == (2,)
+
+ # Results may differ due to different integration times
+ # but both should prefer 10Hz
+ assert np.argmax(result_limited.data) == 0
+ assert np.argmax(result_full.data) == 0
+
+
+@pytest.mark.skip(reason="Empty message handling needs fix in fbcca.py")
+def test_fbcca_empty_message():
+ """Test FBCCA with empty message.
+
+ Note: Currently the implementation has issues reshaping empty arrays.
+ """
+ fs = 250.0
+ n_channels = 4
+
+ msg = AxisArray(
+ data=np.zeros((n_channels, 0)),
+ dims=["ch", "time"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_channels).astype(str), dims=["ch"]
+ ),
+ },
+ key="test_empty",
+ )
+
+ settings = FBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ freqs=[10.0, 12.0],
+ harmonics=3,
+ )
+
+ transformer = FBCCATransformer(settings=settings)
+ result = transformer(msg)
+
+ # Should handle empty data gracefully with correct output shape
+ assert result.data.shape == (2,) # 2 frequencies
+ assert "target_freq" in result.dims
+
+
+def test_fbcca_multidim():
+ """Test FBCCA with additional dimensions (e.g., trials)."""
+ fs = 250.0
+ dur = 2.0
+ n_times = int(dur * fs)
+ n_channels = 4
+ n_trials = 3
+
+ # Create test signal with trials dimension
+ t = np.arange(n_times) / fs
+ signal = np.stack(
+ [
+ np.column_stack([np.sin(2 * np.pi * 10 * t) for _ in range(n_channels)]).T
+ for _ in range(n_trials)
+ ],
+ axis=0,
+ )
+
+ msg = AxisArray(
+ data=signal,
+ dims=["trial", "ch", "time"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_channels).astype(str), dims=["ch"]
+ ),
+ "trial": AxisArray.CoordinateAxis(
+ data=np.arange(n_trials).astype(str), dims=["trial"]
+ ),
+ },
+ key="test_multidim",
+ )
+
+ settings = FBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ freqs=[10.0, 12.0],
+ harmonics=3,
+ )
+
+ transformer = FBCCATransformer(settings=settings)
+ result = transformer(msg)
+
+ # Output should have trial and target_freq dims
+ assert "trial" in result.dims
+ assert "target_freq" in result.dims
+ assert result.data.shape == (n_trials, 2)
+
+
+def test_fbcca_custom_target_freq_dim():
+ """Test FBCCA with custom target frequency dimension name."""
+ fs = 250.0
+ dur = 2.0
+ n_times = int(dur * fs)
+ n_channels = 4
+
+ # Create test signal
+ t = np.arange(n_times) / fs
+ signal = np.column_stack([np.sin(2 * np.pi * 10 * t) for _ in range(n_channels)]).T
+
+ msg = AxisArray(
+ data=signal,
+ dims=["ch", "time"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_channels).astype(str), dims=["ch"]
+ ),
+ },
+ key="test_custom_dim",
+ )
+
+ settings = FBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ freqs=[10.0, 12.0],
+ harmonics=3,
+ target_freq_dim="frequency", # Custom name
+ )
+
+ transformer = FBCCATransformer(settings=settings)
+ result = transformer(msg)
+
+ # Check custom dimension name
+ assert "frequency" in result.dims
+ assert "target_freq" not in result.dims
+
+
+def test_streaming_fbcca_basic():
+ """Test basic StreamingFBCCA functionality."""
+ fs = 250.0
+ dur = 10.0 # Need longer duration for windowing
+ n_times = int(dur * fs)
+ n_channels = 4
+
+ # Create test signal
+ t = np.arange(n_times) / fs
+ signal = np.column_stack(
+ [np.sin(2 * np.pi * 10 * t + i * np.pi / 4) for i in range(n_channels)]
+ ).T
+
+ msg = AxisArray(
+ data=signal,
+ dims=["ch", "time"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_channels).astype(str), dims=["ch"]
+ ),
+ },
+ key="test_streaming_fbcca",
+ )
+
+ settings = StreamingFBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ freqs=[8.0, 10.0, 12.0],
+ filterbank_dim="subband",
+ window_dur=4.0,
+ window_shift=2.0,
+ harmonics=3,
+ subbands=3,
+ )
+
+ transformer = StreamingFBCCATransformer(settings=settings)
+ result = transformer(msg)
+
+ # Should have windowed output
+ assert "fbcca_window" in result.dims
+ assert "target_freq" in result.dims
+
+ # Check that multiple windows were created
+ # (exact count depends on windowing implementation with zero_pad_until="shift")
+ assert result.data.shape[0] > 1 # Multiple windows
+ assert result.data.shape[1] == 3 # 3 frequencies
+
+
+def test_streaming_fbcca_no_filterbank():
+ """Test StreamingFBCCA without filterbank (plain CCA)."""
+ fs = 250.0
+ dur = 10.0
+ n_times = int(dur * fs)
+ n_channels = 4
+
+ # Create test signal
+ t = np.arange(n_times) / fs
+ signal = np.column_stack([np.sin(2 * np.pi * 10 * t) for _ in range(n_channels)]).T
+
+ msg = AxisArray(
+ data=signal,
+ dims=["ch", "time"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_channels).astype(str), dims=["ch"]
+ ),
+ },
+ key="test_streaming_no_filterbank",
+ )
+
+ settings = StreamingFBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ freqs=[10.0, 12.0],
+ filterbank_dim=None, # No filterbank
+ window_dur=4.0,
+ window_shift=2.0,
+ harmonics=3,
+ )
+
+ transformer = StreamingFBCCATransformer(settings=settings)
+ result = transformer(msg)
+
+ # Should have windowed output
+ assert "fbcca_window" in result.dims
+ assert "target_freq" in result.dims
+ assert "subband" not in result.dims
+
+
+def test_fbcca_axes_preserved():
+ """Test that non-processed axes are preserved in output."""
+ fs = 250.0
+ dur = 2.0
+ n_times = int(dur * fs)
+ n_channels = 4
+ n_epochs = 2
+
+ # Create test signal with epoch dimension
+ t = np.arange(n_times) / fs
+ signal = np.stack(
+ [
+ np.column_stack([np.sin(2 * np.pi * 10 * t) for _ in range(n_channels)]).T
+ for _ in range(n_epochs)
+ ],
+ axis=0,
+ )
+
+ msg = AxisArray(
+ data=signal,
+ dims=["epoch", "ch", "time"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_channels).astype(str), dims=["ch"]
+ ),
+ "epoch": AxisArray.CoordinateAxis(
+ data=np.array(["a", "b"]), dims=["epoch"]
+ ),
+ },
+ key="test_axes",
+ )
+
+ settings = FBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ freqs=[10.0, 12.0],
+ harmonics=3,
+ )
+
+ transformer = FBCCATransformer(settings=settings)
+ result = transformer(msg)
+
+ # Epoch axis should be preserved
+ assert "epoch" in result.dims
+ assert "epoch" in result.axes
+ assert np.array_equal(result.axes["epoch"].data, np.array(["a", "b"]))
+
+
+def test_fbcca_frequency_detection():
+ """Test FBCCA correctly identifies different frequencies."""
+ fs = 250.0
+ dur = 3.0
+ n_times = int(dur * fs)
+ n_channels = 4
+
+ test_freqs = [8.0, 10.0, 12.0, 15.0]
+
+ for target_freq in test_freqs:
+ # Create signal at target frequency
+ t = np.arange(n_times) / fs
+ signal = np.column_stack(
+ [
+ np.sin(2 * np.pi * target_freq * t + i * np.pi / 4)
+ for i in range(n_channels)
+ ]
+ ).T
+
+ msg = AxisArray(
+ data=signal,
+ dims=["ch", "time"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_channels).astype(str), dims=["ch"]
+ ),
+ },
+ key=f"test_freq_{target_freq}",
+ )
+
+ settings = FBCCASettings(
+ time_dim="time",
+ ch_dim="ch",
+ freqs=test_freqs,
+ harmonics=5,
+ )
+
+ transformer = FBCCATransformer(settings=settings)
+ result = transformer(msg)
+
+ # Check that correct frequency is detected
+ detected_idx = np.argmax(result.data)
+ detected_freq = test_freqs[detected_idx]
+
+ # Should detect the target frequency
+ assert (
+ detected_freq == target_freq
+ ), f"Expected {target_freq}Hz, detected {detected_freq}Hz"
diff --git a/tests/unit/test_filterbankdesign.py b/tests/unit/test_filterbankdesign.py
new file mode 100644
index 00000000..42abf7b5
--- /dev/null
+++ b/tests/unit/test_filterbankdesign.py
@@ -0,0 +1,692 @@
+import numpy as np
+import pytest
+import scipy.signal
+from ezmsg.util.messages.axisarray import AxisArray
+
+from ezmsg.sigproc.filterbankdesign import (
+ FilterbankDesignSettings,
+ FilterbankDesignTransformer,
+)
+from ezmsg.sigproc.kaiser import KaiserFilterSettings, kaiser_design_fun
+from ezmsg.sigproc.filterbank import FilterbankMode, MinPhaseMode
+
+
+@pytest.mark.parametrize("n_filters", [1, 3, 5])
+def test_calculate_kernels_basic(n_filters):
+ """Test the _calculate_kernels method with varying numbers of filters."""
+ fs = 200.0
+
+ # Create filter settings for different frequency bands
+ filters = []
+ for i in range(n_filters):
+ cutoff = 10.0 + i * 10.0 # 10Hz, 20Hz, 30Hz, etc.
+ filters.append(
+ KaiserFilterSettings(
+ cutoff=cutoff,
+ ripple=60.0,
+ width=5.0,
+ pass_zero=True,
+ wn_hz=True,
+ )
+ )
+
+ settings = FilterbankDesignSettings(filters=filters)
+ transformer = FilterbankDesignTransformer(settings=settings)
+
+ # Calculate kernels
+ kernels = transformer._calculate_kernels(fs)
+
+ assert len(kernels) == n_filters
+ for kernel in kernels:
+ assert isinstance(kernel, np.ndarray)
+ assert len(kernel) > 0
+ # Kaiser filters should have odd number of taps
+ assert len(kernel) % 2 == 1
+
+
+@pytest.mark.parametrize(
+ "mode", [FilterbankMode.CONV, FilterbankMode.FFT, FilterbankMode.AUTO]
+)
+def test_filterbankdesign_transformer_modes(mode):
+ """Test FilterbankDesignTransformer with different processing modes."""
+ fs = 200.0
+ dur = 5.0
+ n_times = int(dur * fs)
+
+ # Create test signal
+ t = np.arange(n_times) / fs
+ # Signal with 10Hz, 40Hz, and 80Hz components
+ signal = (
+ np.sin(2 * np.pi * 10 * t)
+ + np.sin(2 * np.pi * 40 * t)
+ + np.sin(2 * np.pi * 80 * t)
+ )
+
+ # Create filterbank with 3 bandpass filters
+ filters = [
+ # 5-15 Hz (should pass 10Hz)
+ KaiserFilterSettings(
+ cutoff=[5.0, 15.0],
+ ripple=60.0,
+ width=5.0,
+ pass_zero="bandpass",
+ wn_hz=True,
+ ),
+ # 30-50 Hz (should pass 40Hz)
+ KaiserFilterSettings(
+ cutoff=[30.0, 50.0],
+ ripple=60.0,
+ width=5.0,
+ pass_zero="bandpass",
+ wn_hz=True,
+ ),
+ # 70-90 Hz (should pass 80Hz)
+ KaiserFilterSettings(
+ cutoff=[70.0, 90.0],
+ ripple=60.0,
+ width=5.0,
+ pass_zero="bandpass",
+ wn_hz=True,
+ ),
+ ]
+
+ settings = FilterbankDesignSettings(filters=filters, mode=mode, axis="time")
+ transformer = FilterbankDesignTransformer(settings=settings)
+
+ # For FFT mode, use larger chunks to avoid windowing issues
+ # For CONV and AUTO, split into smaller messages to test streaming
+ if mode == FilterbankMode.FFT:
+ # Single large message for FFT mode
+ messages = [
+ AxisArray(
+ data=signal,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=0)},
+ key="test_filterbankdesign",
+ )
+ ]
+ else:
+ # Split into multiple messages for CONV/AUTO modes
+ n_splits = 4
+ messages = []
+ for split_dat in np.array_split(signal, n_splits):
+ offset = len(messages) * len(split_dat) / fs
+ messages.append(
+ AxisArray(
+ data=split_dat,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=offset)},
+ key="test_filterbankdesign",
+ )
+ )
+
+ # Process messages
+ result_msgs = [transformer(msg) for msg in messages]
+ if len(result_msgs) > 1:
+ result = AxisArray.concatenate(*result_msgs, dim="time")
+ else:
+ result = result_msgs[0]
+
+ # Verify output structure
+ assert "kernel" in result.dims
+ assert "time" in result.dims
+ assert result.data.shape[0] == 3 # 3 filters
+
+ # Verify frequency selectivity using FFT
+ # FFT mode has delay/windowing that reduces output length, so check we have enough data
+ if result.data.shape[1] > 200: # Need at least 200 samples for meaningful FFT
+ transient = 100
+ fft_in = np.abs(
+ np.fft.rfft(
+ signal[transient : transient + result.data.shape[1] - transient]
+ )
+ )
+ freqs = np.fft.rfftfreq(result.data.shape[1] - transient, 1 / fs)
+
+ for filter_idx, target_freq in enumerate([10.0, 40.0, 80.0]):
+ fft_out = np.abs(np.fft.rfft(result.data[filter_idx, transient:]))
+
+ idx_target = np.argmin(np.abs(freqs - target_freq))
+
+ # Target frequency should have significant power
+ assert fft_out[idx_target] > 0.3 * fft_in[idx_target]
+
+
+@pytest.mark.parametrize("n_chans", [1, 3])
+@pytest.mark.parametrize("time_ax", [0, 1])
+def test_filterbankdesign_multidim(n_chans, time_ax):
+ """Test FilterbankDesignTransformer with multi-dimensional data."""
+ fs = 200.0
+ dur = 1.0
+ n_times = int(dur * fs)
+
+ # Create test data
+ if n_chans == 1:
+ data_shape = [n_times]
+ dims = ["time"]
+ axes = {"time": AxisArray.TimeAxis(fs=fs, offset=0)}
+ else:
+ if time_ax == 0:
+ data_shape = [n_times, n_chans]
+ dims = ["time", "ch"]
+ else:
+ data_shape = [n_chans, n_times]
+ dims = ["ch", "time"]
+ axes = {
+ "time": AxisArray.TimeAxis(fs=fs, offset=0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_chans).astype(str), dims=["ch"]
+ ),
+ }
+
+ data = np.random.randn(*data_shape)
+
+ # Create simple lowpass filter
+ filters = [
+ KaiserFilterSettings(
+ cutoff=30.0,
+ ripple=60.0,
+ width=10.0,
+ pass_zero=True,
+ wn_hz=True,
+ )
+ ]
+
+ settings = FilterbankDesignSettings(
+ filters=filters, axis="time", mode=FilterbankMode.CONV
+ )
+ transformer = FilterbankDesignTransformer(settings=settings)
+
+ msg = AxisArray(data=data, dims=dims, axes=axes, key="test_multidim")
+ result = transformer(msg)
+
+ # Verify output dimensions
+ assert "kernel" in result.dims
+ assert "time" in result.dims
+ if n_chans > 1:
+ assert "ch" in result.dims
+
+
+@pytest.mark.skip(reason="Empty messages are not supported by filterbank")
+def test_filterbankdesign_empty_message():
+ """Test FilterbankDesignTransformer with empty message.
+
+ Note: Empty messages are not supported by the filterbank transformer.
+ This test is skipped as it's not a valid use case.
+ """
+ fs = 200.0
+
+ filters = [
+ KaiserFilterSettings(
+ cutoff=30.0,
+ ripple=60.0,
+ width=10.0,
+ pass_zero=True,
+ wn_hz=True,
+ )
+ ]
+
+ settings = FilterbankDesignSettings(
+ filters=filters, axis="time", mode=FilterbankMode.CONV
+ )
+ transformer = FilterbankDesignTransformer(settings=settings)
+
+ msg = AxisArray(
+ data=np.zeros((0,)),
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=0)},
+ key="test_empty",
+ )
+
+ result = transformer(msg)
+ assert result.data.size == 0
+ assert "kernel" in result.dims
+
+
+def test_filterbankdesign_normalized_frequencies():
+ """Test FilterbankDesignTransformer with normalized frequencies."""
+ fs = 200.0
+ dur = 1.0
+ n_times = int(dur * fs)
+
+ # Create test signal
+ t = np.arange(n_times) / fs
+ signal = np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 60 * t)
+
+ # Create lowpass filter with normalized frequency
+ # Cutoff at 0.3 (30Hz / 100Hz Nyquist)
+ filters = [
+ KaiserFilterSettings(
+ cutoff=0.3,
+ ripple=60.0,
+ width=0.1, # normalized width
+ pass_zero=True,
+ wn_hz=False, # Use normalized frequencies
+ )
+ ]
+
+ settings = FilterbankDesignSettings(
+ filters=filters, axis="time", mode=FilterbankMode.CONV
+ )
+ transformer = FilterbankDesignTransformer(settings=settings)
+
+ msg = AxisArray(
+ data=signal,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=0)},
+ key="test_normalized",
+ )
+
+ result = transformer(msg)
+
+ # Verify frequency response
+ transient = 50
+ fft_in = np.abs(np.fft.rfft(signal[transient:]))
+ fft_out = np.abs(np.fft.rfft(result.data[0, transient:]))
+ freqs = np.fft.rfftfreq(len(signal[transient:]), 1 / fs)
+
+ idx_10hz = np.argmin(np.abs(freqs - 10))
+ idx_60hz = np.argmin(np.abs(freqs - 60))
+
+ # 10Hz should be mostly preserved
+ assert fft_out[idx_10hz] > 0.7 * fft_in[idx_10hz]
+ # 60Hz should be attenuated
+ assert fft_out[idx_60hz] < 0.1 * fft_in[idx_60hz]
+
+
+@pytest.mark.parametrize(
+ "min_phase", [MinPhaseMode.NONE, MinPhaseMode.HILBERT, MinPhaseMode.HOMOMORPHIC]
+)
+def test_filterbankdesign_min_phase(min_phase):
+ """Test FilterbankDesignTransformer with different minimum phase modes."""
+ fs = 200.0
+ dur = 1.0
+ n_times = int(dur * fs)
+
+ # Create test signal
+ signal = np.random.randn(n_times)
+
+ # Create filter
+ filters = [
+ KaiserFilterSettings(
+ cutoff=30.0,
+ ripple=60.0,
+ width=10.0,
+ pass_zero=True,
+ wn_hz=True,
+ )
+ ]
+
+ settings = FilterbankDesignSettings(
+ filters=filters, axis="time", min_phase=min_phase, mode=FilterbankMode.CONV
+ )
+ transformer = FilterbankDesignTransformer(settings=settings)
+
+ msg = AxisArray(
+ data=signal,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=0)},
+ key="test_min_phase",
+ )
+
+ result = transformer(msg)
+
+ # Should process successfully
+ assert result.data.shape[0] == 1 # 1 filter
+ assert result.data.shape[1] == n_times
+
+
+def test_filterbankdesign_update_settings():
+ """Test update_settings functionality."""
+ fs = 200.0
+ dur = 2.0
+ n_times = int(dur * fs)
+
+ # Create test signal
+ t = np.arange(n_times) / fs
+ signal = np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 60 * t)
+
+ msg = AxisArray(
+ data=signal,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=0)},
+ key="test_update",
+ )
+
+ # Initial filter - lowpass at 30Hz
+ filters_low = [
+ KaiserFilterSettings(
+ cutoff=30.0,
+ ripple=60.0,
+ width=10.0,
+ pass_zero=True,
+ wn_hz=True,
+ )
+ ]
+
+ settings = FilterbankDesignSettings(
+ filters=filters_low, axis="time", mode=FilterbankMode.CONV
+ )
+ transformer = FilterbankDesignTransformer(settings=settings)
+
+ # Process first message
+ result1 = transformer(msg)
+
+ # Update to highpass at 40Hz
+ filters_high = [
+ KaiserFilterSettings(
+ cutoff=40.0,
+ ripple=60.0,
+ width=10.0,
+ pass_zero=False, # highpass
+ wn_hz=True,
+ )
+ ]
+
+ new_settings = FilterbankDesignSettings(
+ filters=filters_high, axis="time", mode=FilterbankMode.CONV
+ )
+ transformer.update_settings(new_settings=new_settings)
+
+ # Process second message with updated settings
+ result2 = transformer(msg)
+
+ # Results should be different
+ assert not np.allclose(result1.data, result2.data)
+
+ # Verify frequency characteristics changed
+ transient = 50
+ fft_in = np.abs(np.fft.rfft(signal[transient:]))
+ freqs = np.fft.rfftfreq(len(signal[transient:]), 1 / fs)
+
+ fft_out1 = np.abs(np.fft.rfft(result1.data[0, transient:]))
+ fft_out2 = np.abs(np.fft.rfft(result2.data[0, transient:]))
+
+ idx_10hz = np.argmin(np.abs(freqs - 10))
+ idx_60hz = np.argmin(np.abs(freqs - 60))
+
+ # Result1 (lowpass) should pass 10Hz, attenuate 60Hz
+ assert fft_out1[idx_10hz] > 0.7 * fft_in[idx_10hz]
+ assert fft_out1[idx_60hz] < 0.1 * fft_in[idx_60hz]
+
+ # Result2 (highpass) should attenuate 10Hz, pass 60Hz
+ assert fft_out2[idx_10hz] < 0.1 * fft_in[idx_10hz]
+ assert fft_out2[idx_60hz] > 0.7 * fft_in[idx_60hz]
+
+
+def test_filterbankdesign_different_filter_types():
+ """Test FilterbankDesignTransformer with various filter types."""
+ fs = 200.0
+ dur = 2.0
+ n_times = int(dur * fs)
+
+ # Create test signal with multiple frequency components
+ t = np.arange(n_times) / fs
+ signal = (
+ np.sin(2 * np.pi * 10 * t)
+ + np.sin(2 * np.pi * 40 * t)
+ + np.sin(2 * np.pi * 80 * t)
+ )
+
+ # Create different types of filters
+ filters = [
+ # Lowpass
+ KaiserFilterSettings(
+ cutoff=20.0,
+ ripple=60.0,
+ width=5.0,
+ pass_zero=True,
+ wn_hz=True,
+ ),
+ # Highpass
+ KaiserFilterSettings(
+ cutoff=60.0,
+ ripple=60.0,
+ width=5.0,
+ pass_zero=False,
+ wn_hz=True,
+ ),
+ # Bandpass
+ KaiserFilterSettings(
+ cutoff=[30.0, 50.0],
+ ripple=60.0,
+ width=5.0,
+ pass_zero="bandpass",
+ wn_hz=True,
+ ),
+ # Bandstop
+ KaiserFilterSettings(
+ cutoff=[35.0, 45.0],
+ ripple=60.0,
+ width=5.0,
+ pass_zero="bandstop",
+ wn_hz=True,
+ ),
+ ]
+
+ settings = FilterbankDesignSettings(
+ filters=filters, axis="time", mode=FilterbankMode.CONV
+ )
+ transformer = FilterbankDesignTransformer(settings=settings)
+
+ msg = AxisArray(
+ data=signal,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=0)},
+ key="test_filter_types",
+ )
+
+ result = transformer(msg)
+
+ # Verify output structure
+ assert result.data.shape[0] == 4 # 4 filters
+ assert result.data.shape[1] == n_times
+
+ # Basic sanity check - outputs should be different from each other
+ for i in range(4):
+ for j in range(i + 1, 4):
+ assert not np.allclose(result.data[i], result.data[j])
+
+
+def test_filterbankdesign_streaming():
+ """Test FilterbankDesignTransformer with streaming data (multiple messages)."""
+ fs = 200.0
+ dur = 5.0
+ n_times = int(dur * fs)
+
+ # Create long test signal
+ t = np.arange(n_times) / fs
+ signal = np.sin(2 * np.pi * 10 * t) + 0.5 * np.sin(2 * np.pi * 50 * t)
+
+ # Create filter
+ filters = [
+ KaiserFilterSettings(
+ cutoff=30.0,
+ ripple=60.0,
+ width=10.0,
+ pass_zero=True,
+ wn_hz=True,
+ )
+ ]
+
+ settings = FilterbankDesignSettings(
+ filters=filters, axis="time", mode=FilterbankMode.CONV
+ )
+ transformer = FilterbankDesignTransformer(settings=settings)
+
+ # Split into many small messages
+ chunk_size = 50
+ messages = []
+ for i in range(0, n_times, chunk_size):
+ chunk = signal[i : i + chunk_size]
+ messages.append(
+ AxisArray(
+ data=chunk,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=i / fs)},
+ key="test_streaming",
+ )
+ )
+
+ # Process all messages
+ results = [transformer(msg) for msg in messages]
+ result = AxisArray.concatenate(*results, dim="time")
+
+ # Compare to reference implementation
+ # Design the same filter manually
+ coefs = kaiser_design_fun(
+ fs=fs,
+ cutoff=30.0,
+ ripple=60.0,
+ width=10.0,
+ pass_zero=True,
+ wn_hz=True,
+ )
+ b, a = coefs
+ zi = scipy.signal.lfiltic(b, a, [])
+ expected, _ = scipy.signal.lfilter(b, a, signal, zi=zi)
+
+ # Results should match (after initial transient)
+ transient = len(b)
+ assert np.allclose(
+ result.data[0, transient:], expected[transient:], rtol=1e-5, atol=1e-8
+ )
+
+
+def test_filterbankdesign_comparison_with_filterbank():
+ """Verify FilterbankDesignTransformer produces same results as FilterbankTransformer with manually designed kernels."""
+ fs = 200.0
+ dur = 2.0
+ n_times = int(dur * fs)
+
+ # Create test signal
+ signal = np.random.randn(n_times)
+
+ # Design filters manually
+ filter_settings = [
+ KaiserFilterSettings(
+ cutoff=20.0,
+ ripple=60.0,
+ width=10.0,
+ pass_zero=True,
+ wn_hz=True,
+ ),
+ KaiserFilterSettings(
+ cutoff=50.0,
+ ripple=60.0,
+ width=10.0,
+ pass_zero=False,
+ wn_hz=True,
+ ),
+ ]
+
+ # Calculate kernels manually
+ kernels = []
+ for filt_set in filter_settings:
+ coefs = kaiser_design_fun(
+ fs=fs,
+ cutoff=filt_set.cutoff,
+ ripple=filt_set.ripple,
+ width=filt_set.width,
+ pass_zero=filt_set.pass_zero,
+ wn_hz=filt_set.wn_hz,
+ )
+ kernels.append(coefs[0])
+
+ # Use FilterbankTransformer directly
+ from ezmsg.sigproc.filterbank import FilterbankTransformer, FilterbankSettings
+
+ filterbank_settings = FilterbankSettings(
+ kernels=kernels,
+ mode=FilterbankMode.CONV,
+ axis="time",
+ )
+ filterbank_transformer = FilterbankTransformer(settings=filterbank_settings)
+
+ # Use FilterbankDesignTransformer
+ design_settings = FilterbankDesignSettings(
+ filters=filter_settings,
+ mode=FilterbankMode.CONV,
+ axis="time",
+ )
+ design_transformer = FilterbankDesignTransformer(settings=design_settings)
+
+ # Create message
+ msg = AxisArray(
+ data=signal,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=0)},
+ key="test_comparison",
+ )
+
+ # Process with both transformers
+ result_filterbank = filterbank_transformer(msg)
+ result_design = design_transformer(msg)
+
+ # Results should be identical
+ assert np.allclose(
+ result_filterbank.data, result_design.data, rtol=1e-10, atol=1e-12
+ )
+
+
+def test_filterbankdesign_ripple_width_variations():
+ """Test FilterbankDesignTransformer with varying ripple and width parameters."""
+ fs = 200.0
+ dur = 1.0
+ n_times = int(dur * fs)
+
+ signal = np.random.randn(n_times)
+
+ # Test different ripple values (affects filter steepness)
+ for ripple in [40.0, 60.0, 80.0]:
+ filters = [
+ KaiserFilterSettings(
+ cutoff=30.0,
+ ripple=ripple,
+ width=10.0,
+ pass_zero=True,
+ wn_hz=True,
+ )
+ ]
+
+ settings = FilterbankDesignSettings(
+ filters=filters, axis="time", mode=FilterbankMode.CONV
+ )
+ transformer = FilterbankDesignTransformer(settings=settings)
+
+ msg = AxisArray(
+ data=signal,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=0)},
+ key=f"test_ripple_{ripple}",
+ )
+
+ result = transformer(msg)
+ assert result.data.shape == (1, n_times)
+
+ # Test different width values (affects transition width)
+ for width in [5.0, 10.0, 20.0]:
+ filters = [
+ KaiserFilterSettings(
+ cutoff=30.0,
+ ripple=60.0,
+ width=width,
+ pass_zero=True,
+ wn_hz=True,
+ )
+ ]
+
+ settings = FilterbankDesignSettings(
+ filters=filters, axis="time", mode=FilterbankMode.CONV
+ )
+ transformer = FilterbankDesignTransformer(settings=settings)
+
+ msg = AxisArray(
+ data=signal,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=0)},
+ key=f"test_width_{width}",
+ )
+
+ result = transformer(msg)
+ assert result.data.shape == (1, n_times)
diff --git a/tests/unit/test_firfilter.py b/tests/unit/test_firfilter.py
new file mode 100644
index 00000000..2038bd9c
--- /dev/null
+++ b/tests/unit/test_firfilter.py
@@ -0,0 +1,287 @@
+import numpy as np
+import pytest
+import scipy.signal
+from frozendict import frozendict
+from ezmsg.util.messages.axisarray import AxisArray
+
+from ezmsg.sigproc.firfilter import (
+ FIRFilterSettings,
+ FIRFilterTransformer,
+ firwin_design_fun,
+)
+
+
+@pytest.mark.parametrize(
+ "cutoff, pass_zero",
+ [
+ (30.0, True), # lowpass
+ (30.0, False), # highpass
+ ([30.0, 45.0], True), # bandpass
+ ([30.0, 45.0], False), # bandstop
+ ([30.0, 45.0], "bandpass"), # explicit bandpass
+ ([30.0, 45.0], "bandstop"), # explicit bandstop
+ ],
+)
+@pytest.mark.parametrize("order", [11, 21, 51]) # Odd numbers for FIR
+@pytest.mark.parametrize("window", ["hamming", "hann", "blackman"])
+def test_firwin_design_fun(cutoff, pass_zero, order, window):
+ """Test the FIR filter design function with various parameters."""
+ fs = 200.0
+ result = firwin_design_fun(
+ fs=fs,
+ order=order,
+ cutoff=cutoff,
+ window=window,
+ pass_zero=pass_zero,
+ wn_hz=True,
+ )
+
+ assert result is not None
+ b, a = result
+ assert len(b) == order
+ assert len(a) == 1
+ assert a[0] == 1.0
+ # FIR filters are always stable (no feedback, only zeros)
+
+
+def test_firwin_design_fun_zero_order():
+ """Test that zero order returns None (no filter)."""
+ result = firwin_design_fun(
+ fs=200.0,
+ order=0,
+ cutoff=30.0,
+ window="hamming",
+ pass_zero=True,
+ wn_hz=True,
+ )
+ assert result is None
+
+
+@pytest.mark.parametrize(
+ "cutoff, pass_zero",
+ [
+ (30.0, True), # lowpass
+ (30.0, False), # highpass
+ ([30.0, 45.0], "bandpass"), # bandpass
+ ([30.0, 45.0], "bandstop"), # bandstop
+ ],
+)
+@pytest.mark.parametrize("order", [0, 11, 21]) # Include 0 for passthrough
+@pytest.mark.parametrize("fs", [200.0])
+@pytest.mark.parametrize("n_chans", [3])
+@pytest.mark.parametrize("n_dims, time_ax", [(1, 0), (3, 0), (3, 1), (3, 2)])
+@pytest.mark.parametrize("coef_type", ["ba", "sos"])
+def test_firfilter_transformer(
+ cutoff,
+ pass_zero,
+ order,
+ fs,
+ n_chans,
+ n_dims,
+ time_ax,
+ coef_type,
+):
+ """Test FIR filter transformer with various configurations."""
+ dur = 2.0
+ n_freqs = 5
+ n_splits = 4
+
+ n_times = int(dur * fs)
+ if n_dims == 1:
+ dat_shape = [n_times]
+ dat_dims = ["time"]
+ other_axes = {}
+ else:
+ dat_shape = [n_freqs, n_chans]
+ dat_shape.insert(time_ax, n_times)
+ dat_dims = ["freq", "ch"]
+ dat_dims.insert(time_ax, "time")
+ other_axes = {
+ "freq": AxisArray.LinearAxis(unit="Hz", offset=0.0, gain=1.0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_chans).astype(str), dims=["ch"]
+ ),
+ }
+ in_dat = np.arange(np.prod(dat_shape), dtype=float).reshape(*dat_shape)
+
+ # Calculate Expected Result
+ if order > 0:
+ coefs = firwin_design_fun(
+ fs=fs,
+ order=order,
+ cutoff=cutoff,
+ window="hamming",
+ pass_zero=pass_zero,
+ scale=True,
+ wn_hz=True,
+ )
+ if coef_type == "sos":
+ # Convert ba to sos for comparison
+ b, a = coefs
+ coefs = scipy.signal.tf2sos(b, a)
+ tmp_dat = np.moveaxis(in_dat, time_ax, -1)
+
+ if coef_type == "ba":
+ b, a = coefs
+ # FIR filters use zero initial conditions (lfiltic with empty arrays)
+ zi = scipy.signal.lfiltic(b, a, [])
+ if n_dims == 3:
+ zi = np.tile(zi[None, None, :], (n_freqs, n_chans, 1))
+ out_dat, _ = scipy.signal.lfilter(b, a, tmp_dat, zi=zi)
+ elif coef_type == "sos":
+ # SOS representation uses sosfilt_zi for initial conditions
+ zi = scipy.signal.sosfilt_zi(coefs)
+ if n_dims == 3:
+ zi = np.tile(zi[:, None, None, :], (1, n_freqs, n_chans, 1))
+ out_dat, _ = scipy.signal.sosfilt(coefs, tmp_dat, zi=zi)
+ expected = np.moveaxis(out_dat, -1, time_ax)
+ else:
+ # Zero order = passthrough
+ expected = in_dat
+
+ # Split the data into multiple messages
+ n_seen = 0
+ messages = []
+ for split_dat in np.array_split(in_dat, n_splits, axis=time_ax):
+ _time_axis = AxisArray.TimeAxis(fs=fs, offset=n_seen / fs)
+ messages.append(
+ AxisArray(
+ split_dat,
+ dims=dat_dims,
+ axes=frozendict({**other_axes, "time": _time_axis}),
+ key="test_firfilter",
+ )
+ )
+ n_seen += split_dat.shape[time_ax]
+
+ # Create transformer
+ axis_name = "time" if time_ax != 0 else None
+ settings = FIRFilterSettings(
+ axis=axis_name,
+ order=order,
+ cutoff=cutoff,
+ window="hamming",
+ pass_zero=pass_zero,
+ scale=True,
+ wn_hz=True,
+ coef_type=coef_type,
+ )
+ transformer = FIRFilterTransformer(settings=settings)
+
+ # Process messages
+ result = []
+ for msg in messages:
+ out_msg = transformer(msg)
+ result.append(out_msg.data)
+ result = np.concatenate(result, axis=time_ax)
+
+ assert np.allclose(result, expected, rtol=1e-5, atol=1e-8)
+
+
+def test_firfilter_empty_msg():
+ """Test FIR filter with empty message."""
+ settings = FIRFilterSettings(
+ axis="time",
+ order=21,
+ cutoff=30.0,
+ window="hamming",
+ pass_zero=True,
+ coef_type="ba",
+ )
+ transformer = FIRFilterTransformer(settings=settings)
+
+ msg_in = AxisArray(
+ data=np.zeros((0, 2)),
+ dims=["time", "ch"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=100.0, offset=0),
+ "ch": AxisArray.CoordinateAxis(data=np.arange(2).astype(str), dims=["ch"]),
+ },
+ key="test_firfilter_empty",
+ )
+
+ result = transformer(msg_in)
+ assert result.data.size == 0
+
+
+def test_firfilter_normalized_frequencies():
+ """Test FIR filter with normalized frequencies (wn_hz=False)."""
+ fs = 200.0
+ dur = 1.0
+ n_times = int(dur * fs)
+
+ # Create input signal
+ t = np.arange(n_times) / fs
+ # Mix of 10Hz and 60Hz sine waves
+ in_dat = np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 60 * t)
+
+ msg = AxisArray(
+ data=in_dat,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=0)},
+ key="test_normalized",
+ )
+
+ # Design lowpass filter at 0.3 (normalized, = 30Hz)
+ settings = FIRFilterSettings(
+ axis="time",
+ order=51,
+ cutoff=0.3, # Normalized cutoff (30Hz / Nyquist(100Hz))
+ window="hamming",
+ pass_zero=True,
+ wn_hz=False, # Use normalized frequencies
+ coef_type="ba",
+ )
+ transformer = FIRFilterTransformer(settings=settings)
+
+ result = transformer(msg)
+
+ # Verify output shape matches input
+ assert result.data.shape == in_dat.shape
+
+ # Check that 10Hz component is preserved and 60Hz is attenuated
+ fft_in = np.abs(np.fft.rfft(in_dat))
+ fft_out = np.abs(np.fft.rfft(result.data))
+ freqs = np.fft.rfftfreq(n_times, 1 / fs)
+
+ idx_10hz = np.argmin(np.abs(freqs - 10))
+ idx_60hz = np.argmin(np.abs(freqs - 60))
+
+ # 10Hz should be mostly preserved (> 80% of input)
+ assert fft_out[idx_10hz] > 0.8 * fft_in[idx_10hz]
+ # 60Hz should be significantly attenuated (< 20% of input)
+ assert fft_out[idx_60hz] < 0.2 * fft_in[idx_60hz]
+
+
+def test_firfilter_kaiser_width():
+ """Test FIR filter using Kaiser window specified via width parameter."""
+ fs = 200.0
+ dur = 1.0
+ n_times = int(dur * fs)
+
+ # Create test signal
+ t = np.arange(n_times) / fs
+ in_dat = np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 50 * t)
+
+ msg = AxisArray(
+ data=in_dat,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=0)},
+ key="test_kaiser_width",
+ )
+
+ # When width is specified, window parameter is ignored and Kaiser is used
+ settings = FIRFilterSettings(
+ axis="time",
+ order=51,
+ cutoff=30.0,
+ width=10.0, # Transition width
+ window="hamming", # Will be ignored
+ pass_zero=True,
+ wn_hz=True,
+ coef_type="ba",
+ )
+ transformer = FIRFilterTransformer(settings=settings)
+
+ result = transformer(msg)
+ assert result.data.shape == in_dat.shape
diff --git a/tests/unit/test_kaiser.py b/tests/unit/test_kaiser.py
new file mode 100644
index 00000000..1d25dbf6
--- /dev/null
+++ b/tests/unit/test_kaiser.py
@@ -0,0 +1,426 @@
+import numpy as np
+import pytest
+import scipy.signal
+from frozendict import frozendict
+from ezmsg.util.messages.axisarray import AxisArray
+
+from ezmsg.sigproc.kaiser import (
+ KaiserFilterSettings,
+ KaiserFilterTransformer,
+ kaiser_design_fun,
+)
+
+
+@pytest.mark.parametrize(
+ "cutoff, pass_zero",
+ [
+ (30.0, True), # lowpass
+ (30.0, False), # highpass
+ ([30.0, 45.0], "bandpass"), # bandpass
+ ([30.0, 45.0], "bandstop"), # bandstop
+ ],
+)
+@pytest.mark.parametrize("ripple", [40.0, 60.0, 80.0]) # dB attenuation
+@pytest.mark.parametrize("width", [5.0, 10.0]) # Transition width in Hz
+def test_kaiser_design_fun(cutoff, pass_zero, ripple, width):
+ """Test the Kaiser filter design function with various parameters."""
+ fs = 200.0
+ result = kaiser_design_fun(
+ fs=fs,
+ cutoff=cutoff,
+ ripple=ripple,
+ width=width,
+ pass_zero=pass_zero,
+ wn_hz=True,
+ )
+
+ assert result is not None
+ b, a = result
+ # Kaiser filter should have odd number of taps
+ assert len(b) % 2 == 1
+ assert len(a) == 1
+ assert a[0] == 1.0
+ # Higher ripple (more attenuation) should require more taps
+ # (This is a general trend but not strict)
+ assert len(b) > 5
+
+
+def test_kaiser_design_fun_missing_params():
+ """Test that missing parameters return None."""
+ fs = 200.0
+
+ # Missing ripple
+ result = kaiser_design_fun(
+ fs=fs,
+ cutoff=30.0,
+ ripple=None,
+ width=10.0,
+ pass_zero=True,
+ wn_hz=True,
+ )
+ assert result is None
+
+ # Missing width
+ result = kaiser_design_fun(
+ fs=fs,
+ cutoff=30.0,
+ ripple=60.0,
+ width=None,
+ pass_zero=True,
+ wn_hz=True,
+ )
+ assert result is None
+
+ # Missing cutoff
+ result = kaiser_design_fun(
+ fs=fs,
+ cutoff=None,
+ ripple=60.0,
+ width=10.0,
+ pass_zero=True,
+ wn_hz=True,
+ )
+ assert result is None
+
+
+@pytest.mark.parametrize(
+ "cutoff, pass_zero",
+ [
+ (30.0, True), # lowpass
+ (30.0, False), # highpass
+ ([30.0, 45.0], "bandpass"), # bandpass
+ ([30.0, 45.0], "bandstop"), # bandstop
+ ],
+)
+@pytest.mark.parametrize("ripple", [60.0])
+@pytest.mark.parametrize("width", [10.0])
+@pytest.mark.parametrize("fs", [200.0])
+@pytest.mark.parametrize("n_chans", [3])
+@pytest.mark.parametrize("n_dims, time_ax", [(1, 0), (3, 0), (3, 1), (3, 2)])
+@pytest.mark.parametrize("coef_type", ["ba", "sos"])
+def test_kaiser_filter_transformer(
+ cutoff,
+ pass_zero,
+ ripple,
+ width,
+ fs,
+ n_chans,
+ n_dims,
+ time_ax,
+ coef_type,
+):
+ """Test Kaiser filter transformer with various configurations."""
+ dur = 2.0
+ n_freqs = 5
+ n_splits = 4
+
+ n_times = int(dur * fs)
+ if n_dims == 1:
+ dat_shape = [n_times]
+ dat_dims = ["time"]
+ other_axes = {}
+ else:
+ dat_shape = [n_freqs, n_chans]
+ dat_shape.insert(time_ax, n_times)
+ dat_dims = ["freq", "ch"]
+ dat_dims.insert(time_ax, "time")
+ other_axes = {
+ "freq": AxisArray.LinearAxis(unit="Hz", offset=0.0, gain=1.0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(n_chans).astype(str), dims=["ch"]
+ ),
+ }
+ in_dat = np.arange(np.prod(dat_shape), dtype=float).reshape(*dat_shape)
+
+ # Calculate Expected Result using scipy directly
+ coefs = kaiser_design_fun(
+ fs=fs,
+ cutoff=cutoff,
+ ripple=ripple,
+ width=width,
+ pass_zero=pass_zero,
+ wn_hz=True,
+ )
+
+ if coef_type == "sos":
+ # Convert ba to sos for comparison
+ b, a = coefs
+ coefs = scipy.signal.tf2sos(b, a)
+
+ tmp_dat = np.moveaxis(in_dat, time_ax, -1)
+
+ if coef_type == "ba":
+ b, a = coefs
+ # FIR filters use zero initial conditions (lfiltic with empty arrays)
+ zi = scipy.signal.lfiltic(b, a, [])
+ if n_dims == 3:
+ zi = np.tile(zi[None, None, :], (n_freqs, n_chans, 1))
+ out_dat, _ = scipy.signal.lfilter(b, a, tmp_dat, zi=zi)
+ elif coef_type == "sos":
+ # SOS representation uses sosfilt_zi for initial conditions
+ zi = scipy.signal.sosfilt_zi(coefs)
+ if n_dims == 3:
+ zi = np.tile(zi[:, None, None, :], (1, n_freqs, n_chans, 1))
+ out_dat, _ = scipy.signal.sosfilt(coefs, tmp_dat, zi=zi)
+
+ expected = np.moveaxis(out_dat, -1, time_ax)
+
+ # Split the data into multiple messages
+ n_seen = 0
+ messages = []
+ for split_dat in np.array_split(in_dat, n_splits, axis=time_ax):
+ _time_axis = AxisArray.TimeAxis(fs=fs, offset=n_seen / fs)
+ messages.append(
+ AxisArray(
+ split_dat,
+ dims=dat_dims,
+ axes=frozendict({**other_axes, "time": _time_axis}),
+ key="test_kaiser",
+ )
+ )
+ n_seen += split_dat.shape[time_ax]
+
+ # Create transformer
+ axis_name = "time" if time_ax != 0 else None
+ settings = KaiserFilterSettings(
+ axis=axis_name,
+ cutoff=cutoff,
+ ripple=ripple,
+ width=width,
+ pass_zero=pass_zero,
+ wn_hz=True,
+ coef_type=coef_type,
+ )
+ transformer = KaiserFilterTransformer(settings=settings)
+
+ # Process messages
+ result = []
+ for msg in messages:
+ out_msg = transformer(msg)
+ result.append(out_msg.data)
+ result = np.concatenate(result, axis=time_ax)
+
+ assert np.allclose(result, expected, rtol=1e-5, atol=1e-8)
+
+
+def test_kaiser_filter_empty_msg():
+ """Test Kaiser filter with empty message."""
+ settings = KaiserFilterSettings(
+ axis="time",
+ cutoff=30.0,
+ ripple=60.0,
+ width=10.0,
+ pass_zero=True,
+ coef_type="ba",
+ )
+ transformer = KaiserFilterTransformer(settings=settings)
+
+ msg_in = AxisArray(
+ data=np.zeros((0, 2)),
+ dims=["time", "ch"],
+ axes={
+ "time": AxisArray.TimeAxis(fs=100.0, offset=0),
+ "ch": AxisArray.CoordinateAxis(data=np.arange(2).astype(str), dims=["ch"]),
+ },
+ key="test_kaiser_empty",
+ )
+
+ result = transformer(msg_in)
+ assert result.data.size == 0
+
+
+def test_kaiser_filter_frequency_response():
+ """Test Kaiser filter frequency response characteristics."""
+ fs = 200.0
+ dur = 2.0
+ n_times = int(dur * fs)
+
+ # Create test signal with multiple frequency components
+ t = np.arange(n_times) / fs
+ # 10Hz (should pass), 40Hz (transition), 60Hz (should stop)
+ in_dat = (
+ np.sin(2 * np.pi * 10 * t)
+ + np.sin(2 * np.pi * 40 * t)
+ + np.sin(2 * np.pi * 60 * t)
+ )
+
+ msg = AxisArray(
+ data=in_dat,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=0)},
+ key="test_kaiser_freq_response",
+ )
+
+ # Design lowpass filter at 30Hz with 10Hz transition width
+ # Should pass 10Hz, attenuate 60Hz
+ settings = KaiserFilterSettings(
+ axis="time",
+ cutoff=30.0,
+ ripple=60.0, # 60dB attenuation in stopband
+ width=10.0, # 10Hz transition width (30Hz to 40Hz)
+ pass_zero=True,
+ wn_hz=True,
+ coef_type="ba",
+ )
+ transformer = KaiserFilterTransformer(settings=settings)
+
+ result = transformer(msg)
+
+ # Analyze frequency content
+ fft_in = np.abs(np.fft.rfft(in_dat))
+ fft_out = np.abs(np.fft.rfft(result.data))
+ freqs = np.fft.rfftfreq(n_times, 1 / fs)
+
+ idx_10hz = np.argmin(np.abs(freqs - 10))
+ idx_60hz = np.argmin(np.abs(freqs - 60))
+
+ # 10Hz should be mostly preserved (passband)
+ assert fft_out[idx_10hz] > 0.85 * fft_in[idx_10hz]
+ # 60Hz should be highly attenuated (stopband)
+ # With 60dB ripple specification, attenuation should be significant
+ assert fft_out[idx_60hz] < 0.01 * fft_in[idx_60hz]
+
+
+def test_kaiser_filter_normalized_frequencies():
+ """Test Kaiser filter with normalized frequencies (wn_hz=False)."""
+ fs = 200.0
+ dur = 1.0
+ n_times = int(dur * fs)
+
+ # Create input signal
+ t = np.arange(n_times) / fs
+ in_dat = np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 60 * t)
+
+ msg = AxisArray(
+ data=in_dat,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=0)},
+ key="test_normalized",
+ )
+
+ # Design filter with normalized frequencies
+ # Cutoff at 0.3 (30Hz / 100Hz Nyquist)
+ # Width of 0.1 (10Hz / 100Hz Nyquist)
+ settings = KaiserFilterSettings(
+ axis="time",
+ cutoff=0.3,
+ ripple=60.0,
+ width=0.1,
+ pass_zero=True,
+ wn_hz=False, # Use normalized frequencies
+ coef_type="ba",
+ )
+ transformer = KaiserFilterTransformer(settings=settings)
+
+ result = transformer(msg)
+
+ # Verify frequency response
+ fft_in = np.abs(np.fft.rfft(in_dat))
+ fft_out = np.abs(np.fft.rfft(result.data))
+ freqs = np.fft.rfftfreq(n_times, 1 / fs)
+
+ idx_10hz = np.argmin(np.abs(freqs - 10))
+ idx_60hz = np.argmin(np.abs(freqs - 60))
+
+ # 10Hz should be mostly preserved (passband)
+ assert fft_out[idx_10hz] > 0.80 * fft_in[idx_10hz]
+ # 60Hz should be highly attenuated (stopband)
+ assert fft_out[idx_60hz] < 0.01 * fft_in[idx_60hz]
+
+
+def test_kaiser_filter_vs_fir_with_kaiser_window():
+ """Verify Kaiser filter produces similar results to FIR with Kaiser window."""
+ from ezmsg.sigproc.firfilter import FIRFilterTransformer, FIRFilterSettings
+
+ fs = 200.0
+ dur = 1.0
+ n_times = int(dur * fs)
+
+ # Create test signal
+ t = np.arange(n_times) / fs
+ in_dat = np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 60 * t)
+
+ msg = AxisArray(
+ data=in_dat,
+ dims=["time"],
+ axes={"time": AxisArray.TimeAxis(fs=fs, offset=0)},
+ key="test_comparison",
+ )
+
+ # Kaiser filter
+ kaiser_settings = KaiserFilterSettings(
+ axis="time",
+ cutoff=30.0,
+ ripple=60.0,
+ width=10.0,
+ pass_zero=True,
+ wn_hz=True,
+ coef_type="ba",
+ )
+ kaiser_transformer = KaiserFilterTransformer(settings=kaiser_settings)
+
+ # Get the designed filter parameters
+ coefs = kaiser_design_fun(
+ fs=fs,
+ cutoff=30.0,
+ ripple=60.0,
+ width=10.0,
+ pass_zero=True,
+ wn_hz=True,
+ )
+ b_kaiser, a_kaiser = coefs
+ n_taps = len(b_kaiser)
+
+ # Calculate beta from ripple
+ ripple_db = 60.0
+ beta = scipy.signal.kaiser_beta(ripple_db)
+
+ # FIR filter with equivalent Kaiser window
+ fir_settings = FIRFilterSettings(
+ axis="time",
+ order=n_taps,
+ cutoff=30.0,
+ window=("kaiser", beta),
+ pass_zero=True,
+ scale=False, # Kaiser filter uses scale=False
+ wn_hz=True,
+ coef_type="ba",
+ )
+ fir_transformer = FIRFilterTransformer(settings=fir_settings)
+
+ result_kaiser = kaiser_transformer(msg)
+ result_fir = fir_transformer(msg)
+
+ # Results should be very similar
+ assert np.allclose(result_kaiser.data, result_fir.data, rtol=1e-3, atol=1e-5)
+
+
+def test_kaiser_higher_ripple_more_taps():
+ """Verify that higher ripple requirements result in more filter taps."""
+ fs = 200.0
+
+ # Lower ripple (less stringent) should use fewer taps
+ coefs_low = kaiser_design_fun(
+ fs=fs,
+ cutoff=30.0,
+ ripple=40.0,
+ width=10.0,
+ pass_zero=True,
+ wn_hz=True,
+ )
+
+ # Higher ripple (more stringent) should use more taps
+ coefs_high = kaiser_design_fun(
+ fs=fs,
+ cutoff=30.0,
+ ripple=80.0,
+ width=10.0,
+ pass_zero=True,
+ wn_hz=True,
+ )
+
+ b_low, _ = coefs_low
+ b_high, _ = coefs_high
+
+ # Higher ripple should require more taps
+ assert len(b_high) > len(b_low)
diff --git a/tests/unit/test_resample.py b/tests/unit/test_resample.py
index a7897da5..cb6780da 100644
--- a/tests/unit/test_resample.py
+++ b/tests/unit/test_resample.py
@@ -11,18 +11,23 @@
@pytest.fixture
def irregular_messages() -> list[AxisArray]:
+ """
+ 10.2 seconds of 128 Hz (jittery intervals) data split unevenly
+ into 10 messages + a duplicate after the first message.
+ """
nch = 3
avg_fs = 128.0
dur = 10.2
ntimes = int(avg_fs * dur)
tvec = np.arange(ntimes) / avg_fs
+ np.random.seed(42) # For reproducibility
tvec += np.random.normal(0, 0.2 / avg_fs, ntimes)
tvec = np.sort(tvec)
n_msgs = 10
splits = np.sort(np.random.choice(np.arange(ntimes), n_msgs - 1))
- # prepend splits with a `0` and `splits[0]. The latter is to test what happens with an empty message.
- # Append with `ntimes` to ensure all samples are sent.
- splits = np.hstack(([0, splits[0]], splits, [ntimes]))
+ splits = np.hstack(([0], splits, [ntimes])) # Ensure we have the borders.
+ # Ensure we have a duplicate after the first message
+ splits = np.insert(splits, 1, [splits[1]])
msgs = []
ch_ax = AxisArray.CoordinateAxis(
data=np.arange(nch).astype(str), dims=["ch"], unit="label"
@@ -48,6 +53,7 @@ def irregular_messages() -> list[AxisArray]:
@pytest.fixture
def reference_messages() -> list[AxisArray]:
+ """10 seconds of data at 500 Hz, split evenly into 10 messages."""
nch = 1
fs = 500.0
dur = 10.0
@@ -81,7 +87,7 @@ async def test_resample(
)
expected_data = f(newx)
- resample = ResampleProcessor(resample_rate=resample_rate)
+ resample = ResampleProcessor(resample_rate=resample_rate, buffer_duration=4.0)
results = []
n_returned = 0
for msg_ix, msg in enumerate(irregular_messages):
@@ -90,9 +96,10 @@ async def test_resample(
resample(msg)
result = next(resample)
msg_len = result.data.shape[0]
- assert np.allclose(
+ b_match = np.allclose(
result.data, expected_data[n_returned : n_returned + msg_len]
)
+ assert b_match, f"Message {msg_ix} data mismatch."
results.append(result)
n_returned += result.data.shape[0]
diff --git a/tests/unit/test_sampler.py b/tests/unit/test_sampler.py
index 680a076b..c3f7ed54 100644
--- a/tests/unit/test_sampler.py
+++ b/tests/unit/test_sampler.py
@@ -5,14 +5,12 @@
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.util.message import SampleTriggerMessage
-from ezmsg.sigproc.sampler import (
- sampler,
-)
+from ezmsg.sigproc.sampler import SamplerTransformer, SamplerSettings
from tests.helpers.util import assert_messages_equal
-def test_sampler_gen():
+def test_sampler():
data_dur = 10.0
chunk_period = 0.1
fs = 500.0
@@ -65,22 +63,24 @@ def test_sampler_gen():
# Create the sample-generator
period_dur = period[1] - period[0]
buffer_dur = 2 * max(period_dur, period[1])
- gen = sampler(
- buffer_dur, axis="time", period=None, value=None, estimate_alignment=True
+ proc = SamplerTransformer(
+ settings=SamplerSettings(
+ buffer_dur, axis="time", period=None, value=None, estimate_alignment=True
+ )
)
# Run the messages through the generator and collect samples.
samples = []
- for msg in mix_msgs:
- samples.extend(gen.send(msg))
+ for msg_ix, msg in enumerate(mix_msgs):
+ samples.extend(proc(msg))
assert_messages_equal(signal_msgs, backup_signal)
assert_messages_equal(trigger_msgs, backup_trigger)
assert len(samples) == n_trigs
- # Check sample data size
+ # Check sample data size. Note: sampler puts the time axis first.
assert all(
- [_.sample.data.shape == (n_chans, int(fs * period_dur)) for _ in samples]
+ [_.sample.data.shape == (int(fs * period_dur), n_chans) for _ in samples]
)
# Compare the sample window slice against the trigger timestamps
latencies = [
diff --git a/tests/unit/test_window.py b/tests/unit/test_window.py
index 9d943e8c..bdaa231e 100644
--- a/tests/unit/test_window.py
+++ b/tests/unit/test_window.py
@@ -6,7 +6,7 @@
from frozendict import frozendict
import sparse
from ezmsg.util.messages.axisarray import AxisArray
-from ezmsg.sigproc.window import windowing
+from ezmsg.sigproc.window import WindowTransformer
from tests.helpers.util import assert_messages_equal, calculate_expected_windows
@@ -32,19 +32,19 @@ def test_window_gen_nodur():
key="test_window_gen_nodur",
)
backup = [copy.deepcopy(test_msg)]
- proc = windowing(window_dur=None)
+ proc = WindowTransformer(window_dur=None)
result = proc(test_msg)
assert_messages_equal([test_msg], backup)
assert result is test_msg
assert np.shares_memory(result.data, test_msg.data)
-@pytest.mark.parametrize("msg_block_size", [1, 5, 10, 20, 60])
-@pytest.mark.parametrize("newaxis", [None, "win"])
+@pytest.mark.parametrize("msg_block_size", [60, 1, 5, 10, 100])
+@pytest.mark.parametrize("newaxis", ["win", None])
@pytest.mark.parametrize("win_dur", [0.3, 1.0])
-@pytest.mark.parametrize("win_shift", [None, 0.2, 1.0])
+@pytest.mark.parametrize("win_shift", [0.2, 1.0, None])
@pytest.mark.parametrize("zero_pad", ["input", "shift", "none"])
-@pytest.mark.parametrize("fs", [10.0, 500.0])
+@pytest.mark.parametrize("fs", [100.0, 500.0])
@pytest.mark.parametrize("anchor", ["beginning", "middle", "end"])
@pytest.mark.parametrize("time_ax", [0, 1])
def test_window_generator(
@@ -57,21 +57,19 @@ def test_window_generator(
anchor: str,
time_ax: int,
):
- nchans = 3
+ nchans = 5
shift_len = int(win_shift * fs) if win_shift is not None else None
win_len = int(win_dur * fs)
- data_len = 2 * win_len
+ data_len = 2 * max(win_len, msg_block_size)
if win_shift is not None:
data_len += shift_len - 1
+ tvec = np.arange(data_len) / fs
data = np.arange(nchans * data_len, dtype=float).reshape((nchans, data_len))
# Below, we transpose the individual messages if time_ax == 0.
- tvec = np.arange(data_len) / fs
-
- n_msgs = int(np.ceil(data_len / msg_block_size))
# Instantiate the processor
- proc = windowing(
+ proc = WindowTransformer(
axis="time",
newaxis=newaxis,
window_dur=win_dur,
@@ -80,8 +78,8 @@ def test_window_generator(
anchor=anchor,
)
- # Create inputs and send them to the process, collecting the results along the way.
- test_msg = AxisArray(
+ # Create inputs
+ template_msg = AxisArray(
data[..., ()],
dims=["ch", "time"] if time_ax == 1 else ["time", "ch"],
axes=frozendict(
@@ -94,36 +92,36 @@ def test_window_generator(
),
key="test_window_generator",
)
- messages = []
- backup = []
- results = []
+ n_msgs = int(np.ceil(data_len / msg_block_size))
+ in_msgs = []
for msg_ix in range(n_msgs):
msg_data = data[..., msg_ix * msg_block_size : (msg_ix + 1) * msg_block_size]
if time_ax == 0:
msg_data = np.ascontiguousarray(msg_data.T)
- test_msg = replace(
- test_msg,
- data=msg_data,
- axes={
- **test_msg.axes,
- "time": replace(
- test_msg.axes["time"], offset=tvec[msg_ix * msg_block_size]
- ),
- },
- key=test_msg.key,
+ in_msgs.append(
+ replace(
+ template_msg,
+ data=msg_data,
+ axes={
+ **template_msg.axes,
+ "time": replace(
+ template_msg.axes["time"], offset=tvec[msg_ix * msg_block_size]
+ ),
+ },
+ )
)
- messages.append(test_msg)
- backup.append(copy.deepcopy(test_msg))
- win_msg = proc(test_msg)
- results.append(win_msg)
+ backup = copy.deepcopy(in_msgs)
- assert_messages_equal(messages, backup)
+ # Do the actual processing.
+ out_msgs = [proc(_) for _ in in_msgs]
+
+ assert_messages_equal(in_msgs, backup)
# Check each return value's metadata (offsets checked at end)
expected_dims = (
- test_msg.dims[:time_ax] + [newaxis or "win"] + test_msg.dims[time_ax:]
+ template_msg.dims[:time_ax] + [newaxis or "win"] + template_msg.dims[time_ax:]
)
- for msg in results:
+ for msg in out_msgs:
assert msg.axes["time"].gain == 1 / fs
assert msg.dims == expected_dims
assert (newaxis or "win") in msg.axes
@@ -134,11 +132,11 @@ def test_window_generator(
# Post-process the results to yield a single data array and a single vector of offsets.
win_ax = time_ax
# time_ax = win_ax + 1
- result = np.concatenate([_.data for _ in results], win_ax)
+ result = np.concatenate([_.data for _ in out_msgs], win_ax)
offsets = np.hstack(
[
_.axes[newaxis or "win"].value(np.arange(_.data.shape[win_ax]))
- for _ in results
+ for _ in out_msgs
]
)
@@ -159,7 +157,9 @@ def test_window_generator(
)
# Compare results to expected
- assert np.array_equal(result, expected)
+ if win_shift is None:
+ assert len(out_msgs) == len(in_msgs)
+ assert np.allclose(result, expected)
assert np.allclose(offsets, tvec)
@@ -171,49 +171,87 @@ def test_sparse_window(
win_shift: float | None,
zero_pad: str,
):
+ msg_block_size = 60
fs = 100.0
- n_ch = 5
- n_samps = 1_000
- msg_len = 100
+ nchans = 5
+
+ # Create sparse data
+ shift_len = int(win_shift * fs) if win_shift is not None else None
win_len = int(win_dur * fs)
+ data_len = 2 * max(win_len, msg_block_size)
+ if win_shift is not None:
+ data_len += shift_len - 1
+ tvec = np.arange(data_len) / fs
rng = np.random.default_rng()
- s = sparse.random((n_samps, n_ch), density=0.1, random_state=rng) > 0
- in_msgs = [
- AxisArray(
- data=s[msg_ix * msg_len : (msg_ix + 1) * msg_len],
- dims=["time", "ch"],
- axes={
- "time": AxisArray.Axis.TimeAxis(fs=fs, offset=msg_ix / fs),
- },
- key="test_sparse_window",
- )
- for msg_ix in range(10)
- ]
+ s = sparse.random((data_len, nchans), density=0.1, random_state=rng) > 0
- proc = windowing(
+ # Create WindowTransformer
+ proc = WindowTransformer(
axis="time",
newaxis="win",
window_dur=win_dur,
window_shift=win_shift,
zero_pad_until=zero_pad,
+ anchor="beginning",
+ )
+
+ template_msg = AxisArray(
+ data=s[:0],
+ dims=["time", "ch"],
+ axes=frozendict(
+ {
+ "time": AxisArray.TimeAxis(fs=fs, offset=0.0),
+ "ch": AxisArray.CoordinateAxis(
+ data=np.arange(nchans).astype(str), unit="label", dims=["ch"]
+ ),
+ }
+ ),
+ key="test_sparse_window",
)
- out_msgs = [proc.send(_) for _ in in_msgs]
+ n_msgs = int(np.ceil(data_len / msg_block_size))
+ in_msgs = [
+ replace(
+ template_msg,
+ data=s[msg_ix * msg_block_size : (msg_ix + 1) * msg_block_size],
+ axes={
+ **template_msg.axes,
+ "time": replace(
+ template_msg.axes["time"], offset=tvec[msg_ix * msg_block_size]
+ ),
+ },
+ )
+ for msg_ix in range(n_msgs)
+ ]
+
+ # Process messages
+ out_msgs = [proc(_) for _ in in_msgs]
+
+ # Assert per-message shape and collect total number of windows and window time vector
nwins = 0
+ win_tvec = []
for om in out_msgs:
assert om.dims == ["win", "time", "ch"]
assert om.data.shape[1] == win_len
- assert om.data.shape[2] == n_ch
+ assert om.data.shape[2] == nchans
nwins += om.data.shape[0]
- if win_shift is None:
- # 1:1 mode
- assert nwins == len(out_msgs)
- else:
- shift_len = int(win_shift * fs)
- prepended = 0
- if zero_pad == "input":
- prepended = max(0, win_len - msg_len)
- elif zero_pad == "shift":
- prepended = max(0, win_len - shift_len)
- win_offsets = np.arange(n_samps + prepended)[::shift_len]
- expected_nwins = np.sum(win_offsets <= (n_samps + prepended - win_len))
- assert nwins == expected_nwins
+ win_tvec.append(om.axes["win"].value(np.arange(om.data.shape[0])))
+ win_tvec = np.hstack(win_tvec)
+
+ # Calculate the expected time vector; note this method expects data time axis to be last.
+ _, expected_tvec = calculate_expected_windows(
+ np.arange(nchans * data_len).reshape((nchans, data_len)),
+ fs,
+ win_shift,
+ zero_pad,
+ "beginning",
+ msg_block_size,
+ shift_len,
+ win_len,
+ nchans,
+ data_len,
+ n_msgs,
+ 0,
+ )
+
+ assert nwins == len(expected_tvec)
+ assert np.allclose(win_tvec, expected_tvec)