Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ authors = [
{ name = "Griffin Milsap", email = "griffin.milsap@gmail.com" },
{ name = "Preston Peranich", email = "pperanich@gmail.com" },
{ name = "Chadwick Boulay", email = "chadwick.boulay@gmail.com" },
{ name = "Kyle McGraw", email = "kmcgraw@blackrockneuro.com" },
]
license = "MIT"
readme = "README.md"
Expand Down
164 changes: 164 additions & 0 deletions src/ezmsg/sigproc/butterworthzerophase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import functools
import typing

import ezmsg.core as ez
import numpy as np
import scipy.signal
from ezmsg.sigproc.base import SettingsType
from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings, butter_design_fun
from ezmsg.sigproc.filter import (
BACoeffs,
BaseFilterByDesignTransformerUnit,
FilterByDesignTransformer,
SOSCoeffs,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace


class ButterworthZeroPhaseSettings(ButterworthFilterSettings):
"""Settings for :obj:`ButterworthZeroPhase`."""

# axis, coef_type, order, cuton, cutoff, wn_hz are inherited from ButterworthFilterSettings
padtype: str | None = None
"""
Padding type to use in `scipy.signal.filtfilt`.
Must be one of {'odd', 'even', 'constant', None}.
Default is None for no padding.
"""
padlen: int | None = 0
"""
Length of the padding to use in `scipy.signal.filtfilt`.
If None, SciPy's default padding is used.
"""


class ButterworthZeroPhaseTransformer(
FilterByDesignTransformer[ButterworthZeroPhaseSettings, BACoeffs | SOSCoeffs]
):
"""Zero-phase (filtfilt) Butterworth using your design function."""

def get_design_function(
self,
) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
return functools.partial(
butter_design_fun,
order=self.settings.order,
cuton=self.settings.cuton,
cutoff=self.settings.cutoff,
coef_type=self.settings.coef_type,
wn_hz=self.settings.wn_hz,
)

def update_settings(
self, new_settings: typing.Optional[SettingsType] = None, **kwargs
) -> None:
"""
Update settings and mark that filter coefficients need to be recalculated.

Args:
new_settings: Complete new settings object to replace current settings
**kwargs: Individual settings to update
"""
# Update settings
if new_settings is not None:
self.settings = new_settings
else:
self.settings = replace(self.settings, **kwargs)

# Set flag to trigger recalculation on next message
self._coefs_cache = None
self._fs_cache = None
self.state.needs_redesign = True

def _reset_state(self, message: AxisArray) -> None:
self._coefs_cache = None
self._fs_cache = None
self.state.needs_redesign = True

def _process(self, message: AxisArray) -> AxisArray:
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
ax_idx = message.get_axis_idx(axis)
fs = 1 / message.axes[axis].gain

if (
self._coefs_cache is None
or self.state.needs_redesign
or (self._fs_cache is None or not np.isclose(self._fs_cache, fs))
):
self._coefs_cache = self.get_design_function()(fs)
self._fs_cache = fs
self.state.needs_redesign = False

if (
self._coefs_cache is None
or self.settings.order <= 0
or message.data.size <= 0
):
return message

x = message.data
if self.settings.coef_type == "sos":
y = scipy.signal.sosfiltfilt(
self._coefs_cache,
x,
axis=ax_idx,
padtype=self.settings.padtype,
padlen=self.settings.padlen,
)
elif self.settings.coef_type == "ba":
b, a = self._coefs_cache
y = scipy.signal.filtfilt(
b,
a,
x,
axis=ax_idx,
padtype=self.settings.padtype,
padlen=self.settings.padlen,
)
else:
ez.logger.error("coef_type must be 'sos' or 'ba'.")
raise ValueError("coef_type must be 'sos' or 'ba'.")

return replace(message, data=y)


class ButterworthZeroPhase(
BaseFilterByDesignTransformerUnit[
ButterworthZeroPhaseSettings, ButterworthZeroPhaseTransformer
]
):
SETTINGS = ButterworthZeroPhaseSettings


def butter_zero_phase(
axis: str | None,
order: int = 0,
cuton: float | None = None,
cutoff: float | None = None,
coef_type: str = "ba",
wn_hz: bool = True,
padtype: str | None = None,
padlen: int | None = 0,
) -> ButterworthZeroPhaseTransformer:
"""
Convenience generator wrapping filter_gen_by_design for Butterworth Zero Phase filters.
Apply Butterworth Zero Phase filter to streaming data. Uses :obj:`scipy.signal.butter` to design the filter.
See :obj:`ButterworthZeroPhaseSettings.filter_specs` for an explanation of specifying different
filter types (lowpass, highpass, bandpass, bandstop) from the parameters.

Returns:
:obj:`ButterworthZeroPhaseTransformer`
"""
return ButterworthZeroPhaseTransformer(
ButterworthZeroPhaseSettings(
axis=axis,
order=order,
cuton=cuton,
cutoff=cutoff,
coef_type=coef_type,
wn_hz=wn_hz,
padtype=padtype,
padlen=padlen,
)
)
51 changes: 51 additions & 0 deletions tests/integration/ezmsg/test_butterworthzerophase_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import tempfile
from pathlib import Path

import ezmsg.core as ez
from ezmsg.sigproc.synth import EEGSynth
from ezmsg.util.messagecodec import message_log
from ezmsg.util.messagelogger import MessageLogger
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.terminate import TerminateOnTotal

from ezmsg.sigproc.butterworthzerophase import ButterworthZeroPhase


def test_butterworth_zero_phase_system():
fs = 1000.0
n_time = 50
n_total = 10
n_channels = 96
test_filename = Path(tempfile.gettempdir())
test_filename = test_filename / Path("test_butterworth_zero_phase_system.txt")
with open(test_filename, "w"):
pass
ez.logger.info(f"Logging to {test_filename}")

comps = {
"SRC": EEGSynth(n_time=n_time, fs=fs, n_ch=n_channels, alpha_freq=10.0),
"BUTTER": ButterworthZeroPhase(
order=4,
cuton=30.0,
cutoff=45.0,
coef_type="sos",
),
"LOG": MessageLogger(output=test_filename),
"TERM": TerminateOnTotal(total=n_total),
}

conns = (
(comps["SRC"].OUTPUT_SIGNAL, comps["BUTTER"].INPUT_SIGNAL),
(comps["BUTTER"].OUTPUT_SIGNAL, comps["LOG"].INPUT_MESSAGE),
(comps["LOG"].OUTPUT_MESSAGE, comps["TERM"].INPUT_MESSAGE),
)

ez.run(components=comps, connections=conns)

messages = list(message_log(test_filename))
assert len(messages) >= n_total

for msg in messages:
assert isinstance(msg, AxisArray)
assert msg.data.shape[1] == n_channels
assert msg.data.shape[0] == n_time
Loading