Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/ezmsg/sigproc/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from array_api_compat import get_namespace
import typing

import numpy as np
Expand All @@ -12,6 +13,7 @@

from .spectral import OptionsEnum
from .base import (
BaseTransformer,
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
Expand Down Expand Up @@ -213,3 +215,70 @@ def ranged_aggregate(
return RangedAggregateTransformer(
RangedAggregateSettings(axis=axis, bands=bands, operation=operation)
)


class AggregateSettings(ez.Settings):
"""Settings for :obj:`Aggregate`."""

axis: str
"""The name of the axis to aggregate over. This axis will be removed from the output."""

operation: AggregationFunction = AggregationFunction.MEAN
""":obj:`AggregationFunction` to apply."""


class AggregateTransformer(BaseTransformer[AggregateSettings, AxisArray, AxisArray]):
"""
Transformer that aggregates an entire axis using a specified operation.

Unlike :obj:`RangedAggregateTransformer` which aggregates over specific ranges/bands
and preserves the axis (with one value per band), this transformer aggregates the
entire axis and removes it from the output, reducing dimensionality by one.
"""

def _process(self, message: AxisArray) -> AxisArray:
xp = get_namespace(message.data)
axis_idx = message.get_axis_idx(self.settings.axis)
op = self.settings.operation

if op == AggregationFunction.NONE:
raise ValueError(
"AggregationFunction.NONE is not supported for full-axis aggregation"
)

if op == AggregationFunction.TRAPEZOID:
# Trapezoid integration requires x-coordinates
target_axis = message.get_axis(self.settings.axis)
if hasattr(target_axis, "data"):
x = target_axis.data
else:
x = target_axis.value(np.arange(message.data.shape[axis_idx]))
agg_data = np.trapezoid(np.asarray(message.data), x=x, axis=axis_idx)
else:
# Try array-API compatible function first, fall back to numpy
func_name = op.value
if hasattr(xp, func_name):
agg_data = getattr(xp, func_name)(message.data, axis=axis_idx)
else:
agg_data = AGGREGATORS[op](message.data, axis=axis_idx)

new_dims = list(message.dims)
new_dims.pop(axis_idx)

new_axes = dict(message.axes)
new_axes.pop(self.settings.axis, None)

return replace(
message,
data=agg_data,
dims=new_dims,
axes=new_axes,
)


class AggregateUnit(
BaseTransformerUnit[AggregateSettings, AxisArray, AxisArray, AggregateTransformer]
):
"""Unit that aggregates an entire axis using a specified operation."""

SETTINGS = AggregateSettings
252 changes: 251 additions & 1 deletion tests/unit/test_aggregate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import copy
from functools import partial

import numpy as np
import pytest
from frozendict import frozendict
from ezmsg.util.messages.axisarray import AxisArray

from ezmsg.sigproc.aggregate import ranged_aggregate, AggregationFunction
from ezmsg.sigproc.aggregate import (
ranged_aggregate,
AggregationFunction,
AggregateTransformer,
AggregateSettings,
)

from tests.helpers.util import assert_messages_equal

Expand Down Expand Up @@ -159,3 +165,247 @@ def test_aggregate_handle_change(change_ax: str):
print(len(out_msgs1))
out_msgs2 = [gen.send(_) for _ in in_msgs2]
print(len(out_msgs2))


# ============== Tests for AggregateTransformer ==============


def get_simple_msg(n_times=10, n_chans=5, n_freqs=8, fs=100.0):
"""Create a simple AxisArray message for testing AggregateTransformer."""
data = np.arange(n_times * n_chans * n_freqs, dtype=float).reshape(
n_times, n_chans, n_freqs
)
return AxisArray(
data=data,
dims=["time", "ch", "freq"],
axes=frozendict(
{
"time": AxisArray.TimeAxis(fs=fs, offset=0.0),
"ch": AxisArray.CoordinateAxis(
data=np.array([f"ch{i}" for i in range(n_chans)]),
dims=["ch"],
),
"freq": AxisArray.LinearAxis(gain=2.0, offset=1.0, unit="Hz"),
}
),
)


@pytest.mark.parametrize(
"operation",
[
AggregationFunction.MEAN,
AggregationFunction.SUM,
AggregationFunction.MAX,
AggregationFunction.MIN,
AggregationFunction.STD,
AggregationFunction.MEDIAN,
],
)
def test_aggregate_transformer_basic(operation: AggregationFunction):
"""Test AggregateTransformer with basic aggregation operations."""
msg_in = get_simple_msg()
backup = copy.deepcopy(msg_in)

transformer = AggregateTransformer(
AggregateSettings(axis="freq", operation=operation)
)
msg_out = transformer(msg_in)

# Verify input wasn't modified
assert_messages_equal([msg_in], [backup])

# Verify output type
assert isinstance(msg_out, AxisArray)

# Verify axis was removed
assert "freq" not in msg_out.dims
assert "freq" not in msg_out.axes
assert msg_out.dims == ["time", "ch"]

# Verify output shape
assert msg_out.data.shape == (10, 5)

# Verify data correctness
np_func = getattr(np, operation.value)
expected = np_func(msg_in.data, axis=2)
assert np.allclose(msg_out.data, expected)


@pytest.mark.parametrize("axis", ["time", "ch", "freq"])
def test_aggregate_transformer_different_axes(axis: str):
"""Test AggregateTransformer can aggregate along different axes."""
msg_in = get_simple_msg(n_times=10, n_chans=5, n_freqs=8)

transformer = AggregateTransformer(
AggregateSettings(axis=axis, operation=AggregationFunction.MEAN)
)
msg_out = transformer(msg_in)

# Verify the specified axis was removed
assert axis not in msg_out.dims
assert axis not in msg_out.axes

# Verify remaining dims
expected_dims = [d for d in ["time", "ch", "freq"] if d != axis]
assert msg_out.dims == expected_dims

# Verify shape
axis_idx = msg_in.get_axis_idx(axis)
expected_shape = list(msg_in.data.shape)
expected_shape.pop(axis_idx)
assert msg_out.data.shape == tuple(expected_shape)

# Verify data
expected = np.mean(msg_in.data, axis=axis_idx)
assert np.allclose(msg_out.data, expected)


def test_aggregate_transformer_none_raises():
"""Test that AggregationFunction.NONE raises an error."""
msg_in = get_simple_msg()

transformer = AggregateTransformer(
AggregateSettings(axis="freq", operation=AggregationFunction.NONE)
)

with pytest.raises(ValueError, match="NONE is not supported"):
transformer(msg_in)


@pytest.mark.parametrize(
"operation",
[
AggregationFunction.NANMEAN,
AggregationFunction.NANSUM,
AggregationFunction.NANMAX,
AggregationFunction.NANMIN,
AggregationFunction.NANSTD,
AggregationFunction.NANMEDIAN,
],
)
def test_aggregate_transformer_nan_operations(operation: AggregationFunction):
"""Test AggregateTransformer with NaN-aware operations."""
msg_in = get_simple_msg()
# Introduce some NaN values
msg_in.data[0, 0, 0] = np.nan
msg_in.data[5, 2, 3] = np.nan

transformer = AggregateTransformer(
AggregateSettings(axis="freq", operation=operation)
)
msg_out = transformer(msg_in)

# Verify output doesn't have NaN where nan-operations should have handled it
np_func = getattr(np, operation.value)
expected = np_func(msg_in.data, axis=2)
assert np.allclose(msg_out.data, expected, equal_nan=True)


@pytest.mark.parametrize(
"operation", [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]
)
def test_aggregate_transformer_argminmax(operation: AggregationFunction):
"""Test AggregateTransformer with argmin/argmax operations."""
msg_in = get_simple_msg()

transformer = AggregateTransformer(
AggregateSettings(axis="freq", operation=operation)
)
msg_out = transformer(msg_in)

# Verify output shape (axis removed)
assert msg_out.data.shape == (10, 5)
assert "freq" not in msg_out.dims

# Verify data correctness (returns indices)
np_func = getattr(np, operation.value)
expected = np_func(msg_in.data, axis=2)
assert np.array_equal(msg_out.data, expected)


def test_aggregate_transformer_trapezoid():
"""Test AggregateTransformer with trapezoid integration."""
msg_in = get_simple_msg(n_times=5, n_chans=3, n_freqs=10)

transformer = AggregateTransformer(
AggregateSettings(axis="freq", operation=AggregationFunction.TRAPEZOID)
)
msg_out = transformer(msg_in)

# Verify output shape
assert msg_out.data.shape == (5, 3)
assert "freq" not in msg_out.dims

# Calculate expected result using axis coordinates
freq_axis = msg_in.axes["freq"]
x = freq_axis.value(np.arange(msg_in.data.shape[2]))
expected = np.trapezoid(msg_in.data, x=x, axis=2)

assert np.allclose(msg_out.data, expected)


def test_aggregate_transformer_trapezoid_coordinate_axis():
"""Test trapezoid integration with CoordinateAxis."""
n_times, n_chans, n_freqs = 5, 3, 10
data = np.arange(n_times * n_chans * n_freqs, dtype=float).reshape(
n_times, n_chans, n_freqs
)
freq_values = np.array([1.0, 2.0, 4.0, 7.0, 11.0, 16.0, 22.0, 29.0, 37.0, 46.0])
msg_in = AxisArray(
data=data,
dims=["time", "ch", "freq"],
axes=frozendict(
{
"time": AxisArray.TimeAxis(fs=100.0, offset=0.0),
"freq": AxisArray.CoordinateAxis(
data=freq_values, dims=["freq"], unit="Hz"
),
}
),
)

transformer = AggregateTransformer(
AggregateSettings(axis="freq", operation=AggregationFunction.TRAPEZOID)
)
msg_out = transformer(msg_in)

# Calculate expected using the coordinate values
expected = np.trapezoid(msg_in.data, x=freq_values, axis=2)
assert np.allclose(msg_out.data, expected)


def test_aggregate_transformer_preserves_other_axes():
"""Test that non-aggregated axes are preserved correctly."""
msg_in = get_simple_msg()

transformer = AggregateTransformer(
AggregateSettings(axis="freq", operation=AggregationFunction.MEAN)
)
msg_out = transformer(msg_in)

# Verify time axis preserved
assert "time" in msg_out.axes
assert msg_out.axes["time"] == msg_in.axes["time"]

# Verify ch axis preserved
assert "ch" in msg_out.axes
ch_ax_in = msg_in.axes["ch"]
ch_ax_out = msg_out.axes["ch"]
assert np.array_equal(ch_ax_out.data, ch_ax_in.data)


def test_aggregate_transformer_multiple_calls():
"""Test that transformer works correctly with multiple calls."""
transformer = AggregateTransformer(
AggregateSettings(axis="freq", operation=AggregationFunction.SUM)
)

for i in range(3):
msg_in = get_simple_msg()
msg_in.data = msg_in.data + i * 1000 # Different data each time

msg_out = transformer(msg_in)

expected = np.sum(msg_in.data, axis=2)
assert np.allclose(msg_out.data, expected)