diff --git a/pyproject.toml b/pyproject.toml index ab4bea9a..93ef193d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ authors = [ { name = "Griffin Milsap", email = "griffin.milsap@gmail.com" }, { name = "Preston Peranich", email = "pperanich@gmail.com" }, { name = "Chadwick Boulay", email = "chadwick.boulay@gmail.com" }, + { name = "Kyle McGraw", email = "kmcgraw@blackrockneuro.com" }, ] license = "MIT" readme = "README.md" diff --git a/src/ezmsg/sigproc/butterworthzerophase.py b/src/ezmsg/sigproc/butterworthzerophase.py new file mode 100644 index 00000000..711efbbf --- /dev/null +++ b/src/ezmsg/sigproc/butterworthzerophase.py @@ -0,0 +1,164 @@ +import functools +import typing + +import ezmsg.core as ez +import numpy as np +import scipy.signal +from ezmsg.sigproc.base import SettingsType +from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings, butter_design_fun +from ezmsg.sigproc.filter import ( + BACoeffs, + BaseFilterByDesignTransformerUnit, + FilterByDesignTransformer, + SOSCoeffs, +) +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace + + +class ButterworthZeroPhaseSettings(ButterworthFilterSettings): + """Settings for :obj:`ButterworthZeroPhase`.""" + + # axis, coef_type, order, cuton, cutoff, wn_hz are inherited from ButterworthFilterSettings + padtype: str | None = None + """ + Padding type to use in `scipy.signal.filtfilt`. + Must be one of {'odd', 'even', 'constant', None}. + Default is None for no padding. + """ + padlen: int | None = 0 + """ + Length of the padding to use in `scipy.signal.filtfilt`. + If None, SciPy's default padding is used. + """ + + +class ButterworthZeroPhaseTransformer( + FilterByDesignTransformer[ButterworthZeroPhaseSettings, BACoeffs | SOSCoeffs] +): + """Zero-phase (filtfilt) Butterworth using your design function.""" + + def get_design_function( + self, + ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]: + return functools.partial( + butter_design_fun, + order=self.settings.order, + cuton=self.settings.cuton, + cutoff=self.settings.cutoff, + coef_type=self.settings.coef_type, + wn_hz=self.settings.wn_hz, + ) + + def update_settings( + self, new_settings: typing.Optional[SettingsType] = None, **kwargs + ) -> None: + """ + Update settings and mark that filter coefficients need to be recalculated. + + Args: + new_settings: Complete new settings object to replace current settings + **kwargs: Individual settings to update + """ + # Update settings + if new_settings is not None: + self.settings = new_settings + else: + self.settings = replace(self.settings, **kwargs) + + # Set flag to trigger recalculation on next message + self._coefs_cache = None + self._fs_cache = None + self.state.needs_redesign = True + + def _reset_state(self, message: AxisArray) -> None: + self._coefs_cache = None + self._fs_cache = None + self.state.needs_redesign = True + + def _process(self, message: AxisArray) -> AxisArray: + axis = message.dims[0] if self.settings.axis is None else self.settings.axis + ax_idx = message.get_axis_idx(axis) + fs = 1 / message.axes[axis].gain + + if ( + self._coefs_cache is None + or self.state.needs_redesign + or (self._fs_cache is None or not np.isclose(self._fs_cache, fs)) + ): + self._coefs_cache = self.get_design_function()(fs) + self._fs_cache = fs + self.state.needs_redesign = False + + if ( + self._coefs_cache is None + or self.settings.order <= 0 + or message.data.size <= 0 + ): + return message + + x = message.data + if self.settings.coef_type == "sos": + y = scipy.signal.sosfiltfilt( + self._coefs_cache, + x, + axis=ax_idx, + padtype=self.settings.padtype, + padlen=self.settings.padlen, + ) + elif self.settings.coef_type == "ba": + b, a = self._coefs_cache + y = scipy.signal.filtfilt( + b, + a, + x, + axis=ax_idx, + padtype=self.settings.padtype, + padlen=self.settings.padlen, + ) + else: + ez.logger.error("coef_type must be 'sos' or 'ba'.") + raise ValueError("coef_type must be 'sos' or 'ba'.") + + return replace(message, data=y) + + +class ButterworthZeroPhase( + BaseFilterByDesignTransformerUnit[ + ButterworthZeroPhaseSettings, ButterworthZeroPhaseTransformer + ] +): + SETTINGS = ButterworthZeroPhaseSettings + + +def butter_zero_phase( + axis: str | None, + order: int = 0, + cuton: float | None = None, + cutoff: float | None = None, + coef_type: str = "ba", + wn_hz: bool = True, + padtype: str | None = None, + padlen: int | None = 0, +) -> ButterworthZeroPhaseTransformer: + """ + Convenience generator wrapping filter_gen_by_design for Butterworth Zero Phase filters. + Apply Butterworth Zero Phase filter to streaming data. Uses :obj:`scipy.signal.butter` to design the filter. + See :obj:`ButterworthZeroPhaseSettings.filter_specs` for an explanation of specifying different + filter types (lowpass, highpass, bandpass, bandstop) from the parameters. + + Returns: + :obj:`ButterworthZeroPhaseTransformer` + """ + return ButterworthZeroPhaseTransformer( + ButterworthZeroPhaseSettings( + axis=axis, + order=order, + cuton=cuton, + cutoff=cutoff, + coef_type=coef_type, + wn_hz=wn_hz, + padtype=padtype, + padlen=padlen, + ) + ) diff --git a/tests/integration/ezmsg/test_butterworthzerophase_system.py b/tests/integration/ezmsg/test_butterworthzerophase_system.py new file mode 100644 index 00000000..16b6690e --- /dev/null +++ b/tests/integration/ezmsg/test_butterworthzerophase_system.py @@ -0,0 +1,51 @@ +import tempfile +from pathlib import Path + +import ezmsg.core as ez +from ezmsg.sigproc.synth import EEGSynth +from ezmsg.util.messagecodec import message_log +from ezmsg.util.messagelogger import MessageLogger +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.terminate import TerminateOnTotal + +from ezmsg.sigproc.butterworthzerophase import ButterworthZeroPhase + + +def test_butterworth_zero_phase_system(): + fs = 1000.0 + n_time = 50 + n_total = 10 + n_channels = 96 + test_filename = Path(tempfile.gettempdir()) + test_filename = test_filename / Path("test_butterworth_zero_phase_system.txt") + with open(test_filename, "w"): + pass + ez.logger.info(f"Logging to {test_filename}") + + comps = { + "SRC": EEGSynth(n_time=n_time, fs=fs, n_ch=n_channels, alpha_freq=10.0), + "BUTTER": ButterworthZeroPhase( + order=4, + cuton=30.0, + cutoff=45.0, + coef_type="sos", + ), + "LOG": MessageLogger(output=test_filename), + "TERM": TerminateOnTotal(total=n_total), + } + + conns = ( + (comps["SRC"].OUTPUT_SIGNAL, comps["BUTTER"].INPUT_SIGNAL), + (comps["BUTTER"].OUTPUT_SIGNAL, comps["LOG"].INPUT_MESSAGE), + (comps["LOG"].OUTPUT_MESSAGE, comps["TERM"].INPUT_MESSAGE), + ) + + ez.run(components=comps, connections=conns) + + messages = list(message_log(test_filename)) + assert len(messages) >= n_total + + for msg in messages: + assert isinstance(msg, AxisArray) + assert msg.data.shape[1] == n_channels + assert msg.data.shape[0] == n_time diff --git a/tests/unit/test_butterworthzerophase.py b/tests/unit/test_butterworthzerophase.py new file mode 100644 index 00000000..afb2ef64 --- /dev/null +++ b/tests/unit/test_butterworthzerophase.py @@ -0,0 +1,178 @@ +import numpy as np +import pytest +import scipy.signal +from ezmsg.util.messages.axisarray import AxisArray +from frozendict import frozendict + +from ezmsg.sigproc.butterworthzerophase import ( + ButterworthZeroPhaseSettings, + butter_zero_phase, +) + + +@pytest.mark.parametrize( + "cutoff, cuton", + [ + (30.0, None), # lowpass + (None, 30.0), # highpass + (45.0, 30.0), # bandpass + (30.0, 45.0), # bandstop + ], +) +@pytest.mark.parametrize("order", [2, 4, 8]) +def test_butterworth_zp_filter_specs(cutoff, cuton, order): + """Zero-phase settings inherit filter_specs logic from legacy Butterworth settings.""" + btype, Wn = ButterworthZeroPhaseSettings( + order=order, cuton=cuton, cutoff=cutoff + ).filter_specs() + if cuton is None: + assert btype == "lowpass" and Wn == cutoff + elif cutoff is None: + assert btype == "highpass" and Wn == cuton + elif cuton <= cutoff: + assert btype == "bandpass" and Wn == (cuton, cutoff) + else: + assert btype == "bandstop" and Wn == (cutoff, cuton) + + +@pytest.mark.parametrize( + "cutoff, cuton", + [ + (30.0, None), # lowpass + (None, 30.0), # highpass + (45.0, 30.0), # bandpass + (30.0, 45.0), # bandstop + ], +) +@pytest.mark.parametrize("order", [0, 2, 4]) +@pytest.mark.parametrize("fs", [200.0]) +@pytest.mark.parametrize("n_chans", [3]) +@pytest.mark.parametrize("n_dims, time_ax", [(1, 0), (3, 0), (3, 1), (3, 2)]) +@pytest.mark.parametrize("coef_type", ["ba", "sos"]) +@pytest.mark.parametrize("padtype,padlen", [(None, 0), ("odd", None)]) +def test_butterworth_zero_phase_matches_scipy( + cutoff, cuton, order, fs, n_chans, n_dims, time_ax, coef_type, padtype, padlen +): + dur = 2.0 + n_times = int(dur * fs) + + if n_dims == 1: + dat_shape = [n_times] + dims = ["time"] + other_axes = {} + else: + dat_shape = [5, n_chans] + dat_shape.insert(time_ax, n_times) + dims = ["freq", "ch"] + dims.insert(time_ax, "time") + other_axes = { + "freq": AxisArray.LinearAxis(unit="Hz", offset=0.0, gain=1.0), + "ch": AxisArray.CoordinateAxis( + data=np.arange(n_chans).astype(str), dims=["ch"] + ), + } + + x = np.linspace(0, 1, np.prod(dat_shape), dtype=float).reshape(*dat_shape) + + msg = AxisArray( + data=x, + dims=dims, + axes=frozendict({**other_axes, "time": AxisArray.TimeAxis(fs=fs, offset=0.0)}), + key="test_butterworth_zero_phase", + ) + + # expected via SciPy + btype, Wn = ButterworthZeroPhaseSettings( + order=order, cuton=cuton, cutoff=cutoff + ).filter_specs() + if order == 0: + expected = x + else: + tmp = np.moveaxis(x, time_ax, -1) + if coef_type == "ba": + b, a = scipy.signal.butter(order, Wn, btype=btype, fs=fs, output="ba") + y = scipy.signal.filtfilt( + b, a, tmp, axis=-1, padtype=padtype, padlen=padlen + ) + else: + sos = scipy.signal.butter(order, Wn, btype=btype, fs=fs, output="sos") + y = scipy.signal.sosfiltfilt( + sos, tmp, axis=-1, padtype=padtype, padlen=padlen + ) + expected = np.moveaxis(y, -1, time_ax) + + axis_name = "time" if time_ax != 0 else None + zp = butter_zero_phase( + axis=axis_name, + order=order, + cuton=cuton, + cutoff=cutoff, + coef_type=coef_type, + wn_hz=True, + padtype=padtype, + padlen=padlen, + ) + + out = zp.send(msg).data + assert np.allclose(out, expected, atol=1e-10, rtol=1e-7) + + +def test_butterworth_zero_phase_empty_msg(): + zp = butter_zero_phase( + axis="time", order=4, cuton=0.1, cutoff=10.0, coef_type="sos" + ) + msg = AxisArray( + data=np.zeros((0, 2)), + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=100.0, offset=0.0), + "ch": AxisArray.CoordinateAxis(data=np.array(["0", "1"]), dims=["ch"]), + }, + key="test_butterworth_zero_phase_empty", + ) + res = zp.send(msg) + assert res.data.size == 0 + + +def test_butterworth_zero_phase_update_settings_changes_output(): + fs = 200.0 + t = np.arange(int(2.0 * fs)) / fs + x = np.vstack([np.sin(2 * np.pi * 10 * t), np.sin(2 * np.pi * 40 * t)]).T + + msg = AxisArray( + data=x, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=fs, offset=0.0), + "ch": AxisArray.CoordinateAxis(data=np.array(["0", "1"]), dims=["ch"]), + }, + key="test_butterworth_zero_phase_update", + ) + + zp = butter_zero_phase( + axis="time", order=4, cutoff=30.0, coef_type="sos", padtype="odd", padlen=None + ) + y1 = zp(msg).data + # LP at 30 should pass 10 Hz and attenuate 40 Hz + p_in = np.abs(np.fft.rfft(x, axis=0)) ** 2 + p1 = np.abs(np.fft.rfft(y1, axis=0)) ** 2 + f = np.fft.rfftfreq(x.shape[0], 1 / fs) + + def peak_power(power, f0): + return power[np.argmin(np.abs(f - f0))] + + assert peak_power(p1[:, 0], 10.0) > 0.7 * peak_power(p_in[:, 0], 10.0) + assert peak_power(p1[:, 1], 40.0) < 0.3 * peak_power(p_in[:, 1], 40.0) + + # Switch to HP at 25 Hz + zp.update_settings(cutoff=None, cuton=25.0) + y2 = zp(msg).data + p2 = np.abs(np.fft.rfft(y2, axis=0)) ** 2 + assert peak_power(p2[:, 0], 10.0) < 0.3 * peak_power(p_in[:, 0], 10.0) + assert peak_power(p2[:, 1], 40.0) > 0.7 * peak_power(p_in[:, 1], 40.0) + + zp.update_settings(coef_type="ba", order=2, cutoff=15.0, cuton=None) + y3 = zp(msg).data + # attenuate 40 more than 10 + p3 = np.abs(np.fft.rfft(y3, axis=0)) ** 2 + assert peak_power(p3[:, 1], 40.0) < peak_power(p3[:, 0], 10.0)