Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ requires-python = ">=3.10.15"
dynamic = ["version"]
dependencies = [
"array-api-compat>=1.11.1",
"ezmsg-baseproc>=1.0",
"ezmsg-baseproc>=1.0.3",
"ezmsg>=3.6.0",
"numba>=0.61.0",
"numpy>=1.26.0",
Expand Down
10 changes: 9 additions & 1 deletion src/ezmsg/sigproc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,16 @@
New code should import directly from ezmsg.baseproc instead.
"""

import warnings

warnings.warn(
"Importing from 'ezmsg.sigproc.base' is deprecated. Please import from 'ezmsg.baseproc' instead.",
DeprecationWarning,
stacklevel=2,
)

# Re-export everything from ezmsg.baseproc for backwards compatibility
from ezmsg.baseproc import (
from ezmsg.baseproc import ( # noqa: E402
# Protocols
AdaptiveTransformer,
# Type variables
Expand Down
121 changes: 121 additions & 0 deletions src/ezmsg/sigproc/math/add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Signal addition utilities."""

import asyncio
import typing
from dataclasses import dataclass, field

import ezmsg.core as ez
from ezmsg.baseproc.util.asio import run_coroutine_sync
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ..base import BaseTransformer, BaseTransformerUnit

# --- Constant Addition (single input) ---


class ConstAddSettings(ez.Settings):
value: float = 0.0
"""Number to add to the input data."""


class ConstAddTransformer(BaseTransformer[ConstAddSettings, AxisArray, AxisArray]):
"""Add a constant value to input data."""

def _process(self, message: AxisArray) -> AxisArray:
return replace(message, data=message.data + self.settings.value)


class ConstAdd(BaseTransformerUnit[ConstAddSettings, AxisArray, AxisArray, ConstAddTransformer]):
"""Unit wrapper for ConstAddTransformer."""

SETTINGS = ConstAddSettings


# --- Two-input Addition ---


@dataclass
class AddState:
"""State for Add processor with two input queues."""

queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)


class AddProcessor:
"""Processor that adds two AxisArray signals together.

This processor maintains separate queues for two input streams and
adds corresponding messages element-wise. It assumes both inputs
have compatible shapes and aligned time spans.
"""

def __init__(self):
self._state = AddState()

@property
def state(self) -> AddState:
return self._state

@state.setter
def state(self, state: AddState | bytes | None) -> None:
if state is not None:
# TODO: Support hydrating state from bytes
# if isinstance(state, bytes):
# self._state = pickle.loads(state)
# else:
self._state = state

def push_a(self, msg: AxisArray) -> None:
"""Push a message to queue A."""
self._state.queue_a.put_nowait(msg)

def push_b(self, msg: AxisArray) -> None:
"""Push a message to queue B."""
self._state.queue_b.put_nowait(msg)

async def __acall__(self) -> AxisArray:
"""Await and add the next messages from both queues."""
a = await self._state.queue_a.get()
b = await self._state.queue_b.get()
return replace(a, data=a.data + b.data)

def __call__(self) -> AxisArray:
"""Synchronously get and add the next messages from both queues."""
return run_coroutine_sync(self.__acall__())

# Aliases for legacy interface
async def __anext__(self) -> AxisArray:
return await self.__acall__()

def __next__(self) -> AxisArray:
return self.__call__()


class Add(ez.Unit):
"""Add two signals together.

Assumes compatible/similar axes/dimensions and aligned time spans.
Messages are paired by arrival order (oldest from each queue).
"""

INPUT_SIGNAL_A = ez.InputStream(AxisArray)
INPUT_SIGNAL_B = ez.InputStream(AxisArray)
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

async def initialize(self) -> None:
self.processor = AddProcessor()

@ez.subscriber(INPUT_SIGNAL_A)
async def on_a(self, msg: AxisArray) -> None:
self.processor.push_a(msg)

@ez.subscriber(INPUT_SIGNAL_B)
async def on_b(self, msg: AxisArray) -> None:
self.processor.push_b(msg)

@ez.publisher(OUTPUT_SIGNAL)
async def output(self) -> typing.AsyncGenerator:
while True:
yield self.OUTPUT_SIGNAL, await self.processor.__acall__()
109 changes: 90 additions & 19 deletions src/ezmsg/sigproc/math/difference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import asyncio
import typing
from dataclasses import dataclass, field

import ezmsg.core as ez
from ezmsg.baseproc.util.asio import run_coroutine_sync
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

Expand Down Expand Up @@ -43,22 +48,88 @@ def const_difference(value: float = 0.0, subtrahend: bool = True) -> ConstDiffer
return ConstDifferenceTransformer(ConstDifferenceSettings(value=value, subtrahend=subtrahend))


# class DifferenceSettings(ez.Settings):
# pass
#
#
# class Difference(ez.Unit):
# SETTINGS = DifferenceSettings
#
# INPUT_SIGNAL_1 = ez.InputStream(AxisArray)
# INPUT_SIGNAL_2 = ez.InputStream(AxisArray)
# OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
#
# @ez.subscriber(INPUT_SIGNAL_2, zero_copy=True)
# @ez.publisher(OUTPUT_SIGNAL)
# async def on_input_2(self, message: AxisArray) -> typing.AsyncGenerator:
# # TODO: buffer_2
# # TODO: take buffer_1 - buffer_2 for ranges that align
# # TODO: Drop samples from buffer_1 and buffer_2
# if ret is not None:
# yield self.OUTPUT_SIGNAL, ret
# --- Two-input Difference ---


@dataclass
class DifferenceState:
"""State for Difference processor with two input queues."""

queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)


class DifferenceProcessor:
"""Processor that subtracts two AxisArray signals (A - B).

This processor maintains separate queues for two input streams and
subtracts corresponding messages element-wise. It assumes both inputs
have compatible shapes and aligned time spans.
"""

def __init__(self):
self._state = DifferenceState()

@property
def state(self) -> DifferenceState:
return self._state

@state.setter
def state(self, state: DifferenceState | bytes | None) -> None:
if state is not None:
self._state = state

def push_a(self, msg: AxisArray) -> None:
"""Push a message to queue A (minuend)."""
self._state.queue_a.put_nowait(msg)

def push_b(self, msg: AxisArray) -> None:
"""Push a message to queue B (subtrahend)."""
self._state.queue_b.put_nowait(msg)

async def __acall__(self) -> AxisArray:
"""Await and subtract the next messages (A - B)."""
a = await self._state.queue_a.get()
b = await self._state.queue_b.get()
return replace(a, data=a.data - b.data)

def __call__(self) -> AxisArray:
"""Synchronously get and subtract the next messages."""
return run_coroutine_sync(self.__acall__())

# Aliases for legacy interface
async def __anext__(self) -> AxisArray:
return await self.__acall__()

def __next__(self) -> AxisArray:
return self.__call__()


class Difference(ez.Unit):
"""Subtract two signals (A - B).

Assumes compatible/similar axes/dimensions and aligned time spans.
Messages are paired by arrival order (oldest from each queue).

OUTPUT = INPUT_SIGNAL_A - INPUT_SIGNAL_B
"""

INPUT_SIGNAL_A = ez.InputStream(AxisArray)
INPUT_SIGNAL_B = ez.InputStream(AxisArray)
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

async def initialize(self) -> None:
self.processor = DifferenceProcessor()

@ez.subscriber(INPUT_SIGNAL_A)
async def on_a(self, msg: AxisArray) -> None:
self.processor.push_a(msg)

@ez.subscriber(INPUT_SIGNAL_B)
async def on_b(self, msg: AxisArray) -> None:
self.processor.push_b(msg)

@ez.publisher(OUTPUT_SIGNAL)
async def output(self) -> typing.AsyncGenerator:
while True:
yield self.OUTPUT_SIGNAL, await self.processor.__acall__()
Loading