From 1d2e0cde07747f1c66cfcadfefb4d3bedc097f18 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Tue, 27 Jan 2026 00:41:42 -0500 Subject: [PATCH] Update butter zero phase with composition of standard butterworth for forward and a backward filter that operates on --- src/ezmsg/sigproc/butterworthzerophase.py | 304 +++++++-- .../ezmsg/test_butterworthzerophase_system.py | 40 +- tests/unit/test_butterworthzerophase.py | 577 +++++++++++++++--- 3 files changed, 756 insertions(+), 165 deletions(-) diff --git a/src/ezmsg/sigproc/butterworthzerophase.py b/src/ezmsg/sigproc/butterworthzerophase.py index 1bd2fbff..f054344c 100644 --- a/src/ezmsg/sigproc/butterworthzerophase.py +++ b/src/ezmsg/sigproc/butterworthzerophase.py @@ -1,42 +1,80 @@ +""" +Streaming zero-phase Butterworth filter implemented as a two-stage composite processor. + +Stage 1: Forward causal Butterworth filter (from ezmsg.sigproc.butterworthfilter) +Stage 2: Backward acausal filter with buffering (ButterworthBackwardFilterTransformer) + +The output is delayed by `pad_length` samples to ensure the backward pass has sufficient +future context. The pad_length is derived from the filter's impulse-response settling time, with scipy's padding heuristic as a lower bound.
+""" + import functools import typing -import ezmsg.core as ez import numpy as np import scipy.signal -from ezmsg.baseproc import SettingsType +from ezmsg.baseproc import BaseTransformerUnit +from ezmsg.baseproc.composite import CompositeProcessor from ezmsg.util.messages.axisarray import AxisArray from ezmsg.util.messages.util import replace -from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings, butter_design_fun -from ezmsg.sigproc.filter import ( - BACoeffs, - BaseFilterByDesignTransformerUnit, - FilterByDesignTransformer, - SOSCoeffs, +from .butterworthfilter import ( + ButterworthFilterSettings, + ButterworthFilterTransformer, + butter_design_fun, ) +from .filter import BACoeffs, FilterByDesignTransformer, SOSCoeffs +from .util.axisarray_buffer import HybridAxisArrayBuffer class ButterworthZeroPhaseSettings(ButterworthFilterSettings): - """Settings for :obj:`ButterworthZeroPhase`.""" + """ + Settings for :obj:`ButterworthZeroPhase`. + + This implements a streaming zero-phase Butterworth filter using forward-backward + filtering. The output is delayed by `pad_length` samples to ensure the backward + pass has sufficient future context. + + The pad_length is computed by finding where the filter's impulse response decays + to `settle_cutoff` fraction of its peak value. This accounts for the filter's + actual time constant rather than just its order. + """ + + # Inherits from ButterworthFilterSettings: + # axis, coef_type, order, cuton, cutoff, wn_hz - # axis, coef_type, order, cuton, cutoff, wn_hz are inherited from ButterworthFilterSettings - padtype: str | None = None + settle_cutoff: float = 0.01 """ - Padding type to use in `scipy.signal.filtfilt`. - Must be one of {'odd', 'even', 'constant', None}. - Default is None for no padding. + Fraction of peak impulse response used to determine settling time. + The pad_length is set to the number of samples until the impulse response + decays to this fraction of its peak. Default is 0.01 (1% of peak). 
""" - padlen: int | None = 0 + max_pad_duration: float | None = None """ - Length of the padding to use in `scipy.signal.filtfilt`. - If None, SciPy's default padding is used. + Maximum pad duration in seconds. If set, the pad_length will be capped + at this value times the sampling rate. Use this to limit latency for + filters with very long impulse responses. Default is None (no limit). """ -class ButterworthZeroPhaseTransformer(FilterByDesignTransformer[ButterworthZeroPhaseSettings, BACoeffs | SOSCoeffs]): - """Zero-phase (filtfilt) Butterworth using your design function.""" +class ButterworthBackwardFilterTransformer(FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]): + """ + Backward (acausal) Butterworth filter with buffering. + + This transformer buffers its input and applies the filter in reverse, + outputting only the "settled" portion where transients have decayed. + This introduces a lag of ``pad_length`` samples. + + Intended to be used as stage 2 in a zero-phase filter pipeline, receiving + forward-filtered data from a ButterworthFilterTransformer. + """ + + # Instance attributes (initialized in _reset_state) + _buffer: HybridAxisArrayBuffer | None + _coefs_cache: BACoeffs | SOSCoeffs | None + _zi_tiled: np.ndarray | None + _pad_length: int def get_design_function( self, @@ -50,74 +88,218 @@ def get_design_function( wn_hz=self.settings.wn_hz, ) - def update_settings(self, new_settings: typing.Optional[SettingsType] = None, **kwargs) -> None: + def _compute_pad_length(self, fs: float) -> int: """ - Update settings and mark that filter coefficients need to be recalculated. + Compute pad length based on the filter's impulse response settling time. + + The pad_length is determined by finding where the impulse response decays + to `settle_cutoff` fraction of its peak value. This is then optionally + capped by `max_pad_duration`. 
Args: - new_settings: Complete new settings object to replace current settings - **kwargs: Individual settings to update + fs: Sampling frequency in Hz. + + Returns: + Number of samples for the pad length. """ - # Update settings - if new_settings is not None: - self.settings = new_settings + # Design the filter to compute impulse response + coefs = self.get_design_function()(fs) + if coefs is None: + # Filter design failed or is disabled + return 0 + + # Generate impulse response - use a generous length initially + # Start with scipy's heuristic as minimum, then extend if needed + if self.settings.coef_type == "ba": + min_length = 3 * (self.settings.order + 1) else: - self.settings = replace(self.settings, **kwargs) + n_sections = (self.settings.order + 1) // 2 + min_length = 3 * n_sections * 2 - # Set flag to trigger recalculation on next message - self._coefs_cache = None - self._fs_cache = None - self.state.needs_redesign = True + # Use 10x the minimum as initial impulse length, or at least 10000 samples + # (10000 samples allows for ~333ms at 30kHz, covering most practical cases) + impulse_length = max(min_length * 10, 10000) + + # Cap impulse length computation if max_pad_duration is set + if self.settings.max_pad_duration is not None: + max_samples = int(self.settings.max_pad_duration * fs) + impulse_length = min(impulse_length, max_samples + 1) + + impulse = np.zeros(impulse_length) + impulse[0] = 1.0 + + if self.settings.coef_type == "ba": + b, a = coefs + h = scipy.signal.lfilter(b, a, impulse) + else: + h = scipy.signal.sosfilt(coefs, impulse) + + # Find where impulse response settles to settle_cutoff of peak + abs_h = np.abs(h) + peak = abs_h.max() + if peak == 0: + return min_length + + threshold = self.settings.settle_cutoff * peak + above_threshold = np.where(abs_h > threshold)[0] + + if len(above_threshold) == 0: + pad_length = min_length + else: + pad_length = above_threshold[-1] + 1 + + # Ensure at least the scipy heuristic minimum + pad_length = 
max(pad_length, min_length) + + # Apply max_pad_duration cap if set + if self.settings.max_pad_duration is not None: + max_samples = int(self.settings.max_pad_duration * fs) + pad_length = min(pad_length, max_samples) + + return pad_length def _reset_state(self, message: AxisArray) -> None: + """Reset filter state when stream changes.""" self._coefs_cache = None - self._fs_cache = None + self._zi_tiled = None + self._buffer = None + # Compute pad_length based on the message's sampling rate + axis = message.dims[0] if self.settings.axis is None else self.settings.axis + fs = 1 / message.axes[axis].gain + self._pad_length = self._compute_pad_length(fs) self.state.needs_redesign = True + def _compute_zi_tiled(self, data: np.ndarray, ax_idx: int) -> None: + """Compute and cache the tiled zi for the given data shape. + + Called once per stream (or after filter redesign). The result is + broadcast-ready for multiplication by the edge sample on each chunk. + """ + if self.settings.coef_type == "ba": + b, a = self._coefs_cache + zi_base = scipy.signal.lfilter_zi(b, a) + else: # sos + zi_base = scipy.signal.sosfilt_zi(self._coefs_cache) + + n_tail = data.ndim - ax_idx - 1 + + if self.settings.coef_type == "ba": + zi_expand = (None,) * ax_idx + (slice(None),) + (None,) * n_tail + n_tile = data.shape[:ax_idx] + (1,) + data.shape[ax_idx + 1 :] + else: # sos + zi_expand = (slice(None),) + (None,) * ax_idx + (slice(None),) + (None,) * n_tail + n_tile = (1,) + data.shape[:ax_idx] + (1,) + data.shape[ax_idx + 1 :] + + self._zi_tiled = np.tile(zi_base[zi_expand], n_tile) + + def _initialize_zi(self, data: np.ndarray, ax_idx: int) -> np.ndarray: + """Initialize filter state (zi) scaled by edge value.""" + if self._zi_tiled is None: + self._compute_zi_tiled(data, ax_idx) + first_sample = np.take(data, [0], axis=ax_idx) + return self._zi_tiled * first_sample + def _process(self, message: AxisArray) -> AxisArray: axis = message.dims[0] if self.settings.axis is None else 
self.settings.axis ax_idx = message.get_axis_idx(axis) fs = 1 / message.axes[axis].gain - if ( - self._coefs_cache is None - or self.state.needs_redesign - or (self._fs_cache is None or not np.isclose(self._fs_cache, fs)) - ): + # Check if we need to redesign filter + if self._coefs_cache is None or self.state.needs_redesign: self._coefs_cache = self.get_design_function()(fs) - self._fs_cache = fs + self._pad_length = self._compute_pad_length(fs) + self._zi_tiled = None # Invalidate; recomputed on next use. self.state.needs_redesign = False + # Initialize buffer with duration based on pad_length + # Add some margin to handle variable chunk sizes + buffer_duration = (self._pad_length + 1) / fs + self._buffer = HybridAxisArrayBuffer(duration=buffer_duration, axis=axis) + + # Early exit if filter is effectively disabled if self._coefs_cache is None or self.settings.order <= 0 or message.data.size <= 0: return message - x = message.data - if self.settings.coef_type == "sos": - y = scipy.signal.sosfiltfilt( - self._coefs_cache, - x, - axis=ax_idx, - padtype=self.settings.padtype, - padlen=self.settings.padlen, - ) - elif self.settings.coef_type == "ba": + # Write new data to buffer + self._buffer.write(message) + n_available = self._buffer.available() + n_output = n_available - self._pad_length + + # If we don't have enough data yet, return empty + if n_output <= 0: + new_shape = list(message.data.shape) + new_shape[ax_idx] = 0 + empty_data = np.empty(new_shape, dtype=message.data.dtype) + return replace(message, data=empty_data) + + # Peek all available data from buffer + # Note: HybridAxisArrayBuffer moves the target axis to position 0 + buffered = self._buffer.peek(n_available) + combined = buffered.data + buffer_ax_idx = 0 # Buffer always puts time axis at position 0 + + # Backward filter on reversed data + combined_rev = np.flip(combined, axis=buffer_ax_idx) + backward_zi = self._initialize_zi(combined_rev, buffer_ax_idx) + + if self.settings.coef_type == "ba": b, 
a = self._coefs_cache - y = scipy.signal.filtfilt( - b, - a, - x, - axis=ax_idx, - padtype=self.settings.padtype, - padlen=self.settings.padlen, - ) - else: - ez.logger.error("coef_type must be 'sos' or 'ba'.") - raise ValueError("coef_type must be 'sos' or 'ba'.") + y_bwd_rev, _ = scipy.signal.lfilter(b, a, combined_rev, axis=buffer_ax_idx, zi=backward_zi) + else: # sos + y_bwd_rev, _ = scipy.signal.sosfilt(self._coefs_cache, combined_rev, axis=buffer_ax_idx, zi=backward_zi) + + # Reverse back to get output in correct time order + y_bwd = np.flip(y_bwd_rev, axis=buffer_ax_idx) + + # Output the settled portion (first n_output samples) + y = y_bwd[:n_output] + + # Advance buffer read head to discard output samples, keep pad_length + self._buffer.seek(n_output) + + # Build output with adjusted time axis + # LinearAxis offset is already correct from the buffer + out_axis = buffered.axes[axis] + + # Move axis back to original position if needed + if ax_idx != 0: + y = np.moveaxis(y, 0, ax_idx) + + return replace( + message, + data=y, + axes={**message.axes, axis: out_axis}, + ) + + +class ButterworthZeroPhaseTransformer(CompositeProcessor[ButterworthZeroPhaseSettings, AxisArray, AxisArray]): + """ + Streaming zero-phase Butterworth filter as a composite of two stages. + + Stage 1 (forward): Standard causal Butterworth filter with state + Stage 2 (backward): Acausal Butterworth filter with buffering + + The output is delayed by ``pad_length`` samples. + """ + + @staticmethod + def _initialize_processors( + settings: ButterworthZeroPhaseSettings, + ) -> dict[str, typing.Any]: + # Both stages use the same filter design settings + return { + "forward": ButterworthFilterTransformer(settings), + "backward": ButterworthBackwardFilterTransformer(settings), + } - return replace(message, data=y) + @classmethod + def get_message_type(cls, dir: str) -> type[AxisArray]: + if dir in ("in", "out"): + return AxisArray + raise ValueError(f"Invalid direction: {dir}. 
Must be 'in' or 'out'.") class ButterworthZeroPhase( - BaseFilterByDesignTransformerUnit[ButterworthZeroPhaseSettings, ButterworthZeroPhaseTransformer] + BaseTransformerUnit[ButterworthZeroPhaseSettings, AxisArray, AxisArray, ButterworthZeroPhaseTransformer] ): SETTINGS = ButterworthZeroPhaseSettings diff --git a/tests/integration/ezmsg/test_butterworthzerophase_system.py b/tests/integration/ezmsg/test_butterworthzerophase_system.py index 2dbbbefb..43667495 100644 --- a/tests/integration/ezmsg/test_butterworthzerophase_system.py +++ b/tests/integration/ezmsg/test_butterworthzerophase_system.py @@ -2,12 +2,17 @@ from pathlib import Path import ezmsg.core as ez +import numpy as np from ezmsg.util.messagecodec import message_log from ezmsg.util.messagelogger import MessageLogger from ezmsg.util.messages.axisarray import AxisArray from ezmsg.util.terminate import TerminateOnTotal -from ezmsg.sigproc.butterworthzerophase import ButterworthZeroPhase +from ezmsg.sigproc.butterworthzerophase import ( + ButterworthBackwardFilterTransformer, + ButterworthZeroPhase, + ButterworthZeroPhaseSettings, +) from tests.helpers.synth import EEGSynth @@ -16,6 +21,16 @@ def test_butterworth_zero_phase_system(): n_time = 50 n_total = 10 n_channels = 96 + order = 4 + cuton = 30.0 + cutoff = 45.0 + coef_type = "sos" + + # Compute expected pad_length for this filter configuration + settings = ButterworthZeroPhaseSettings(order=order, cuton=cuton, cutoff=cutoff, coef_type=coef_type) + backward = ButterworthBackwardFilterTransformer(settings) + pad_length = backward._compute_pad_length(fs) + test_filename = Path(tempfile.gettempdir()) test_filename = test_filename / Path("test_butterworth_zero_phase_system.txt") with open(test_filename, "w"): @@ -25,10 +40,10 @@ def test_butterworth_zero_phase_system(): comps = { "SRC": EEGSynth(n_time=n_time, fs=fs, n_ch=n_channels, alpha_freq=10.0), "BUTTER": ButterworthZeroPhase( - order=4, - cuton=30.0, - cutoff=45.0, - coef_type="sos", + order=order, + 
cuton=cuton, + cutoff=cutoff, + coef_type=coef_type, ), "LOG": MessageLogger(output=test_filename), "TERM": TerminateOnTotal(total=n_total), @@ -45,7 +60,20 @@ def test_butterworth_zero_phase_system(): messages = list(message_log(test_filename)) assert len(messages) >= n_total + total_output_samples = 0 for msg in messages: assert isinstance(msg, AxisArray) + # Non-time dimensions must always be preserved, even during warmup assert msg.data.shape[1] == n_channels - assert msg.data.shape[0] == n_time + # Time dimension may be 0 during warmup while buffering for backward pass + assert msg.data.shape[0] >= 0 + # Data should be finite (for non-empty messages) + if msg.data.size > 0: + assert np.isfinite(msg.data).all() + total_output_samples += msg.data.shape[0] + + # Total output should be input samples minus pad_length delay + total_input_samples = n_total * n_time + expected_output = total_input_samples - pad_length + # Allow some tolerance since we may have received slightly more than n_total messages + assert total_output_samples >= expected_output diff --git a/tests/unit/test_butterworthzerophase.py b/tests/unit/test_butterworthzerophase.py index 5d0d4714..4f17df01 100644 --- a/tests/unit/test_butterworthzerophase.py +++ b/tests/unit/test_butterworthzerophase.py @@ -5,11 +5,34 @@ from frozendict import frozendict from ezmsg.sigproc.butterworthzerophase import ( + ButterworthBackwardFilterTransformer, ButterworthZeroPhaseSettings, ButterworthZeroPhaseTransformer, ) +def _compute_pad_length( + order: int, + coef_type: str, + fs: float, + cutoff: float | None = None, + cuton: float | None = None, + settle_cutoff: float = 0.01, + max_pad_duration: float | None = None, +) -> int: + """Helper to compute expected pad_length using impulse response settling.""" + settings = ButterworthZeroPhaseSettings( + order=order, + coef_type=coef_type, + cutoff=cutoff, + cuton=cuton, + settle_cutoff=settle_cutoff, + max_pad_duration=max_pad_duration, + ) + backward = 
ButterworthBackwardFilterTransformer(settings) + return backward._compute_pad_length(fs) + + @pytest.mark.parametrize( "cutoff, cuton", [ @@ -21,7 +44,7 @@ ) @pytest.mark.parametrize("order", [2, 4, 8]) def test_butterworth_zp_filter_specs(cutoff, cuton, order): - """Zero-phase settings inherit filter_specs logic from legacy Butterworth settings.""" + """Zero-phase settings inherit filter_specs logic from ButterworthFilterSettings.""" btype, Wn = ButterworthZeroPhaseSettings(order=order, cuton=cuton, cutoff=cutoff).filter_specs() if cuton is None: assert btype == "lowpass" and Wn == cutoff @@ -33,82 +56,284 @@ def test_butterworth_zp_filter_specs(cutoff, cuton, order): assert btype == "bandstop" and Wn == (cutoff, cuton) +@pytest.mark.parametrize("order", [2, 4, 8]) +@pytest.mark.parametrize("coef_type", ["ba", "sos"]) +def test_pad_length_computation(order, coef_type): + """Verify pad_length is computed based on impulse response settling.""" + fs = 1000.0 # Use a moderate fs for this test + settings = ButterworthZeroPhaseSettings(order=order, cutoff=30.0, coef_type=coef_type) + backward = ButterworthBackwardFilterTransformer(settings) + pad_length = backward._compute_pad_length(fs) + + # Verify pad_length is at least the scipy heuristic minimum + if coef_type == "ba": + min_length = 3 * (order + 1) + else: + n_sections = (order + 1) // 2 + min_length = 3 * n_sections * 2 + assert pad_length >= min_length + + # Verify pad_length is reasonable (not excessively large for this filter) + # At fs=1000, cutoff=30 Hz gives normalized freq = 0.06, should settle quickly + assert pad_length < 500 # Sanity check + + +def test_settle_cutoff_affects_pad_length(): + """Larger settle_cutoff should result in shorter pad_length.""" + fs = 1000.0 + order = 4 + cutoff = 30.0 + + # Default settle_cutoff = 0.01 + settings_default = ButterworthZeroPhaseSettings(order=order, cutoff=cutoff, settle_cutoff=0.01) + backward_default = ButterworthBackwardFilterTransformer(settings_default) + 
pad_default = backward_default._compute_pad_length(fs) + + # Larger settle_cutoff = 0.1 (10% of peak instead of 1%) + settings_larger = ButterworthZeroPhaseSettings(order=order, cutoff=cutoff, settle_cutoff=0.1) + backward_larger = ButterworthBackwardFilterTransformer(settings_larger) + pad_larger = backward_larger._compute_pad_length(fs) + + # Smaller settle_cutoff = 0.001 (0.1% of peak) + settings_smaller = ButterworthZeroPhaseSettings(order=order, cutoff=cutoff, settle_cutoff=0.001) + backward_smaller = ButterworthBackwardFilterTransformer(settings_smaller) + pad_smaller = backward_smaller._compute_pad_length(fs) + + # Larger cutoff threshold should give shorter pad length + assert pad_larger < pad_default + # Smaller cutoff threshold should give longer pad length + assert pad_smaller > pad_default + + +def test_max_pad_duration_caps_pad_length(): + """max_pad_duration should cap the pad_length.""" + fs = 1000.0 + order = 4 + cutoff = 1.0 # 1 Hz lowpass - very long impulse response (~2292 samples) + + # Without cap + settings_uncapped = ButterworthZeroPhaseSettings(order=order, cutoff=cutoff) + backward_uncapped = ButterworthBackwardFilterTransformer(settings_uncapped) + pad_uncapped = backward_uncapped._compute_pad_length(fs) + + # With cap of 0.5 seconds = 500 samples at 1 kHz + max_duration = 0.5 + settings_capped = ButterworthZeroPhaseSettings(order=order, cutoff=cutoff, max_pad_duration=max_duration) + backward_capped = ButterworthBackwardFilterTransformer(settings_capped) + pad_capped = backward_capped._compute_pad_length(fs) + + expected_max = int(max_duration * fs) + + # Uncapped should be longer than the cap + assert pad_uncapped > expected_max, f"Expected uncapped {pad_uncapped} > {expected_max}" + # Capped should be at most the expected max + assert pad_capped <= expected_max + # Capped should be exactly the max (since uncapped exceeds it) + assert pad_capped == expected_max + + +def test_max_pad_duration_no_effect_when_not_limiting(): + 
"""max_pad_duration should have no effect when pad_length is already shorter.""" + fs = 1000.0 + order = 4 + cutoff = 100.0 # Higher cutoff = faster settling + + # Get natural pad length (should be short) + settings_natural = ButterworthZeroPhaseSettings(order=order, cutoff=cutoff) + backward_natural = ButterworthBackwardFilterTransformer(settings_natural) + pad_natural = backward_natural._compute_pad_length(fs) + + # With generous cap of 1 second = 1000 samples + settings_with_cap = ButterworthZeroPhaseSettings(order=order, cutoff=cutoff, max_pad_duration=1.0) + backward_with_cap = ButterworthBackwardFilterTransformer(settings_with_cap) + pad_with_cap = backward_with_cap._compute_pad_length(fs) + + # Both should be equal since natural pad is well under 1 second + assert pad_natural == pad_with_cap + + +def _make_message(data, dims, fs, time_axis_name="time"): + """Helper to create AxisArray messages with frozendict axes to detect mutation.""" + axes = {} + for i, dim in enumerate(dims): + if dim == time_axis_name: + axes[dim] = AxisArray.TimeAxis(fs=fs, offset=0.0) + elif dim == "ch": + axes[dim] = AxisArray.CoordinateAxis(data=np.arange(data.shape[i]).astype(str), dims=[dim]) + else: + axes[dim] = AxisArray.LinearAxis(unit="", offset=0.0, gain=1.0) + return AxisArray(data=data, dims=dims, axes=frozendict(axes), key="test") + + @pytest.mark.parametrize( "cutoff, cuton", [ - (30.0, None), # lowpass - (None, 30.0), # highpass - (45.0, 30.0), # bandpass - (30.0, 45.0), # bandstop + (500.0, None), # lowpass + (None, 250.0), # highpass + (7500.0, 300.0), # bandpass + (3000.0, 6000.0), # bandstop ], ) -@pytest.mark.parametrize("order", [0, 2, 4]) -@pytest.mark.parametrize("fs", [200.0]) -@pytest.mark.parametrize("n_chans", [3]) -@pytest.mark.parametrize("n_dims, time_ax", [(1, 0), (3, 0), (3, 1), (3, 2)]) +@pytest.mark.parametrize("order", [4, 8]) @pytest.mark.parametrize("coef_type", ["ba", "sos"]) -@pytest.mark.parametrize("padtype,padlen", [(None, 0), ("odd", None)]) 
-def test_butterworth_zero_phase_matches_scipy( - cutoff, cuton, order, fs, n_chans, n_dims, time_ax, coef_type, padtype, padlen -): - dur = 2.0 - n_times = int(dur * fs) +def test_single_chunk_matches_reference(cutoff, cuton, order, coef_type): + """ + Single large chunk output should be highly correlated with scipy filtfilt. - if n_dims == 1: - dat_shape = [n_times] - dims = ["time"] - other_axes = {} - else: - dat_shape = [5, n_chans] - dat_shape.insert(time_ax, n_times) - dims = ["freq", "ch"] - dims.insert(time_ax, "time") - other_axes = { - "freq": AxisArray.LinearAxis(unit="Hz", offset=0.0, gain=1.0), - "ch": AxisArray.CoordinateAxis(data=np.arange(n_chans).astype(str), dims=["ch"]), - } + Note: The streaming implementation initializes zi differently than scipy's + filtfilt, so exact numerical match is not expected. Instead, we verify: + 1. Output shape is correct + 2. Output is highly correlated with reference + 3. Values are all finite + """ + fs = 30000.0 + n_times = int(2.0 * fs) + rng = np.random.default_rng(42) + x = rng.standard_normal((n_times, 3)) - x = np.linspace(0, 1, np.prod(dat_shape), dtype=float).reshape(*dat_shape) + msg = _make_message(x, ["time", "ch"], fs) - msg = AxisArray( - data=x, - dims=dims, - axes=frozendict({**other_axes, "time": AxisArray.TimeAxis(fs=fs, offset=0.0)}), - key="test_butterworth_zero_phase", + transformer = ButterworthZeroPhaseTransformer( + ButterworthZeroPhaseSettings( + axis="time", + order=order, + cuton=cuton, + cutoff=cutoff, + coef_type=coef_type, + ) ) - # expected via SciPy + result = transformer(msg) + pad_length = _compute_pad_length(order, coef_type, fs, cutoff=cutoff, cuton=cuton) + + # Output should be n_times - pad_length samples + assert result.data.shape[0] == n_times - pad_length + assert result.data.shape[1] == 3 + + # All values should be finite + assert np.isfinite(result.data).all() + + # Compute reference using scipy filtfilt on the same data (no padding) btype, Wn = 
ButterworthZeroPhaseSettings(order=order, cuton=cuton, cutoff=cutoff).filter_specs() - if order == 0: - expected = x + if coef_type == "ba": + b, a = scipy.signal.butter(order, Wn, btype=btype, fs=fs, output="ba") + ref = scipy.signal.filtfilt(b, a, x, axis=0, padtype=None, padlen=0) else: - tmp = np.moveaxis(x, time_ax, -1) - if coef_type == "ba": - b, a = scipy.signal.butter(order, Wn, btype=btype, fs=fs, output="ba") - y = scipy.signal.filtfilt(b, a, tmp, axis=-1, padtype=padtype, padlen=padlen) - else: - sos = scipy.signal.butter(order, Wn, btype=btype, fs=fs, output="sos") - y = scipy.signal.sosfiltfilt(sos, tmp, axis=-1, padtype=padtype, padlen=padlen) - expected = np.moveaxis(y, -1, time_ax) + sos = scipy.signal.butter(order, Wn, btype=btype, fs=fs, output="sos") + ref = scipy.signal.sosfiltfilt(sos, x, axis=0, padtype=None, padlen=0) - axis_name = "time" if time_ax != 0 else None - zp = ButterworthZeroPhaseTransformer( - axis=axis_name, - order=order, - cuton=cuton, - cutoff=cutoff, - coef_type=coef_type, - wn_hz=True, - padtype=padtype, - padlen=padlen, + # Trim reference to match output length + ref_trimmed = ref[: n_times - pad_length] + + # Outputs should be highly correlated (r > 0.999) + for ch in range(result.data.shape[1]): + # Note: Skip first pad_length samples to account for differences in initialization + r = np.corrcoef(result.data[pad_length:, ch], ref_trimmed[pad_length:, ch])[0, 1] + assert r > 0.999, f"Correlation {r} too low for channel {ch}" + + +@pytest.mark.parametrize("order", [2, 4]) +@pytest.mark.parametrize("coef_type", ["ba", "sos"]) +def test_streaming_chunked_processing(order, coef_type): + """ + Verify streaming chunked processing produces valid output. + + Note: Chunked processing won't produce exactly identical results to + single-chunk processing because the backward filter sees different amounts + of future context at chunk boundaries. This test verifies: + 1. Output shape is correct + 2. All values are finite + 3. 
Output is reasonably correlated with single-chunk output + """ + fs = 30000.0 + cuton = 300.0 + cutoff = None + n_times = int(2.0 * fs) + chunk_size = 48 + rng = np.random.default_rng(42) + x = rng.standard_normal((n_times, 3)) + + # Single chunk processing (reference) + single_transformer = ButterworthZeroPhaseTransformer( + ButterworthZeroPhaseSettings(axis="time", order=order, cuton=cuton, cutoff=cutoff, coef_type=coef_type) ) + single_msg = _make_message(x, ["time", "ch"], fs) + single_result = single_transformer(single_msg) - out = zp(msg).data - assert np.allclose(out, expected, atol=1e-10, rtol=1e-7) + # Chunked processing + chunked_transformer = ButterworthZeroPhaseTransformer( + ButterworthZeroPhaseSettings(axis="time", order=order, cuton=cuton, cutoff=cutoff, coef_type=coef_type) + ) + + outputs = [] + for i in range(0, n_times, chunk_size): + chunk = x[i : i + chunk_size] + msg = _make_message(chunk, ["time", "ch"], fs) + result = chunked_transformer(msg) + if result.data.size > 0: + outputs.append(result.data) + + chunked_output = np.concatenate(outputs, axis=0) + + # Both should have the same output length + assert chunked_output.shape == single_result.data.shape + # All values should be finite + assert np.isfinite(chunked_output).all() -def test_butterworth_zero_phase_empty_msg(): - zp = ButterworthZeroPhaseTransformer(axis="time", order=4, cuton=0.1, cutoff=10.0, coef_type="sos") + # Outputs should be highly correlated (r > 0.98) + # Note: Correlation may be slightly lower for low-order filters with small chunks + for ch in range(chunked_output.shape[1]): + r = np.corrcoef(chunked_output[:, ch], single_result.data[:, ch])[0, 1] + assert r > 0.98, f"Correlation {r} too low for channel {ch}" + + +@pytest.mark.parametrize("order", [2, 4]) +@pytest.mark.parametrize("coef_type", ["ba", "sos"]) +def test_warmup_returns_empty(order, coef_type): + """During warmup (< pad_length samples), output should be empty.""" + fs = 200.0 + cutoff = 30.0 + pad_length = 
_compute_pad_length(order, coef_type, fs, cutoff=cutoff) + + transformer = ButterworthZeroPhaseTransformer( + ButterworthZeroPhaseSettings(axis="time", order=order, cutoff=cutoff, coef_type=coef_type) + ) + + # Send chunks smaller than pad_length + chunk_size = max(1, pad_length // 3) + rng = np.random.default_rng(42) + x = rng.standard_normal((chunk_size, 2)) + msg = _make_message(x, ["time", "ch"], fs) + + # First chunk - should return empty but preserve non-time dimensions + result1 = transformer(msg) + assert result1.data.shape[0] == 0 + assert result1.data.shape[1] == 2 # Channel dimension preserved + + # Second chunk - still in warmup, non-time dimensions still preserved + result2 = transformer(msg) + assert result2.data.shape[0] == 0 + assert result2.data.shape[1] == 2 # Channel dimension preserved + + # After enough chunks, should start outputting + total_sent = 2 * chunk_size + while total_sent <= pad_length: + result = transformer(msg) + total_sent += chunk_size + if total_sent <= pad_length: + assert result.data.shape[0] == 0 + + # Next chunk should produce output + result = transformer(msg) + assert result.data.shape[0] > 0 + + +def test_empty_message(): + """Empty input should return empty output.""" + transformer = ButterworthZeroPhaseTransformer( + ButterworthZeroPhaseSettings(axis="time", order=4, cutoff=30.0, coef_type="sos") + ) msg = AxisArray( data=np.zeros((0, 2)), dims=["time", "ch"], @@ -116,49 +341,205 @@ def test_butterworth_zero_phase_empty_msg(): "time": AxisArray.TimeAxis(fs=100.0, offset=0.0), "ch": AxisArray.CoordinateAxis(data=np.array(["0", "1"]), dims=["ch"]), }, - key="test_butterworth_zero_phase_empty", + key="empty", ) - res = zp(msg) - assert res.data.size == 0 + result = transformer(msg) + assert result.data.size == 0 -def test_butterworth_zero_phase_update_settings_changes_output(): +def test_order_zero_passthrough(): + """Order 0 should pass data through unchanged and return same object.""" fs = 200.0 - t = np.arange(int(2.0 
* fs)) / fs - x = np.vstack([np.sin(2 * np.pi * 10 * t), np.sin(2 * np.pi * 40 * t)]).T + rng = np.random.default_rng(42) + x = rng.standard_normal((100, 3)) - msg = AxisArray( - data=x, - dims=["time", "ch"], - axes={ - "time": AxisArray.TimeAxis(fs=fs, offset=0.0), - "ch": AxisArray.CoordinateAxis(data=np.array(["0", "1"]), dims=["ch"]), - }, - key="test_butterworth_zero_phase_update", + transformer = ButterworthZeroPhaseTransformer(ButterworthZeroPhaseSettings(axis="time", order=0, cutoff=30.0)) + msg = _make_message(x, ["time", "ch"], fs) + result = transformer(msg) + + assert result is msg + + +@pytest.mark.parametrize("coef_type", ["ba", "sos"]) +def test_zero_phase_property(coef_type): + """ + Zero-phase filter should not introduce phase delay for sinusoids + in the passband. + """ + fs = 1000.0 + duration = 2.0 + n_times = int(duration * fs) + t = np.arange(n_times) / fs + f0 = 20.0 # Test frequency in passband + + # Pure sinusoid + x = np.sin(2 * np.pi * f0 * t) + + order = 4 + cuton = 5.0 + cutoff = 50.0 + transformer = ButterworthZeroPhaseTransformer( + ButterworthZeroPhaseSettings( + axis="time", + order=order, + cuton=cuton, + cutoff=cutoff, # Bandpass 5-50 Hz + coef_type=coef_type, + ) ) + pad_length = _compute_pad_length(order, coef_type, fs, cutoff=cutoff, cuton=cuton) + + msg = _make_message(x.reshape(-1, 1), ["time", "ch"], fs) + result = transformer(msg) + y = result.data.flatten() + + # Compare in the interior region where both signals are valid + edge = 100 + n_output = n_times - pad_length + xi = x[edge : n_output - edge] + yi = y[edge : n_output - edge] + + # Cross-correlation to find lag + corr = np.correlate(yi, xi, mode="full") + lag = np.argmax(corr) - (len(xi) - 1) + + # Zero-phase means zero lag (within 1 sample tolerance) + assert abs(lag) <= 1 + + +@pytest.mark.parametrize("n_dims, time_ax", [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]) +@pytest.mark.parametrize("coef_type", ["ba", "sos"]) +def 
test_different_axis_positions(n_dims, time_ax, coef_type): + """Filter should work correctly with time axis in different positions.""" + fs = 200.0 + n_times = 200 + order = 4 + rng = np.random.default_rng(42) + + if n_dims == 1: + shape = [n_times] + dims = ["time"] + axis_name = None + elif n_dims == 2: + shape = [3, 5] + shape[time_ax] = n_times + dims = ["ch", "freq"] + dims[time_ax] = "time" + axis_name = "time" + else: + shape = [3, 5, 7] + shape[time_ax] = n_times + dims = ["ch", "freq", "other"] + dims[time_ax] = "time" + axis_name = "time" + + x = rng.standard_normal(shape) + msg = _make_message(x, dims, fs) + + cutoff = 30.0 + transformer = ButterworthZeroPhaseTransformer( + ButterworthZeroPhaseSettings(axis=axis_name, order=order, cutoff=cutoff, coef_type=coef_type) + ) + pad_length = _compute_pad_length(order, coef_type, fs, cutoff=cutoff) + + result = transformer(msg) + + # Check output shape + expected_shape = list(shape) + expected_shape[time_ax] = n_times - pad_length + assert list(result.data.shape) == expected_shape + + # Check output is finite + assert np.isfinite(result.data).all() + + +def test_offset_accumulates_correctly(): + """Time axis offset should be handled correctly across chunks.""" + fs = 100.0 + chunk_size = 50 + n_chunks = 5 + order = 2 + coef_type = "ba" + cutoff = 20.0 + pad_length = _compute_pad_length(order, coef_type, fs, cutoff=cutoff) + + transformer = ButterworthZeroPhaseTransformer( + ButterworthZeroPhaseSettings(axis="time", order=order, cutoff=cutoff, coef_type=coef_type) + ) + + rng = np.random.default_rng(42) + total_output_samples = 0 + + for i in range(n_chunks): + x = rng.standard_normal((chunk_size, 2)) + msg = AxisArray( + data=x, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=fs, offset=i * chunk_size / fs), + "ch": AxisArray.CoordinateAxis(data=np.array(["0", "1"]), dims=["ch"]), + }, + key="test", + ) + result = transformer(msg) + + if result.data.shape[0] > 0: + total_output_samples += 
result.data.shape[0] + + # After all chunks, total output should be total_input - pad_length + expected_output = n_chunks * chunk_size - pad_length + assert total_output_samples == expected_output + + +@pytest.mark.parametrize("coef_type", ["ba", "sos"]) +def test_filter_actually_filters(coef_type): + """Verify the filter actually attenuates out-of-band frequencies.""" + fs = 1000.0 + duration = 2.0 + n_times = int(duration * fs) + t = np.arange(n_times) / fs + + # Signal: 10 Hz (in passband) + 200 Hz (out of passband for LP at 50 Hz) + x = np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 200 * t) + + transformer = ButterworthZeroPhaseTransformer( + ButterworthZeroPhaseSettings(axis="time", order=4, cutoff=50.0, coef_type=coef_type) + ) + + msg = _make_message(x.reshape(-1, 1), ["time", "ch"], fs) + result = transformer(msg) + y = result.data.flatten() + + # Check power spectrum + fft_in = np.abs(np.fft.rfft(x)) + fft_out = np.abs(np.fft.rfft(y)) + freqs = np.fft.rfftfreq(len(x), 1 / fs) + freqs_out = np.fft.rfftfreq(len(y), 1 / fs) + + idx_10_in = np.argmin(np.abs(freqs - 10)) + idx_200_in = np.argmin(np.abs(freqs - 200)) + idx_10_out = np.argmin(np.abs(freqs_out - 10)) + idx_200_out = np.argmin(np.abs(freqs_out - 200)) + + # 10 Hz should be mostly preserved + assert fft_out[idx_10_out] > 0.5 * fft_in[idx_10_in] + + # 200 Hz should be heavily attenuated (order 4 = 80 dB/decade) + assert fft_out[idx_200_out] < 0.01 * fft_in[idx_200_in] + + +def test_composite_structure(): + """Verify the composite processor has the expected structure.""" + transformer = ButterworthZeroPhaseTransformer(ButterworthZeroPhaseSettings(axis="time", order=4, cutoff=30.0)) + + # Should have forward and backward processors + assert "forward" in transformer._procs + assert "backward" in transformer._procs + + # Forward should be ButterworthFilterTransformer + from ezmsg.sigproc.butterworthfilter import ButterworthFilterTransformer + + assert isinstance(transformer._procs["forward"], 
ButterworthFilterTransformer) - zp = ButterworthZeroPhaseTransformer(axis="time", order=4, cutoff=30.0, coef_type="sos", padtype="odd", padlen=None) - y1 = zp(msg).data - # LP at 30 should pass 10 Hz and attenuate 40 Hz - p_in = np.abs(np.fft.rfft(x, axis=0)) ** 2 - p1 = np.abs(np.fft.rfft(y1, axis=0)) ** 2 - f = np.fft.rfftfreq(x.shape[0], 1 / fs) - - def peak_power(power, f0): - return power[np.argmin(np.abs(f - f0))] - - assert peak_power(p1[:, 0], 10.0) > 0.7 * peak_power(p_in[:, 0], 10.0) - assert peak_power(p1[:, 1], 40.0) < 0.3 * peak_power(p_in[:, 1], 40.0) - - # Switch to HP at 25 Hz - zp.update_settings(cutoff=None, cuton=25.0) - y2 = zp(msg).data - p2 = np.abs(np.fft.rfft(y2, axis=0)) ** 2 - assert peak_power(p2[:, 0], 10.0) < 0.3 * peak_power(p_in[:, 0], 10.0) - assert peak_power(p2[:, 1], 40.0) > 0.7 * peak_power(p_in[:, 1], 40.0) - - zp.update_settings(coef_type="ba", order=2, cutoff=15.0, cuton=None) - y3 = zp(msg).data - # attenuate 40 more than 10 - p3 = np.abs(np.fft.rfft(y3, axis=0)) ** 2 - assert peak_power(p3[:, 1], 40.0) < peak_power(p3[:, 0], 10.0) + # Backward should be ButterworthBackwardFilterTransformer + assert isinstance(transformer._procs["backward"], ButterworthBackwardFilterTransformer)