diff --git a/examples/mouse_poll.py b/examples/mouse_poll.py index 3b0473c..54801c7 100644 --- a/examples/mouse_poll.py +++ b/examples/mouse_poll.py @@ -43,7 +43,7 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( - (self.CLOCK.OUTPUT_SIGNAL, self.MOUSE.INPUT_SIGNAL), + (self.CLOCK.OUTPUT_SIGNAL, self.MOUSE.INPUT_CLOCK), (self.MOUSE.OUTPUT_SIGNAL, self.LOG.INPUT), ) diff --git a/pyproject.toml b/pyproject.toml index 537f2b1..3d5a270 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ readme = "README.md" requires-python = ">=3.10" dependencies = [ "ezmsg>=3.6.0", - "ezmsg-baseproc>=1.1.0", + "ezmsg-baseproc>=1.2.1", "numpy>=1.26.0", "pynput>=1.8.1", ] diff --git a/src/ezmsg/peripheraldevice/__init__.py b/src/ezmsg/peripheraldevice/__init__.py index 1645162..33ab6ce 100644 --- a/src/ezmsg/peripheraldevice/__init__.py +++ b/src/ezmsg/peripheraldevice/__init__.py @@ -5,9 +5,9 @@ MouseListenerSettings, MouseListenerState, MousePoller, + MousePollerProducer, MousePollerSettings, MousePollerState, - MousePollerTransformer, ) __all__ = [ @@ -16,8 +16,8 @@ "MouseListenerSettings", "MouseListenerState", "MouseListener", + "MousePollerProducer", "MousePollerSettings", "MousePollerState", - "MousePollerTransformer", "MousePoller", ] diff --git a/src/ezmsg/peripheraldevice/mouse.py b/src/ezmsg/peripheraldevice/mouse.py index 86cd830..4915d2e 100644 --- a/src/ezmsg/peripheraldevice/mouse.py +++ b/src/ezmsg/peripheraldevice/mouse.py @@ -1,71 +1,144 @@ """Mouse input via pynput.""" import queue +import threading import time +import time as time_module +from collections import deque import ezmsg.core as ez import numpy as np from ezmsg.baseproc import ( + BaseClockDrivenProducer, + BaseClockDrivenUnit, BaseProducerUnit, BaseStatefulProducer, - BaseStatefulTransformer, - BaseTransformerUnit, + ClockDrivenSettings, + ClockDrivenState, processor_state, ) from ezmsg.util.messages.axisarray import AxisArray, replace from pynput.mouse import Controller, Listener # ============================================================================= -# Polled Mouse Transformer (takes LinearAxis from Clock, like Counter) +# Polled Mouse Producer (takes LinearAxis from Clock) # ============================================================================= -class MousePollerSettings(ez.Settings): - """Settings for MousePollerTransformer.""" +class MousePollerSettings(ClockDrivenSettings): + """Settings for MousePollerProducer.""" - pass + fs: float = 60.0 + """ + Output sample rate in Hz (only used when n_time is None). + + When n_time is specified, the effective poll rate is derived from the clock: + poll_rate = n_time / clock.gain (i.e., n_time samples per tick). + + When n_time is None, fs determines both the poll rate and output sample rate, + with n_samples = fs * clock.gain per tick. + """ + + n_time: int | None = None + """ + Samples per block. + - If specified: fixed chunk size, poll rate derived from clock timing + - If None: derived from fs * clock.gain, poll rate is fs + """ @processor_state -class MousePollerState: - """State for MousePollerTransformer.""" +class MousePollerState(ClockDrivenState): + """State for MousePollerProducer.""" controller: Controller | None = None template: AxisArray | None = None + # Background polling state + poll_thread: threading.Thread | None = None + poll_buffer: deque | None = None + stop_event: threading.Event | None = None + last_position: tuple[float, float] = (0.0, 0.0) + use_thread: bool = False + poll_rate: float = 0.0 + -class MousePollerTransformer( - BaseStatefulTransformer[ - MousePollerSettings, - AxisArray.LinearAxis, - AxisArray, - MousePollerState, - ] -): +class MousePollerProducer(BaseClockDrivenProducer[MousePollerSettings, MousePollerState]): """ - Reads current mouse position when triggered by clock tick. + Reads mouse position, optionally with high-rate background polling. + + Takes LinearAxis input (from Clock) and outputs mouse positions. - Takes LinearAxis input (from Clock) and outputs the current mouse position - as a single sample with the clock's timestamp. + Behavior depends on settings: + + **Fixed chunk mode (n_time is set):** + - Each tick produces exactly n_time samples + - Poll rate = n_time * clock_rate (derived from clock timing) + - settings.fs is ignored + - Thread used when n_time > 1 + + **Variable chunk mode (n_time is None):** + - Each tick produces fs * clock.gain samples (with fractional tracking) + - Poll rate = fs (from settings) + - Thread used when fs != clock_rate Input: LinearAxis (from Clock - provides timing info) - Output: AxisArray with shape (1, 2) - single sample with x, y channels + Output: AxisArray with shape (n_samples, 2) - x, y channels """ - def _reset_state(self, message: AxisArray.LinearAxis) -> None: - """Initialize mouse controller.""" - self._state.controller = Controller() + def _hash_message(self, message: AxisArray.LinearAxis) -> int: + """ + Hash based on clock gain to detect rate changes. - # Pre-construct template AxisArray + Returns different hash when clock rate changes significantly, + triggering state reset and potential thread restart. + """ + # Quantize gain to avoid floating point noise triggering resets + quantized_gain = round(message.gain * 1e6) + return hash(quantized_gain) + + def _reset_state(self, time_axis: AxisArray.LinearAxis) -> None: + """Initialize mouse controller and optionally start polling thread.""" + # Stop any existing polling thread + self._stop_poll_thread() + + self._state.controller = Controller() + self._state.last_position = self._state.controller.position + + clock_rate = 1.0 / time_axis.gain if time_axis.gain > 0 else float("inf") + + if self.settings.n_time is not None: + # Fixed chunk mode: poll rate derived from clock timing + # Need n_time samples per tick, so poll at n_time * clock_rate + self._state.poll_rate = self.settings.n_time * clock_rate + need_thread = self.settings.n_time > 1 + else: + # Variable chunk mode: poll rate is fs from settings + # n_samples = fs * clock.gain (handled by base class) + self._state.poll_rate = self.settings.fs + # Need thread if fs != clock_rate (meaning n_samples != 1) + need_thread = not np.isclose(self.settings.fs, clock_rate, rtol=0.01) + + self._state.use_thread = need_thread + + if self._state.use_thread: + # Start background polling thread + buffer_size = max(int(self._state.poll_rate * 10), 1000) + self._state.poll_buffer = deque(maxlen=buffer_size) + self._state.stop_event = threading.Event() + self._state.poll_thread = threading.Thread( + target=self._poll_loop, + daemon=True, + ) + self._state.poll_thread.start() + + # Pre-construct template AxisArray (shape will be updated in _produce) + n_time_for_template = self.settings.n_time if self.settings.n_time is not None else 1 self._state.template = AxisArray( - data=np.zeros((1, 2), dtype=np.float64), + data=np.zeros((n_time_for_template, 2), dtype=np.float64), dims=["time", "ch"], axes={ - "time": AxisArray.LinearAxis( - unit="s", - gain=message.gain, - offset=message.offset, - ), + "time": time_axis, "ch": AxisArray.CoordinateAxis( data=np.array(["x", "y"]), dims=["ch"], @@ -74,36 +147,68 @@ def _reset_state(self, message: AxisArray.LinearAxis) -> None: key="mouse", ) - def _process(self, message: AxisArray.LinearAxis) -> AxisArray: - """Read current mouse position and return as AxisArray.""" - pos = self._state.controller.position - - # Create output with single sample - data = np.array([[pos[0], pos[1]]], dtype=np.float64) - time_axis = replace( - self._state.template.axes["time"], - offset=message.offset, - ) + def _poll_loop(self) -> None: + """Background thread that polls mouse at poll_rate.""" + interval = 1.0 / self._state.poll_rate + last_poll = time_module.perf_counter() - interval + + while not self._state.stop_event.is_set(): + sleep_time = (last_poll + interval) - time_module.perf_counter() + if sleep_time > 0: + time_module.sleep(sleep_time) + pos = self._state.controller.position + self._state.poll_buffer.append(pos) + self._state.last_position = pos + last_poll = time_module.perf_counter() + + def _stop_poll_thread(self) -> None: + """Stop the background polling thread if running.""" + if self._state.stop_event is not None: + self._state.stop_event.set() + if self._state.poll_thread is not None: + self._state.poll_thread.join(timeout=1.0) + self._state.poll_thread = None + self._state.stop_event = None + self._state.poll_buffer = None + + def _produce(self, n_samples: int, time_axis: AxisArray.LinearAxis) -> AxisArray: + """Generate mouse position data.""" + if self._state.use_thread and self._state.poll_buffer is not None: + # Get samples from buffer + positions = [] + for _ in range(n_samples): + if self._state.poll_buffer: + pos = self._state.poll_buffer.popleft() + self._state.last_position = pos + else: + # Buffer empty - hold last known position + pos = self._state.last_position + positions.append([pos[0], pos[1]]) + data = np.array(positions, dtype=np.float64) + else: + # Simple single-poll mode + pos = self._state.controller.position + data = np.array([[pos[0], pos[1]]], dtype=np.float64) return replace( self._state.template, data=data, - axes={"time": time_axis, "ch": self._state.template.axes["ch"]}, + axes={**self._state.template.axes, "time": time_axis}, ) + def __del__(self) -> None: + """Stop polling thread on destruction.""" + if hasattr(self, "_state"): + self._stop_poll_thread() + -class MousePoller( - BaseTransformerUnit[ - MousePollerSettings, - AxisArray.LinearAxis, - AxisArray, - MousePollerTransformer, - ] -): +class MousePoller(BaseClockDrivenUnit[MousePollerSettings, MousePollerProducer]): """ Unit for reading mouse position from Clock input. - Receives LinearAxis from Clock and outputs current mouse position. + Receives LinearAxis from Clock and outputs mouse positions. + Supports both simple polling (one sample per tick) and high-rate + background polling with buffering. """ SETTINGS = MousePollerSettings diff --git a/tests/test_mouse.py b/tests/test_mouse.py index 8bcf964..48f50d9 100644 --- a/tests/test_mouse.py +++ b/tests/test_mouse.py @@ -9,28 +9,29 @@ from ezmsg.peripheraldevice.mouse import ( MouseListenerProducer, MouseListenerSettings, + MousePollerProducer, MousePollerSettings, - MousePollerTransformer, ) -class TestMousePollerTransformer: - """Tests for MousePollerTransformer.""" +class TestMousePollerProducer: + """Tests for MousePollerProducer.""" @patch("ezmsg.peripheraldevice.mouse.Controller") - def test_basic_output(self, mock_controller_class): - """Test that transformer produces valid output.""" + def test_basic_output_simple_mode(self, mock_controller_class): + """Test that producer produces valid output in simple polling mode.""" # Setup mock mock_controller = MagicMock() mock_controller.position = (100, 200) mock_controller_class.return_value = mock_controller - transformer = MousePollerTransformer(MousePollerSettings()) + # Use fs=10 to match clock rate (1/0.1 = 10 Hz) for simple mode + producer = MousePollerProducer(MousePollerSettings(fs=10.0, n_time=1)) # Create a LinearAxis input (like what Clock produces) clock_tick = AxisArray.LinearAxis(gain=0.1, offset=0.0) - result = transformer(clock_tick) + result = producer(clock_tick) # Check output shape assert result.data.shape == (1, 2), "Output should be (1, 2) for single sample with x, y" @@ -59,11 +60,12 @@ def test_output_values_are_numeric(self, mock_controller_class): mock_controller.position = (500, 300) mock_controller_class.return_value = mock_controller - transformer = MousePollerTransformer(MousePollerSettings()) + # Use matching fs for simple mode + producer = MousePollerProducer(MousePollerSettings(fs=10.0, n_time=1)) clock_tick = AxisArray.LinearAxis(gain=0.1, offset=0.0) - result = transformer(clock_tick) + result = producer(clock_tick) # Check that values are finite numbers assert np.all(np.isfinite(result.data)), "Output should contain finite numbers" @@ -75,11 +77,12 @@ def test_preserves_time_offset(self, mock_controller_class): mock_controller.position = (0, 0) mock_controller_class.return_value = mock_controller - transformer = MousePollerTransformer(MousePollerSettings()) + # Use matching fs for simple mode + producer = MousePollerProducer(MousePollerSettings(fs=20.0, n_time=1)) clock_tick = AxisArray.LinearAxis(gain=0.05, offset=1.5) - result = transformer(clock_tick) + result = producer(clock_tick) # Check that time axis offset matches input time_axis = result.axes["time"] @@ -87,12 +90,13 @@ def test_preserves_time_offset(self, mock_controller_class): assert time_axis.offset == 1.5 @patch("ezmsg.peripheraldevice.mouse.Controller") - def test_multiple_calls(self, mock_controller_class): - """Test that transformer works correctly across multiple calls.""" + def test_multiple_calls_simple_mode(self, mock_controller_class): + """Test that producer works correctly across multiple calls in simple mode.""" mock_controller = MagicMock() mock_controller_class.return_value = mock_controller - transformer = MousePollerTransformer(MousePollerSettings()) + # Use matching fs for simple mode (1/0.1 = 10 Hz) + producer = MousePollerProducer(MousePollerSettings(fs=10.0, n_time=1)) for i in range(5): # Update mock position each call @@ -100,7 +104,7 @@ def test_multiple_calls(self, mock_controller_class): clock_tick = AxisArray.LinearAxis(gain=0.1, offset=i * 0.1) - result = transformer(clock_tick) + result = producer(clock_tick) assert result.data.shape == (1, 2) assert np.all(np.isfinite(result.data)) @@ -109,6 +113,80 @@ def test_multiple_calls(self, mock_controller_class): # Data should match mock position np.testing.assert_array_equal(result.data[0], [i * 10, i * 20]) + @patch("ezmsg.peripheraldevice.mouse.Controller") + def test_threaded_mode_with_n_time_greater_than_1(self, mock_controller_class): + """Test that producer uses threaded mode when n_time > 1.""" + mock_controller = MagicMock() + mock_controller.position = (100, 200) + mock_controller_class.return_value = mock_controller + + # n_time > 1 triggers threaded mode + producer = MousePollerProducer(MousePollerSettings(fs=100.0, n_time=10)) + + clock_tick = AxisArray.LinearAxis(gain=0.1, offset=0.0) + + # Give the thread a moment to poll + result = producer(clock_tick) + time_module.sleep(0.15) # Let thread poll a few times + + result = producer(clock_tick) + + # Should have n_time samples + assert result.data.shape == (10, 2) + assert np.all(np.isfinite(result.data)) + + # Clean up thread + del producer + + @patch("ezmsg.peripheraldevice.mouse.Controller") + def test_threaded_mode_with_rate_mismatch(self, mock_controller_class): + """Test variable chunk mode (n_time=None) with fs != clock rate.""" + mock_controller = MagicMock() + mock_controller.position = (50, 75) + mock_controller_class.return_value = mock_controller + + # n_time=None, fs=100 != clock rate of 10 Hz triggers threaded mode + # n_samples = fs * clock.gain = 100 * 0.1 = 10 samples per tick + producer = MousePollerProducer(MousePollerSettings(fs=100.0, n_time=None)) + + clock_tick = AxisArray.LinearAxis(gain=0.1, offset=0.0) + + # First call initializes, give thread time to poll + result = producer(clock_tick) + time_module.sleep(0.15) # Let thread poll ~15 times at 100 Hz + + # Now get result with buffered data + result = producer(clock_tick) + + # Should have n_samples = fs * gain = 100 * 0.1 = 10 samples + assert result.data.shape == (10, 2) + assert np.all(np.isfinite(result.data)) + + # Clean up thread + del producer + + @patch("ezmsg.peripheraldevice.mouse.Controller") + def test_variable_chunk_mode_simple(self, mock_controller_class): + """Test variable chunk mode (n_time=None) with fs matching clock rate.""" + mock_controller = MagicMock() + mock_controller.position = (123, 456) + mock_controller_class.return_value = mock_controller + + # n_time=None, fs=10 matches clock rate of 10 Hz - no thread needed + # n_samples = fs * clock.gain = 10 * 0.1 = 1 sample per tick + producer = MousePollerProducer(MousePollerSettings(fs=10.0, n_time=None)) + + clock_tick = AxisArray.LinearAxis(gain=0.1, offset=0.0) + + result = producer(clock_tick) + + # Should have n_samples = 1 (simple mode, no thread) + assert result.data.shape == (1, 2) + np.testing.assert_array_equal(result.data[0], [123, 456]) + + # No thread should be running + assert producer._state.use_thread is False + class TestMouseListenerProducer: """Tests for MouseListenerProducer."""