Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions docs/source/guides/explanations/array_api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
Array API Support
=================

ezmsg-sigproc provides support for the `Python Array API standard
<https://data-apis.org/array-api/>`_, enabling many transformers to work with
arrays from different backends such as NumPy, CuPy, PyTorch, and JAX.

What is the Array API?
----------------------

The Array API is a standardized interface for array operations across different
Python array libraries. By coding to this standard, ezmsg-sigproc transformers
can process data regardless of which array library created it, enabling:

- **GPU acceleration** via CuPy or PyTorch tensors
- **Framework interoperability** for integration with ML pipelines
- **Hardware flexibility** without code changes

How It Works
------------

Compatible transformers use `array-api-compat <https://github.com/data-apis/array-api-compat>`_
to detect the input array's namespace and use the appropriate operations:

.. code-block:: python

from array_api_compat import get_namespace

def _process(self, message: AxisArray) -> AxisArray:
xp = get_namespace(message.data) # numpy, cupy, torch, etc.
result = xp.abs(message.data) # Uses the correct backend
return replace(message, data=result)

Usage Example
-------------

Using Array API compatible transformers with CuPy for GPU acceleration:

.. code-block:: python

import cupy as cp
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.math.abs import AbsTransformer
from ezmsg.sigproc.math.clip import ClipTransformer, ClipSettings

# Create data on GPU
gpu_data = cp.random.randn(1000, 64).astype(cp.float32)
message = AxisArray(gpu_data, dims=["time", "ch"])

# Process entirely on GPU - no data transfer!
abs_transformer = AbsTransformer()
clip_transformer = ClipTransformer(ClipSettings(min=0.0, max=1.0))

result = clip_transformer(abs_transformer(message))
# result.data is still a CuPy array on GPU

Compatible Modules
------------------

The following transformers fully support the Array API standard:

Math Operations
^^^^^^^^^^^^^^^

.. list-table::
:header-rows: 1
:widths: 30 70

* - Module
- Description
* - :mod:`ezmsg.sigproc.math.abs`
- Absolute value
* - :mod:`ezmsg.sigproc.math.clip`
- Clip values to a range
* - :mod:`ezmsg.sigproc.math.log`
- Logarithm with configurable base
* - :mod:`ezmsg.sigproc.math.scale`
- Multiply by a constant
* - :mod:`ezmsg.sigproc.math.invert`
- Compute 1/x
* - :mod:`ezmsg.sigproc.math.difference`
- Subtract a constant (ConstDifferenceTransformer)

Signal Processing
^^^^^^^^^^^^^^^^^

.. list-table::
:header-rows: 1
:widths: 30 70

* - Module
- Description
* - :mod:`ezmsg.sigproc.diff`
- Compute differences along an axis
* - :mod:`ezmsg.sigproc.transpose`
- Transpose/permute array dimensions
* - :mod:`ezmsg.sigproc.linear`
- Per-channel linear transform (scale + offset)
* - :mod:`ezmsg.sigproc.aggregate`
- Aggregate operations (AggregateTransformer only)

Coordinate Transforms
^^^^^^^^^^^^^^^^^^^^^

.. list-table::
:header-rows: 1
:widths: 30 70

* - Module
- Description
* - :mod:`ezmsg.sigproc.coordinatespaces`
- Cartesian/polar coordinate conversions

Limitations
-----------

Some operations remain NumPy-only due to lack of Array API equivalents:

- **Random number generation**: Modules using ``np.random`` (e.g., ``denormalize``)
- **SciPy operations**: Filtering (``scipy.signal.lfilter``), FFT, wavelets
- **Advanced indexing**: Some slicing operations for metadata handling
- **Memory layout**: ``np.require`` for contiguous array optimization (NumPy only)

Metadata arrays (axis labels, coordinates) typically remain as NumPy arrays
since they are not performance-critical.

Adding Array API Support
------------------------

When contributing new transformers, follow this pattern:

.. code-block:: python

from array_api_compat import get_namespace
from ezmsg.baseproc import BaseTransformer
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

class MyTransformer(BaseTransformer[MySettings, AxisArray, AxisArray]):
def _process(self, message: AxisArray) -> AxisArray:
xp = get_namespace(message.data)

# Use xp instead of np for array operations
result = xp.sqrt(xp.abs(message.data))

return replace(message, data=result)

Key guidelines:

1. Call ``get_namespace(message.data)`` at the start of ``_process``
2. Use ``xp.function_name`` instead of ``np.function_name``
3. Note that some functions have different names:
- ``np.concatenate`` → ``xp.concat``
- ``np.transpose`` → ``xp.permute_dims``
4. Keep metadata operations (axis labels, etc.) as NumPy
5. Use in-place operations (``/=``, ``*=``) where possible for efficiency
4 changes: 3 additions & 1 deletion docs/source/guides/sigproc/content-sigproc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ ezmsg-sigproc
Timeseries signal processing implementations in ezmsg, leveraging numpy and scipy.
Most of the methods and classes in this extension are intended to be used in building signal processing pipelines.
They use :class:`ezmsg.util.messages.axisarray.AxisArray` as the primary data structure for passing signals between components.
The message's data are expected to be a numpy array.
The message's data are typically NumPy arrays, though many transformers support the
:doc:`Array API standard <../explanations/array_api>` for use with CuPy, PyTorch, and other backends.

.. note:: Some generators might yield valid :class:`AxisArray` messages with ``.data`` size of 0.
This may occur when the generator receives inadequate data to produce a valid output, such as when windowing or buffering.
Expand All @@ -21,3 +22,4 @@ This may occur when the generator receives inadequate data to produce a valid ou
base
units
processors
../explanations/array_api
9 changes: 9 additions & 0 deletions src/ezmsg/sigproc/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
"""
Aggregation operations over arrays.

.. note::
:obj:`AggregateTransformer` supports the :doc:`Array API standard </guides/explanations/array_api>`,
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
:obj:`RangedAggregateTransformer` currently requires NumPy arrays.
"""

import typing

import ezmsg.core as ez
Expand Down
27 changes: 22 additions & 5 deletions src/ezmsg/sigproc/coordinatespaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

This module provides utilities and ezmsg nodes for transforming between
Cartesian (x, y) and polar (r, theta) coordinate systems.

.. note::
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
"""

from enum import Enum
Expand All @@ -11,6 +15,7 @@
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from array_api_compat import get_namespace, is_array_api_obj
from ezmsg.baseproc import (
BaseTransformer,
BaseTransformerUnit,
Expand All @@ -20,14 +25,24 @@
# -- Utility functions for coordinate transformations --


def _get_namespace_or_numpy(*args: npt.ArrayLike):
"""Get array namespace if any arg is an array, otherwise return numpy."""
for arg in args:
if is_array_api_obj(arg):
return get_namespace(arg)
return np


def polar2z(r: npt.ArrayLike, theta: npt.ArrayLike) -> npt.ArrayLike:
"""Convert polar coordinates to complex number representation."""
return r * np.exp(1j * theta)
xp = _get_namespace_or_numpy(r, theta)
return r * xp.exp(1j * theta)


def z2polar(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
"""Convert complex number to polar coordinates (r, theta)."""
return np.abs(z), np.angle(z)
xp = _get_namespace_or_numpy(z)
return xp.abs(z), xp.atan2(xp.imag(z), xp.real(z))


def cart2z(x: npt.ArrayLike, y: npt.ArrayLike) -> npt.ArrayLike:
Expand All @@ -37,7 +52,8 @@ def cart2z(x: npt.ArrayLike, y: npt.ArrayLike) -> npt.ArrayLike:

def z2cart(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
"""Convert complex number to Cartesian coordinates (x, y)."""
return np.real(z), np.imag(z)
xp = _get_namespace_or_numpy(z)
return xp.real(z), xp.imag(z)


def cart2pol(x: npt.ArrayLike, y: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
Expand Down Expand Up @@ -90,6 +106,7 @@ class CoordinateSpacesTransformer(BaseTransformer[CoordinateSpacesSettings, Axis
"""

def _process(self, message: AxisArray) -> AxisArray:
xp = get_namespace(message.data)
axis = self.settings.axis or message.dims[-1]
axis_idx = message.get_axis_idx(axis)

Expand All @@ -116,9 +133,9 @@ def _process(self, message: AxisArray) -> AxisArray:
out_a, out_b = pol2cart(component_a, component_b)

# Stack results back along the same axis
result = np.stack([out_a, out_b], axis=axis_idx)
result = xp.stack([out_a, out_b], axis=axis_idx)

# Update axis labels if present
# Update axis labels if present (use numpy for string labels)
axes = message.axes
if axis in axes and hasattr(axes[axis], "data"):
if self.settings.mode == CoordinateMode.CART2POL:
Expand Down
19 changes: 15 additions & 4 deletions src/ezmsg/sigproc/diff.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
"""
Compute differences along an axis.

.. note::
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
"""

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from array_api_compat import get_namespace
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
Expand Down Expand Up @@ -39,23 +48,25 @@ def _reset_state(self, message) -> None:
self.state.last_time = ax_info.data[0] - 0.001

def _process(self, message: AxisArray) -> AxisArray:
xp = get_namespace(message.data)
axis = self.settings.axis or message.dims[0]
ax_idx = message.get_axis_idx(axis)

diffs = np.diff(
np.concatenate((self.state.last_dat, message.data), axis=ax_idx),
diffs = xp.diff(
xp.concat((self.state.last_dat, message.data), axis=ax_idx),
axis=ax_idx,
)
# Prepare last_dat for next iteration
self.state.last_dat = slice_along_axis(message.data, slice(-1, None), axis=ax_idx)
# Scale by fs if requested. This convers the diff to a derivative. e.g., diff of position becomes velocity.
# Scale by fs if requested. This converts the diff to a derivative. e.g., diff of position becomes velocity.
if self.settings.scale_by_fs:
ax_info = message.get_axis(axis)
if hasattr(ax_info, "data"):
# ax_info.data is typically numpy for metadata, so use np.diff here
dt = np.diff(np.concatenate(([self.state.last_time], ax_info.data)))
# Expand dt dims to match diffs
exp_sl = (None,) * ax_idx + (Ellipsis,) + (None,) * (message.data.ndim - ax_idx - 1)
diffs /= dt[exp_sl]
diffs /= xp.asarray(dt[exp_sl])
self.state.last_time = ax_info.data[-1] # For next iteration
else:
diffs /= ax_info.gain
Expand Down
9 changes: 6 additions & 3 deletions src/ezmsg/sigproc/linear.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Apply a linear transformation: output = scale * input + offset.
"""
Apply a linear transformation: output = scale * input + offset.

Supports per-element scale and offset along a specified axis.
Uses Array API for compatibility with numpy, cupy, torch, etc.

For full matrix transformations, use :obj:`AffineTransformTransformer` instead.

.. note::
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
"""

import ezmsg.core as ez
Expand Down
14 changes: 10 additions & 4 deletions src/ezmsg/sigproc/math/abs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
"""Take the absolute value of the data."""
# TODO: Array API
"""
Take the absolute value of the data.

import numpy as np
.. note::
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
"""

from array_api_compat import get_namespace
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
Expand All @@ -13,7 +18,8 @@ class AbsSettings:

class AbsTransformer(BaseTransformer[None, AxisArray, AxisArray]):
def _process(self, message: AxisArray) -> AxisArray:
return replace(message, data=np.abs(message.data))
xp = get_namespace(message.data)
return replace(message, data=xp.abs(message.data))


class Abs(BaseTransformerUnit[None, AxisArray, AxisArray, AbsTransformer]): ... # SETTINGS = None
Expand Down
Loading