diff --git a/src/ezmsg/baseproc/protocols.py b/src/ezmsg/baseproc/protocols.py index 48cefcc..916ba66 100644 --- a/src/ezmsg/baseproc/protocols.py +++ b/src/ezmsg/baseproc/protocols.py @@ -146,3 +146,5 @@ def partial_fit(self, message: AxisArray) -> None: ... async def apartial_fit(self, message: AxisArray) -> None: ... + def partial_fit_transform(self, message: AxisArray) -> MessageOutType: ... + async def apartial_fit_transform(self, message: AxisArray) -> MessageOutType: ... diff --git a/src/ezmsg/baseproc/stateful.py b/src/ezmsg/baseproc/stateful.py index b859345..b3e42a6 100644 --- a/src/ezmsg/baseproc/stateful.py +++ b/src/ezmsg/baseproc/stateful.py @@ -2,10 +2,10 @@ import pickle import typing +import warnings from abc import ABC, abstractmethod from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.util.messages.util import replace from .processor import ( BaseProcessor, @@ -14,7 +14,7 @@ ) from .protocols import MessageInType, MessageOutType, SettingsType, StateType from .util.asio import run_coroutine_sync -from .util.message import SampleMessage, is_sample_message +from .util.message import is_sample_message from .util.typeresolution import resolve_typevar @@ -274,37 +274,47 @@ async def apartial_fit(self, message: AxisArray) -> None: return self.partial_fit(message) def __call__(self, message: MessageInType) -> MessageOutType | None: - """ - Adapt transformer with training data (and optionally labels) - in AxisArray with attrs["trigger"]. - - Args: - message: An AxisArray with optional trigger in attrs["trigger"], - containing labels (y) in attrs["trigger"].value and - data (X) in message.data - - Returns: None - """ if is_sample_message(message): - if isinstance(message, SampleMessage): - # Auto-convert old format → new format - message = replace( - message.sample, - attrs={**message.sample.attrs, "trigger": message.trigger}, - ) - return self.partial_fit(message) + warnings.warn( + f"{self.__class__.__name__}.__call__() received a sample message " + "(AxisArray with 'trigger' in attrs). Auto-routing to partial_fit " + "has been removed. Use partial_fit() for training only, or " + "partial_fit_transform() for training + inference.", + UserWarning, + stacklevel=2, + ) return super().__call__(message) async def __acall__(self, message: MessageInType) -> MessageOutType | None: if is_sample_message(message): - if isinstance(message, SampleMessage): - message = replace( - message.sample, - attrs={**message.sample.attrs, "trigger": message.trigger}, - ) - return await self.apartial_fit(message) + warnings.warn( + f"{self.__class__.__name__}.__acall__() received a sample message " + "(AxisArray with 'trigger' in attrs). Auto-routing to partial_fit " + "has been removed. Use apartial_fit() for training only, or " + "apartial_fit_transform() for training + inference.", + UserWarning, + stacklevel=2, + ) return await super().__acall__(message) + def partial_fit_transform(self, message: AxisArray) -> MessageOutType: + """Train on the message, then run inference and return the result.""" + msg_hash = self._hash_message(message) + if msg_hash != self._hash: + self._reset_state(message) + self._hash = msg_hash + self.partial_fit(message) + return self._process(message) + + async def apartial_fit_transform(self, message: AxisArray) -> MessageOutType: + """Async variant of partial_fit_transform.""" + msg_hash = self._hash_message(message) + if msg_hash != self._hash: + self._reset_state(message) + self._hash = msg_hash + await self.apartial_fit(message) + return await self._aprocess(message) + class BaseAsyncTransformer( BaseStatefulTransformer[SettingsType, MessageInType, MessageOutType, StateType], diff --git a/src/ezmsg/baseproc/units.py b/src/ezmsg/baseproc/units.py index c41429b..3ea7343 100644 --- a/src/ezmsg/baseproc/units.py +++ b/src/ezmsg/baseproc/units.py @@ -225,6 +225,7 @@ class BaseAdaptiveTransformerUnit( INPUT_SAMPLE = ez.InputStream(AxisArray) INPUT_SIGNAL = ez.InputStream(MessageInType) OUTPUT_SIGNAL = ez.OutputStream(MessageOutType) + OUTPUT_SAMPLE = ez.OutputStream(MessageOutType) def create_processor(self) -> None: # self.processor: AdaptiveTransformerType[SettingsType, MessageInType, MessageOutType, StateType] @@ -241,8 +242,12 @@ async def on_signal(self, message: MessageInType) -> typing.AsyncGenerator: yield self.OUTPUT_SIGNAL, result @ez.subscriber(INPUT_SAMPLE) - async def on_sample(self, msg: AxisArray) -> None: - await self.processor.apartial_fit(msg) + @ez.publisher(OUTPUT_SAMPLE) + @profile_subpub(trace_oldest=False) + async def on_sample(self, msg: AxisArray) -> typing.AsyncGenerator: + result = await self.processor.apartial_fit_transform(msg) + if result is not None: + yield self.OUTPUT_SAMPLE, result class BaseClockDrivenUnit( diff --git a/tests/test_baseproc.py b/tests/test_baseproc.py index ac481a8..90b1b03 100644 --- a/tests/test_baseproc.py +++ b/tests/test_baseproc.py @@ -43,6 +43,7 @@ class MockSettings: @processor_state class MockState: iterations: int = 0 + fit_count: int = 0 hash: int = -1 @@ -141,7 +142,7 @@ def _process(self, message: MockMessageA) -> MockMessageB: return MockMessageB() def partial_fit(self, message: AxisArray) -> None: - self._state.iterations += 1 + self._state.fit_count += 1 class MockAsyncTransformer(BaseAsyncTransformer[MockSettings, MockMessageA, MockMessageB, MockState]): @@ -770,20 +771,29 @@ class TestBaseAdaptiveTransformer: def test_partial_fit(self): transformer = MockAdaptiveTransformer() transformer.partial_fit(mock_sample_message()) - assert transformer.state.iterations == 1 + assert transformer.state.fit_count == 1 @pytest.mark.asyncio async def test_apartial_fit(self): transformer = MockAdaptiveTransformer() await transformer.apartial_fit(mock_sample_message()) - assert transformer.state.iterations == 1 + assert transformer.state.fit_count == 1 - def test_call_with_sample_message(self): + def test_call_with_sample_message_warns(self): transformer = MockAdaptiveTransformer() sample_msg = mock_sample_message() - result = transformer(sample_msg) - assert result is None # partial_fit returns None - assert transformer.state.iterations == 1 + with pytest.warns(UserWarning, match="Auto-routing to partial_fit"): + result = transformer(sample_msg) + assert isinstance(result, MockMessageB) # inference, not partial_fit + assert transformer.state.fit_count == 0 # partial_fit NOT called + + def test_partial_fit_transform(self): + transformer = MockAdaptiveTransformer() + sample_msg = mock_sample_message() + result = transformer.partial_fit_transform(sample_msg) + assert isinstance(result, MockMessageB) + assert transformer.state.fit_count == 1 + assert transformer.state.iterations == 1 # _process was called class TestBaseAsyncTransformer: