diff --git a/src/sysls/strategy/__init__.py b/src/sysls/strategy/__init__.py new file mode 100644 index 0000000..b569a83 --- /dev/null +++ b/src/sysls/strategy/__init__.py @@ -0,0 +1 @@ +"""Strategy and risk framework for sysls.""" diff --git a/src/sysls/strategy/base.py b/src/sysls/strategy/base.py new file mode 100644 index 0000000..d1f0dd5 --- /dev/null +++ b/src/sysls/strategy/base.py @@ -0,0 +1,322 @@ +"""Strategy abstract base class and context for the sysls framework. + +The Strategy ABC is the main user extension point. Users subclass it to +implement trading strategies. The strategy receives market data, manages +internal state, generates signals, and can request orders. + +The StrategyContext provides strategies with access to framework services +(event bus, clock) and convenience methods for common operations. + +Example usage:: + + class MyStrategy(Strategy): + async def on_market_data(self, event: MarketDataEvent) -> None: + if some_condition(event): + await self.emit_signal( + instrument=event.instrument, + direction=SignalDirection.LONG, + strength=0.8, + ) +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +import structlog + +from sysls.core.events import ( + FillEvent, + MarketDataEvent, + OrderSubmitted, + PositionEvent, + SignalDirection, + SignalEvent, +) +from sysls.core.types import ( + OrderRequest, + OrderType, + Side, + TimeInForce, + generate_order_id, +) + +if TYPE_CHECKING: + from decimal import Decimal + + from sysls.core.bus import EventBus + from sysls.core.clock import Clock + from sysls.core.types import Instrument + +logger: structlog.stdlib.BoundLogger = structlog.get_logger() + + +class StrategyContext: + """Context provided to strategies for interacting with the framework. + + Provides access to the event bus, clock, and convenience methods + for common operations like emitting signals and requesting orders. + Strategies should use the context instead of directly accessing + framework internals. + + Args: + bus: The event bus for publishing/subscribing events. + clock: The clock for getting current time. + """ + + def __init__(self, bus: EventBus, clock: Clock) -> None: + self._bus = bus + self._clock = clock + + @property + def bus(self) -> EventBus: + """The event bus.""" + return self._bus + + @property + def clock(self) -> Clock: + """The clock.""" + return self._clock + + async def emit_signal( + self, + strategy_id: str, + instrument: Instrument, + direction: SignalDirection, + strength: float = 1.0, + metadata: dict[str, str] | None = None, + ) -> None: + """Emit a SignalEvent on the bus. + + Convenience method that constructs and publishes a SignalEvent. + + Args: + strategy_id: ID of the strategy emitting the signal. + instrument: Target instrument for the signal. + direction: Signal direction (LONG, SHORT, FLAT). + strength: Signal strength/conviction, typically in [-1.0, 1.0]. + metadata: Optional key-value metadata. + """ + event = SignalEvent( + strategy_id=strategy_id, + instrument=instrument, + direction=direction, + strength=strength, + metadata=metadata or {}, + source=f"strategy:{strategy_id}", + ) + await self._bus.publish(event) + logger.debug( + "signal_emitted", + strategy_id=strategy_id, + instrument=str(instrument), + direction=direction.value, + strength=strength, + ) + + async def request_order( + self, + instrument: Instrument, + side: Side, + quantity: Decimal, + order_type: OrderType = OrderType.MARKET, + price: Decimal | None = None, + time_in_force: TimeInForce = TimeInForce.GTC, + ) -> OrderRequest: + """Create an OrderRequest and publish it as an event. + + Creates the OrderRequest, publishes an OrderSubmitted event on the bus, + and returns the OrderRequest for tracking. + + NOTE: This does NOT submit through the OMS. The engine or an order + manager subscribes to these events and routes them through the OMS. + For Phase 3, this simply creates the request and emits an event. + + Args: + instrument: The instrument to trade. + side: Buy or sell. + quantity: Order quantity (always positive). + order_type: Market, limit, stop, etc. + price: Limit price, required for LIMIT and STOP_LIMIT orders. + time_in_force: How long the order remains active. + + Returns: + The created OrderRequest for tracking. + """ + order_id = generate_order_id() + request = OrderRequest( + order_id=order_id, + instrument=instrument, + side=side, + order_type=order_type, + quantity=quantity, + price=price, + time_in_force=time_in_force, + ) + event = OrderSubmitted( + order_id=order_id, + instrument=instrument, + side=side, + quantity=quantity, + price=price, + source="strategy_context", + ) + await self._bus.publish(event) + logger.debug( + "order_requested", + order_id=order_id, + instrument=str(instrument), + side=side.value, + quantity=str(quantity), + order_type=order_type.value, + ) + return request + + +class Strategy(ABC): + """Abstract base class for trading strategies. + + Users subclass Strategy and implement the abstract methods to create + trading strategies. The engine calls lifecycle methods in this order: + + 1. ``__init__`` -- set up parameters (before engine start) + 2. ``on_start`` -- called once when the engine starts (bus is running) + 3. ``on_market_data`` -- called on every market data event for subscribed instruments + 4. ``on_fill`` -- called on every fill for orders this strategy submitted + 5. ``on_position`` -- called on every position change for relevant instruments + 6. ``on_stop`` -- called once when the engine stops + + Strategies access the event bus and clock through the StrategyContext + provided at initialization. + + Args: + strategy_id: Unique identifier for this strategy instance. + context: StrategyContext providing bus, clock, and helper methods. + instruments: List of instruments this strategy trades. + params: Optional strategy-specific parameters dict. + """ + + def __init__( + self, + strategy_id: str, + context: StrategyContext, + instruments: list[Instrument], + params: dict[str, Any] | None = None, + ) -> None: + self._strategy_id = strategy_id + self._context = context + self._instruments = list(instruments) + self._params: dict[str, Any] = params if params is not None else {} + self._log: structlog.stdlib.BoundLogger = structlog.get_logger( + strategy_id=strategy_id, + ) + + # --- Properties --- + + @property + def strategy_id(self) -> str: + """The strategy's unique identifier.""" + return self._strategy_id + + @property + def instruments(self) -> list[Instrument]: + """Instruments this strategy is registered for.""" + return list(self._instruments) + + @property + def params(self) -> dict[str, Any]: + """Strategy parameters.""" + return dict(self._params) + + # --- Abstract methods (users MUST implement) --- + + @abstractmethod + async def on_market_data(self, event: MarketDataEvent) -> None: + """Called on every market data event for subscribed instruments. + + This is the main entry point for strategy logic. Analyze the + incoming data and optionally emit signals or request orders. + + Args: + event: The market data event to process. + """ + + # --- Optional lifecycle hooks (default no-op) --- + + async def on_start(self) -> None: # noqa: B027 + """Called once when the engine starts. Override for initialization.""" + + async def on_stop(self) -> None: # noqa: B027 + """Called once when the engine stops. Override for cleanup.""" + + async def on_fill(self, event: FillEvent) -> None: # noqa: B027 + """Called on every fill for orders this strategy submitted. + + Args: + event: The fill event to process. + """ + + async def on_position(self, event: PositionEvent) -> None: # noqa: B027 + """Called on every position change for relevant instruments. + + Args: + event: The position event to process. + """ + + # --- Concrete helper methods --- + + async def emit_signal( + self, + instrument: Instrument, + direction: SignalDirection, + strength: float = 1.0, + metadata: dict[str, str] | None = None, + ) -> None: + """Emit a signal through the context. + + Convenience method that delegates to ``StrategyContext.emit_signal``. + + Args: + instrument: Target instrument for the signal. + direction: Signal direction (LONG, SHORT, FLAT). + strength: Signal strength/conviction, typically in [-1.0, 1.0]. + metadata: Optional key-value metadata. + """ + await self._context.emit_signal( + strategy_id=self._strategy_id, + instrument=instrument, + direction=direction, + strength=strength, + metadata=metadata, + ) + + async def request_order( + self, + instrument: Instrument, + side: Side, + quantity: Decimal, + order_type: OrderType = OrderType.MARKET, + price: Decimal | None = None, + ) -> OrderRequest: + """Request an order through the context. + + Convenience method that delegates to ``StrategyContext.request_order``. + + Args: + instrument: The instrument to trade. + side: Buy or sell. + quantity: Order quantity (always positive). + order_type: Market, limit, stop, etc. + price: Limit price, required for LIMIT and STOP_LIMIT orders. + + Returns: + The created OrderRequest for tracking. + """ + return await self._context.request_order( + instrument=instrument, + side=side, + quantity=quantity, + order_type=order_type, + price=price, + ) diff --git a/src/sysls/strategy/signal.py b/src/sysls/strategy/signal.py new file mode 100644 index 0000000..234558d --- /dev/null +++ b/src/sysls/strategy/signal.py @@ -0,0 +1,351 @@ +"""Signal types and combinators for the sysls strategy framework. + +Signals represent trading intentions -- the output of strategy analysis. +They express a directional opinion with a strength/conviction level. +This module provides: + +- ``Signal``: Frozen Pydantic model representing a trading signal. +- ``SignalBook``: Mutable container tracking the latest signal per instrument. +- Combinator functions for combining multiple signals (average, majority, weighted). +- Conversion utilities between ``Signal`` models and ``SignalEvent`` bus events. + +Example usage:: + + signal = Signal( + instrument=nvda, + direction=SignalDirection.LONG, + strength=0.8, + strategy_id="momentum", + ) + + book = SignalBook(max_age_seconds=300) + book.update(signal) + + combined = combine_signals_average([sig1, sig2, sig3], instrument=nvda) +""" + +from __future__ import annotations + +import time + +from pydantic import BaseModel, Field, model_validator + +from sysls.core.events import SignalDirection, SignalEvent +from sysls.core.types import Instrument # noqa: TC001 + + +class Signal(BaseModel, frozen=True): + """A trading signal with instrument, direction, and strength. + + Signals are the output of strategy analysis. They express a directional + opinion with a strength/conviction level. Multiple signals can be + combined using the combinator functions. + + Attributes: + instrument: Target instrument. + direction: Signal direction (LONG, SHORT, FLAT). + strength: Signal strength in [-1.0, 1.0]. Positive = long conviction, + negative = short conviction, 0 = flat/neutral. Values outside this + range are clamped. + strategy_id: ID of the strategy that generated this signal. + timestamp_ns: When the signal was generated (ns since epoch). + metadata: Optional key-value metadata. + """ + + instrument: Instrument + direction: SignalDirection + strength: float = 1.0 + strategy_id: str = "" + timestamp_ns: int = Field(default_factory=lambda: int(time.time() * 1_000_000_000)) + metadata: dict[str, str] = Field(default_factory=dict) + + @model_validator(mode="after") + def _clamp_strength(self) -> Signal: + """Clamp strength to [-1.0, 1.0].""" + clamped = max(-1.0, min(1.0, self.strength)) + if clamped != self.strength: + # Use object.__setattr__ because the model is frozen + object.__setattr__(self, "strength", clamped) + return self + + +class SignalBook: + """Tracks the latest signal per instrument. + + The SignalBook is a mutable container that stores the most recent + signal for each instrument. It supports iteration, lookups, and + bulk operations. + + Args: + max_age_seconds: Optional maximum age in seconds. Signals older + than this are considered stale and filtered from active signals. + """ + + def __init__(self, max_age_seconds: float | None = None) -> None: + self._signals: dict[Instrument, Signal] = {} + self._max_age_seconds = max_age_seconds + + def update(self, signal: Signal) -> None: + """Update the signal for an instrument (replaces previous). + + Args: + signal: The new signal to store. + """ + self._signals[signal.instrument] = signal + + def get(self, instrument: Instrument) -> Signal | None: + """Get the latest signal for an instrument. + + Args: + instrument: The instrument to look up. + + Returns: + The latest signal, or None if no signal exists for this instrument. + """ + signal = self._signals.get(instrument) + if signal is not None and self._is_stale(signal): + return None + return signal + + def remove(self, instrument: Instrument) -> None: + """Remove the signal for an instrument. + + Args: + instrument: The instrument whose signal to remove. + """ + self._signals.pop(instrument, None) + + def clear(self) -> None: + """Remove all signals.""" + self._signals.clear() + + @property + def active_signals(self) -> dict[Instrument, Signal]: + """All current signals (excluding stale ones if max_age set).""" + if self._max_age_seconds is None: + return dict(self._signals) + return {inst: sig for inst, sig in self._signals.items() if not self._is_stale(sig)} + + @property + def instruments(self) -> list[Instrument]: + """Instruments with active signals.""" + return list(self.active_signals.keys()) + + def __len__(self) -> int: + """Number of active signals.""" + return len(self.active_signals) + + def __contains__(self, instrument: Instrument) -> bool: + """Check if an instrument has an active signal.""" + return self.get(instrument) is not None + + def _is_stale(self, signal: Signal) -> bool: + """Check if a signal is older than max_age_seconds. + + Args: + signal: The signal to check. + + Returns: + True if the signal is stale, False otherwise. + """ + if self._max_age_seconds is None: + return False + now_ns = int(time.time() * 1_000_000_000) + age_seconds = (now_ns - signal.timestamp_ns) / 1_000_000_000 + return age_seconds > self._max_age_seconds + + +# --------------------------------------------------------------------------- +# Combinator functions +# --------------------------------------------------------------------------- + + +def combine_signals_average(signals: list[Signal], instrument: Instrument) -> Signal: + """Combine signals by averaging their strengths. + + All signals should relate to the same instrument. The resulting direction + is determined by the sign of the average strength. + + Args: + signals: List of signals to combine. + instrument: The target instrument. + + Returns: + A new Signal with averaged strength. + + Raises: + ValueError: If signals is empty. + """ + if not signals: + raise ValueError("Cannot combine empty list of signals.") + + avg_strength = sum(s.strength for s in signals) / len(signals) + direction = _direction_from_strength(avg_strength) + return Signal( + instrument=instrument, + direction=direction, + strength=avg_strength, + ) + + +def combine_signals_majority(signals: list[Signal], instrument: Instrument) -> Signal: + """Combine signals by majority vote on direction. + + Counts LONG vs SHORT vs FLAT. The majority direction wins. + Strength is the proportion of votes for the winning direction. + + Args: + signals: List of signals to combine. + instrument: The target instrument. + + Returns: + A new Signal with majority-voted direction. + + Raises: + ValueError: If signals is empty. + """ + if not signals: + raise ValueError("Cannot combine empty list of signals.") + + counts: dict[SignalDirection, int] = { + SignalDirection.LONG: 0, + SignalDirection.SHORT: 0, + SignalDirection.FLAT: 0, + } + for s in signals: + counts[s.direction] += 1 + + # Find the direction(s) with the maximum count + max_count = max(counts.values()) + # In case of tie, prefer FLAT as the conservative choice + if counts[SignalDirection.FLAT] == max_count: + winner = SignalDirection.FLAT + elif counts[SignalDirection.LONG] == max_count: + winner = SignalDirection.LONG + else: + winner = SignalDirection.SHORT + + strength = max_count / len(signals) + # Map strength to the appropriate sign + if winner == SignalDirection.SHORT: + strength = -strength + elif winner == SignalDirection.FLAT: + strength = 0.0 + + return Signal( + instrument=instrument, + direction=winner, + strength=strength, + ) + + +def combine_signals_weighted( + signals: list[Signal], + weights: list[float], + instrument: Instrument, +) -> Signal: + """Combine signals with explicit weights. + + Computes a weighted average of signal strengths. Weights are + normalized to sum to 1. + + Args: + signals: List of signals to combine. + weights: Weight for each signal (must be same length as signals). + instrument: The target instrument. + + Returns: + A new Signal with weighted average strength. + + Raises: + ValueError: If signals is empty or lengths don't match. + """ + if not signals: + raise ValueError("Cannot combine empty list of signals.") + if len(signals) != len(weights): + raise ValueError( + f"signals and weights must have same length: {len(signals)} != {len(weights)}" + ) + + total_weight = sum(weights) + if total_weight == 0: + return Signal( + instrument=instrument, + direction=SignalDirection.FLAT, + strength=0.0, + ) + + normalized = [w / total_weight for w in weights] + weighted_strength = sum(s.strength * w for s, w in zip(signals, normalized, strict=True)) + direction = _direction_from_strength(weighted_strength) + + return Signal( + instrument=instrument, + direction=direction, + strength=weighted_strength, + ) + + +# --------------------------------------------------------------------------- +# Conversion utilities +# --------------------------------------------------------------------------- + + +def signal_from_event(event: SignalEvent) -> Signal: + """Convert a SignalEvent to a Signal model. + + Args: + event: The SignalEvent to convert. + + Returns: + A Signal model with the same data. + """ + return Signal( + instrument=event.instrument, + direction=event.direction, + strength=event.strength, + strategy_id=event.strategy_id, + timestamp_ns=event.timestamp_ns, + metadata=dict(event.metadata), + ) + + +def signal_to_event(signal: Signal, source: str | None = None) -> SignalEvent: + """Convert a Signal model to a SignalEvent for bus publishing. + + Args: + signal: The Signal model to convert. + source: Optional source identifier for the event. + + Returns: + A SignalEvent suitable for publishing on the event bus. + """ + return SignalEvent( + strategy_id=signal.strategy_id, + instrument=signal.instrument, + direction=signal.direction, + strength=signal.strength, + metadata=dict(signal.metadata), + source=source, + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _direction_from_strength(strength: float) -> SignalDirection: + """Determine direction from a numeric strength value. + + Args: + strength: The strength value. + + Returns: + LONG if positive, SHORT if negative, FLAT if zero. + """ + if strength > 0: + return SignalDirection.LONG + if strength < 0: + return SignalDirection.SHORT + return SignalDirection.FLAT diff --git a/tests/strategy/__init__.py b/tests/strategy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/strategy/test_base.py b/tests/strategy/test_base.py new file mode 100644 index 0000000..cbde25e --- /dev/null +++ b/tests/strategy/test_base.py @@ -0,0 +1,521 @@ +"""Tests for sysls.strategy.base module.""" + +from __future__ import annotations + +import asyncio +from decimal import Decimal +from typing import Any + +import pytest + +from sysls.core.bus import EventBus +from sysls.core.clock import LiveClock +from sysls.core.events import ( + BarEvent, + FillEvent, + MarketDataEvent, + OrderSubmitted, + PositionEvent, + SignalDirection, + SignalEvent, +) +from sysls.core.types import ( + AssetClass, + Instrument, + OrderStatus, + OrderType, + Side, + Venue, +) +from sysls.strategy.base import Strategy, StrategyContext + +# --------------------------------------------------------------------------- +# Test fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def instrument() -> Instrument: + """Provide a standard test instrument.""" + return Instrument( + symbol="NVDA", + asset_class=AssetClass.EQUITY, + venue=Venue.TASTYTRADE, + ) + + +@pytest.fixture() +def instrument_btc() -> Instrument: + """Provide a second test instrument.""" + return Instrument( + symbol="BTC-USDT-PERP", + asset_class=AssetClass.CRYPTO_PERP, + venue=Venue.CCXT, + exchange="binance", + currency="USDT", + ) + + +@pytest.fixture() +def bus() -> EventBus: + """Provide a fresh EventBus instance.""" + return EventBus() + + +@pytest.fixture() +def clock() -> LiveClock: + """Provide a LiveClock instance.""" + return LiveClock() + + +@pytest.fixture() +def context(bus: EventBus, clock: LiveClock) -> StrategyContext: + """Provide a StrategyContext.""" + return StrategyContext(bus=bus, clock=clock) + + +# --------------------------------------------------------------------------- +# Concrete test strategy +# --------------------------------------------------------------------------- + + +class _TestStrategy(Strategy): + """Concrete strategy subclass that records calls for testing.""" + + def __init__( + self, + strategy_id: str, + context: StrategyContext, + instruments: list[Instrument], + params: dict[str, Any] | None = None, + ) -> None: + super().__init__(strategy_id, context, instruments, params) + self.market_data_events: list[MarketDataEvent] = [] + self.fill_events: list[FillEvent] = [] + self.position_events: list[PositionEvent] = [] + self.started: bool = False + self.stopped: bool = False + + async def on_market_data(self, event: MarketDataEvent) -> None: + """Record market data events.""" + self.market_data_events.append(event) + + async def on_start(self) -> None: + """Record start.""" + self.started = True + + async def on_stop(self) -> None: + """Record stop.""" + self.stopped = True + + async def on_fill(self, event: FillEvent) -> None: + """Record fills.""" + self.fill_events.append(event) + + async def on_position(self, event: PositionEvent) -> None: + """Record position events.""" + self.position_events.append(event) + + +# --------------------------------------------------------------------------- +# StrategyContext tests +# --------------------------------------------------------------------------- + + +class TestStrategyContext: + """Tests for StrategyContext.""" + + def test_context_properties(self, bus: EventBus, clock: LiveClock) -> None: + """Context exposes bus and clock via properties.""" + ctx = StrategyContext(bus=bus, clock=clock) + assert ctx.bus is bus + assert ctx.clock is clock + + @pytest.mark.asyncio + async def test_context_emit_signal( + self, + bus: EventBus, + clock: LiveClock, + instrument: Instrument, + ) -> None: + """emit_signal publishes a SignalEvent on the bus.""" + ctx = StrategyContext(bus=bus, clock=clock) + received: list[SignalEvent] = [] + + async def handler(event: SignalEvent) -> None: + received.append(event) + + bus.subscribe(SignalEvent, handler) + await bus.start() + try: + await ctx.emit_signal( + strategy_id="test-strat", + instrument=instrument, + direction=SignalDirection.LONG, + strength=0.75, + metadata={"reason": "breakout"}, + ) + # Allow dispatch + await asyncio.sleep(0.05) + finally: + await bus.stop() + + assert len(received) == 1 + event = received[0] + assert event.strategy_id == "test-strat" + assert event.instrument == instrument + assert event.direction == SignalDirection.LONG + assert event.strength == 0.75 + assert event.metadata == {"reason": "breakout"} + assert event.source == "strategy:test-strat" + + @pytest.mark.asyncio + async def test_context_request_order( + self, + bus: EventBus, + clock: LiveClock, + instrument: Instrument, + ) -> None: + """request_order creates an OrderRequest and publishes OrderSubmitted.""" + ctx = StrategyContext(bus=bus, clock=clock) + submitted: list[OrderSubmitted] = [] + + async def handler(event: OrderSubmitted) -> None: + submitted.append(event) + + bus.subscribe(OrderSubmitted, handler) + await bus.start() + try: + request = await ctx.request_order( + instrument=instrument, + side=Side.BUY, + quantity=Decimal("10"), + order_type=OrderType.LIMIT, + price=Decimal("150.00"), + ) + await asyncio.sleep(0.05) + finally: + await bus.stop() + + # Verify OrderRequest returned + assert request.instrument == instrument + assert request.side == Side.BUY + assert request.quantity == Decimal("10") + assert request.order_type == OrderType.LIMIT + assert request.price == Decimal("150.00") + + # Verify OrderSubmitted event published + assert len(submitted) == 1 + event = submitted[0] + assert event.order_id == request.order_id + assert event.instrument == instrument + assert event.side == Side.BUY + assert event.quantity == Decimal("10") + assert event.price == Decimal("150.00") + + +# --------------------------------------------------------------------------- +# Strategy ABC tests +# --------------------------------------------------------------------------- + + +class TestStrategyInit: + """Tests for Strategy initialization.""" + + def test_strategy_init_stores_attributes( + self, + context: StrategyContext, + instrument: Instrument, + ) -> None: + """Strategy stores strategy_id, instruments, and params.""" + strat = _TestStrategy( + strategy_id="my-strat", + context=context, + instruments=[instrument], + params={"lookback": 20}, + ) + assert strat.strategy_id == "my-strat" + assert strat.instruments == [instrument] + assert strat.params == {"lookback": 20} + + def test_strategy_default_params_empty_dict( + self, + context: StrategyContext, + instrument: Instrument, + ) -> None: + """When params is None, defaults to empty dict.""" + strat = _TestStrategy( + strategy_id="s1", + context=context, + instruments=[instrument], + ) + assert strat.params == {} + + def test_strategy_strategy_id_property( + self, + context: StrategyContext, + instrument: Instrument, + ) -> None: + """strategy_id property returns the ID.""" + strat = _TestStrategy("abc-123", context, [instrument]) + assert strat.strategy_id == "abc-123" + + def test_strategy_instruments_property( + self, + context: StrategyContext, + instrument: Instrument, + instrument_btc: Instrument, + ) -> None: + """instruments property returns a copy of the instruments list.""" + instruments = [instrument, instrument_btc] + strat = _TestStrategy("s1", context, instruments) + result = strat.instruments + assert result == instruments + # Should be a copy, not the same list + assert result is not strat._instruments + + +class TestStrategyLifecycle: + """Tests for strategy lifecycle hooks.""" + + @pytest.mark.asyncio + async def test_strategy_lifecycle_on_start_default_noop( + self, + context: StrategyContext, + instrument: Instrument, + ) -> None: + """Default on_start is a no-op (doesn't raise).""" + + class _MinimalStrategy(Strategy): + async def on_market_data(self, event: MarketDataEvent) -> None: + pass + + strat = _MinimalStrategy("s1", context, [instrument]) + # Should not raise + await strat.on_start() + + @pytest.mark.asyncio + async def test_strategy_lifecycle_on_stop_default_noop( + self, + context: StrategyContext, + instrument: Instrument, + ) -> None: + """Default on_stop is a no-op (doesn't raise).""" + + class _MinimalStrategy(Strategy): + async def on_market_data(self, event: MarketDataEvent) -> None: + pass + + strat = _MinimalStrategy("s1", context, [instrument]) + await strat.on_stop() + + @pytest.mark.asyncio + async def test_strategy_lifecycle_on_fill_default_noop( + self, + context: StrategyContext, + instrument: Instrument, + ) -> None: + """Default on_fill is a no-op (doesn't raise).""" + + class _MinimalStrategy(Strategy): + async def on_market_data(self, event: MarketDataEvent) -> None: + pass + + strat = _MinimalStrategy("s1", context, [instrument]) + fill = FillEvent( + order_id="ord-1", + instrument=instrument, + side=Side.BUY, + fill_price=Decimal("100"), + fill_quantity=Decimal("5"), + cumulative_quantity=Decimal("5"), + order_status=OrderStatus.FILLED, + ) + await strat.on_fill(fill) + + @pytest.mark.asyncio + async def test_strategy_lifecycle_on_position_default_noop( + self, + context: StrategyContext, + instrument: Instrument, + ) -> None: + """Default on_position is a no-op (doesn't raise).""" + + class _MinimalStrategy(Strategy): + async def on_market_data(self, event: MarketDataEvent) -> None: + pass + + strat = _MinimalStrategy("s1", context, [instrument]) + pos = PositionEvent( + instrument=instrument, + quantity=Decimal("10"), + avg_price=Decimal("100"), + ) + await strat.on_position(pos) + + +class TestStrategyAbstract: + """Tests for abstract method enforcement.""" + + def test_strategy_on_market_data_abstract(self) -> None: + """Cannot instantiate Strategy without implementing on_market_data.""" + with pytest.raises(TypeError, match="on_market_data"): + Strategy( # type: ignore[abstract] + strategy_id="s1", + context=None, # type: ignore[arg-type] + instruments=[], + ) + + +class TestStrategyConvenienceMethods: + """Tests for Strategy convenience methods.""" + + @pytest.mark.asyncio + async def test_strategy_emit_signal_convenience( + self, + bus: EventBus, + clock: LiveClock, + instrument: Instrument, + ) -> None: + """Strategy.emit_signal delegates to context.emit_signal.""" + ctx = StrategyContext(bus=bus, clock=clock) + strat = _TestStrategy("my-strat", ctx, [instrument]) + + received: list[SignalEvent] = [] + + async def handler(event: SignalEvent) -> None: + received.append(event) + + bus.subscribe(SignalEvent, handler) + await bus.start() + try: + await strat.emit_signal( + instrument=instrument, + direction=SignalDirection.SHORT, + strength=0.5, + metadata={"indicator": "rsi"}, + ) + await asyncio.sleep(0.05) + finally: + await bus.stop() + + assert len(received) == 1 + event = received[0] + assert event.strategy_id == "my-strat" + assert event.direction == SignalDirection.SHORT + assert event.strength == 0.5 + assert event.metadata == {"indicator": "rsi"} + + @pytest.mark.asyncio + async def test_strategy_request_order_convenience( + self, + bus: EventBus, + clock: LiveClock, + instrument: Instrument, + ) -> None: + """Strategy.request_order delegates to context.request_order.""" + ctx = StrategyContext(bus=bus, clock=clock) + strat = _TestStrategy("my-strat", ctx, [instrument]) + + submitted: list[OrderSubmitted] = [] + + async def handler(event: OrderSubmitted) -> None: + submitted.append(event) + + bus.subscribe(OrderSubmitted, handler) + await bus.start() + try: + request = await strat.request_order( + instrument=instrument, + side=Side.SELL, + quantity=Decimal("5"), + order_type=OrderType.MARKET, + ) + await asyncio.sleep(0.05) + finally: + await bus.stop() + + assert request.instrument == instrument + assert request.side == Side.SELL + assert request.quantity == Decimal("5") + assert len(submitted) == 1 + + +class TestConcreteStrategy: + """Tests for the concrete _TestStrategy receiving events.""" + + @pytest.mark.asyncio + async def test_concrete_strategy_on_market_data( + self, + context: StrategyContext, + instrument: Instrument, + ) -> None: + """Concrete strategy receives and records market data events.""" + strat = _TestStrategy("s1", context, [instrument]) + bar = BarEvent( + instrument=instrument, + open=Decimal("100"), + high=Decimal("105"), + low=Decimal("99"), + close=Decimal("103"), + volume=Decimal("1000"), + bar_start_ns=0, + bar_end_ns=60_000_000_000, + ) + await strat.on_market_data(bar) + assert len(strat.market_data_events) == 1 + assert strat.market_data_events[0] is bar + + @pytest.mark.asyncio + async def test_concrete_strategy_on_start_on_stop( + self, + context: StrategyContext, + instrument: Instrument, + ) -> None: + """Concrete strategy records start/stop lifecycle calls.""" + strat = _TestStrategy("s1", context, [instrument]) + assert not strat.started + assert not strat.stopped + await strat.on_start() + assert strat.started + await strat.on_stop() + assert strat.stopped + + @pytest.mark.asyncio + async def test_concrete_strategy_on_fill( + self, + context: StrategyContext, + instrument: Instrument, + ) -> None: + """Concrete strategy records fill events.""" + strat = _TestStrategy("s1", context, [instrument]) + fill = FillEvent( + order_id="ord-1", + instrument=instrument, + side=Side.BUY, + fill_price=Decimal("150"), + fill_quantity=Decimal("10"), + cumulative_quantity=Decimal("10"), + order_status=OrderStatus.FILLED, + ) + await strat.on_fill(fill) + assert len(strat.fill_events) == 1 + assert strat.fill_events[0] is fill + + @pytest.mark.asyncio + async def test_concrete_strategy_on_position( + self, + context: StrategyContext, + instrument: Instrument, + ) -> None: + """Concrete strategy records position events.""" + strat = _TestStrategy("s1", context, [instrument]) + pos = PositionEvent( + instrument=instrument, + quantity=Decimal("20"), + avg_price=Decimal("100"), + realized_pnl=Decimal("50"), + ) + await strat.on_position(pos) + assert len(strat.position_events) == 1 + assert strat.position_events[0] is pos diff --git a/tests/strategy/test_signal.py b/tests/strategy/test_signal.py new file mode 100644 index 0000000..410de7f --- /dev/null +++ b/tests/strategy/test_signal.py @@ -0,0 +1,442 @@ +"""Tests for sysls.strategy.signal module.""" + +from __future__ import annotations + +import time + +import pytest + +from sysls.core.events import SignalDirection, SignalEvent +from sysls.core.types import AssetClass, Instrument, Venue +from sysls.strategy.signal import ( + Signal, + SignalBook, + combine_signals_average, + combine_signals_majority, + combine_signals_weighted, + signal_from_event, + signal_to_event, +) + +# --------------------------------------------------------------------------- +# Test fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def instrument() -> Instrument: + """Provide a standard test instrument.""" + return Instrument( + symbol="NVDA", + asset_class=AssetClass.EQUITY, + venue=Venue.TASTYTRADE, + ) + + +@pytest.fixture() +def instrument_btc() -> Instrument: + """Provide a second test instrument.""" + return Instrument( + symbol="BTC-USDT-PERP", + asset_class=AssetClass.CRYPTO_PERP, + venue=Venue.CCXT, + exchange="binance", + currency="USDT", + ) + + +# --------------------------------------------------------------------------- +# Signal model tests +# --------------------------------------------------------------------------- + + +class TestSignal: + """Tests for the Signal model.""" + + def test_signal_creation(self, instrument: Instrument) -> None: + """Signal can be created with basic fields.""" + sig = Signal( + instrument=instrument, + direction=SignalDirection.LONG, + strength=0.8, + strategy_id="momentum", + ) + assert sig.instrument == instrument + assert sig.direction == SignalDirection.LONG + assert sig.strength == 0.8 + assert sig.strategy_id == "momentum" + assert sig.timestamp_ns > 0 + + def test_signal_strength_clamped_above(self, instrument: Instrument) -> None: + """Strength values above 1.0 are clamped to 1.0.""" + sig = Signal( + instrument=instrument, + direction=SignalDirection.LONG, + strength=5.0, + ) + assert sig.strength == 1.0 + + def test_signal_strength_clamped_below(self, instrument: Instrument) -> None: + """Strength values below -1.0 are clamped to -1.0.""" + sig = Signal( + instrument=instrument, + direction=SignalDirection.SHORT, + strength=-3.0, + ) + assert sig.strength == -1.0 + + def test_signal_strength_within_range_unchanged(self, instrument: Instrument) -> None: + """Strength values within [-1.0, 1.0] are not modified.""" + sig = Signal( + instrument=instrument, + direction=SignalDirection.LONG, + strength=0.5, + ) + assert sig.strength == 0.5 + + def test_signal_frozen(self, instrument: Instrument) -> None: + """Signal is immutable (frozen).""" + sig = Signal( + instrument=instrument, + direction=SignalDirection.LONG, + strength=0.5, + ) + with pytest.raises(Exception): # noqa: B017 + sig.strength = 0.9 # type: ignore[misc] + + def test_signal_default_metadata_empty(self, instrument: Instrument) -> None: + """Default metadata is an empty dict.""" + sig = Signal( + instrument=instrument, + direction=SignalDirection.FLAT, + ) + assert sig.metadata == {} + + def test_signal_default_strategy_id_empty(self, instrument: Instrument) -> None: + """Default strategy_id is an empty string.""" + sig = Signal( + instrument=instrument, + direction=SignalDirection.FLAT, + ) + assert sig.strategy_id == "" + + +# --------------------------------------------------------------------------- +# SignalBook tests +# --------------------------------------------------------------------------- + + +class TestSignalBook: + """Tests for the SignalBook container.""" + + def test_signal_book_update_and_get(self, instrument: Instrument) -> None: + """Can update and retrieve a signal.""" + book = SignalBook() + sig = Signal( + instrument=instrument, + direction=SignalDirection.LONG, + strength=0.7, + ) + book.update(sig) + result = book.get(instrument) + assert result is sig + + def test_signal_book_update_replaces(self, instrument: Instrument) -> None: + """Updating a signal for the same instrument replaces the previous one.""" + book = SignalBook() + sig1 = Signal( + instrument=instrument, + direction=SignalDirection.LONG, + strength=0.5, + ) + sig2 = Signal( + instrument=instrument, + direction=SignalDirection.SHORT, + strength=-0.8, + ) + book.update(sig1) + book.update(sig2) + result = book.get(instrument) + assert result is sig2 + + def test_signal_book_get_missing_returns_none(self, instrument: Instrument) -> None: + """Getting a signal for an unknown instrument returns None.""" + book = SignalBook() + assert book.get(instrument) is None + + def test_signal_book_remove(self, instrument: Instrument) -> None: + """Can remove a signal for an instrument.""" + book = SignalBook() + sig = Signal( + instrument=instrument, + direction=SignalDirection.LONG, + strength=0.5, + ) + book.update(sig) + book.remove(instrument) + assert book.get(instrument) is None + + def test_signal_book_remove_missing_no_error(self, instrument: Instrument) -> None: + """Removing a non-existent signal does not raise.""" + book = SignalBook() + book.remove(instrument) # Should not raise + + def test_signal_book_clear(self, instrument: Instrument, instrument_btc: Instrument) -> None: + """Clear removes all signals.""" + book = SignalBook() + book.update(Signal(instrument=instrument, direction=SignalDirection.LONG)) + book.update(Signal(instrument=instrument_btc, direction=SignalDirection.SHORT)) + assert len(book) == 2 + book.clear() + assert len(book) == 0 + + def test_signal_book_contains(self, instrument: Instrument) -> None: + """in-operator checks for active signal presence.""" + book = SignalBook() + assert instrument not in book + book.update(Signal(instrument=instrument, direction=SignalDirection.LONG)) + assert instrument in book + + def test_signal_book_len(self, instrument: Instrument, instrument_btc: Instrument) -> None: + """len returns the number of active signals.""" + book = SignalBook() + assert len(book) == 0 + book.update(Signal(instrument=instrument, direction=SignalDirection.LONG)) + assert len(book) == 1 + book.update(Signal(instrument=instrument_btc, direction=SignalDirection.SHORT)) + assert len(book) == 2 + + def test_signal_book_active_signals_filters_stale( + self, instrument: Instrument, instrument_btc: Instrument + ) -> None: + """Stale signals are filtered from active_signals when max_age is set.""" + book = SignalBook(max_age_seconds=1.0) + + # Create a stale signal (timestamp far in the past) + stale_ts = int((time.time() - 10) * 1_000_000_000) + stale_sig = Signal( + instrument=instrument, + direction=SignalDirection.LONG, + strength=0.5, + timestamp_ns=stale_ts, + ) + # Create a fresh signal + fresh_sig = Signal( + instrument=instrument_btc, + direction=SignalDirection.SHORT, + strength=-0.3, + ) + + book.update(stale_sig) + book.update(fresh_sig) + + active = book.active_signals + assert instrument not in active + assert instrument_btc in active + assert len(book) == 1 + + def test_signal_book_instruments( + self, instrument: Instrument, instrument_btc: Instrument + ) -> None: + """instruments property returns list of instruments with active signals.""" + book = SignalBook() + book.update(Signal(instrument=instrument, direction=SignalDirection.LONG)) + book.update(Signal(instrument=instrument_btc, direction=SignalDirection.SHORT)) + instruments = book.instruments + assert set(instruments) == {instrument, instrument_btc} + + def test_signal_book_no_max_age_returns_all(self, instrument: Instrument) -> None: + """Without max_age, all signals are considered active regardless of timestamp.""" + book = SignalBook() + old_ts = int((time.time() - 100_000) * 1_000_000_000) + sig = Signal( + instrument=instrument, + direction=SignalDirection.LONG, + strength=0.5, + timestamp_ns=old_ts, + ) + book.update(sig) + assert instrument in book + assert len(book) == 1 + + +# --------------------------------------------------------------------------- +# Combinator tests +# --------------------------------------------------------------------------- + + +class TestCombineSignalsAverage: + """Tests for combine_signals_average.""" + + def test_combine_signals_average(self, instrument: Instrument) -> None: + """Average of same-direction signals gives correct result.""" + signals = [ + Signal(instrument=instrument, direction=SignalDirection.LONG, strength=0.8), + Signal(instrument=instrument, direction=SignalDirection.LONG, strength=0.6), + Signal(instrument=instrument, direction=SignalDirection.LONG, strength=0.4), + ] + result = combine_signals_average(signals, instrument) + assert result.direction == SignalDirection.LONG + assert abs(result.strength - 0.6) < 1e-10 + + def test_combine_signals_average_mixed_directions(self, instrument: Instrument) -> None: + """Average of mixed-direction signals can produce any direction.""" + signals = [ + Signal(instrument=instrument, direction=SignalDirection.LONG, strength=0.5), + Signal(instrument=instrument, direction=SignalDirection.SHORT, strength=-0.8), + Signal(instrument=instrument, direction=SignalDirection.LONG, strength=0.1), + ] + result = combine_signals_average(signals, instrument) + # (0.5 + -0.8 + 0.1) / 3 = -0.2 / 3 = -0.0667 + assert result.direction == SignalDirection.SHORT + assert result.strength < 0 + + def test_combine_signals_average_empty_raises(self, instrument: Instrument) -> None: + """Combining an empty list raises ValueError.""" + with pytest.raises(ValueError, match="empty"): + combine_signals_average([], instrument) + + +class TestCombineSignalsMajority: + """Tests for combine_signals_majority.""" + + def test_combine_signals_majority(self, instrument: Instrument) -> None: + """Majority vote selects the most common direction.""" + signals = [ + Signal(instrument=instrument, direction=SignalDirection.LONG, strength=0.3), + Signal(instrument=instrument, direction=SignalDirection.LONG, strength=0.5), + Signal(instrument=instrument, direction=SignalDirection.SHORT, strength=-0.9), + ] + result = combine_signals_majority(signals, instrument) + assert result.direction == SignalDirection.LONG + # 2 out of 3 voted LONG + assert abs(result.strength - 2 / 3) < 1e-10 + + def test_combine_signals_majority_tie(self, instrument: Instrument) -> None: + """In a tie, FLAT is preferred as the conservative choice.""" + signals = [ + Signal(instrument=instrument, direction=SignalDirection.LONG, strength=0.5), + Signal(instrument=instrument, direction=SignalDirection.SHORT, strength=-0.5), + Signal(instrument=instrument, direction=SignalDirection.FLAT, strength=0.0), + ] + result = combine_signals_majority(signals, instrument) + # All three have count=1, tie -> FLAT preferred + assert result.direction == SignalDirection.FLAT + assert result.strength == 0.0 + + def test_combine_signals_majority_empty_raises(self, instrument: Instrument) -> None: + """Combining an empty list raises ValueError.""" + with pytest.raises(ValueError, match="empty"): + combine_signals_majority([], instrument) + + def test_combine_signals_majority_short_wins(self, instrument: Instrument) -> None: + """SHORT majority produces negative strength.""" + signals = [ + Signal(instrument=instrument, direction=SignalDirection.SHORT, strength=-0.5), + Signal(instrument=instrument, direction=SignalDirection.SHORT, strength=-0.8), + Signal(instrument=instrument, direction=SignalDirection.LONG, strength=0.3), + ] + result = combine_signals_majority(signals, instrument) + assert result.direction == SignalDirection.SHORT + assert result.strength < 0 + + +class TestCombineSignalsWeighted: + """Tests for combine_signals_weighted.""" + + def test_combine_signals_weighted(self, instrument: Instrument) -> None: + """Weighted combination uses normalized weights.""" + signals = [ + Signal(instrument=instrument, direction=SignalDirection.LONG, strength=1.0), + Signal(instrument=instrument, direction=SignalDirection.SHORT, strength=-1.0), + ] + # Weight the LONG signal 3x more than SHORT + result = combine_signals_weighted(signals, [3.0, 1.0], instrument) + # (1.0 * 0.75) + (-1.0 * 0.25) = 0.5 + assert result.direction == SignalDirection.LONG + assert abs(result.strength - 0.5) < 1e-10 + + def test_combine_signals_weighted_mismatched_lengths_raises( + self, instrument: Instrument + ) -> None: + """Mismatched signals and weights lengths raises ValueError.""" + signals = [ + Signal(instrument=instrument, direction=SignalDirection.LONG, strength=0.5), + ] + with pytest.raises(ValueError, match="same length"): + combine_signals_weighted(signals, [1.0, 2.0], instrument) + + def test_combine_signals_weighted_empty_raises(self, instrument: Instrument) -> None: + """Combining an empty list raises ValueError.""" + with pytest.raises(ValueError, match="empty"): + combine_signals_weighted([], [], instrument) + + def test_combine_signals_weighted_zero_weights(self, instrument: Instrument) -> None: + """All-zero weights produce a FLAT signal.""" + signals = [ + Signal(instrument=instrument, direction=SignalDirection.LONG, strength=1.0), + ] + result = combine_signals_weighted(signals, [0.0], instrument) + assert result.direction == SignalDirection.FLAT + assert result.strength == 0.0 + + +# --------------------------------------------------------------------------- +# Conversion tests +# --------------------------------------------------------------------------- + + +class TestSignalConversion: + """Tests for signal_from_event and signal_to_event.""" + + def test_signal_from_event(self, instrument: Instrument) -> None: + """Can convert a SignalEvent to a Signal model.""" + event = SignalEvent( + strategy_id="my-strat", + instrument=instrument, + direction=SignalDirection.LONG, + strength=0.75, + metadata={"key": "value"}, + ) + sig = signal_from_event(event) + assert sig.instrument == instrument + assert sig.direction == SignalDirection.LONG + assert sig.strength == 0.75 + assert sig.strategy_id == "my-strat" + assert sig.timestamp_ns == event.timestamp_ns + assert sig.metadata == {"key": "value"} + + def test_signal_to_event(self, instrument: Instrument) -> None: + """Can convert a Signal model to a SignalEvent.""" + sig = Signal( + instrument=instrument, + direction=SignalDirection.SHORT, + strength=-0.5, + strategy_id="mean-rev", + metadata={"indicator": "bollinger"}, + ) + event = signal_to_event(sig, source="test") + assert isinstance(event, SignalEvent) + assert event.instrument == instrument + assert event.direction == SignalDirection.SHORT + assert event.strength == -0.5 + assert event.strategy_id == "mean-rev" + assert event.metadata == {"indicator": "bollinger"} + assert event.source == "test" + + def test_signal_roundtrip(self, instrument: Instrument) -> None: + """Signal -> SignalEvent -> Signal preserves key fields.""" + original = Signal( + instrument=instrument, + direction=SignalDirection.LONG, + strength=0.9, + strategy_id="trend", + metadata={"tf": "1h"}, + ) + event = signal_to_event(original) + recovered = signal_from_event(event) + assert recovered.instrument == original.instrument + assert recovered.direction == original.direction + assert recovered.strength == original.strength + assert recovered.strategy_id == original.strategy_id + assert recovered.metadata == original.metadata