diff --git a/docs/img/HybridBufferBasic.svg b/docs/source/guides/img/HybridBufferBasic.svg
similarity index 100%
rename from docs/img/HybridBufferBasic.svg
rename to docs/source/guides/img/HybridBufferBasic.svg
diff --git a/docs/img/HybridBufferOverflow.svg b/docs/source/guides/img/HybridBufferOverflow.svg
similarity index 100%
rename from docs/img/HybridBufferOverflow.svg
rename to docs/source/guides/img/HybridBufferOverflow.svg
diff --git a/src/ezmsg/sigproc/aggregate.py b/src/ezmsg/sigproc/aggregate.py
index 2c084cc1..ddf2d922 100644
--- a/src/ezmsg/sigproc/aggregate.py
+++ b/src/ezmsg/sigproc/aggregate.py
@@ -1,3 +1,4 @@
 import typing
 
+from array_api_compat import get_namespace
 import numpy as np
@@ -12,6 +13,7 @@
 
 from .spectral import OptionsEnum
 from .base import (
+    BaseTransformer,
     BaseStatefulTransformer,
     BaseTransformerUnit,
     processor_state,
@@ -213,3 +215,79 @@ def ranged_aggregate(
     return RangedAggregateTransformer(
         RangedAggregateSettings(axis=axis, bands=bands, operation=operation)
     )
+
+
+class AggregateSettings(ez.Settings):
+    """Settings for :obj:`AggregateTransformer` and :obj:`AggregateUnit`."""
+
+    axis: str
+    """The name of the axis to aggregate over. This axis will be removed from the output."""
+
+    operation: AggregationFunction = AggregationFunction.MEAN
+    """:obj:`AggregationFunction` to apply."""
+
+
+class AggregateTransformer(BaseTransformer[AggregateSettings, AxisArray, AxisArray]):
+    """
+    Transformer that aggregates an entire axis using a specified operation.
+
+    Unlike :obj:`RangedAggregateTransformer` which aggregates over specific ranges/bands
+    and preserves the axis (with one value per band), this transformer aggregates the
+    entire axis and removes it from the output, reducing dimensionality by one.
+    """
+
+    def _process(self, message: AxisArray) -> AxisArray:
+        """
+        Reduce ``message`` along ``settings.axis`` with ``settings.operation``.
+
+        Args:
+            message: Input whose ``dims`` include ``settings.axis``.
+
+        Returns:
+            The result of ``replace`` on ``message``: aggregated ``data``, and
+            ``dims`` / ``axes`` without the aggregated axis.
+
+        Raises:
+            ValueError: If the operation is ``AggregationFunction.NONE``.
+        """
+        axis_name = self.settings.axis
+        axis_idx = message.get_axis_idx(axis_name)
+        op = self.settings.operation
+
+        if op == AggregationFunction.NONE:
+            raise ValueError(
+                "AggregationFunction.NONE is not supported for full-axis aggregation"
+            )
+
+        if op == AggregationFunction.TRAPEZOID:
+            # Integration needs x-coordinates: an axis exposing ``data`` supplies
+            # them directly; otherwise they are computed from the sample indices.
+            target_axis = message.get_axis(axis_name)
+            if hasattr(target_axis, "data"):
+                x = target_axis.data
+            else:
+                x = target_axis.value(np.arange(message.data.shape[axis_idx]))
+            # This path always goes through NumPy, so no array namespace is needed.
+            agg_data = np.trapezoid(np.asarray(message.data), x=x, axis=axis_idx)
+        else:
+            # Prefer the function from the data's own array namespace; fall back
+            # to the AGGREGATORS lookup when the namespace has no such function.
+            xp = get_namespace(message.data)
+            reducer = getattr(xp, op.value, None)
+            if reducer is None:
+                reducer = AGGREGATORS[op]
+            agg_data = reducer(message.data, axis=axis_idx)
+
+        # Drop the aggregated axis from both the dim list and the axes mapping.
+        new_dims = [dim for idx, dim in enumerate(message.dims) if idx != axis_idx]
+        new_axes = {k: v for k, v in message.axes.items() if k != axis_name}
+
+        return replace(message, data=agg_data, dims=new_dims, axes=new_axes)
+
+
+class AggregateUnit(
+    BaseTransformerUnit[AggregateSettings, AxisArray, AxisArray, AggregateTransformer]
+):
+    """Unit that aggregates an entire axis using a specified operation."""
+
+    SETTINGS = AggregateSettings
diff --git a/tests/unit/test_aggregate.py b/tests/unit/test_aggregate.py
index fae2d571..4685445a 100644
--- a/tests/unit/test_aggregate.py
+++ b/tests/unit/test_aggregate.py
@@ -1,3 +1,4 @@
+import copy
 from functools import partial
 
 import numpy as np
@@ -5,7 +6,12 @@
 from frozendict import frozendict
 from ezmsg.util.messages.axisarray import AxisArray
 
-from ezmsg.sigproc.aggregate import ranged_aggregate, AggregationFunction
+from ezmsg.sigproc.aggregate import (
+    ranged_aggregate,
+    AggregationFunction,
+    AggregateTransformer,
+    AggregateSettings,
+)
 from tests.helpers.util import assert_messages_equal
 
 
@@ -159,3 +165,247 @@
def test_aggregate_handle_change(change_ax: str):
     print(len(out_msgs1))
     out_msgs2 = [gen.send(_) for _ in in_msgs2]
     print(len(out_msgs2))
+
+
+# ============== Tests for AggregateTransformer ==============
+
+
+def get_simple_msg(n_times=10, n_chans=5, n_freqs=8, fs=100.0):
+    """Build a (time, ch, freq) AxisArray filled with consecutive floats."""
+    data = np.arange(n_times * n_chans * n_freqs, dtype=float).reshape(
+        n_times, n_chans, n_freqs
+    )
+    return AxisArray(
+        data=data,
+        dims=["time", "ch", "freq"],
+        axes=frozendict(
+            {
+                "time": AxisArray.TimeAxis(fs=fs, offset=0.0),
+                "ch": AxisArray.CoordinateAxis(
+                    data=np.array([f"ch{i}" for i in range(n_chans)]),
+                    dims=["ch"],
+                ),
+                "freq": AxisArray.LinearAxis(gain=2.0, offset=1.0, unit="Hz"),
+            }
+        ),
+    )
+
+
+@pytest.mark.parametrize(
+    "operation",
+    [
+        AggregationFunction.MEAN,
+        AggregationFunction.SUM,
+        AggregationFunction.MAX,
+        AggregationFunction.MIN,
+        AggregationFunction.STD,
+        AggregationFunction.MEDIAN,
+    ],
+)
+def test_aggregate_transformer_basic(operation: AggregationFunction):
+    """Test AggregateTransformer with basic aggregation operations over "freq"."""
+    msg_in = get_simple_msg()
+    backup = copy.deepcopy(msg_in)
+
+    transformer = AggregateTransformer(
+        AggregateSettings(axis="freq", operation=operation)
+    )
+    msg_out = transformer(msg_in)
+
+    # Verify input wasn't modified (compare against the deep copy taken above)
+    assert_messages_equal([msg_in], [backup])
+
+    # Verify output type
+    assert isinstance(msg_out, AxisArray)
+
+    # Verify axis was removed from both dims and axes
+    assert "freq" not in msg_out.dims
+    assert "freq" not in msg_out.axes
+    assert msg_out.dims == ["time", "ch"]
+
+    # Verify output shape (default n_times=10, n_chans=5)
+    assert msg_out.data.shape == (10, 5)
+
+    # Verify data against the NumPy function named by the enum value
+    np_func = getattr(np, operation.value)
+    expected = np_func(msg_in.data, axis=2)
+    assert np.allclose(msg_out.data, expected)
+
+
+@pytest.mark.parametrize("axis", ["time", "ch", "freq"])
+def test_aggregate_transformer_different_axes(axis: str):
+    """Test that AggregateTransformer can aggregate along each axis of the message."""
+    msg_in = get_simple_msg(n_times=10, n_chans=5, n_freqs=8)
+
+    transformer = AggregateTransformer(
+        AggregateSettings(axis=axis, operation=AggregationFunction.MEAN)
+    )
+    msg_out = transformer(msg_in)
+
+    # Verify the specified axis was removed
+    assert axis not in msg_out.dims
+    assert axis not in msg_out.axes
+
+    # Verify remaining dims keep their original order
+    expected_dims = [d for d in ["time", "ch", "freq"] if d != axis]
+    assert msg_out.dims == expected_dims
+
+    # Verify shape: input shape with the aggregated dimension dropped
+    axis_idx = msg_in.get_axis_idx(axis)
+    expected_shape = list(msg_in.data.shape)
+    expected_shape.pop(axis_idx)
+    assert msg_out.data.shape == tuple(expected_shape)
+
+    # Verify data against np.mean along the same axis
+    expected = np.mean(msg_in.data, axis=axis_idx)
+    assert np.allclose(msg_out.data, expected)
+
+
+def test_aggregate_transformer_none_raises():
+    """Test that AggregationFunction.NONE raises a ValueError when a message is processed."""
+    msg_in = get_simple_msg()
+
+    transformer = AggregateTransformer(
+        AggregateSettings(axis="freq", operation=AggregationFunction.NONE)
+    )
+
+    with pytest.raises(ValueError, match="NONE is not supported"):
+        transformer(msg_in)
+
+
+@pytest.mark.parametrize(
+    "operation",
+    [
+        AggregationFunction.NANMEAN,
+        AggregationFunction.NANSUM,
+        AggregationFunction.NANMAX,
+        AggregationFunction.NANMIN,
+        AggregationFunction.NANSTD,
+        AggregationFunction.NANMEDIAN,
+    ],
+)
+def test_aggregate_transformer_nan_operations(operation: AggregationFunction):
+    """Test AggregateTransformer with NaN-aware operations."""
+    msg_in = get_simple_msg()
+    # Place a NaN in two different (time, ch) rows of the freq axis
+    msg_in.data[0, 0, 0] = np.nan
+    msg_in.data[5, 2, 3] = np.nan
+
+    transformer = AggregateTransformer(
+        AggregateSettings(axis="freq", operation=operation)
+    )
+    msg_out = transformer(msg_in)
+
+    # Verify output matches NumPy's function of the same name, treating NaNs as equal
+    np_func = getattr(np, operation.value)
+    expected = np_func(msg_in.data, axis=2)
+    assert np.allclose(msg_out.data, expected, equal_nan=True)
+
+
+@pytest.mark.parametrize(
+    "operation", [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]
+)
+def test_aggregate_transformer_argminmax(operation: AggregationFunction):
+    """Test AggregateTransformer with argmin/argmax operations."""
+    msg_in = get_simple_msg()
+
+    transformer = AggregateTransformer(
+        AggregateSettings(axis="freq", operation=operation)
+    )
+    msg_out = transformer(msg_in)
+
+    # Verify output shape (axis removed)
+    assert msg_out.data.shape == (10, 5)
+    assert "freq" not in msg_out.dims
+
+    # Verify data correctness (indices along freq, so compare exactly)
+    np_func = getattr(np, operation.value)
+    expected = np_func(msg_in.data, axis=2)
+    assert np.array_equal(msg_out.data, expected)
+
+
+def test_aggregate_transformer_trapezoid():
+    """Test AggregateTransformer with trapezoid integration."""
+    msg_in = get_simple_msg(n_times=5, n_chans=3, n_freqs=10)
+
+    transformer = AggregateTransformer(
+        AggregateSettings(axis="freq", operation=AggregationFunction.TRAPEZOID)
+    )
+    msg_out = transformer(msg_in)
+
+    # Verify output shape
+    assert msg_out.data.shape == (5, 3)
+    assert "freq" not in msg_out.dims
+
+    # Expected value: integrate against the coordinates the freq axis reports
+    freq_axis = msg_in.axes["freq"]
+    x = freq_axis.value(np.arange(msg_in.data.shape[2]))
+    expected = np.trapezoid(msg_in.data, x=x, axis=2)
+
+    assert np.allclose(msg_out.data, expected)
+
+
+def test_aggregate_transformer_trapezoid_coordinate_axis():
+    """Test trapezoid integration over a non-uniformly spaced CoordinateAxis."""
+    n_times, n_chans, n_freqs = 5, 3, 10
+    data = np.arange(n_times * n_chans * n_freqs, dtype=float).reshape(
+        n_times, n_chans, n_freqs
+    )
+    freq_values = np.array([1.0, 2.0, 4.0, 7.0, 11.0, 16.0, 22.0, 29.0, 37.0, 46.0])
+    msg_in = AxisArray(
+        data=data,
+        dims=["time", "ch", "freq"],
+        axes=frozendict(
+            {
+                "time": AxisArray.TimeAxis(fs=100.0, offset=0.0),
+                "freq": AxisArray.CoordinateAxis(
+                    data=freq_values, dims=["freq"], unit="Hz"
+                ),
+            }
+        ),
+    )
+
+    transformer = AggregateTransformer(
+        AggregateSettings(axis="freq", operation=AggregationFunction.TRAPEZOID)
+    )
+    msg_out = transformer(msg_in)
+
+    # Expected value: integrate against the explicit coordinate values
+    expected = np.trapezoid(msg_in.data, x=freq_values, axis=2)
+    assert np.allclose(msg_out.data, expected)
+
+
+def test_aggregate_transformer_preserves_other_axes():
+    """Test that non-aggregated axes are preserved correctly."""
+    msg_in = get_simple_msg()
+
+    transformer = AggregateTransformer(
+        AggregateSettings(axis="freq", operation=AggregationFunction.MEAN)
+    )
+    msg_out = transformer(msg_in)
+
+    # Verify time axis preserved (compared by equality)
+    assert "time" in msg_out.axes
+    assert msg_out.axes["time"] == msg_in.axes["time"]
+
+    # Verify ch axis preserved (compare the coordinate labels element-wise)
+    assert "ch" in msg_out.axes
+    ch_ax_in = msg_in.axes["ch"]
+    ch_ax_out = msg_out.axes["ch"]
+    assert np.array_equal(ch_ax_out.data, ch_ax_in.data)
+
+
+def test_aggregate_transformer_multiple_calls():
+    """Test that one transformer instance gives correct results on repeated calls."""
+    transformer = AggregateTransformer(
+        AggregateSettings(axis="freq", operation=AggregationFunction.SUM)
+    )
+
+    for i in range(3):
+        msg_in = get_simple_msg()
+        msg_in.data = msg_in.data + i * 1000  # Offset so every call sees different values
+
+        msg_out = transformer(msg_in)
+
+        expected = np.sum(msg_in.data, axis=2)
+        assert np.allclose(msg_out.data, expected)