From c2777e50ee7efaf1aa4fe7f2eb699e70480a132d Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Tue, 16 Dec 2025 21:44:50 -0500 Subject: [PATCH 1/2] Add deprecation warning when importing modules that have been migrated and are being re-exported --- src/ezmsg/sigproc/base.py | 10 +++++++++- src/ezmsg/sigproc/util/asio.py | 10 +++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/ezmsg/sigproc/base.py b/src/ezmsg/sigproc/base.py index 8334c118..71f48a3d 100644 --- a/src/ezmsg/sigproc/base.py +++ b/src/ezmsg/sigproc/base.py @@ -7,8 +7,16 @@ New code should import directly from ezmsg.baseproc instead. """ +import warnings + +warnings.warn( + "Importing from 'ezmsg.sigproc.base' is deprecated. Please import from 'ezmsg.baseproc' instead.", + DeprecationWarning, + stacklevel=2, +) + # Re-export everything from ezmsg.baseproc for backwards compatibility -from ezmsg.baseproc import ( +from ezmsg.baseproc import ( # noqa: E402 # Protocols AdaptiveTransformer, # Type variables diff --git a/src/ezmsg/sigproc/util/asio.py b/src/ezmsg/sigproc/util/asio.py index 6703baa2..98d03c20 100644 --- a/src/ezmsg/sigproc/util/asio.py +++ b/src/ezmsg/sigproc/util/asio.py @@ -4,7 +4,15 @@ New code should import directly from ezmsg.baseproc instead. """ -from ezmsg.baseproc.util.asio import ( +import warnings + +warnings.warn( + "Importing from 'ezmsg.sigproc.util.asio' is deprecated. Please import from 'ezmsg.baseproc.util.asio' instead.", + DeprecationWarning, + stacklevel=2, +) + +from ezmsg.baseproc.util.asio import ( # noqa: E402 CoroutineExecutionError, SyncToAsyncGeneratorWrapper, run_coroutine_sync, From b36ea847a977e37a64b6ad57abeedff7db7276ec Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Tue, 16 Dec 2025 21:46:28 -0500 Subject: [PATCH 2/2] Move synth to ezmsg.simbiophys --- pyproject.toml | 2 +- src/ezmsg/sigproc/math/add.py | 121 +++ src/ezmsg/sigproc/math/difference.py | 109 ++- src/ezmsg/sigproc/synth.py | 740 ------------------ tests/helpers/synth.py | 281 +++++++ tests/integration/ezmsg/test_add_system.py | 148 ++++ .../ezmsg/test_butterworth_system.py | 2 +- .../ezmsg/test_butterworthzerophase_system.py | 2 +- .../integration/ezmsg/test_decimate_system.py | 2 +- .../ezmsg/test_difference_system.py | 203 +++++ .../ezmsg/test_downsample_system.py | 2 +- tests/integration/ezmsg/test_filter_system.py | 2 +- .../ezmsg/test_fir_hilbert_system.py | 2 +- .../integration/ezmsg/test_fir_pmc_system.py | 2 +- .../ezmsg/test_rollingscaler_system.py | 2 +- .../integration/ezmsg/test_sampler_system.py | 2 +- tests/integration/ezmsg/test_scaler_system.py | 3 +- .../integration/ezmsg/test_spectrum_system.py | 2 +- tests/integration/ezmsg/test_synth_system.py | 237 ------ tests/integration/ezmsg/test_window_system.py | 2 +- tests/unit/test_math_add.py | 247 ++++++ tests/unit/test_math_difference.py | 278 +++++++ tests/unit/test_synth.py | 142 ---- 23 files changed, 1381 insertions(+), 1152 deletions(-) create mode 100644 src/ezmsg/sigproc/math/add.py delete mode 100644 src/ezmsg/sigproc/synth.py create mode 100644 tests/helpers/synth.py create mode 100644 tests/integration/ezmsg/test_add_system.py create mode 100644 tests/integration/ezmsg/test_difference_system.py delete mode 100644 tests/integration/ezmsg/test_synth_system.py create mode 100644 tests/unit/test_math_add.py create mode 100644 tests/unit/test_math_difference.py delete mode 100644 tests/unit/test_synth.py diff --git a/pyproject.toml b/pyproject.toml index d6ba50f4..e556a2bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ requires-python = ">=3.10.15" dynamic = ["version"] dependencies = [ "array-api-compat>=1.11.1", - "ezmsg-baseproc>=1.0", + "ezmsg-baseproc>=1.0.3", "ezmsg>=3.6.0", "numba>=0.61.0", "numpy>=1.26.0", diff --git a/src/ezmsg/sigproc/math/add.py b/src/ezmsg/sigproc/math/add.py new file mode 100644 index 00000000..dbf7b64e --- /dev/null +++ b/src/ezmsg/sigproc/math/add.py @@ -0,0 +1,121 @@ +"""Signal addition utilities.""" + +import asyncio +import typing +from dataclasses import dataclass, field + +import ezmsg.core as ez +from ezmsg.baseproc.util.asio import run_coroutine_sync +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace + +from ..base import BaseTransformer, BaseTransformerUnit + +# --- Constant Addition (single input) --- + + +class ConstAddSettings(ez.Settings): + value: float = 0.0 + """Number to add to the input data.""" + + +class ConstAddTransformer(BaseTransformer[ConstAddSettings, AxisArray, AxisArray]): + """Add a constant value to input data.""" + + def _process(self, message: AxisArray) -> AxisArray: + return replace(message, data=message.data + self.settings.value) + + +class ConstAdd(BaseTransformerUnit[ConstAddSettings, AxisArray, AxisArray, ConstAddTransformer]): + """Unit wrapper for ConstAddTransformer.""" + + SETTINGS = ConstAddSettings + + +# --- Two-input Addition --- + + +@dataclass +class AddState: + """State for Add processor with two input queues.""" + + queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue) + queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue) + + +class AddProcessor: + """Processor that adds two AxisArray signals together. + + This processor maintains separate queues for two input streams and + adds corresponding messages element-wise. It assumes both inputs + have compatible shapes and aligned time spans. + """ + + def __init__(self): + self._state = AddState() + + @property + def state(self) -> AddState: + return self._state + + @state.setter + def state(self, state: AddState | bytes | None) -> None: + if state is not None: + # TODO: Support hydrating state from bytes + # if isinstance(state, bytes): + # self._state = pickle.loads(state) + # else: + self._state = state + + def push_a(self, msg: AxisArray) -> None: + """Push a message to queue A.""" + self._state.queue_a.put_nowait(msg) + + def push_b(self, msg: AxisArray) -> None: + """Push a message to queue B.""" + self._state.queue_b.put_nowait(msg) + + async def __acall__(self) -> AxisArray: + """Await and add the next messages from both queues.""" + a = await self._state.queue_a.get() + b = await self._state.queue_b.get() + return replace(a, data=a.data + b.data) + + def __call__(self) -> AxisArray: + """Synchronously get and add the next messages from both queues.""" + return run_coroutine_sync(self.__acall__()) + + # Aliases for legacy interface + async def __anext__(self) -> AxisArray: + return await self.__acall__() + + def __next__(self) -> AxisArray: + return self.__call__() + + +class Add(ez.Unit): + """Add two signals together. + + Assumes compatible/similar axes/dimensions and aligned time spans. + Messages are paired by arrival order (oldest from each queue). + """ + + INPUT_SIGNAL_A = ez.InputStream(AxisArray) + INPUT_SIGNAL_B = ez.InputStream(AxisArray) + OUTPUT_SIGNAL = ez.OutputStream(AxisArray) + + async def initialize(self) -> None: + self.processor = AddProcessor() + + @ez.subscriber(INPUT_SIGNAL_A) + async def on_a(self, msg: AxisArray) -> None: + self.processor.push_a(msg) + + @ez.subscriber(INPUT_SIGNAL_B) + async def on_b(self, msg: AxisArray) -> None: + self.processor.push_b(msg) + + @ez.publisher(OUTPUT_SIGNAL) + async def output(self) -> typing.AsyncGenerator: + while True: + yield self.OUTPUT_SIGNAL, await self.processor.__acall__() diff --git a/src/ezmsg/sigproc/math/difference.py b/src/ezmsg/sigproc/math/difference.py index 0a61e6be..c95b0270 100644 --- a/src/ezmsg/sigproc/math/difference.py +++ b/src/ezmsg/sigproc/math/difference.py @@ -1,4 +1,9 @@ +import asyncio +import typing +from dataclasses import dataclass, field + import ezmsg.core as ez +from ezmsg.baseproc.util.asio import run_coroutine_sync from ezmsg.util.messages.axisarray import AxisArray from ezmsg.util.messages.util import replace @@ -43,22 +48,88 @@ def const_difference(value: float = 0.0, subtrahend: bool = True) -> ConstDiffer return ConstDifferenceTransformer(ConstDifferenceSettings(value=value, subtrahend=subtrahend)) -# class DifferenceSettings(ez.Settings): -# pass -# -# -# class Difference(ez.Unit): -# SETTINGS = DifferenceSettings -# -# INPUT_SIGNAL_1 = ez.InputStream(AxisArray) -# INPUT_SIGNAL_2 = ez.InputStream(AxisArray) -# OUTPUT_SIGNAL = ez.OutputStream(AxisArray) -# -# @ez.subscriber(INPUT_SIGNAL_2, zero_copy=True) -# @ez.publisher(OUTPUT_SIGNAL) -# async def on_input_2(self, message: AxisArray) -> typing.AsyncGenerator: -# # TODO: buffer_2 -# # TODO: take buffer_1 - buffer_2 for ranges that align -# # TODO: Drop samples from buffer_1 and buffer_2 -# if ret is not None: -# yield self.OUTPUT_SIGNAL, ret +# --- Two-input Difference --- + + +@dataclass +class DifferenceState: + """State for Difference processor with two input queues.""" + + queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue) + queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue) + + +class DifferenceProcessor: + """Processor that subtracts two AxisArray signals (A - B). + + This processor maintains separate queues for two input streams and + subtracts corresponding messages element-wise. It assumes both inputs + have compatible shapes and aligned time spans. + """ + + def __init__(self): + self._state = DifferenceState() + + @property + def state(self) -> DifferenceState: + return self._state + + @state.setter + def state(self, state: DifferenceState | bytes | None) -> None: + if state is not None: + self._state = state + + def push_a(self, msg: AxisArray) -> None: + """Push a message to queue A (minuend).""" + self._state.queue_a.put_nowait(msg) + + def push_b(self, msg: AxisArray) -> None: + """Push a message to queue B (subtrahend).""" + self._state.queue_b.put_nowait(msg) + + async def __acall__(self) -> AxisArray: + """Await and subtract the next messages (A - B).""" + a = await self._state.queue_a.get() + b = await self._state.queue_b.get() + return replace(a, data=a.data - b.data) + + def __call__(self) -> AxisArray: + """Synchronously get and subtract the next messages.""" + return run_coroutine_sync(self.__acall__()) + + # Aliases for legacy interface + async def __anext__(self) -> AxisArray: + return await self.__acall__() + + def __next__(self) -> AxisArray: + return self.__call__() + + +class Difference(ez.Unit): + """Subtract two signals (A - B). + + Assumes compatible/similar axes/dimensions and aligned time spans. + Messages are paired by arrival order (oldest from each queue). + + OUTPUT = INPUT_SIGNAL_A - INPUT_SIGNAL_B + """ + + INPUT_SIGNAL_A = ez.InputStream(AxisArray) + INPUT_SIGNAL_B = ez.InputStream(AxisArray) + OUTPUT_SIGNAL = ez.OutputStream(AxisArray) + + async def initialize(self) -> None: + self.processor = DifferenceProcessor() + + @ez.subscriber(INPUT_SIGNAL_A) + async def on_a(self, msg: AxisArray) -> None: + self.processor.push_a(msg) + + @ez.subscriber(INPUT_SIGNAL_B) + async def on_b(self, msg: AxisArray) -> None: + self.processor.push_b(msg) + + @ez.publisher(OUTPUT_SIGNAL) + async def output(self) -> typing.AsyncGenerator: + while True: + yield self.OUTPUT_SIGNAL, await self.processor.__acall__() diff --git a/src/ezmsg/sigproc/synth.py b/src/ezmsg/sigproc/synth.py deleted file mode 100644 index c6cec838..00000000 --- a/src/ezmsg/sigproc/synth.py +++ /dev/null @@ -1,740 +0,0 @@ -import asyncio -import time -import traceback -import typing -from dataclasses import dataclass, field - -import ezmsg.core as ez -import numpy as np -from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.util.messages.util import replace - -from .base import ( - BaseProducerUnit, - BaseStatefulProducer, - BaseTransformer, - BaseTransformerUnit, - CompositeProducer, - MessageInType, - MessageOutType, - ProducerType, - SettingsType, - processor_state, -) -from .butterworthfilter import ButterworthFilterSettings, ButterworthFilterTransformer -from .util.asio import run_coroutine_sync -from .util.profile import profile_subpub - - -@dataclass -class AddState: - queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue) - queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue) - - -class AddProcessor: - def __init__(self): - self._state = AddState() - - @property - def state(self) -> AddState: - return self._state - - @state.setter - def state(self, state: AddState | bytes | None) -> None: - if state is not None: - # TODO: Support hydrating state from bytes - # if isinstance(state, bytes): - # self._state = pickle.loads(state) - # else: - self._state = state - - def push_a(self, msg: AxisArray) -> None: - self._state.queue_a.put_nowait(msg) - - def push_b(self, msg: AxisArray) -> None: - self._state.queue_b.put_nowait(msg) - - async def __acall__(self) -> AxisArray: - a = await self._state.queue_a.get() - b = await self._state.queue_b.get() - return replace(a, data=a.data + b.data) - - def __call__(self) -> AxisArray: - return run_coroutine_sync(self.__acall__()) - - # Aliases for legacy interface - async def __anext__(self) -> AxisArray: - return await self.__acall__() - - def __next__(self) -> AxisArray: - return self.__call__() - - -class Add(ez.Unit): - """Add two signals together. Assumes compatible/similar axes/dimensions.""" - - INPUT_SIGNAL_A = ez.InputStream(AxisArray) - INPUT_SIGNAL_B = ez.InputStream(AxisArray) - OUTPUT_SIGNAL = ez.OutputStream(AxisArray) - - async def initialize(self) -> None: - self.processor = AddProcessor() - - @ez.subscriber(INPUT_SIGNAL_A) - async def on_a(self, msg: AxisArray) -> None: - self.processor.push_a(msg) - - @ez.subscriber(INPUT_SIGNAL_B) - async def on_b(self, msg: AxisArray) -> None: - self.processor.push_b(msg) - - @ez.publisher(OUTPUT_SIGNAL) - async def output(self) -> typing.AsyncGenerator: - while True: - yield self.OUTPUT_SIGNAL, await self.processor.__acall__() - - -class ClockSettings(ez.Settings): - """Settings for clock generator.""" - - dispatch_rate: float | str | None = None - """Dispatch rate in Hz, 'realtime', or None for external clock""" - - -@processor_state -class ClockState: - """State for clock generator.""" - - t_0: float = field(default_factory=time.time) # Start time - n_dispatch: int = 0 # Number of dispatches - - -class ClockProducer(BaseStatefulProducer[ClockSettings, ez.Flag, ClockState]): - """ - Produces clock ticks at specified rate. - Can be used to drive periodic operations. - """ - - def _reset_state(self) -> None: - """Reset internal state.""" - self._state.t_0 = time.time() - self._state.n_dispatch = 0 - - def __call__(self) -> ez.Flag: - """Synchronous clock production. We override __call__ (which uses run_coroutine_sync) - to avoid async overhead.""" - if self._hash == -1: - self._reset_state() - self._hash = 0 - - if isinstance(self.settings.dispatch_rate, (int, float)): - # Manual dispatch_rate. (else it is 'as fast as possible') - target_time = self.state.t_0 + (self.state.n_dispatch + 1) / self.settings.dispatch_rate - now = time.time() - if target_time > now: - time.sleep(target_time - now) - - self.state.n_dispatch += 1 - return ez.Flag() - - async def _produce(self) -> ez.Flag: - """Generate next clock tick.""" - if isinstance(self.settings.dispatch_rate, (int, float)): - # Manual dispatch_rate. (else it is 'as fast as possible') - target_time = self.state.t_0 + (self.state.n_dispatch + 1) / self.settings.dispatch_rate - now = time.time() - if target_time > now: - await asyncio.sleep(target_time - now) - - self.state.n_dispatch += 1 - return ez.Flag() - - -def aclock(dispatch_rate: float | None) -> ClockProducer: - """ - Construct an async generator that yields events at a specified rate. - - Returns: - A :obj:`ClockProducer` object. - """ - return ClockProducer(ClockSettings(dispatch_rate=dispatch_rate)) - - -clock = aclock -""" -Alias for :obj:`aclock` expected by synchronous methods. `ClockProducer` can be used in sync or async. -""" - - -class Clock( - BaseProducerUnit[ - ClockSettings, # SettingsType - ez.Flag, # MessageType - ClockProducer, # ProducerType - ] -): - SETTINGS = ClockSettings - - @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL) - async def produce(self) -> typing.AsyncGenerator: - # Override so we can not to yield if out is False-like - while True: - out = await self.producer.__acall__() - if out: - yield self.OUTPUT_SIGNAL, out - - -# COUNTER - Generate incrementing integer. fs and dispatch_rate parameters combine to give many options. # -class CounterSettings(ez.Settings): - # TODO: Adapt this to use ezmsg.util.rate? - """ - Settings for :obj:`Counter`. - See :obj:`acounter` for a description of the parameters. - """ - - n_time: int - """Number of samples to output per block.""" - - fs: float - """Sampling rate of signal output in Hz""" - - n_ch: int = 1 - """Number of channels to synthesize""" - - dispatch_rate: float | str | None = None - """ - Message dispatch rate (Hz), 'realtime', 'ext_clock', or None (fast as possible) - Note: if dispatch_rate is a float then time offsets will be synthetic and the - system will run faster or slower than wall clock time. - """ - - mod: int | None = None - """If set to an integer, counter will rollover""" - - -@processor_state -class CounterState: - """ - State for counter generator. - """ - - counter_start: int = 0 - """next sample's first value""" - - n_sent: int = 0 - """number of samples sent""" - - clock_zero: float | None = None - """time of first sample""" - - timer_type: str = "unspecified" - """ - "realtime" | "ext_clock" | "manual" | "unspecified" - """ - - new_generator: asyncio.Event | None = None - """ - Event to signal the counter has been reset. - """ - - -class CounterProducer(BaseStatefulProducer[CounterSettings, AxisArray, CounterState]): - """Produces incrementing integer blocks as AxisArray.""" - - # TODO: Adapt this to use ezmsg.util.rate? - - @classmethod - def get_message_type(cls, dir: str) -> typing.Optional[type[AxisArray]]: - if dir == "in": - return None - elif dir == "out": - return AxisArray - else: - raise ValueError(f"Invalid direction: {dir}. Use 'in' or 'out'.") - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if isinstance(self.settings.dispatch_rate, str) and self.settings.dispatch_rate not in [ - "realtime", - "ext_clock", - ]: - raise ValueError(f"Unknown dispatch_rate: {self.settings.dispatch_rate}") - self._reset_state() - self._hash = 0 - - def _reset_state(self) -> None: - """Reset internal state.""" - self._state.counter_start = 0 - self._state.n_sent = 0 - self._state.clock_zero = time.time() - if self.settings.dispatch_rate is not None: - if isinstance(self.settings.dispatch_rate, str): - self._state.timer_type = self.settings.dispatch_rate.lower() - else: - self._state.timer_type = "manual" - if self._state.new_generator is None: - self._state.new_generator = asyncio.Event() - # Set the event to indicate that the state has been reset. - self._state.new_generator.set() - - async def _produce(self) -> AxisArray: - """Generate next counter block.""" - # 1. Prepare counter data - block_samp = np.arange(self.state.counter_start, self.state.counter_start + self.settings.n_time)[:, np.newaxis] - if self.settings.mod is not None: - block_samp %= self.settings.mod - block_samp = np.tile(block_samp, (1, self.settings.n_ch)) - - # 2. Sleep if necessary. 3. Calculate time offset. - if self._state.timer_type == "realtime": - n_next = self.state.n_sent + self.settings.n_time - t_next = self.state.clock_zero + n_next / self.settings.fs - await asyncio.sleep(t_next - time.time()) - offset = t_next - self.settings.n_time / self.settings.fs - elif self._state.timer_type == "manual": - # manual dispatch rate - n_disp_next = 1 + self.state.n_sent / self.settings.n_time - t_disp_next = self.state.clock_zero + n_disp_next / self.settings.dispatch_rate - await asyncio.sleep(t_disp_next - time.time()) - offset = self.state.n_sent / self.settings.fs - elif self._state.timer_type == "ext_clock": - # ext_clock -- no sleep. Assume this is called at appropriate intervals. - offset = time.time() - else: - # Was "unspecified" - offset = self.state.n_sent / self.settings.fs - - # 4. Create output AxisArray - # Note: We can make this a bit faster by preparing a template for self._state - result = AxisArray( - data=block_samp, - dims=["time", "ch"], - axes={ - "time": AxisArray.TimeAxis(fs=self.settings.fs, offset=offset), - "ch": AxisArray.CoordinateAxis( - data=np.array([f"Ch{_}" for _ in range(self.settings.n_ch)]), - dims=["ch"], - ), - }, - key="acounter", - ) - - # 5. Update state - self.state.counter_start = block_samp[-1, 0] + 1 - self.state.n_sent += self.settings.n_time - - return result - - -def acounter( - n_time: int, - fs: float | None, - n_ch: int = 1, - dispatch_rate: float | str | None = None, - mod: int | None = None, -) -> CounterProducer: - """ - Construct an asynchronous generator to generate AxisArray objects at a specified rate - and with the specified sampling rate. - - NOTE: This module uses asyncio.sleep to delay appropriately in realtime mode. - This method of sleeping/yielding execution priority has quirky behavior with - sub-millisecond sleep periods which may result in unexpected behavior (e.g. - fs = 2000, n_time = 1, realtime = True -- may result in ~1400 msgs/sec) - - Returns: - An asynchronous generator. - """ - return CounterProducer(CounterSettings(n_time=n_time, fs=fs, n_ch=n_ch, dispatch_rate=dispatch_rate, mod=mod)) - - -class Counter( - BaseProducerUnit[ - CounterSettings, # SettingsType - AxisArray, # MessageOutType - CounterProducer, # ProducerType - ] -): - """Generates monotonically increasing counter. Unit for :obj:`CounterProducer`.""" - - SETTINGS = CounterSettings - INPUT_CLOCK = ez.InputStream(ez.Flag) - - @ez.subscriber(INPUT_CLOCK) - @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL) - async def on_clock(self, _: ez.Flag): - if self.producer.settings.dispatch_rate == "ext_clock": - out = await self.producer.__acall__() - yield self.OUTPUT_SIGNAL, out - - @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL) - async def produce(self) -> typing.AsyncGenerator: - """ - Generate counter output. - This is an infinite loop, but we will likely only enter the loop once if we are self-timed, - and twice if we are using an external clock. - - When using an internal clock, we enter the loop, and wait for the event which should have - been reset upon initialization then we immediately clear, then go to the internal loop - that will async call __acall__ to let the internal timer determine when to produce an output. - - When using an external clock, we enter the loop, and wait for the event which should have been - reset upon initialization then we immediately clear, then we hit `continue` to loop back around - and wait for the event to be set again -- potentially forever. In this case, it is expected that - `on_clock` will be called to produce the output. - """ - try: - while True: - # Once-only, enter the generator loop - await self.producer.state.new_generator.wait() - self.producer.state.new_generator.clear() - - if self.producer.settings.dispatch_rate == "ext_clock": - # We shouldn't even be here. Cycle around and wait on the event again. - continue - - # We are not using an external clock. Run the generator. - while not self.producer.state.new_generator.is_set(): - out = await self.producer.__acall__() - yield self.OUTPUT_SIGNAL, out - except Exception: - ez.logger.info(traceback.format_exc()) - - -class SinGeneratorSettings(ez.Settings): - """ - Settings for :obj:`SinGenerator`. - See :obj:`sin` for parameter descriptions. - """ - - axis: str | None = "time" - """ - The name of the axis over which the sinusoid passes. - Note: The axis must exist in the msg.axes and be of type AxisArray.LinearAxis. - """ - - freq: float = 1.0 - """The frequency of the sinusoid, in Hz.""" - - amp: float = 1.0 # Amplitude - """The amplitude of the sinusoid.""" - - phase: float = 0.0 # Phase offset (in radians) - """The initial phase of the sinusoid, in radians.""" - - -class SinTransformer(BaseTransformer[SinGeneratorSettings, AxisArray, AxisArray]): - """Transforms counter values into sinusoidal waveforms.""" - - def _process(self, message: AxisArray) -> AxisArray: - """Transform input counter values into sinusoidal waveform.""" - axis = self.settings.axis or message.dims[0] - - ang_freq = 2.0 * np.pi * self.settings.freq - w = (ang_freq * message.get_axis(axis).gain) * message.data - out_data = self.settings.amp * np.sin(w + self.settings.phase) - - return replace(message, data=out_data) - - -class SinGenerator(BaseTransformerUnit[SinGeneratorSettings, AxisArray, AxisArray, SinTransformer]): - """Unit for generating sinusoidal waveforms.""" - - SETTINGS = SinGeneratorSettings - - -def sin( - axis: str | None = "time", - freq: float = 1.0, - amp: float = 1.0, - phase: float = 0.0, -) -> SinTransformer: - """ - Construct a generator of sinusoidal waveforms in AxisArray objects. - - Returns: - A primed generator that expects .send(axis_array) of sample counts - and yields an AxisArray of sinusoids. - """ - return SinTransformer(SinGeneratorSettings(axis=axis, freq=freq, amp=amp, phase=phase)) - - -class RandomGeneratorSettings(ez.Settings): - loc: float = 0.0 - """loc argument for :obj:`numpy.random.normal`""" - - scale: float = 1.0 - """scale argument for :obj:`numpy.random.normal`""" - - -class RandomTransformer(BaseTransformer[RandomGeneratorSettings, AxisArray, AxisArray]): - """ - Replaces input data with random data and returns the result. - """ - - def __init__(self, *args, settings: RandomGeneratorSettings | None = None, **kwargs): - super().__init__(*args, settings=settings, **kwargs) - - def _process(self, message: AxisArray) -> AxisArray: - random_data = np.random.normal(size=message.shape, loc=self.settings.loc, scale=self.settings.scale) - return replace(message, data=random_data) - - -class RandomGenerator( - BaseTransformerUnit[ - RandomGeneratorSettings, - AxisArray, - AxisArray, - RandomTransformer, - ] -): - SETTINGS = RandomGeneratorSettings - - -class OscillatorSettings(ez.Settings): - """Settings for :obj:`Oscillator`""" - - n_time: int - """Number of samples to output per block.""" - - fs: float - """Sampling rate of signal output in Hz""" - - n_ch: int = 1 - """Number of channels to output per block""" - - dispatch_rate: float | str | None = None - """(Hz) | 'realtime' | 'ext_clock'""" - - freq: float = 1.0 - """Oscillation frequency in Hz""" - - amp: float = 1.0 - """Amplitude""" - - phase: float = 0.0 - """Phase offset (in radians)""" - - sync: bool = False - """Adjust `freq` to sync with sampling rate""" - - -class OscillatorProducer(CompositeProducer[OscillatorSettings, AxisArray]): - @staticmethod - def _initialize_processors( - settings: OscillatorSettings, - ) -> dict[str, CounterProducer | SinTransformer]: - # Calculate synchronous settings if necessary - freq = settings.freq - mod = None - if settings.sync: - period = 1.0 / settings.freq - mod = round(period * settings.fs) - freq = 1.0 / (mod / settings.fs) - - return { - "counter": CounterProducer( - CounterSettings( - n_time=settings.n_time, - fs=settings.fs, - n_ch=settings.n_ch, - dispatch_rate=settings.dispatch_rate, - mod=mod, - ) - ), - "sin": SinTransformer(SinGeneratorSettings(freq=freq, amp=settings.amp, phase=settings.phase)), - } - - -class BaseCounterFirstProducerUnit( - BaseProducerUnit[SettingsType, MessageOutType, ProducerType], - typing.Generic[SettingsType, MessageInType, MessageOutType, ProducerType], -): - """ - Base class for units whose primary processor is a composite producer with a CounterProducer as the first - processor (producer) in the chain. - """ - - INPUT_SIGNAL = ez.InputStream(MessageInType) - - def create_producer(self): - super().create_producer() - - def recurse_get_counter(proc) -> CounterProducer: - if hasattr(proc, "_procs"): - return recurse_get_counter(list(proc._procs.values())[0]) - return proc - - self._counter = recurse_get_counter(self.producer) - - @ez.subscriber(INPUT_SIGNAL, zero_copy=True) - @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL) - @profile_subpub(trace_oldest=False) - async def on_signal(self, _: ez.Flag): - if self.producer.settings.dispatch_rate == "ext_clock": - out = await self.producer.__acall__() - yield self.OUTPUT_SIGNAL, out - - @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL) - async def produce(self) -> typing.AsyncGenerator: - try: - counter_state = self._counter.state - while True: - # Once-only, enter the generator loop - await counter_state.new_generator.wait() - counter_state.new_generator.clear() - - if self.producer.settings.dispatch_rate == "ext_clock": - # We shouldn't even be here. Cycle around and wait on the event again. - continue - - # We are not using an external clock. Run the generator. - while not counter_state.new_generator.is_set(): - out = await self.producer.__acall__() - yield self.OUTPUT_SIGNAL, out - except Exception: - ez.logger.info(traceback.format_exc()) - - -class Oscillator(BaseCounterFirstProducerUnit[OscillatorSettings, AxisArray, AxisArray, OscillatorProducer]): - """Generates sinusoidal waveforms using a counter and sine transformer.""" - - SETTINGS = OscillatorSettings - - -class NoiseSettings(ez.Settings): - """ - See :obj:`CounterSettings` and :obj:`RandomGeneratorSettings`. - """ - - n_time: int # Number of samples to output per block - fs: float # Sampling rate of signal output in Hz - n_ch: int = 1 # Number of channels to output - dispatch_rate: float | str | None = None - """(Hz), 'realtime', or 'ext_clock'""" - loc: float = 0.0 # DC offset - scale: float = 1.0 # Scale (in standard deviations) - - -WhiteNoiseSettings = NoiseSettings - - -class WhiteNoiseProducer(CompositeProducer[NoiseSettings, AxisArray]): - @staticmethod - def _initialize_processors( - settings: NoiseSettings, - ) -> dict[str, CounterProducer | RandomTransformer]: - return { - "counter": CounterProducer( - CounterSettings( - n_time=settings.n_time, - fs=settings.fs, - n_ch=settings.n_ch, - dispatch_rate=settings.dispatch_rate, - mod=None, - ) - ), - "random": RandomTransformer( - RandomGeneratorSettings( - loc=settings.loc, - scale=settings.scale, - ) - ), - } - - -class WhiteNoise(BaseCounterFirstProducerUnit[NoiseSettings, AxisArray, AxisArray, WhiteNoiseProducer]): - """chains a :obj:`Counter` and :obj:`RandomGenerator`.""" - - SETTINGS = NoiseSettings - - -PinkNoiseSettings = NoiseSettings - - -class PinkNoiseProducer(CompositeProducer[PinkNoiseSettings, AxisArray]): - @staticmethod - def _initialize_processors( - settings: PinkNoiseSettings, - ) -> dict[str, WhiteNoiseProducer | ButterworthFilterTransformer]: - return { - "white_noise": WhiteNoiseProducer(settings=settings), - "filter": ButterworthFilterTransformer( - settings=ButterworthFilterSettings( - axis="time", - order=1, - cutoff=settings.fs * 0.01, # Hz - ) - ), - } - - -class PinkNoise(BaseCounterFirstProducerUnit[NoiseSettings, AxisArray, AxisArray, PinkNoiseProducer]): - """chains :obj:`WhiteNoise` and :obj:`ButterworthFilter`.""" - - SETTINGS = NoiseSettings - - -class EEGSynthSettings(ez.Settings): - """See :obj:`OscillatorSettings`.""" - - fs: float = 500.0 # Hz - n_time: int = 100 - alpha_freq: float = 10.5 # Hz - n_ch: int = 8 - - -class EEGSynth(ez.Collection): - """ - A :obj:`Collection` that chains a :obj:`Clock` to both :obj:`PinkNoise` - and :obj:`Oscillator`, then :obj:`Add` s the result. - - Unlike the Oscillator, WhiteNoise, and PinkNoise composite processors which have linear - flows, this class has a diamond flow, with clock branching to both PinkNoise and Oscillator, - which then are combined in Add. - - Optional: Refactor as a ProducerUnit, similar to Clock, but we manually add all the other - transformers. - """ - - SETTINGS = EEGSynthSettings - - OUTPUT_SIGNAL = ez.OutputStream(AxisArray) - - CLOCK = Clock() - NOISE = PinkNoise() - OSC = Oscillator() - ADD = Add() - - def configure(self) -> None: - self.CLOCK.apply_settings(ClockSettings(dispatch_rate=self.SETTINGS.fs / self.SETTINGS.n_time)) - - self.OSC.apply_settings( - OscillatorSettings( - n_time=self.SETTINGS.n_time, - fs=self.SETTINGS.fs, - n_ch=self.SETTINGS.n_ch, - dispatch_rate="ext_clock", - freq=self.SETTINGS.alpha_freq, - ) - ) - - self.NOISE.apply_settings( - PinkNoiseSettings( - n_time=self.SETTINGS.n_time, - fs=self.SETTINGS.fs, - n_ch=self.SETTINGS.n_ch, - dispatch_rate="ext_clock", - scale=5.0, - ) - ) - - def network(self) -> ez.NetworkDefinition: - return ( - (self.CLOCK.OUTPUT_SIGNAL, self.OSC.INPUT_SIGNAL), - (self.CLOCK.OUTPUT_SIGNAL, self.NOISE.INPUT_SIGNAL), - (self.OSC.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_A), - (self.NOISE.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_B), - (self.ADD.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL), - ) diff --git a/tests/helpers/synth.py b/tests/helpers/synth.py new file mode 100644 index 00000000..dbdb9d1d --- /dev/null +++ b/tests/helpers/synth.py @@ -0,0 +1,281 @@ +""" +Test signal generators for ezmsg-sigproc integration tests. + +These are simplified signal generators intended for testing purposes only. +For production use, see ezmsg-simbiophys package. +""" + +import asyncio +import time +import typing + +import ezmsg.core as ez +import numpy as np +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.sigproc.math.add import Add # noqa: F401 - re-exported for test use + + +# Counter - Produces incrementing integer samples +class CounterSettings(ez.Settings): + n_time: int = 100 + fs: float = 1000.0 + n_ch: int = 1 + dispatch_rate: float | None = None # Hz or None for fast as possible + + +class Counter(ez.Unit): + """Simple counter generator for testing.""" + + SETTINGS = CounterSettings + OUTPUT_SIGNAL = ez.OutputStream(AxisArray) + + async def initialize(self) -> None: + self._counter = 0 + self._n_sent = 0 + self._t0 = time.time() + + @ez.publisher(OUTPUT_SIGNAL) + async def produce(self) -> typing.AsyncGenerator: + while True: + # Sleep if needed + if self.SETTINGS.dispatch_rate is not None: + n_disp = 1 + self._n_sent / self.SETTINGS.n_time + t_next = self._t0 + n_disp / self.SETTINGS.dispatch_rate + sleep_time = t_next - time.time() + if sleep_time > 0: + await asyncio.sleep(sleep_time) + + # Generate counter data + block = np.arange(self._counter, self._counter + self.SETTINGS.n_time)[:, np.newaxis] + block = np.tile(block, (1, self.SETTINGS.n_ch)) + + offset = self._n_sent / self.SETTINGS.fs + result = AxisArray( + data=block, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=self.SETTINGS.fs, offset=offset), + "ch": AxisArray.CoordinateAxis( + data=np.array([f"Ch{_}" for _ in range(self.SETTINGS.n_ch)]), + dims=["ch"], + ), + }, + ) + + self._counter = block[-1, 0] + 1 + self._n_sent += self.SETTINGS.n_time + + yield self.OUTPUT_SIGNAL, result + + +# WhiteNoise - Produces random Gaussian noise +class WhiteNoiseSettings(ez.Settings): + n_time: int = 100 + fs: float = 1000.0 + n_ch: int = 1 + dispatch_rate: float | None = None + loc: float = 0.0 + scale: float = 1.0 + + +class WhiteNoise(ez.Unit): + """Simple white noise generator for testing.""" + + SETTINGS = WhiteNoiseSettings + OUTPUT_SIGNAL = ez.OutputStream(AxisArray) + + async def initialize(self) -> None: + self._n_sent = 0 + self._t0 = time.time() + + @ez.publisher(OUTPUT_SIGNAL) + async def produce(self) -> typing.AsyncGenerator: + while True: + # Sleep if needed + if self.SETTINGS.dispatch_rate is not None: + n_disp = 1 + self._n_sent / self.SETTINGS.n_time + t_next = self._t0 + n_disp / self.SETTINGS.dispatch_rate + sleep_time = t_next - time.time() + if sleep_time > 0: + await asyncio.sleep(sleep_time) + + # Generate noise data + data = np.random.normal( + loc=self.SETTINGS.loc, + scale=self.SETTINGS.scale, + size=(self.SETTINGS.n_time, self.SETTINGS.n_ch), + ) + + offset = self._n_sent / self.SETTINGS.fs + result = AxisArray( + data=data, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=self.SETTINGS.fs, offset=offset), + "ch": AxisArray.CoordinateAxis( + data=np.array([f"Ch{_}" for _ in range(self.SETTINGS.n_ch)]), + dims=["ch"], + ), + }, + ) + + self._n_sent += self.SETTINGS.n_time + + yield self.OUTPUT_SIGNAL, result + + +# Oscillator - Produces sinusoidal signals +class OscillatorSettings(ez.Settings): + n_time: int = 100 + fs: float = 1000.0 + n_ch: int = 1 + dispatch_rate: float | str | None = None # Hz, "realtime", or None for fast as possible + freq: float = 10.0 # Hz + amp: float = 1.0 + phase: float = 0.0 + sync: bool = False # Adjust freq to sync with sampling rate + + +class Oscillator(ez.Unit): + """Simple oscillator generator for testing.""" + + SETTINGS = OscillatorSettings + OUTPUT_SIGNAL = ez.OutputStream(AxisArray) + + async def initialize(self) -> None: + self._n_sent = 0 + self._t0 = time.time() + + # Calculate synchronized frequency if requested + self._freq = self.SETTINGS.freq + if self.SETTINGS.sync: + period = 1.0 / self.SETTINGS.freq + mod = round(period * self.SETTINGS.fs) + self._freq = 1.0 / (mod / self.SETTINGS.fs) + + @ez.publisher(OUTPUT_SIGNAL) + async def produce(self) -> typing.AsyncGenerator: + while True: + # Calculate offset based on timing mode + if self.SETTINGS.dispatch_rate == "realtime": + # Realtime mode: sleep until wall-clock time matches sample time + n_next = self._n_sent + self.SETTINGS.n_time + t_next = self._t0 + n_next / self.SETTINGS.fs + sleep_time = t_next - time.time() + if sleep_time > 0: + await asyncio.sleep(sleep_time) + offset = t_next - self.SETTINGS.n_time / self.SETTINGS.fs + elif self.SETTINGS.dispatch_rate is not None: + # Manual dispatch rate mode + n_disp = 1 + self._n_sent / self.SETTINGS.n_time + t_next = self._t0 + n_disp / self.SETTINGS.dispatch_rate + sleep_time = t_next - time.time() + if sleep_time > 0: + await asyncio.sleep(sleep_time) + offset = self._n_sent / self.SETTINGS.fs + else: + # Fast as possible mode + offset = self._n_sent / self.SETTINGS.fs + + # Generate sinusoidal data + sample_indices = np.arange(self._n_sent, self._n_sent + self.SETTINGS.n_time) + t = sample_indices / self.SETTINGS.fs + data = self.SETTINGS.amp * np.sin(2 * np.pi * self._freq * t + self.SETTINGS.phase) + data = data[:, np.newaxis] + data = np.tile(data, (1, self.SETTINGS.n_ch)) + + result = AxisArray( + data=data, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=self.SETTINGS.fs, offset=offset), + "ch": AxisArray.CoordinateAxis( + data=np.array([f"Ch{_}" for _ in range(self.SETTINGS.n_ch)]), + dims=["ch"], + ), + }, + ) + + self._n_sent += self.SETTINGS.n_time + + yield self.OUTPUT_SIGNAL, result + + +# EEGSynth - Combines oscillator and pink noise (simplified version without actual pink noise filter) +class EEGSynthSettings(ez.Settings): + fs: float = 500.0 + n_time: int = 100 + alpha_freq: float = 10.5 + n_ch: int = 8 + + +class Clock(ez.Unit): + """Simple clock generator.""" + + OUTPUT_SIGNAL = ez.OutputStream(ez.Flag) + + SETTINGS: ez.Settings + + async def initialize(self) -> None: + self._t0 = time.time() + self._n_dispatch = 0 + + @ez.publisher(OUTPUT_SIGNAL) + async def produce(self) -> typing.AsyncGenerator: + while True: + if hasattr(self.SETTINGS, "dispatch_rate") and self.SETTINGS.dispatch_rate is not None: + target_time = self._t0 + (self._n_dispatch + 1) / self.SETTINGS.dispatch_rate + sleep_time = target_time - time.time() + if sleep_time > 0: + await asyncio.sleep(sleep_time) + + self._n_dispatch += 1 + yield self.OUTPUT_SIGNAL, ez.Flag() + + +class ClockSettings(ez.Settings): + dispatch_rate: float | None = None + + +class EEGSynth(ez.Collection): + """ + Simple EEG-like signal generator for testing. + Combines oscillator (alpha rhythm) with white noise. + """ + + SETTINGS = EEGSynthSettings + + OUTPUT_SIGNAL = ez.OutputStream(AxisArray) + + OSC = Oscillator() + NOISE = WhiteNoise() + ADD = Add() + + def configure(self) -> None: + self.OSC.apply_settings( + OscillatorSettings( + n_time=self.SETTINGS.n_time, + fs=self.SETTINGS.fs, + n_ch=self.SETTINGS.n_ch, + dispatch_rate=self.SETTINGS.fs / self.SETTINGS.n_time, + freq=self.SETTINGS.alpha_freq, + ) + ) + + self.NOISE.apply_settings( + WhiteNoiseSettings( + n_time=self.SETTINGS.n_time, + fs=self.SETTINGS.fs, + n_ch=self.SETTINGS.n_ch, + dispatch_rate=self.SETTINGS.fs / self.SETTINGS.n_time, + scale=5.0, + ) + ) + + def network(self) -> ez.NetworkDefinition: + return ( + (self.OSC.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_A), + (self.NOISE.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_B), + (self.ADD.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL), + ) diff --git a/tests/integration/ezmsg/test_add_system.py b/tests/integration/ezmsg/test_add_system.py new file mode 100644 index 00000000..17cbc34e --- /dev/null +++ b/tests/integration/ezmsg/test_add_system.py @@ -0,0 +1,148 @@ +"""Integration tests for ezmsg.sigproc.math.add module.""" + +import os + +import ezmsg.core as ez +import numpy as np +from ezmsg.util.messagecodec import message_log +from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings + +from ezmsg.sigproc.math.add import Add, ConstAdd, ConstAddSettings +from tests.helpers.synth import Counter, CounterSettings, Oscillator, OscillatorSettings +from tests.helpers.util import get_test_fn + + +def test_add_two_signals_system( + fs: float = 100.0, + n_time: int = 10, + n_messages: int = 5, + test_name: str | None = None, +): + """ + Test that Add unit correctly adds two synchronized signals. + + Uses two Counter units with different starting values to verify + element-wise addition. + """ + test_filename = get_test_fn(test_name) + ez.logger.info(test_filename) + + # We'll use two oscillators with different frequencies + # and verify the sum matches expected + comps = { + "OSC1": Oscillator( + OscillatorSettings( + n_time=n_time, + fs=fs, + n_ch=1, + dispatch_rate=fs / n_time, + freq=5.0, + amp=1.0, + ) + ), + "OSC2": Oscillator( + OscillatorSettings( + n_time=n_time, + fs=fs, + n_ch=1, + dispatch_rate=fs / n_time, + freq=10.0, + amp=2.0, + ) + ), + "ADD": Add(), + "LOG": MessageLogger( + MessageLoggerSettings( + output=test_filename, + ) + ), + "TERM": TerminateOnTotal( + TerminateOnTotalSettings( + total=n_messages, + ) + ), + } + conns = ( + (comps["OSC1"].OUTPUT_SIGNAL, comps["ADD"].INPUT_SIGNAL_A), + (comps["OSC2"].OUTPUT_SIGNAL, comps["ADD"].INPUT_SIGNAL_B), + (comps["ADD"].OUTPUT_SIGNAL, comps["LOG"].INPUT_MESSAGE), + (comps["LOG"].OUTPUT_MESSAGE, comps["TERM"].INPUT_MESSAGE), + ) + ez.run(components=comps, connections=conns) + + # Collect result + messages: list[AxisArray] = [_ for _ in message_log(test_filename)] + os.remove(test_filename) + + assert len(messages) == n_messages + + # Verify each message has correct shape + for msg in messages: + assert msg.data.shape == (n_time, 1) + + # Reconstruct the full signal and verify it's the sum of two sinusoids + data = np.concatenate([_.data for _ in messages]).squeeze() + n_samples = len(data) + t = np.arange(n_samples) / fs + + # Expected: 1.0 * sin(2*pi*5*t) + 2.0 * sin(2*pi*10*t) + expected = 1.0 * np.sin(2 * np.pi * 5.0 * t) + 2.0 * np.sin(2 * np.pi * 10.0 * t) + assert np.allclose(data, expected, atol=1e-10) + + +def test_const_add_system( + fs: float = 100.0, + n_time: int = 10, + n_messages: int = 5, + add_value: float = 100.0, + test_name: str | None = None, +): + """ + Test that ConstAdd unit correctly adds a constant to a signal. + """ + test_filename = get_test_fn(test_name) + ez.logger.info(test_filename) + + comps = { + "COUNTER": Counter( + CounterSettings( + n_time=n_time, + fs=fs, + n_ch=1, + dispatch_rate=fs / n_time, + ) + ), + "ADD": ConstAdd(ConstAddSettings(value=add_value)), + "LOG": MessageLogger( + MessageLoggerSettings( + output=test_filename, + ) + ), + "TERM": TerminateOnTotal( + TerminateOnTotalSettings( + total=n_messages, + ) + ), + } + conns = ( + (comps["COUNTER"].OUTPUT_SIGNAL, comps["ADD"].INPUT_SIGNAL), + (comps["ADD"].OUTPUT_SIGNAL, comps["LOG"].INPUT_MESSAGE), + (comps["LOG"].OUTPUT_MESSAGE, comps["TERM"].INPUT_MESSAGE), + ) + ez.run(components=comps, connections=conns) + + # Collect result + messages: list[AxisArray] = [_ for _ in message_log(test_filename)] + os.remove(test_filename) + + assert len(messages) == n_messages + + # Verify the constant was added + data = np.concatenate([_.data for _ in messages]).squeeze() + n_samples = len(data) + + # Counter produces 0, 1, 2, 3, ... so with add_value=100, we expect 100, 101, 102, ... + expected = np.arange(n_samples) + add_value + assert np.allclose(data, expected) diff --git a/tests/integration/ezmsg/test_butterworth_system.py b/tests/integration/ezmsg/test_butterworth_system.py index 229d63cf..a4d33291 100644 --- a/tests/integration/ezmsg/test_butterworth_system.py +++ b/tests/integration/ezmsg/test_butterworth_system.py @@ -11,7 +11,7 @@ from ezmsg.util.terminate import TerminateOnTimeoutSettings as TerminateTestSettings from ezmsg.sigproc.butterworthfilter import ButterworthFilter, ButterworthFilterSettings -from ezmsg.sigproc.synth import WhiteNoise, WhiteNoiseSettings +from tests.helpers.synth import WhiteNoise, WhiteNoiseSettings from tests.helpers.util import get_test_fn diff --git a/tests/integration/ezmsg/test_butterworthzerophase_system.py b/tests/integration/ezmsg/test_butterworthzerophase_system.py index a353e48e..2dbbbefb 100644 --- a/tests/integration/ezmsg/test_butterworthzerophase_system.py +++ b/tests/integration/ezmsg/test_butterworthzerophase_system.py @@ -8,7 +8,7 @@ from ezmsg.util.terminate import TerminateOnTotal from ezmsg.sigproc.butterworthzerophase import ButterworthZeroPhase -from ezmsg.sigproc.synth import EEGSynth +from tests.helpers.synth import EEGSynth def test_butterworth_zero_phase_system(): diff --git a/tests/integration/ezmsg/test_decimate_system.py b/tests/integration/ezmsg/test_decimate_system.py index 36463e4e..443fb66d 100644 --- a/tests/integration/ezmsg/test_decimate_system.py +++ b/tests/integration/ezmsg/test_decimate_system.py @@ -8,7 +8,7 @@ from ezmsg.util.terminate import TerminateOnTotal from ezmsg.sigproc.decimate import Decimate -from ezmsg.sigproc.synth import EEGSynth +from tests.helpers.synth import EEGSynth from tests.helpers.util import get_test_fn diff --git a/tests/integration/ezmsg/test_difference_system.py b/tests/integration/ezmsg/test_difference_system.py new file mode 100644 index 00000000..cdefe0d2 --- /dev/null +++ b/tests/integration/ezmsg/test_difference_system.py @@ -0,0 +1,203 @@ +"""Integration tests for ezmsg.sigproc.math.difference module.""" + +import os + +import ezmsg.core as ez +import numpy as np +from ezmsg.util.messagecodec import message_log +from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings + +from ezmsg.sigproc.math.difference import ConstDifference, ConstDifferenceSettings, Difference +from tests.helpers.synth import Counter, CounterSettings, Oscillator, OscillatorSettings +from tests.helpers.util import get_test_fn + + +def test_difference_two_signals_system( + fs: float = 100.0, + n_time: int = 10, + n_messages: int = 5, + test_name: str | None = None, +): + """ + Test that Difference unit correctly subtracts two synchronized signals. + + Uses two Oscillators and verifies the difference matches expected. + """ + test_filename = get_test_fn(test_name) + ez.logger.info(test_filename) + + # Two oscillators: OSC1 - OSC2 + comps = { + "OSC1": Oscillator( + OscillatorSettings( + n_time=n_time, + fs=fs, + n_ch=1, + dispatch_rate=fs / n_time, + freq=5.0, + amp=3.0, # Larger amplitude + ) + ), + "OSC2": Oscillator( + OscillatorSettings( + n_time=n_time, + fs=fs, + n_ch=1, + dispatch_rate=fs / n_time, + freq=5.0, + amp=1.0, # Smaller amplitude, same frequency + ) + ), + "DIFF": Difference(), + "LOG": MessageLogger( + MessageLoggerSettings( + output=test_filename, + ) + ), + "TERM": TerminateOnTotal( + TerminateOnTotalSettings( + total=n_messages, + ) + ), + } + conns = ( + (comps["OSC1"].OUTPUT_SIGNAL, comps["DIFF"].INPUT_SIGNAL_A), + (comps["OSC2"].OUTPUT_SIGNAL, comps["DIFF"].INPUT_SIGNAL_B), + (comps["DIFF"].OUTPUT_SIGNAL, comps["LOG"].INPUT_MESSAGE), + (comps["LOG"].OUTPUT_MESSAGE, comps["TERM"].INPUT_MESSAGE), + ) + ez.run(components=comps, connections=conns) + + # Collect result + messages: list[AxisArray] = [_ for _ in message_log(test_filename)] + os.remove(test_filename) + + assert len(messages) == n_messages + + # Verify each message has correct shape + for msg in messages: + assert msg.data.shape == (n_time, 1) + + # Reconstruct the full signal and verify it's the difference of two sinusoids + data = np.concatenate([_.data for _ in messages]).squeeze() + n_samples = len(data) + t = np.arange(n_samples) / fs + + # Expected: 3.0 * sin(2*pi*5*t) - 1.0 * sin(2*pi*5*t) = 2.0 * sin(2*pi*5*t) + expected = 2.0 * np.sin(2 * np.pi * 5.0 * t) + assert np.allclose(data, expected, atol=1e-10) + + +def test_const_difference_system( + fs: float = 100.0, + n_time: int = 10, + n_messages: int = 5, + subtract_value: float = 50.0, + test_name: str | None = None, +): + """ + Test that ConstDifference unit correctly subtracts a constant from a signal. + """ + test_filename = get_test_fn(test_name) + ez.logger.info(test_filename) + + comps = { + "COUNTER": Counter( + CounterSettings( + n_time=n_time, + fs=fs, + n_ch=1, + dispatch_rate=fs / n_time, + ) + ), + "DIFF": ConstDifference(ConstDifferenceSettings(value=subtract_value, subtrahend=True)), + "LOG": MessageLogger( + MessageLoggerSettings( + output=test_filename, + ) + ), + "TERM": TerminateOnTotal( + TerminateOnTotalSettings( + total=n_messages, + ) + ), + } + conns = ( + (comps["COUNTER"].OUTPUT_SIGNAL, comps["DIFF"].INPUT_SIGNAL), + (comps["DIFF"].OUTPUT_SIGNAL, comps["LOG"].INPUT_MESSAGE), + (comps["LOG"].OUTPUT_MESSAGE, comps["TERM"].INPUT_MESSAGE), + ) + ez.run(components=comps, connections=conns) + + # Collect result + messages: list[AxisArray] = [_ for _ in message_log(test_filename)] + os.remove(test_filename) + + assert len(messages) == n_messages + + # Verify the constant was subtracted + data = np.concatenate([_.data for _ in messages]).squeeze() + n_samples = len(data) + + # Counter produces 0, 1, 2, 3, ... so with subtract_value=50, we expect -50, -49, -48, ... + expected = np.arange(n_samples) - subtract_value + assert np.allclose(data, expected) + + +def test_const_difference_subtrahend_false_system( + fs: float = 100.0, + n_time: int = 10, + n_messages: int = 5, + value: float = 100.0, + test_name: str | None = None, +): + """ + Test ConstDifference with subtrahend=False (value - input). + """ + test_filename = get_test_fn(test_name) + ez.logger.info(test_filename) + + comps = { + "COUNTER": Counter( + CounterSettings( + n_time=n_time, + fs=fs, + n_ch=1, + dispatch_rate=fs / n_time, + ) + ), + "DIFF": ConstDifference(ConstDifferenceSettings(value=value, subtrahend=False)), + "LOG": MessageLogger( + MessageLoggerSettings( + output=test_filename, + ) + ), + "TERM": TerminateOnTotal( + TerminateOnTotalSettings( + total=n_messages, + ) + ), + } + conns = ( + (comps["COUNTER"].OUTPUT_SIGNAL, comps["DIFF"].INPUT_SIGNAL), + (comps["DIFF"].OUTPUT_SIGNAL, comps["LOG"].INPUT_MESSAGE), + (comps["LOG"].OUTPUT_MESSAGE, comps["TERM"].INPUT_MESSAGE), + ) + ez.run(components=comps, connections=conns) + + # Collect result + messages: list[AxisArray] = [_ for _ in message_log(test_filename)] + os.remove(test_filename) + + assert len(messages) == n_messages + + # Verify: value - input + data = np.concatenate([_.data for _ in messages]).squeeze() + n_samples = len(data) + + # Counter produces 0, 1, 2, 3, ... so with value=100 and subtrahend=False: + # result = 100 - counter = 100, 99, 98, ... + expected = value - np.arange(n_samples) + assert np.allclose(data, expected) diff --git a/tests/integration/ezmsg/test_downsample_system.py b/tests/integration/ezmsg/test_downsample_system.py index ea16ab69..0127da95 100644 --- a/tests/integration/ezmsg/test_downsample_system.py +++ b/tests/integration/ezmsg/test_downsample_system.py @@ -12,7 +12,7 @@ from ezmsg.util.terminate import TerminateOnTimeoutSettings as TerminateTestSettings from ezmsg.sigproc.downsample import Downsample, DownsampleSettings -from ezmsg.sigproc.synth import Counter, CounterSettings +from tests.helpers.synth import Counter, CounterSettings from tests.helpers.util import get_test_fn diff --git a/tests/integration/ezmsg/test_filter_system.py b/tests/integration/ezmsg/test_filter_system.py index d1b85fe2..2c5098b0 100644 --- a/tests/integration/ezmsg/test_filter_system.py +++ b/tests/integration/ezmsg/test_filter_system.py @@ -9,7 +9,7 @@ from ezmsg.sigproc.butterworthfilter import ButterworthFilter from ezmsg.sigproc.cheby import ChebyshevFilter -from ezmsg.sigproc.synth import EEGSynth +from tests.helpers.synth import EEGSynth from tests.helpers.util import get_test_fn diff --git a/tests/integration/ezmsg/test_fir_hilbert_system.py b/tests/integration/ezmsg/test_fir_hilbert_system.py index 855dd076..e9b332df 100644 --- a/tests/integration/ezmsg/test_fir_hilbert_system.py +++ b/tests/integration/ezmsg/test_fir_hilbert_system.py @@ -8,7 +8,7 @@ from ezmsg.util.terminate import TerminateOnTotal from ezmsg.sigproc.fir_hilbert import FIRHilbertEnvelopeUnit -from ezmsg.sigproc.synth import EEGSynth +from tests.helpers.synth import EEGSynth def test_hilbert_system(): diff --git a/tests/integration/ezmsg/test_fir_pmc_system.py b/tests/integration/ezmsg/test_fir_pmc_system.py index 10be08b9..e7f2c90e 100644 --- a/tests/integration/ezmsg/test_fir_pmc_system.py +++ b/tests/integration/ezmsg/test_fir_pmc_system.py @@ -8,7 +8,7 @@ from ezmsg.util.terminate import TerminateOnTotal from ezmsg.sigproc.fir_pmc import ParksMcClellanFIR -from ezmsg.sigproc.synth import EEGSynth +from tests.helpers.synth import EEGSynth def test_pmc_fir_system(): diff --git a/tests/integration/ezmsg/test_rollingscaler_system.py b/tests/integration/ezmsg/test_rollingscaler_system.py index cab2f802..52e5f948 100644 --- a/tests/integration/ezmsg/test_rollingscaler_system.py +++ b/tests/integration/ezmsg/test_rollingscaler_system.py @@ -8,7 +8,7 @@ from ezmsg.util.terminate import TerminateOnTotal from ezmsg.sigproc.rollingscaler import RollingScalerUnit -from ezmsg.sigproc.synth import EEGSynth +from tests.helpers.synth import EEGSynth def test_rolling_scaler_system(): diff --git a/tests/integration/ezmsg/test_sampler_system.py b/tests/integration/ezmsg/test_sampler_system.py index 79d855c6..5841eaf1 100644 --- a/tests/integration/ezmsg/test_sampler_system.py +++ b/tests/integration/ezmsg/test_sampler_system.py @@ -13,7 +13,7 @@ TriggerGenerator, TriggerGeneratorSettings, ) -from ezmsg.sigproc.synth import Oscillator, OscillatorSettings +from tests.helpers.synth import Oscillator, OscillatorSettings from tests.helpers.util import get_test_fn diff --git a/tests/integration/ezmsg/test_scaler_system.py b/tests/integration/ezmsg/test_scaler_system.py index ee1c2f49..cbc5e48d 100644 --- a/tests/integration/ezmsg/test_scaler_system.py +++ b/tests/integration/ezmsg/test_scaler_system.py @@ -9,7 +9,7 @@ from frozendict import frozendict from ezmsg.sigproc.scaler import AdaptiveStandardScaler, AdaptiveStandardScalerSettings, scaler_np -from ezmsg.sigproc.synth import Counter, CounterSettings +from tests.helpers.synth import Counter, CounterSettings from tests.helpers.util import get_test_fn @@ -36,7 +36,6 @@ def test_scaler_system( fs=fs, n_ch=1, dispatch_rate=duration, # Simulation duration in 1.0 seconds - mod=None, ) ), "SCALER": AdaptiveStandardScaler(AdaptiveStandardScalerSettings(time_constant=tau, axis="time")), diff --git a/tests/integration/ezmsg/test_spectrum_system.py b/tests/integration/ezmsg/test_spectrum_system.py index c562bb9d..defd517d 100644 --- a/tests/integration/ezmsg/test_spectrum_system.py +++ b/tests/integration/ezmsg/test_spectrum_system.py @@ -14,8 +14,8 @@ SpectrumSettings, WindowFunction, ) -from ezmsg.sigproc.synth import EEGSynth, EEGSynthSettings from ezmsg.sigproc.window import Window, WindowSettings +from tests.helpers.synth import EEGSynth, EEGSynthSettings from tests.helpers.util import ( get_test_fn, ) diff --git a/tests/integration/ezmsg/test_synth_system.py b/tests/integration/ezmsg/test_synth_system.py deleted file mode 100644 index 90de1928..00000000 --- a/tests/integration/ezmsg/test_synth_system.py +++ /dev/null @@ -1,237 +0,0 @@ -import asyncio # noqa: F401 -import os -from dataclasses import field - -import ezmsg.core as ez -import numpy as np -import pytest -from ezmsg.util.messagecodec import message_log -from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings -from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings - -from ezmsg.sigproc.synth import ( - Clock, - ClockSettings, - Counter, - CounterSettings, - EEGSynth, - EEGSynthSettings, -) -from tests.helpers.util import get_test_fn - - -class ClockTestSystemSettings(ez.Settings): - clock_settings: ClockSettings - log_settings: MessageLoggerSettings - term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) - - -class ClockTestSystem(ez.Collection): - SETTINGS = ClockTestSystemSettings - - CLOCK = Clock() - LOG = MessageLogger() - TERM = TerminateOnTotal() - - def configure(self) -> None: - self.CLOCK.apply_settings(self.SETTINGS.clock_settings) - self.LOG.apply_settings(self.SETTINGS.log_settings) - self.TERM.apply_settings(self.SETTINGS.term_settings) - - def network(self) -> ez.NetworkDefinition: - return ( - (self.CLOCK.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), - (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), - ) - - -@pytest.mark.parametrize("dispatch_rate", [None, 2.0, 20.0]) -def test_clock_system( - dispatch_rate: float | None, - test_name: str | None = None, -): - run_time = 1.0 - n_target = int(np.ceil(dispatch_rate * run_time)) if dispatch_rate else 100 - test_filename = get_test_fn(test_name) - ez.logger.info(test_filename) - settings = ClockTestSystemSettings( - clock_settings=ClockSettings(dispatch_rate=dispatch_rate), - log_settings=MessageLoggerSettings(output=test_filename), - term_settings=TerminateOnTotalSettings(total=n_target), - ) - system = ClockTestSystem(settings) - ez.run(SYSTEM=system) - - # Collect result - messages: list[AxisArray] = [_ for _ in message_log(test_filename)] - os.remove(test_filename) - - assert all([_ == ez.Flag() for _ in messages]) - assert len(messages) >= n_target - - -class CounterTestSystemSettings(ez.Settings): - counter_settings: CounterSettings - log_settings: MessageLoggerSettings - term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) - - -class CounterTestSystem(ez.Collection): - SETTINGS = CounterTestSystemSettings - - COUNTER = Counter() - LOG = MessageLogger() - TERM = TerminateOnTotal() - - def configure(self) -> None: - self.COUNTER.apply_settings(self.SETTINGS.counter_settings) - self.LOG.apply_settings(self.SETTINGS.log_settings) - self.TERM.apply_settings(self.SETTINGS.term_settings) - - def network(self) -> ez.NetworkDefinition: - return ( - (self.COUNTER.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), - (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), - ) - - -# Integration Test. -# General functionality of acounter verified above. Here we only need to test a couple configs. -@pytest.mark.parametrize( - "block_size, fs, dispatch_rate, mod", - [ - (1, 10.0, None, None), - (20, 1000.0, "realtime", None), - (1, 1000.0, 2.0, 2**3), - (10, 10.0, 20.0, 2**3), - # No test for ext_clock because that requires a different system - # (20, 10.0, "ext_clock", None), - ], -) -def test_counter_system( - block_size: int, - fs: float, - dispatch_rate: float | str | None, - mod: int | None, - test_name: str | None = None, -): - n_ch = 3 - target_dur = 2.6 # 2.6 seconds per test - if dispatch_rate is None: - # No sleep / wait - chunk_dur = 0.1 - elif isinstance(dispatch_rate, str): - if dispatch_rate == "realtime": - chunk_dur = block_size / fs - else: - # Note: float dispatch_rate will yield different number of samples than expected by target_dur and fs - chunk_dur = 1.0 / dispatch_rate - target_messages = int(target_dur / chunk_dur) - - test_filename = get_test_fn(test_name) - ez.logger.info(test_filename) - settings = CounterTestSystemSettings( - counter_settings=CounterSettings( - n_time=block_size, - fs=fs, - n_ch=n_ch, - dispatch_rate=dispatch_rate, - mod=mod, - ), - log_settings=MessageLoggerSettings( - output=test_filename, - ), - term_settings=TerminateOnTotalSettings( - total=target_messages, - ), - ) - system = CounterTestSystem(settings) - ez.run(SYSTEM=system) - - # Collect result - messages: list[AxisArray] = [_ for _ in message_log(test_filename)] - os.remove(test_filename) - - if dispatch_rate is None: - # The number of messages depends on how fast the computer is - target_messages = len(messages) - # This should be an equivalence assertion (==) but the use of TerminateOnTotal does - # not guarantee that MessageLogger will exit before an additional message is received. - # Let's just clip the last message if we exceed the target messages. - if len(messages) > target_messages: - messages = messages[:target_messages] - assert len(messages) == target_messages - - # Just do one quick data check - agg = AxisArray.concatenate(*messages, dim="time") - target_samples = block_size * target_messages - expected_data = np.arange(target_samples) - if mod is not None: - expected_data = expected_data % mod - assert np.array_equal(agg.data[:, 0], expected_data) - - -# TODO: test SinGenerator in a system. - - -class EEGSynthSettingsTest(ez.Settings): - synth_settings: EEGSynthSettings - log_settings: MessageLoggerSettings - term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) - - -class EEGSynthIntegrationTest(ez.Collection): - SETTINGS = EEGSynthSettingsTest - - SOURCE = EEGSynth() - SINK = MessageLogger() - TERM = TerminateOnTotal() - - def configure(self) -> None: - self.SOURCE.apply_settings(self.SETTINGS.synth_settings) - self.SINK.apply_settings(self.SETTINGS.log_settings) - self.TERM.apply_settings(self.SETTINGS.term_settings) - - def network(self) -> ez.NetworkDefinition: - return ( - (self.SOURCE.OUTPUT_SIGNAL, self.SINK.INPUT_MESSAGE), - (self.SINK.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), - ) - - -def test_eegsynth_system( - test_name: str | None = None, -): - # Just a quick test to make sure the system runs. We aren't checking validity of values or anything. - fs = 500.0 - n_time = 100 # samples per block. dispatch_rate = fs / n_time - target_dur = 2.0 - target_messages = int(target_dur * fs / n_time) - - test_filename = get_test_fn(test_name) - ez.logger.info(test_filename) - - settings = EEGSynthSettingsTest( - synth_settings=EEGSynthSettings( - fs=fs, - n_time=n_time, - alpha_freq=10.5, - n_ch=8, - ), - log_settings=MessageLoggerSettings( - output=test_filename, - ), - term_settings=TerminateOnTotalSettings( - total=target_messages, - ), - ) - - system = EEGSynthIntegrationTest(settings) - ez.run(SYSTEM=system) - - messages: list[AxisArray] = [_ for _ in message_log(test_filename)] - os.remove(test_filename) - agg = AxisArray.concatenate(*messages, dim="time") - assert agg.axes["time"].gain == 1 / fs - assert agg.data.ndim == 2 diff --git a/tests/integration/ezmsg/test_window_system.py b/tests/integration/ezmsg/test_window_system.py index c3e5b829..b5be2c94 100644 --- a/tests/integration/ezmsg/test_window_system.py +++ b/tests/integration/ezmsg/test_window_system.py @@ -13,8 +13,8 @@ from ezmsg.util.terminate import TerminateOnTimeout as TerminateTest from ezmsg.util.terminate import TerminateOnTimeoutSettings as TerminateTestSettings -from ezmsg.sigproc.synth import Counter, CounterSettings from ezmsg.sigproc.window import Window, WindowSettings +from tests.helpers.synth import Counter, CounterSettings from tests.helpers.util import calculate_expected_windows, get_test_fn diff --git a/tests/unit/test_math_add.py b/tests/unit/test_math_add.py new file mode 100644 index 00000000..4aa6c143 --- /dev/null +++ b/tests/unit/test_math_add.py @@ -0,0 +1,247 @@ +"""Unit tests for ezmsg.sigproc.math.add module.""" + +import asyncio +import copy + +import numpy as np +from ezmsg.util.messages.axisarray import AxisArray +from frozendict import frozendict + +from ezmsg.sigproc.math.add import ( + AddProcessor, + AddState, + ConstAddSettings, + ConstAddTransformer, +) +from tests.helpers.util import assert_messages_equal + + +class TestConstAddTransformer: + """Tests for ConstAddTransformer.""" + + def test_basic_add_positive(self): + """Test adding a positive constant.""" + transformer = ConstAddTransformer(ConstAddSettings(value=5.0)) + + data = np.array([[1.0, 2.0], [3.0, 4.0]]) + msg_in = AxisArray( + data, + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + backup = copy.deepcopy(msg_in) + + msg_out = transformer(msg_in) + + expected = np.array([[6.0, 7.0], [8.0, 9.0]]) + assert np.allclose(msg_out.data, expected) + assert_messages_equal([msg_in], [backup]) + + def test_basic_add_negative(self): + """Test adding a negative constant (effectively subtraction).""" + transformer = ConstAddTransformer(ConstAddSettings(value=-3.0)) + + data = np.array([[10.0, 20.0], [30.0, 40.0]]) + msg_in = AxisArray( + data, + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + + msg_out = transformer(msg_in) + + expected = np.array([[7.0, 17.0], [27.0, 37.0]]) + assert np.allclose(msg_out.data, expected) + + def test_add_zero(self): + """Test adding zero (identity operation).""" + transformer = ConstAddTransformer(ConstAddSettings(value=0.0)) + + data = np.array([[1.0, 2.0], [3.0, 4.0]]) + msg_in = AxisArray( + data, + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + + msg_out = transformer(msg_in) + + assert np.allclose(msg_out.data, data) + + def test_preserves_axes(self): + """Test that axes are preserved in output.""" + transformer = ConstAddTransformer(ConstAddSettings(value=1.0)) + + data = np.array([[1.0, 2.0], [3.0, 4.0]]) + ch_axis = AxisArray.CoordinateAxis(data=np.array(["A", "B"]), dims=["ch"]) + msg_in = AxisArray( + data, + dims=["time", "ch"], + axes=frozendict( + { + "time": AxisArray.TimeAxis(fs=100.0, offset=1.5), + "ch": ch_axis, + } + ), + ) + + msg_out = transformer(msg_in) + + assert msg_out.dims == msg_in.dims + assert msg_out.axes["time"].gain == msg_in.axes["time"].gain + assert msg_out.axes["time"].offset == msg_in.axes["time"].offset + + def test_stateless_across_chunks(self): + """Test that transformer is stateless across multiple chunks.""" + transformer = ConstAddTransformer(ConstAddSettings(value=10.0)) + + chunks = [ + AxisArray( + np.array([[i * 1.0]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0, offset=i * 0.01)}), + ) + for i in range(5) + ] + + outputs = [transformer(chunk) for chunk in chunks] + + for i, out in enumerate(outputs): + assert np.allclose(out.data, np.array([[i * 1.0 + 10.0]])) + + +class TestAddProcessor: + """Tests for AddProcessor.""" + + def test_basic_add(self): + """Test basic addition of two messages.""" + processor = AddProcessor() + + data_a = np.array([[1.0, 2.0], [3.0, 4.0]]) + data_b = np.array([[10.0, 20.0], [30.0, 40.0]]) + + msg_a = AxisArray( + data_a, + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + msg_b = AxisArray( + data_b, + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + + processor.push_a(msg_a) + processor.push_b(msg_b) + + # Use sync call + result = processor() + + expected = np.array([[11.0, 22.0], [33.0, 44.0]]) + assert np.allclose(result.data, expected) + + def test_queue_ordering(self): + """Test that messages are paired in order.""" + processor = AddProcessor() + + # Push multiple messages to each queue + for i in range(3): + msg_a = AxisArray( + np.array([[float(i)]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + msg_b = AxisArray( + np.array([[float(i * 10)]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + processor.push_a(msg_a) + processor.push_b(msg_b) + + # Results should be paired in order + for i in range(3): + result = processor() # Use sync call + expected = float(i) + float(i * 10) + assert np.allclose(result.data, np.array([[expected]])) + + def test_state_property(self): + """Test state getter and setter.""" + processor = AddProcessor() + + assert isinstance(processor.state, AddState) + + new_state = AddState() + processor.state = new_state + assert processor.state is new_state + + # Setting None should not change state + old_state = processor.state + processor.state = None + assert processor.state is old_state + + def test_sync_call(self): + """Test synchronous __call__ method.""" + processor = AddProcessor() + + msg_a = AxisArray( + np.array([[1.0]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + msg_b = AxisArray( + np.array([[2.0]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + + processor.push_a(msg_a) + processor.push_b(msg_b) + + result = processor() + assert np.allclose(result.data, np.array([[3.0]])) + + def test_legacy_interface(self): + """Test legacy __next__ and __anext__ interfaces.""" + processor = AddProcessor() + + msg_a = AxisArray( + np.array([[5.0]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + msg_b = AxisArray( + np.array([[7.0]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + + processor.push_a(msg_a) + processor.push_b(msg_b) + + # Test __next__ + result = next(processor) + assert np.allclose(result.data, np.array([[12.0]])) + + +class TestAddState: + """Tests for AddState dataclass.""" + + def test_default_queues(self): + """Test that default queues are created.""" + state = AddState() + + assert isinstance(state.queue_a, asyncio.Queue) + assert isinstance(state.queue_b, asyncio.Queue) + assert state.queue_a.empty() + assert state.queue_b.empty() + + def test_independent_queues(self): + """Test that queues are independent between instances.""" + state1 = AddState() + state2 = AddState() + + state1.queue_a.put_nowait("test") + + assert not state1.queue_a.empty() + assert state2.queue_a.empty() diff --git a/tests/unit/test_math_difference.py b/tests/unit/test_math_difference.py new file mode 100644 index 00000000..b17d5d52 --- /dev/null +++ b/tests/unit/test_math_difference.py @@ -0,0 +1,278 @@ +"""Unit tests for ezmsg.sigproc.math.difference module.""" + +import asyncio +import copy + +import numpy as np +from ezmsg.util.messages.axisarray import AxisArray +from frozendict import frozendict + +from ezmsg.sigproc.math.difference import ( + ConstDifferenceSettings, + ConstDifferenceTransformer, + DifferenceProcessor, + DifferenceState, + const_difference, +) +from tests.helpers.util import assert_messages_equal + + +class TestConstDifferenceTransformer: + """Tests for ConstDifferenceTransformer.""" + + def test_subtract_positive(self): + """Test subtracting a positive constant from input.""" + transformer = ConstDifferenceTransformer(ConstDifferenceSettings(value=5.0, subtrahend=True)) + + data = np.array([[10.0, 20.0], [30.0, 40.0]]) + msg_in = AxisArray( + data, + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + backup = copy.deepcopy(msg_in) + + msg_out = transformer(msg_in) + + expected = np.array([[5.0, 15.0], [25.0, 35.0]]) + assert np.allclose(msg_out.data, expected) + assert_messages_equal([msg_in], [backup]) + + def test_subtract_from_value(self): + """Test subtracting input from a constant value.""" + transformer = ConstDifferenceTransformer(ConstDifferenceSettings(value=100.0, subtrahend=False)) + + data = np.array([[10.0, 20.0], [30.0, 40.0]]) + msg_in = AxisArray( + data, + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + + msg_out = transformer(msg_in) + + # value - data = 100 - data + expected = np.array([[90.0, 80.0], [70.0, 60.0]]) + assert np.allclose(msg_out.data, expected) + + def test_subtract_zero(self): + """Test subtracting zero (identity operation).""" + transformer = ConstDifferenceTransformer(ConstDifferenceSettings(value=0.0, subtrahend=True)) + + data = np.array([[1.0, 2.0], [3.0, 4.0]]) + msg_in = AxisArray( + data, + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + + msg_out = transformer(msg_in) + + assert np.allclose(msg_out.data, data) + + def test_preserves_axes(self): + """Test that axes are preserved in output.""" + transformer = ConstDifferenceTransformer(ConstDifferenceSettings(value=1.0)) + + data = np.array([[1.0, 2.0], [3.0, 4.0]]) + ch_axis = AxisArray.CoordinateAxis(data=np.array(["A", "B"]), dims=["ch"]) + msg_in = AxisArray( + data, + dims=["time", "ch"], + axes=frozendict( + { + "time": AxisArray.TimeAxis(fs=100.0, offset=1.5), + "ch": ch_axis, + } + ), + ) + + msg_out = transformer(msg_in) + + assert msg_out.dims == msg_in.dims + assert msg_out.axes["time"].gain == msg_in.axes["time"].gain + assert msg_out.axes["time"].offset == msg_in.axes["time"].offset + + +class TestConstDifferenceFactory: + """Tests for const_difference factory function.""" + + def test_factory_creates_transformer(self): + """Test that factory creates a properly configured transformer.""" + transformer = const_difference(value=7.5, subtrahend=True) + + assert isinstance(transformer, ConstDifferenceTransformer) + assert transformer.settings.value == 7.5 + assert transformer.settings.subtrahend is True + + def test_factory_subtrahend_false(self): + """Test factory with subtrahend=False.""" + transformer = const_difference(value=50.0, subtrahend=False) + + assert transformer.settings.value == 50.0 + assert transformer.settings.subtrahend is False + + def test_factory_default_values(self): + """Test factory with default values.""" + transformer = const_difference() + + assert transformer.settings.value == 0.0 + assert transformer.settings.subtrahend is True + + +class TestDifferenceProcessor: + """Tests for DifferenceProcessor.""" + + def test_basic_difference(self): + """Test basic subtraction of two messages (A - B).""" + processor = DifferenceProcessor() + + data_a = np.array([[10.0, 20.0], [30.0, 40.0]]) + data_b = np.array([[1.0, 2.0], [3.0, 4.0]]) + + msg_a = AxisArray( + data_a, + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + msg_b = AxisArray( + data_b, + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + + processor.push_a(msg_a) + processor.push_b(msg_b) + + # Use sync call + result = processor() + + expected = np.array([[9.0, 18.0], [27.0, 36.0]]) + assert np.allclose(result.data, expected) + + def test_queue_ordering(self): + """Test that messages are paired in order.""" + processor = DifferenceProcessor() + + # Push multiple messages to each queue + for i in range(3): + msg_a = AxisArray( + np.array([[float(i * 10)]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + msg_b = AxisArray( + np.array([[float(i)]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + processor.push_a(msg_a) + processor.push_b(msg_b) + + # Results should be paired in order: (0 - 0), (10 - 1), (20 - 2) + for i in range(3): + result = processor() # Use sync call + expected = float(i * 10) - float(i) + assert np.allclose(result.data, np.array([[expected]])) + + def test_state_property(self): + """Test state getter and setter.""" + processor = DifferenceProcessor() + + assert isinstance(processor.state, DifferenceState) + + new_state = DifferenceState() + processor.state = new_state + assert processor.state is new_state + + # Setting None should not change state + old_state = processor.state + processor.state = None + assert processor.state is old_state + + def test_sync_call(self): + """Test synchronous __call__ method.""" + processor = DifferenceProcessor() + + msg_a = AxisArray( + np.array([[10.0]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + msg_b = AxisArray( + np.array([[3.0]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + + processor.push_a(msg_a) + processor.push_b(msg_b) + + result = processor() + assert np.allclose(result.data, np.array([[7.0]])) + + def test_legacy_interface(self): + """Test legacy __next__ and __anext__ interfaces.""" + processor = DifferenceProcessor() + + msg_a = AxisArray( + np.array([[20.0]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + msg_b = AxisArray( + np.array([[8.0]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + + processor.push_a(msg_a) + processor.push_b(msg_b) + + # Test __next__ + result = next(processor) + assert np.allclose(result.data, np.array([[12.0]])) + + def test_negative_result(self): + """Test that negative results are handled correctly (B > A).""" + processor = DifferenceProcessor() + + msg_a = AxisArray( + np.array([[5.0]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + msg_b = AxisArray( + np.array([[10.0]]), + dims=["time", "ch"], + axes=frozendict({"time": AxisArray.TimeAxis(fs=100.0)}), + ) + + processor.push_a(msg_a) + processor.push_b(msg_b) + + result = processor() + assert np.allclose(result.data, np.array([[-5.0]])) + + +class TestDifferenceState: + """Tests for DifferenceState dataclass.""" + + def test_default_queues(self): + """Test that default queues are created.""" + state = DifferenceState() + + assert isinstance(state.queue_a, asyncio.Queue) + assert isinstance(state.queue_b, asyncio.Queue) + assert state.queue_a.empty() + assert state.queue_b.empty() + + def test_independent_queues(self): + """Test that queues are independent between instances.""" + state1 = DifferenceState() + state2 = DifferenceState() + + state1.queue_a.put_nowait("test") + + assert not state1.queue_a.empty() + assert state2.queue_a.empty() diff --git a/tests/unit/test_synth.py b/tests/unit/test_synth.py deleted file mode 100644 index e64ad664..00000000 --- a/tests/unit/test_synth.py +++ /dev/null @@ -1,142 +0,0 @@ -import asyncio # noqa: F401 -import time - -import ezmsg.core as ez -import numpy as np -import pytest -from ezmsg.util.messages.axisarray import AxisArray - -from ezmsg.sigproc.synth import ( - aclock, - acounter, - clock, - sin, -) - - -# TEST CLOCK -@pytest.mark.parametrize("dispatch_rate", [None, 1.0, 2.0, 5.0, 10.0, 20.0]) -def test_clock_gen(dispatch_rate: float | None): - run_time = 1.0 - n_target = int(np.ceil(dispatch_rate * run_time)) if dispatch_rate else 100 - gen = clock(dispatch_rate=dispatch_rate) - result = [] - t_start = time.time() - while len(result) < n_target: - result.append(next(gen)) - t_elapsed = time.time() - t_start - assert all([_ == ez.Flag() for _ in result]) - if dispatch_rate is not None: - assert (run_time - 1 / dispatch_rate) < t_elapsed < (run_time + 0.2) - else: - # 100 usec per iteration is pretty generous - assert t_elapsed < (n_target * 1e-4) - - -@pytest.mark.parametrize("dispatch_rate", [None, 2.0, 20.0]) -@pytest.mark.asyncio -async def test_aclock_agen(dispatch_rate: float | None): - run_time = 1.0 - n_target = int(np.ceil(dispatch_rate * run_time)) if dispatch_rate else 100 - agen = aclock(dispatch_rate=dispatch_rate) - result = [] - t_start = time.time() - while len(result) < n_target: - new_result = await agen.__anext__() - result.append(new_result) - t_elapsed = time.time() - t_start - assert all([_ == ez.Flag() for _ in result]) - if dispatch_rate: - assert (run_time - 1.1 / dispatch_rate) < t_elapsed < (run_time + 0.1) - else: - # 100 usec per iteration is pretty generous - assert t_elapsed < (n_target * 1e-4) - - -@pytest.mark.parametrize("block_size", [1, 20]) -@pytest.mark.parametrize("fs", [10.0, 1000.0]) -@pytest.mark.parametrize("n_ch", [3]) -@pytest.mark.parametrize( - "dispatch_rate", [None, "realtime", "ext_clock", 2.0, 20.0] -) # "ext_clock" needs a separate test -@pytest.mark.parametrize("mod", [2**3, None]) -@pytest.mark.asyncio -async def test_acounter( - block_size: int, - fs: float, - n_ch: int, - dispatch_rate: float | str | None, - mod: int | None, -): - target_dur = 2.6 # 2.6 seconds per test - if dispatch_rate is None: - # No sleep / wait - chunk_dur = 0.1 - elif isinstance(dispatch_rate, str): - if dispatch_rate == "realtime": - chunk_dur = block_size / fs - elif dispatch_rate == "ext_clock": - # No sleep / wait - chunk_dur = 0.1 - else: - # Note: float dispatch_rate will yield different number of samples than expected by target_dur and fs - chunk_dur = 1.0 / dispatch_rate - target_messages = int(target_dur / chunk_dur) - - # Run generator - agen = acounter(block_size, fs, n_ch=n_ch, dispatch_rate=dispatch_rate, mod=mod) - messages = [await agen.__anext__() for _ in range(target_messages)] - - # Test contents of individual messages - for msg in messages: - assert type(msg) is AxisArray - assert msg.data.shape == (block_size, n_ch) - assert "time" in msg.axes - assert msg.axes["time"].gain == 1 / fs - assert "ch" in msg.axes - assert np.array_equal(msg.axes["ch"].data, np.array([f"Ch{_}" for _ in range(n_ch)])) - - agg = AxisArray.concatenate(*messages, dim="time") - - target_samples = block_size * target_messages - expected_data = np.arange(target_samples) - if mod is not None: - expected_data = expected_data % mod - assert np.array_equal(agg.data[:, 0], expected_data) - - offsets = np.array([m.axes["time"].offset for m in messages]) - expected_offsets = np.arange(target_messages) * block_size / fs - if dispatch_rate == "realtime" or dispatch_rate == "ext_clock": - expected_offsets += offsets[0] # offsets are in real-time - atol = 0.002 - else: - # Offsets are synthetic. - atol = 1.0e-8 - assert np.allclose(offsets[2:], expected_offsets[2:], atol=atol) - - -# TEST SIN # -def test_sin_gen(freq: float = 1.0, amp: float = 1.0, phase: float = 0.0): - axis: str | None = "time" - srate = max(4.0 * freq, 1000.0) - sim_dur = 30.0 - n_samples = int(srate * sim_dur) - n_msgs = min(n_samples, 10) - axis_idx = 0 - - messages = [] - for split_dat in np.array_split(np.arange(n_samples)[:, None], n_msgs, axis=axis_idx): - _time_axis = AxisArray.TimeAxis(fs=srate, offset=float(split_dat[0, 0])) - messages.append(AxisArray(split_dat, dims=["time", "ch"], axes={"time": _time_axis})) - - def f_test(t): - return amp * np.sin(2 * np.pi * freq * t + phase) - - gen = sin(axis=axis, freq=freq, amp=amp, phase=phase) - results = [] - for msg in messages: - res = gen.send(msg) - assert np.allclose(res.data, f_test(msg.data / srate)) - results.append(res) - concat_ax_arr = AxisArray.concatenate(*results, dim="time") - assert np.allclose(concat_ax_arr.data, f_test(np.arange(n_samples) / srate)[:, None])