diff --git a/src/ezmsg/baseproc/protocols.py b/src/ezmsg/baseproc/protocols.py index d5707dd..48cefcc 100644 --- a/src/ezmsg/baseproc/protocols.py +++ b/src/ezmsg/baseproc/protocols.py @@ -5,8 +5,7 @@ from dataclasses import dataclass import ezmsg.core as ez - -from .util.message import SampleMessage +from ezmsg.util.messages.axisarray import AxisArray # --- Processor state decorator --- processor_state = functools.partial(dataclass, unsafe_hash=True, frozen=False, init=False) @@ -138,7 +137,7 @@ def stateful_op( class AdaptiveTransformer(StatefulTransformer, typing.Protocol): - def partial_fit(self, message: SampleMessage) -> None: + def partial_fit(self, message: AxisArray) -> None: """Update transformer state using labeled training data. This method should update the internal state/parameters of the transformer @@ -146,4 +145,4 @@ def partial_fit(self, message: SampleMessage) -> None: """ ... - async def apartial_fit(self, message: SampleMessage) -> None: ... + async def apartial_fit(self, message: AxisArray) -> None: ... diff --git a/src/ezmsg/baseproc/stateful.py b/src/ezmsg/baseproc/stateful.py index 990f134..b859345 100644 --- a/src/ezmsg/baseproc/stateful.py +++ b/src/ezmsg/baseproc/stateful.py @@ -4,6 +4,9 @@ import typing from abc import ABC, abstractmethod +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace + from .processor import ( BaseProcessor, BaseProducer, @@ -256,7 +259,7 @@ def stateful_op( class BaseAdaptiveTransformer( BaseStatefulTransformer[ SettingsType, - MessageInType | SampleMessage, + MessageInType, MessageOutType | None, StateType, ], @@ -264,30 +267,41 @@ class BaseAdaptiveTransformer( typing.Generic[SettingsType, MessageInType, MessageOutType, StateType], ): @abstractmethod - def partial_fit(self, message: SampleMessage) -> None: ... + def partial_fit(self, message: AxisArray) -> None: ... - async def apartial_fit(self, message: SampleMessage) -> None: + async def apartial_fit(self, message: AxisArray) -> None: """Override me if you need async partial fitting.""" return self.partial_fit(message) - def __call__(self, message: MessageInType | SampleMessage) -> MessageOutType | None: + def __call__(self, message: MessageInType) -> MessageOutType | None: """ Adapt transformer with training data (and optionally labels) - in SampleMessage + in AxisArray with attrs["trigger"]. Args: - message: An instance of SampleMessage with optional - labels (y) in message.trigger.value.data and - data (X) in message.sample.data + message: An AxisArray with optional trigger in attrs["trigger"], + containing labels (y) in attrs["trigger"].value and + data (X) in message.data Returns: None """ if is_sample_message(message): + if isinstance(message, SampleMessage): + # Auto-convert old format → new format + message = replace( + message.sample, + attrs={**message.sample.attrs, "trigger": message.trigger}, + ) return self.partial_fit(message) return super().__call__(message) - async def __acall__(self, message: MessageInType | SampleMessage) -> MessageOutType | None: + async def __acall__(self, message: MessageInType) -> MessageOutType | None: if is_sample_message(message): + if isinstance(message, SampleMessage): + message = replace( + message.sample, + attrs={**message.sample.attrs, "trigger": message.trigger}, + ) return await self.apartial_fit(message) return await super().__acall__(message) diff --git a/src/ezmsg/baseproc/units.py b/src/ezmsg/baseproc/units.py index bb0554b..c41429b 100644 --- a/src/ezmsg/baseproc/units.py +++ b/src/ezmsg/baseproc/units.py @@ -14,7 +14,6 @@ from .processor import BaseConsumer, BaseProducer, BaseTransformer from .protocols import MessageInType, MessageOutType, SettingsType from .stateful import BaseAdaptiveTransformer, BaseStatefulConsumer, BaseStatefulTransformer -from .util.message import SampleMessage from .util.profile import profile_subpub from .util.typeresolution import resolve_typevar @@ -223,7 +222,7 @@ class BaseAdaptiveTransformerUnit( ABC, typing.Generic[SettingsType, MessageInType, MessageOutType, AdaptiveTransformerType], ): - INPUT_SAMPLE = ez.InputStream(SampleMessage) + INPUT_SAMPLE = ez.InputStream(AxisArray) INPUT_SIGNAL = ez.InputStream(MessageInType) OUTPUT_SIGNAL = ez.OutputStream(MessageOutType) @@ -242,7 +241,7 @@ async def on_signal(self, message: MessageInType) -> typing.AsyncGenerator: yield self.OUTPUT_SIGNAL, result @ez.subscriber(INPUT_SAMPLE) - async def on_sample(self, msg: SampleMessage) -> None: + async def on_sample(self, msg: AxisArray) -> None: await self.processor.apartial_fit(msg) diff --git a/src/ezmsg/baseproc/util/message.py b/src/ezmsg/baseproc/util/message.py index 960d4e9..e78b085 100644 --- a/src/ezmsg/baseproc/util/message.py +++ b/src/ezmsg/baseproc/util/message.py @@ -1,5 +1,6 @@ import time import typing +import warnings from dataclasses import dataclass, field from ezmsg.util.messages.axisarray import AxisArray @@ -19,13 +20,28 @@ class SampleTriggerMessage: @dataclass class SampleMessage: + """ + .. deprecated:: + ``SampleMessage`` is deprecated. Use ``AxisArray`` with + ``attrs={"trigger": SampleTriggerMessage(...)}`` instead. + """ + trigger: SampleTriggerMessage """The time, window, and value (if any) associated with the trigger.""" sample: AxisArray """The data sampled around the trigger.""" + def __post_init__(self): + warnings.warn( + "SampleMessage is deprecated. Use AxisArray with " "attrs={'trigger': SampleTriggerMessage(...)} instead.", + DeprecationWarning, + stacklevel=2, + ) + -def is_sample_message(message: typing.Any) -> typing.TypeGuard[SampleMessage]: - """Check if the message is a SampleMessage.""" - return hasattr(message, "trigger") +def is_sample_message(message: typing.Any) -> bool: + """Detect old SampleMessage OR new AxisArray-with-trigger.""" + if isinstance(message, SampleMessage): + return True + return isinstance(message, AxisArray) and "trigger" in getattr(message, "attrs", {}) diff --git a/tests/test_baseproc.py b/tests/test_baseproc.py index 4797347..ac481a8 100644 --- a/tests/test_baseproc.py +++ b/tests/test_baseproc.py @@ -4,9 +4,10 @@ import pickle from types import NoneType from typing import Any -from unittest.mock import MagicMock +import numpy as np import pytest +from ezmsg.util.messages.axisarray import AxisArray from ezmsg.baseproc import ( BaseAdaptiveTransformer, @@ -21,7 +22,7 @@ BaseTransformer, CompositeProcessor, CompositeProducer, - SampleMessage, + SampleTriggerMessage, _get_base_processor_message_in_type, _get_base_processor_message_out_type, _get_base_processor_settings_type, @@ -135,11 +136,11 @@ class MockAdaptiveTransformer(BaseAdaptiveTransformer[MockSettings, MockMessageA def _reset_state(self, message: MockMessageA) -> None: self._state.iterations = 0 - def _process(self, message: MockMessageA | SampleMessage) -> MockMessageB: + def _process(self, message: MockMessageA) -> MockMessageB: self._state.iterations += 1 return MockMessageB() - def partial_fit(self, message: SampleMessage) -> None: + def partial_fit(self, message: AxisArray) -> None: self._state.iterations += 1 @@ -756,10 +757,13 @@ def test_stateful_op(self): assert new_state[0].iterations == 1 -# Mock SampleMessage for testing BaseAdaptiveTransformer +# Helper to create an AxisArray with trigger in attrs for testing BaseAdaptiveTransformer def mock_sample_message(): - sample_message = MagicMock(spec=SampleMessage) - return sample_message + return AxisArray( + data=np.zeros((1, 1)), + dims=["time", "ch"], + attrs={"trigger": SampleTriggerMessage()}, + ) class TestBaseAdaptiveTransformer: @@ -776,9 +780,7 @@ async def test_apartial_fit(self): def test_call_with_sample_message(self): transformer = MockAdaptiveTransformer() - # Create a sample message with a trigger attribute sample_msg = mock_sample_message() - setattr(sample_msg, "trigger", None) result = transformer(sample_msg) assert result is None # partial_fit returns None assert transformer.state.iterations == 1