Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/mouse_poll.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def configure(self) -> None:

def network(self) -> ez.NetworkDefinition:
return (
(self.CLOCK.OUTPUT_SIGNAL, self.MOUSE.INPUT_SIGNAL),
(self.CLOCK.OUTPUT_SIGNAL, self.MOUSE.INPUT_CLOCK),
(self.MOUSE.OUTPUT_SIGNAL, self.LOG.INPUT),
)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"ezmsg>=3.6.0",
"ezmsg-baseproc>=1.1.0",
"ezmsg-baseproc>=1.2.1",
"numpy>=1.26.0",
"pynput>=1.8.1",
]
Expand Down
4 changes: 2 additions & 2 deletions src/ezmsg/peripheraldevice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
MouseListenerSettings,
MouseListenerState,
MousePoller,
MousePollerProducer,
MousePollerSettings,
MousePollerState,
MousePollerTransformer,
)

__all__ = [
Expand All @@ -16,8 +16,8 @@
"MouseListenerSettings",
"MouseListenerState",
"MouseListener",
"MousePollerProducer",
"MousePollerSettings",
"MousePollerState",
"MousePollerTransformer",
"MousePoller",
]
205 changes: 155 additions & 50 deletions src/ezmsg/peripheraldevice/mouse.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,144 @@
"""Mouse input via pynput."""

import queue
import threading
import time
import time as time_module
from collections import deque

import ezmsg.core as ez
import numpy as np
from ezmsg.baseproc import (
BaseClockDrivenProducer,
BaseClockDrivenUnit,
BaseProducerUnit,
BaseStatefulProducer,
BaseStatefulTransformer,
BaseTransformerUnit,
ClockDrivenSettings,
ClockDrivenState,
processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray, replace
from pynput.mouse import Controller, Listener

# =============================================================================
# Polled Mouse Transformer (takes LinearAxis from Clock, like Counter)
# Polled Mouse Producer (takes LinearAxis from Clock)
# =============================================================================


class MousePollerSettings(ez.Settings):
"""Settings for MousePollerTransformer."""
class MousePollerSettings(ClockDrivenSettings):
"""Settings for MousePollerProducer."""

pass
fs: float = 60.0
"""
Output sample rate in Hz (only used when n_time is None).

When n_time is specified, the effective poll rate is derived from the clock:
poll_rate = n_time / clock.gain (i.e., n_time samples per tick).

When n_time is None, fs determines both the poll rate and output sample rate,
with n_samples = fs * clock.gain per tick.
"""

n_time: int | None = None
"""
Samples per block.
- If specified: fixed chunk size, poll rate derived from clock timing
- If None: derived from fs * clock.gain, poll rate is fs
"""


@processor_state
class MousePollerState:
"""State for MousePollerTransformer."""
class MousePollerState(ClockDrivenState):
"""State for MousePollerProducer."""

controller: Controller | None = None
template: AxisArray | None = None

# Background polling state
poll_thread: threading.Thread | None = None
poll_buffer: deque | None = None
stop_event: threading.Event | None = None
last_position: tuple[float, float] = (0.0, 0.0)
use_thread: bool = False
poll_rate: float = 0.0


class MousePollerTransformer(
BaseStatefulTransformer[
MousePollerSettings,
AxisArray.LinearAxis,
AxisArray,
MousePollerState,
]
):
class MousePollerProducer(BaseClockDrivenProducer[MousePollerSettings, MousePollerState]):
"""
Reads current mouse position when triggered by clock tick.
Reads mouse position, optionally with high-rate background polling.

Takes LinearAxis input (from Clock) and outputs mouse positions.

Takes LinearAxis input (from Clock) and outputs the current mouse position
as a single sample with the clock's timestamp.
Behavior depends on settings:

**Fixed chunk mode (n_time is set):**
- Each tick produces exactly n_time samples
- Poll rate = n_time * clock_rate (derived from clock timing)
- settings.fs is ignored
- Thread used when n_time > 1

**Variable chunk mode (n_time is None):**
- Each tick produces fs * clock.gain samples (with fractional tracking)
- Poll rate = fs (from settings)
- Thread used when fs != clock_rate

Input: LinearAxis (from Clock - provides timing info)
Output: AxisArray with shape (1, 2) - single sample with x, y channels
Output: AxisArray with shape (n_samples, 2) - x, y channels
"""

def _reset_state(self, message: AxisArray.LinearAxis) -> None:
"""Initialize mouse controller."""
self._state.controller = Controller()
def _hash_message(self, message: AxisArray.LinearAxis) -> int:
"""
Hash based on clock gain to detect rate changes.

# Pre-construct template AxisArray
Returns different hash when clock rate changes significantly,
triggering state reset and potential thread restart.
"""
# Quantize gain to avoid floating point noise triggering resets
quantized_gain = round(message.gain * 1e6)
return hash(quantized_gain)

def _reset_state(self, time_axis: AxisArray.LinearAxis) -> None:
"""Initialize mouse controller and optionally start polling thread."""
# Stop any existing polling thread
self._stop_poll_thread()

self._state.controller = Controller()
self._state.last_position = self._state.controller.position

clock_rate = 1.0 / time_axis.gain if time_axis.gain > 0 else float("inf")

if self.settings.n_time is not None:
# Fixed chunk mode: poll rate derived from clock timing
# Need n_time samples per tick, so poll at n_time * clock_rate
self._state.poll_rate = self.settings.n_time * clock_rate
need_thread = self.settings.n_time > 1
else:
# Variable chunk mode: poll rate is fs from settings
# n_samples = fs * clock.gain (handled by base class)
self._state.poll_rate = self.settings.fs
# Need thread if fs != clock_rate (meaning n_samples != 1)
need_thread = not np.isclose(self.settings.fs, clock_rate, rtol=0.01)

self._state.use_thread = need_thread

if self._state.use_thread:
# Start background polling thread
buffer_size = max(int(self._state.poll_rate * 10), 1000)
self._state.poll_buffer = deque(maxlen=buffer_size)
self._state.stop_event = threading.Event()
self._state.poll_thread = threading.Thread(
target=self._poll_loop,
daemon=True,
)
self._state.poll_thread.start()

# Pre-construct template AxisArray (shape will be updated in _produce)
n_time_for_template = self.settings.n_time if self.settings.n_time is not None else 1
self._state.template = AxisArray(
data=np.zeros((1, 2), dtype=np.float64),
data=np.zeros((n_time_for_template, 2), dtype=np.float64),
dims=["time", "ch"],
axes={
"time": AxisArray.LinearAxis(
unit="s",
gain=message.gain,
offset=message.offset,
),
"time": time_axis,
"ch": AxisArray.CoordinateAxis(
data=np.array(["x", "y"]),
dims=["ch"],
Expand All @@ -74,36 +147,68 @@ def _reset_state(self, message: AxisArray.LinearAxis) -> None:
key="mouse",
)

def _process(self, message: AxisArray.LinearAxis) -> AxisArray:
"""Read current mouse position and return as AxisArray."""
pos = self._state.controller.position

# Create output with single sample
data = np.array([[pos[0], pos[1]]], dtype=np.float64)
time_axis = replace(
self._state.template.axes["time"],
offset=message.offset,
)
def _poll_loop(self) -> None:
"""Background thread that polls mouse at poll_rate."""
interval = 1.0 / self._state.poll_rate
last_poll = time_module.perf_counter() - interval

while not self._state.stop_event.is_set():
sleep_time = (last_poll + interval) - time_module.perf_counter()
if sleep_time > 0:
time_module.sleep(sleep_time)
pos = self._state.controller.position
self._state.poll_buffer.append(pos)
self._state.last_position = pos
last_poll = time_module.perf_counter()

def _stop_poll_thread(self) -> None:
"""Stop the background polling thread if running."""
if self._state.stop_event is not None:
self._state.stop_event.set()
if self._state.poll_thread is not None:
self._state.poll_thread.join(timeout=1.0)
self._state.poll_thread = None
self._state.stop_event = None
self._state.poll_buffer = None

def _produce(self, n_samples: int, time_axis: AxisArray.LinearAxis) -> AxisArray:
"""Generate mouse position data."""
if self._state.use_thread and self._state.poll_buffer is not None:
# Get samples from buffer
positions = []
for _ in range(n_samples):
if self._state.poll_buffer:
pos = self._state.poll_buffer.popleft()
self._state.last_position = pos
else:
# Buffer empty - hold last known position
pos = self._state.last_position
positions.append([pos[0], pos[1]])
data = np.array(positions, dtype=np.float64)
else:
# Simple single-poll mode
pos = self._state.controller.position
data = np.array([[pos[0], pos[1]]], dtype=np.float64)

return replace(
self._state.template,
data=data,
axes={"time": time_axis, "ch": self._state.template.axes["ch"]},
axes={**self._state.template.axes, "time": time_axis},
)

def __del__(self) -> None:
"""Stop polling thread on destruction."""
if hasattr(self, "_state"):
self._stop_poll_thread()


class MousePoller(
BaseTransformerUnit[
MousePollerSettings,
AxisArray.LinearAxis,
AxisArray,
MousePollerTransformer,
]
):
class MousePoller(BaseClockDrivenUnit[MousePollerSettings, MousePollerProducer]):
"""
Unit for reading mouse position from Clock input.

Receives LinearAxis from Clock and outputs current mouse position.
Receives LinearAxis from Clock and outputs mouse positions.
Supports both simple polling (one sample per tick) and high-rate
background polling with buffering.
"""

SETTINGS = MousePollerSettings
Expand Down
Loading