Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/guides/ProcessorsBase.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ do not inherit from `BaseStatefulProcessor` and `BaseStatefulProducer`. They acc
| 3 | `BaseConsumerUnit` | 1 | `ConsumerType` |
| 4 | `BaseTransformerUnit` | 1 | `TransformerType` |
| 5 | `BaseAdaptiveTransformerUnit` | 1 | `AdaptiveTransformerType` |
| 6 | `BaseClockDrivenProducerUnit` | 1 | `ClockDrivenProducerType` |
| 6 | `BaseClockDrivenUnit` | 1 | `ClockDrivenProducerType` |

Note, it is strongly recommended to use `BaseConsumerUnit`, `BaseTransformerUnit`, `BaseAdaptiveTransformerUnit`, or `BaseClockDrivenProducerUnit` for implementing concrete subclasses rather than `BaseProcessorUnit`.
Note, it is strongly recommended to use `BaseConsumerUnit`, `BaseTransformerUnit`, `BaseAdaptiveTransformerUnit`, or `BaseClockDrivenUnit` for implementing concrete subclasses rather than `BaseProcessorUnit`.


## Implementing a custom standalone processor
Expand Down
4 changes: 2 additions & 2 deletions docs/source/guides/how-tos/processors/clockdriven.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Here's a complete example of a sine wave generator:

from ezmsg.baseproc import (
BaseClockDrivenProducer,
BaseClockDrivenProducerUnit,
BaseClockDrivenUnit,
ClockDrivenSettings,
ClockDrivenState,
processor_state,
Expand Down Expand Up @@ -124,7 +124,7 @@ Here's a complete example of a sine wave generator:


class SinGeneratorUnit(
BaseClockDrivenProducerUnit[SinGeneratorSettings, SinGenerator]
BaseClockDrivenUnit[SinGeneratorSettings, SinGenerator]
):
"""
ezmsg Unit wrapper for SinGenerator.
Expand Down
4 changes: 2 additions & 2 deletions src/ezmsg/baseproc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
from .units import (
AdaptiveTransformerType,
BaseAdaptiveTransformerUnit,
BaseClockDrivenProducerUnit,
BaseClockDrivenUnit,
BaseConsumerUnit,
BaseProcessorUnit,
BaseProducerUnit,
Expand Down Expand Up @@ -158,7 +158,7 @@
"BaseConsumerUnit",
"BaseTransformerUnit",
"BaseAdaptiveTransformerUnit",
"BaseClockDrivenProducerUnit",
"BaseClockDrivenUnit",
"GenAxisArray",
# Type resolution helpers
"get_base_producer_type",
Expand Down
4 changes: 2 additions & 2 deletions src/ezmsg/baseproc/counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ClockDrivenState,
)
from .protocols import processor_state
from .units import BaseClockDrivenProducerUnit
from .units import BaseClockDrivenUnit


class CounterSettings(ClockDrivenSettings):
Expand Down Expand Up @@ -57,7 +57,7 @@ def _produce(self, n_samples: int, time_axis: LinearAxis) -> AxisArray:
)


class Counter(BaseClockDrivenProducerUnit[CounterSettings, CounterTransformer]):
class Counter(BaseClockDrivenUnit[CounterSettings, CounterTransformer]):
"""
Transforms clock ticks into monotonically increasing counter values as AxisArray.

Expand Down
4 changes: 2 additions & 2 deletions src/ezmsg/baseproc/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ async def on_sample(self, msg: SampleMessage) -> None:
await self.processor.apartial_fit(msg)


class BaseClockDrivenProducerUnit(
class BaseClockDrivenUnit(
BaseProcessorUnit[SettingsType],
ABC,
typing.Generic[SettingsType, ClockDrivenProducerType],
Expand All @@ -260,7 +260,7 @@ class BaseClockDrivenProducerUnit(

Implement a new Unit as follows::

class SinGeneratorUnit(BaseClockDrivenProducerUnit[
class SinGeneratorUnit(BaseClockDrivenUnit[
SinGeneratorSettings, # SettingsType (must extend ClockDrivenSettings)
SinProducer, # ClockDrivenProducerType
]):
Expand Down
18 changes: 18 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""pytest configuration for ezmsg-baseproc tests."""

import os
import sys

import pytest

# Add tests directory to path so 'tests.helpers' can be imported
_tests_dir = os.path.dirname(__file__)
_parent_dir = os.path.dirname(_tests_dir)
if _parent_dir not in sys.path:
sys.path.insert(0, _parent_dir)


@pytest.fixture
def test_name(request):
"""Provide the test name to test functions."""
return request.node.name
Empty file added tests/helpers/__init__.py
Empty file.
24 changes: 24 additions & 0 deletions tests/helpers/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os
import tempfile
from pathlib import Path


def get_test_fn(test_name: str | None = None, extension: str = "txt") -> Path:
"""PYTEST compatible temporary test file creator"""

# Get current test name if we can..
if test_name is None:
test_name = os.environ.get("PYTEST_CURRENT_TEST")
if test_name is not None:
test_name = test_name.split(":")[-1].split(" ")[0]
else:
test_name = __name__

file_path = Path(tempfile.gettempdir())
file_path = file_path / Path(f"{test_name}.{extension}")

# Create the file
with open(file_path, "w"):
pass

return file_path
229 changes: 229 additions & 0 deletions tests/test_clock_counter_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
"""Integration tests for Clock and Counter ezmsg systems."""

import math
import os
from dataclasses import field

import ezmsg.core as ez
import numpy as np
import pytest
from ezmsg.util.messagecodec import message_log
from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings

from ezmsg.baseproc import (
Clock,
ClockSettings,
Counter,
CounterSettings,
)
from tests.helpers.util import get_test_fn


class ClockTestSystemSettings(ez.Settings):
clock_settings: ClockSettings
log_settings: MessageLoggerSettings
term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings)


class ClockTestSystem(ez.Collection):
SETTINGS = ClockTestSystemSettings

CLOCK = Clock()
LOG = MessageLogger()
TERM = TerminateOnTotal()

def configure(self) -> None:
self.CLOCK.apply_settings(self.SETTINGS.clock_settings)
self.LOG.apply_settings(self.SETTINGS.log_settings)
self.TERM.apply_settings(self.SETTINGS.term_settings)

def network(self) -> ez.NetworkDefinition:
return (
(self.CLOCK.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE),
(self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE),
)


@pytest.mark.parametrize("dispatch_rate", [math.inf, 2.0, 20.0])
def test_clock_system(
dispatch_rate: float,
test_name: str | None = None,
):
run_time = 1.0
n_target = 100 if math.isinf(dispatch_rate) else int(np.ceil(dispatch_rate * run_time))
test_filename = get_test_fn(test_name)
ez.logger.info(test_filename)
settings = ClockTestSystemSettings(
clock_settings=ClockSettings(dispatch_rate=dispatch_rate),
log_settings=MessageLoggerSettings(output=test_filename),
term_settings=TerminateOnTotalSettings(total=n_target),
)
system = ClockTestSystem(settings)
ez.run(SYSTEM=system)

# Collect result
messages = list(message_log(test_filename))
os.remove(test_filename)

# Clock produces LinearAxis with gain and offset
assert all(isinstance(m, AxisArray.LinearAxis) for m in messages)
assert len(messages) >= n_target


class CounterTestSystemSettings(ez.Settings):
clock_settings: ClockSettings
counter_settings: CounterSettings
log_settings: MessageLoggerSettings
term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings)


class CounterTestSystem(ez.Collection):
"""Counter must be driven by Clock in the new architecture."""

SETTINGS = CounterTestSystemSettings

CLOCK = Clock()
COUNTER = Counter()
LOG = MessageLogger()
TERM = TerminateOnTotal()

def configure(self) -> None:
self.CLOCK.apply_settings(self.SETTINGS.clock_settings)
self.COUNTER.apply_settings(self.SETTINGS.counter_settings)
self.LOG.apply_settings(self.SETTINGS.log_settings)
self.TERM.apply_settings(self.SETTINGS.term_settings)

def network(self) -> ez.NetworkDefinition:
return (
(self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_CLOCK),
(self.COUNTER.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE),
(self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE),
)


@pytest.mark.parametrize(
"n_time, fs, dispatch_rate, mod",
[
(1, 10.0, math.inf, None), # AFAP mode
(20, 1000.0, 50.0, None), # Realtime mode (50 Hz dispatch = 20 samples/tick @ 1000 Hz)
(1, 1000.0, 100.0, 2**3), # 100 Hz dispatch with mod
(10, 10.0, 10.0, 2**3), # 10 Hz dispatch with mod
],
)
def test_counter_system(
n_time: int,
fs: float,
dispatch_rate: float,
mod: int | None,
test_name: str | None = None,
):
target_dur = 2.6 # 2.6 seconds per test
if math.isinf(dispatch_rate):
# AFAP mode - runs as fast as possible
target_messages = 100 # Fixed target for AFAP
else:
target_messages = int(target_dur * dispatch_rate)

test_filename = get_test_fn(test_name)
ez.logger.info(test_filename)
settings = CounterTestSystemSettings(
clock_settings=ClockSettings(dispatch_rate=dispatch_rate),
counter_settings=CounterSettings(
n_time=n_time,
fs=fs,
mod=mod,
),
log_settings=MessageLoggerSettings(
output=test_filename,
),
term_settings=TerminateOnTotalSettings(
total=target_messages,
),
)
system = CounterTestSystem(settings)
ez.run(SYSTEM=system)

# Collect result
messages: list[AxisArray] = [_ for _ in message_log(test_filename)]
os.remove(test_filename)

if math.isinf(dispatch_rate):
# The number of messages depends on how fast the computer is
target_messages = len(messages)
# This should be an equivalence assertion (==) but the use of TerminateOnTotal does
# not guarantee that MessageLogger will exit before an additional message is received.
# Let's just clip the last message if we exceed the target messages.
if len(messages) > target_messages:
messages = messages[:target_messages]
assert len(messages) >= target_messages

# Just do one quick data check (Counter now outputs 1D array)
agg = AxisArray.concatenate(*messages, dim="time")
target_samples = n_time * target_messages
expected_data = np.arange(target_samples)
if mod is not None:
expected_data = expected_data % mod
assert np.array_equal(agg.data, expected_data)


@pytest.mark.parametrize(
"clock_rate, fs, n_time",
[
(10.0, 1000.0, 100), # 10 Hz clock, fs=1000, n_time=100 (fixed)
(20.0, 500.0, None), # 20 Hz clock, fs=500, n_time derived (25 samples per tick)
(5.0, 1000.0, None), # 5 Hz clock, fs=1000, n_time derived (200 samples per tick)
],
)
def test_counter_with_external_clock(
clock_rate: float,
fs: float,
n_time: int | None,
test_name: str | None = None,
):
"""Test Counter driven by external Clock (now the standard pattern)."""
target_messages = 20
test_filename = get_test_fn(test_name)
ez.logger.info(test_filename)

# This now uses the same CounterTestSystem since all counters need clocks
settings = CounterTestSystemSettings(
clock_settings=ClockSettings(dispatch_rate=clock_rate),
counter_settings=CounterSettings(
fs=fs,
n_time=n_time,
),
log_settings=MessageLoggerSettings(output=test_filename),
term_settings=TerminateOnTotalSettings(total=target_messages),
)
system = CounterTestSystem(settings)
ez.run(SYSTEM=system)

# Collect result
messages: list[AxisArray] = [_ for _ in message_log(test_filename)]
os.remove(test_filename)

assert len(messages) >= target_messages

# Verify each message has correct sample rate (gain = 1/fs)
for msg in messages:
assert msg.axes["time"].gain == 1.0 / fs

# Verify data continuity
messages = messages[:target_messages] # Trim to target
agg = AxisArray.concatenate(*messages, dim="time")

# Expected samples per tick
if n_time is not None:
expected_samples_per_tick = n_time
else:
expected_samples_per_tick = int(fs / clock_rate)

expected_total = expected_samples_per_tick * target_messages
# Allow for fractional sample accumulation variance
assert abs(len(agg.data) - expected_total) <= target_messages

# Counter values should be sequential (0, 1, 2, ...)
expected_data = np.arange(len(agg.data))
assert np.array_equal(agg.data, expected_data)