diff --git a/code_puppy/agents/base_agent.py b/code_puppy/agents/base_agent.py index 86ff147de..4ef1d227d 100644 --- a/code_puppy/agents/base_agent.py +++ b/code_puppy/agents/base_agent.py @@ -37,6 +37,12 @@ UsageLimitExceeded, UsageLimits, ) + +from code_puppy.llm_retry import ( + LLMRetryConfig, + RetryExhaustedError, + llm_run_with_retry, +) from pydantic_ai.durable_exec.dbos import DBOSAgent from pydantic_ai.messages import ( ModelMessage, @@ -1890,6 +1896,7 @@ async def run_agent_task(): ) usage_limits = UsageLimits(request_limit=get_message_limit()) + retry_config = LLMRetryConfig() # Handle MCP servers - add them temporarily when using DBOS if ( @@ -1905,12 +1912,15 @@ async def run_agent_task(): try: # Set the workflow ID for DBOS context so DBOS and Code Puppy ID match with SetWorkflowID(group_id): - result_ = await pydantic_agent.run( - prompt_payload, - message_history=self.get_message_history(), - usage_limits=usage_limits, - event_stream_handler=event_stream_handler, - **kwargs, + result_ = await llm_run_with_retry( + lambda: pydantic_agent.run( + prompt_payload, + message_history=self.get_message_history(), + usage_limits=usage_limits, + event_stream_handler=event_stream_handler, + **kwargs, + ), + config=retry_config, ) return result_ finally: @@ -1918,22 +1928,28 @@ async def run_agent_task(): pydantic_agent._toolsets = original_toolsets elif get_use_dbos(): with SetWorkflowID(group_id): - result_ = await pydantic_agent.run( + result_ = await llm_run_with_retry( + lambda: pydantic_agent.run( + prompt_payload, + message_history=self.get_message_history(), + usage_limits=usage_limits, + event_stream_handler=event_stream_handler, + **kwargs, + ), + config=retry_config, + ) + return result_ + else: + # Non-DBOS path (MCP servers are already included) + result_ = await llm_run_with_retry( + lambda: pydantic_agent.run( prompt_payload, message_history=self.get_message_history(), usage_limits=usage_limits, event_stream_handler=event_stream_handler, 
**kwargs, - ) - return result_ - else: - # Non-DBOS path (MCP servers are already included) - result_ = await pydantic_agent.run( - prompt_payload, - message_history=self.get_message_history(), - usage_limits=usage_limits, - event_stream_handler=event_stream_handler, - **kwargs, + ), + config=retry_config, ) return result_ except* UsageLimitExceeded as ule: @@ -1942,6 +1958,16 @@ async def run_agent_task(): "The agent has reached its usage limit. You can ask it to continue by saying 'please continue' or similar.", group_id=group_id, ) + except* RetryExhaustedError as retry_error: + emit_info( + f"API request failed after retries: {str(retry_error)}", + group_id=group_id, + ) + emit_info( + "The API may be experiencing high load. Try again in a moment, " + "or switch models with /model.", + group_id=group_id, + ) except* mcp.shared.exceptions.McpError as mcp_error: emit_info(f"MCP server error: {str(mcp_error)}", group_id=group_id) emit_info(f"{str(mcp_error)}", group_id=group_id) diff --git a/code_puppy/callbacks.py b/code_puppy/callbacks.py index 047a70a02..c0f3b6733 100644 --- a/code_puppy/callbacks.py +++ b/code_puppy/callbacks.py @@ -34,6 +34,8 @@ "register_model_providers", "message_history_processor_start", "message_history_processor_end", + "api_retry_start", + "api_retry_end", ] CallbackFunc = Callable[..., Any] @@ -68,6 +70,8 @@ "register_model_providers": [], "message_history_processor_start": [], "message_history_processor_end": [], + "api_retry_start": [], + "api_retry_end": [], } logger = logging.getLogger(__name__) diff --git a/code_puppy/llm_retry.py b/code_puppy/llm_retry.py new file mode 100644 index 000000000..c39f09e84 --- /dev/null +++ b/code_puppy/llm_retry.py @@ -0,0 +1,365 @@ +""" +LLM API retry engine with exponential backoff, jitter, and gateway awareness. + +Wraps pydantic_ai agent.run() calls with retry logic for transient API +failures (429, 529, 5xx, network errors). 
+ +Usage: + from code_puppy.llm_retry import llm_run_with_retry, LLMRetryConfig + + result = await llm_run_with_retry( + lambda: pydantic_agent.run(prompt, message_history=history, ...), + config=LLMRetryConfig(), + ) +""" + +import asyncio +import logging +import os +import random +from dataclasses import dataclass, field +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- +_DEFAULT_MAX_RETRIES = 10 +_BASE_DELAY_MS = 500 +_MAX_DELAY_MS = 32_000 +_MAX_CONSECUTIVE_OVERLOADS = 3 + + +# --------------------------------------------------------------------------- +# Config & errors +# --------------------------------------------------------------------------- +def _resolve_max_retries() -> int: + env = os.environ.get("PUPPY_MAX_LLM_RETRIES") + if env is not None: + try: + value = int(env) + if value < 0: + logger.warning( + "PUPPY_MAX_LLM_RETRIES=%r must be >= 0, using default %d", + env, + _DEFAULT_MAX_RETRIES, + ) + return _DEFAULT_MAX_RETRIES + return value + except ValueError: + logger.warning( + "PUPPY_MAX_LLM_RETRIES=%r is not a valid integer, using default %d", + env, + _DEFAULT_MAX_RETRIES, + ) + return _DEFAULT_MAX_RETRIES + + +@dataclass +class LLMRetryConfig: + """Configuration for the LLM retry engine.""" + + max_retries: int = field(default_factory=_resolve_max_retries) + cancel_event: Optional[asyncio.Event] = None + + +class RetryExhaustedError(Exception): + """All retry attempts failed.""" + + def __init__(self, message: str, original_error: Exception): + super().__init__(message) + self.original_error = original_error + + +# --------------------------------------------------------------------------- +# Error introspection helpers +# --------------------------------------------------------------------------- +def _get_status_code(error: Exception) -> 
Optional[int]: + """Extract HTTP status code from various error types.""" + # anthropic SDK / pydantic_ai errors + if hasattr(error, "status_code"): + return error.status_code + # Some errors wrap an HTTP response + resp = getattr(error, "response", None) + if resp is not None and hasattr(resp, "status_code"): + return resp.status_code + return None + + +def _get_retry_after(error: Exception) -> Optional[float]: + """Extract Retry-After header value in seconds, or None.""" + # Try direct headers attribute (anthropic SDK errors) + headers = getattr(error, "headers", None) + # Fall back to response headers + if headers is None: + resp = getattr(error, "response", None) + headers = getattr(resp, "headers", None) + if not headers: + return None + + val = None + if hasattr(headers, "get"): + val = headers.get("retry-after") or headers.get("Retry-After") + if val is None: + return None + try: + return float(val) + except (ValueError, TypeError): + return None + + +def _get_x_should_retry(error: Exception) -> Optional[bool]: + """Check x-should-retry header. Returns True, False, or None if absent.""" + headers = getattr(error, "headers", None) + if headers is None: + resp = getattr(error, "response", None) + headers = getattr(resp, "headers", None) + if not headers or not hasattr(headers, "get"): + return None + val = headers.get("x-should-retry") + if val == "true": + return True + if val == "false": + return False + return None + + +def _is_overloaded(error: Exception) -> bool: + """True for 529 or body-level overloaded_error.""" + if _get_status_code(error) == 529: + return True + return "overloaded_error" in str(error).lower() + + +# --------------------------------------------------------------------------- +# Retryability decision +# --------------------------------------------------------------------------- +def is_retryable(error: Exception) -> bool: + """Determine whether an LLM API error should be retried. 
+ + Returns: + True if the error is transient and the request should be retried. + """ + # Never retry cancellation + if isinstance(error, (asyncio.CancelledError, KeyboardInterrupt)): + return False + + # x-should-retry header is authoritative when present + hint = _get_x_should_retry(error) + if hint is False: + return False + if hint is True: + return True + + # Overloaded errors (529 or body-level) + if _is_overloaded(error): + return True + + # Network-level errors are always retryable — check the error itself and + # walk the full __cause__/__context__ chain. The real chain in production is: + # pydantic_ai ModelAPIError → SDK APIConnectionError → httpx ConnectError → OSError + # None of the intermediate types inherit from Python's ConnectionError, so we + # must walk all the way down to find the stdlib error at the root. + _network_types = (asyncio.TimeoutError, ConnectionError, OSError) + if isinstance(error, _network_types): + return True + exc: BaseException | None = error + for _ in range(10): # bounded walk to prevent infinite loops + exc = getattr(exc, "__cause__", None) or getattr(exc, "__context__", None) + if exc is None: + break + if isinstance(exc, _network_types): + return True + + # Streaming errors from pydantic_ai — only retry the specific transient + # message that pydantic_ai raises when a streamed response terminates early. + error_msg = str(error).lower() + if "streamed response ended" in error_msg: + return True + # Schema/validation errors are always fatal — never retry + if "schema" in error_msg or "validation" in error_msg: + return False + + status = _get_status_code(error) + if status is None: + # No status code and not a recognized transient pattern — don't retry + # blindly. Only specifically-identified patterns above are retried. 
+        return False
+
+    if status == 408:
+        return True  # Request Timeout
+    if status == 409:
+        return True  # Conflict
+    if status == 429:
+        return True  # Rate Limit
+    if status == 401:
+        return True  # Unauthorized (token may need refresh)
+    if status >= 500:
+        return True  # Server errors
+
+    # 400 (non-overloaded), 402, 403, 404, 422, etc. — fatal
+    return False
+
+
+# ---------------------------------------------------------------------------
+# Backoff formula
+# ---------------------------------------------------------------------------
+def _compute_backoff(attempt: int, retry_after_secs: Optional[float] = None) -> float:
+    """Compute retry delay in seconds.
+
+    Uses exponential backoff with up to 25% jitter.
+    Server-provided Retry-After header takes absolute priority.
+
+    Args:
+        attempt: 1-based attempt number.
+        retry_after_secs: Value from Retry-After header, if present.
+
+    Returns:
+        Delay in seconds.
+    """
+    if retry_after_secs is not None and retry_after_secs > 0:
+        return retry_after_secs
+
+    base = min(_BASE_DELAY_MS * (2 ** (attempt - 1)), _MAX_DELAY_MS) / 1000.0
+    jitter = random.random() * 0.25 * base
+    return base + jitter
+
+
+# ---------------------------------------------------------------------------
+# Abort-aware sleep
+# ---------------------------------------------------------------------------
+async def _cancellable_sleep(
+    seconds: float, cancel_event: Optional[asyncio.Event]
+) -> None:
+    """Sleep that aborts immediately if cancel_event is set."""
+    if cancel_event is None:
+        await asyncio.sleep(seconds)
+        return
+
+    if cancel_event.is_set():
+        raise asyncio.CancelledError("LLM retry sleep interrupted by cancel event")
+
+    try:
+        await asyncio.wait_for(cancel_event.wait(), timeout=seconds)
+        # If we get here, the event was set during the wait
+        raise asyncio.CancelledError("LLM retry sleep interrupted by cancel event")
+    except asyncio.TimeoutError:
+        pass  # Normal: full duration elapsed without cancellation
+
+
+# 
--------------------------------------------------------------------------- +# Main retry loop +# --------------------------------------------------------------------------- +async def llm_run_with_retry( + coro_factory: Callable[[], Any], + config: Optional[LLMRetryConfig] = None, +) -> Any: + """Execute an LLM API call with retry logic. + + Wraps a coroutine factory (typically ``lambda: agent.run(...)``) with + production-grade retry handling for transient API failures. + + Args: + coro_factory: Callable that returns a fresh coroutine for each attempt. + Must create a new coroutine on every call (use a lambda). + config: Retry configuration. Uses defaults if None. + + Returns: + The successful result of coro_factory(). + + Raises: + RetryExhaustedError: All retry attempts failed. + asyncio.CancelledError: If cancelled during retry sleep. + """ + # Lazy import to avoid circular dependency + from code_puppy.callbacks import _trigger_callbacks + + if config is None: + config = LLMRetryConfig() + + max_retries = config.max_retries + overload_hits = 0 + last_error: Optional[Exception] = None + + # 1 initial attempt + max_retries retries = max_retries + 1 total attempts + for attempt in range(1, max_retries + 2): + try: + result = await coro_factory() + + # If we recovered after retries, fire the callback + if attempt > 1: + await _trigger_callbacks( + "api_retry_end", + total_attempts=attempt, + ) + + return result + + except (asyncio.CancelledError, KeyboardInterrupt): + raise # Never swallow cancellation + + except Exception as error: + last_error = error + status = _get_status_code(error) + + logger.warning( + "LLM API error on attempt %d/%d: %s: %s (status=%s)", + attempt, + max_retries + 1, + type(error).__name__, + error, + status, + ) + + # Track consecutive overloads — short-circuit early when the + # model is clearly overloaded rather than burning all retries + if _is_overloaded(error): + overload_hits += 1 + if overload_hits >= _MAX_CONSECUTIVE_OVERLOADS: + 
raise RetryExhaustedError( + f"API returned {_MAX_CONSECUTIVE_OVERLOADS} consecutive " + f"overloaded errors", + error, + ) from error + else: + overload_hits = 0 + + # Retries exhausted + if attempt > max_retries: + raise RetryExhaustedError( + f"LLM API call failed after {max_retries} retries: {error}", + error, + ) from error + + # Non-retryable → fail immediately + if not is_retryable(error): + raise + + # Compute delay + retry_after = _get_retry_after(error) + delay_secs = _compute_backoff(attempt, retry_after) + + # Notify plugins + await _trigger_callbacks( + "api_retry_start", + error=error, + attempt=attempt, + delay_ms=int(delay_secs * 1000), + max_retries=max_retries, + ) + + logger.info( + "Retrying LLM API call in %.1fs (attempt %d/%d)", + delay_secs, + attempt, + max_retries + 1, + ) + + await _cancellable_sleep(delay_secs, config.cancel_event) + + # Unreachable, but satisfies type checker + raise RetryExhaustedError( + f"LLM API call failed: {last_error}", + last_error, # type: ignore[arg-type] + ) diff --git a/tests/test_llm_retry.py b/tests/test_llm_retry.py new file mode 100644 index 000000000..6e113de93 --- /dev/null +++ b/tests/test_llm_retry.py @@ -0,0 +1,467 @@ +"""Tests for code_puppy/llm_retry.py — LLM API retry engine.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from code_puppy.llm_retry import ( + LLMRetryConfig, + RetryExhaustedError, + _cancellable_sleep, + _compute_backoff, + _get_retry_after, + _get_status_code, + _get_x_should_retry, + _is_overloaded, + is_retryable, + llm_run_with_retry, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def _make_api_error(status_code: int, message: str = "error", headers=None): + """Create a mock API error with a status code and optional headers.""" + err = Exception(message) + err.status_code = status_code + err.headers = 
headers or {} + return err + + +def _make_overloaded_error(): + """Create a 529 overloaded error.""" + return _make_api_error(529, '{"type":"overloaded_error"}') + + +# --------------------------------------------------------------------------- +# _get_status_code +# --------------------------------------------------------------------------- +class TestGetStatusCode: + def test_direct_attribute(self): + err = _make_api_error(429) + assert _get_status_code(err) == 429 + + def test_response_attribute(self): + err = Exception("fail") + err.response = MagicMock(status_code=503) + assert _get_status_code(err) == 503 + + def test_no_status(self): + assert _get_status_code(Exception("boom")) is None + + +# --------------------------------------------------------------------------- +# _get_retry_after +# --------------------------------------------------------------------------- +class TestGetRetryAfter: + def test_direct_headers(self): + err = _make_api_error(429, headers={"retry-after": "5"}) + assert _get_retry_after(err) == 5.0 + + def test_response_headers(self): + err = Exception("fail") + err.headers = None + err.response = MagicMock(headers={"Retry-After": "10"}) + assert _get_retry_after(err) == 10.0 + + def test_none_when_absent(self): + assert _get_retry_after(Exception("no headers")) is None + + def test_invalid_value(self): + err = _make_api_error(429, headers={"retry-after": "not-a-number"}) + assert _get_retry_after(err) is None + + +# --------------------------------------------------------------------------- +# _get_x_should_retry +# --------------------------------------------------------------------------- +class TestGetXShouldRetry: + def test_true(self): + err = _make_api_error(500, headers={"x-should-retry": "true"}) + assert _get_x_should_retry(err) is True + + def test_false(self): + err = _make_api_error(400, headers={"x-should-retry": "false"}) + assert _get_x_should_retry(err) is False + + def test_absent(self): + err = _make_api_error(500, 
headers={}) + assert _get_x_should_retry(err) is None + + +# --------------------------------------------------------------------------- +# _is_overloaded +# --------------------------------------------------------------------------- +class TestIsOverloaded: + def test_529_status(self): + assert _is_overloaded(_make_api_error(529)) is True + + def test_overloaded_in_body(self): + err = Exception('{"type":"overloaded_error","message":"busy"}') + assert _is_overloaded(err) is True + + def test_not_overloaded(self): + assert _is_overloaded(_make_api_error(429)) is False + + +# --------------------------------------------------------------------------- +# _compute_backoff +# --------------------------------------------------------------------------- +class TestComputeBackoff: + def test_retry_after_takes_priority(self): + assert _compute_backoff(1, retry_after_secs=5.0) == 5.0 + + def test_attempt_1_is_500ms(self): + # Base is 500ms = 0.5s, jitter adds up to 25% + delay = _compute_backoff(1) + assert 0.5 <= delay <= 0.625 + + def test_attempt_2_is_1s(self): + delay = _compute_backoff(2) + assert 1.0 <= delay <= 1.25 + + def test_attempt_7_capped_at_32s(self): + delay = _compute_backoff(7) + assert 32.0 <= delay <= 40.0 + + def test_attempt_10_still_capped(self): + delay = _compute_backoff(10) + assert 32.0 <= delay <= 40.0 + + def test_zero_retry_after_ignored(self): + delay = _compute_backoff(1, retry_after_secs=0.0) + # 0.0 is not > 0, so falls through to computed backoff + assert 0.5 <= delay <= 0.625 + + def test_negative_retry_after_ignored(self): + delay = _compute_backoff(1, retry_after_secs=-1.0) + assert 0.5 <= delay <= 0.625 + + +# --------------------------------------------------------------------------- +# is_retryable +# --------------------------------------------------------------------------- +class TestIsRetryable: + def test_429_retryable(self): + assert is_retryable(_make_api_error(429)) is True + + def test_529_retryable(self): + assert 
is_retryable(_make_api_error(529)) is True + + def test_503_retryable(self): + assert is_retryable(_make_api_error(503)) is True + + def test_500_retryable(self): + assert is_retryable(_make_api_error(500)) is True + + def test_408_retryable(self): + assert is_retryable(_make_api_error(408)) is True + + def test_409_retryable(self): + assert is_retryable(_make_api_error(409)) is True + + def test_401_retryable(self): + assert is_retryable(_make_api_error(401)) is True + + def test_400_not_retryable(self): + assert is_retryable(_make_api_error(400)) is False + + def test_402_not_retryable(self): + assert is_retryable(_make_api_error(402)) is False + + def test_404_not_retryable(self): + assert is_retryable(_make_api_error(404)) is False + + def test_422_not_retryable(self): + assert is_retryable(_make_api_error(422)) is False + + def test_cancelled_error_not_retryable(self): + assert is_retryable(asyncio.CancelledError()) is False + + def test_keyboard_interrupt_not_retryable(self): + assert is_retryable(KeyboardInterrupt()) is False + + def test_timeout_error_retryable(self): + assert is_retryable(asyncio.TimeoutError()) is True + + def test_connection_error_retryable(self): + assert is_retryable(ConnectionError("reset")) is True + + def test_os_error_retryable(self): + assert is_retryable(OSError("network unreachable")) is True + + def test_wrapped_connection_error_retryable(self): + """pydantic_ai wraps ConnectionError in ModelAPIError (RuntimeError). 
+ Our engine must detect the __cause__ chain.""" + wrapper = RuntimeError("Connection error.") + wrapper.__cause__ = ConnectionError("nodename nor servname provided") + assert is_retryable(wrapper) is True + + def test_wrapped_os_error_retryable(self): + """Wrapped OSError via __context__ should also be retryable.""" + wrapper = RuntimeError("Network failure") + wrapper.__context__ = OSError("network unreachable") + assert is_retryable(wrapper) is True + + def test_deep_chain_connection_error_retryable(self): + """Real chain: ModelAPIError -> APIConnectionError -> httpx.ConnectError -> OSError. + Must walk the full chain to find the stdlib OSError at the root.""" + os_error = OSError("nodename nor servname provided") + httpx_error = Exception("Connect error") # simulates httpx.ConnectError + httpx_error.__cause__ = os_error + sdk_error = Exception("Connection error") # simulates SDK APIConnectionError + sdk_error.__cause__ = httpx_error + model_error = RuntimeError("Connection error.") # simulates ModelAPIError + model_error.__cause__ = sdk_error + assert is_retryable(model_error) is True + + def test_x_should_retry_false_overrides(self): + err = _make_api_error(500, headers={"x-should-retry": "false"}) + assert is_retryable(err) is False + + def test_x_should_retry_true_overrides(self): + err = _make_api_error(400, headers={"x-should-retry": "true"}) + assert is_retryable(err) is True + + def test_unknown_error_no_status_not_retryable(self): + """Unknown errors with no status code are NOT retried.""" + assert is_retryable(Exception("mysterious error")) is False + + def test_streaming_error_retryable(self): + """Transient streaming errors (pydantic_ai) are retried.""" + err = Exception("Streamed response ended without content") + assert is_retryable(err) is True + err2 = Exception("Streamed response ended without content or tool calls") + assert is_retryable(err2) is True + + def test_validation_error_not_retryable(self): + """Schema/validation errors are fatal — 
never retried.""" + assert is_retryable(Exception("Schema validation failed")) is False + assert is_retryable(Exception("Response validation error")) is False + + def test_overloaded_body_retryable(self): + err = Exception('{"type":"overloaded_error"}') + assert is_retryable(err) is True + + +# --------------------------------------------------------------------------- +# _cancellable_sleep +# --------------------------------------------------------------------------- +class TestCancellableSleep: + @pytest.mark.asyncio + async def test_normal_sleep(self): + """Sleep completes normally without cancel event.""" + await _cancellable_sleep(0.01, cancel_event=None) + + @pytest.mark.asyncio + async def test_cancel_interrupts(self): + """Setting the event during sleep raises CancelledError.""" + event = asyncio.Event() + + async def set_after_delay(): + await asyncio.sleep(0.01) + event.set() + + asyncio.create_task(set_after_delay()) + with pytest.raises(asyncio.CancelledError): + await _cancellable_sleep(10.0, cancel_event=event) + + @pytest.mark.asyncio + async def test_already_set_event(self): + """If event is already set, raises immediately.""" + event = asyncio.Event() + event.set() + with pytest.raises(asyncio.CancelledError): + await _cancellable_sleep(10.0, cancel_event=event) + + +# --------------------------------------------------------------------------- +# llm_run_with_retry +# --------------------------------------------------------------------------- +class TestLLMRunWithRetry: + @pytest.mark.asyncio + async def test_success_first_try(self): + """Succeeds on first attempt — no retries needed.""" + factory = AsyncMock(return_value="result") + result = await llm_run_with_retry(factory, config=LLMRetryConfig(max_retries=3)) + assert result == "result" + assert factory.call_count == 1 + + @pytest.mark.asyncio + async def test_retry_then_succeed(self): + """Fails with 429, then succeeds on second attempt.""" + call_count = 0 + + async def factory(): + nonlocal 
call_count + call_count += 1 + if call_count == 1: + raise _make_api_error(429) + return "recovered" + + with patch("code_puppy.llm_retry._compute_backoff", return_value=0.001): + result = await llm_run_with_retry( + factory, config=LLMRetryConfig(max_retries=3) + ) + assert result == "recovered" + assert call_count == 2 + + @pytest.mark.asyncio + async def test_retry_exhaustion(self): + """All retries fail — raises RetryExhaustedError.""" + + async def factory(): + raise _make_api_error(500) + + with patch("code_puppy.llm_retry._compute_backoff", return_value=0.001): + with pytest.raises(RetryExhaustedError) as exc_info: + await llm_run_with_retry(factory, config=LLMRetryConfig(max_retries=2)) + assert exc_info.value.original_error.status_code == 500 + + @pytest.mark.asyncio + async def test_fatal_error_not_retried(self): + """Non-retryable error raises immediately without retry.""" + call_count = 0 + + async def factory(): + nonlocal call_count + call_count += 1 + raise _make_api_error(400, "bad request") + + with pytest.raises(Exception, match="bad request"): + await llm_run_with_retry(factory, config=LLMRetryConfig(max_retries=5)) + assert call_count == 1 # No retry — failed on first attempt + + @pytest.mark.asyncio + async def test_cancelled_error_propagates(self): + """CancelledError is never swallowed.""" + + async def factory(): + raise asyncio.CancelledError() + + with pytest.raises(asyncio.CancelledError): + await llm_run_with_retry(factory, config=LLMRetryConfig(max_retries=5)) + + @pytest.mark.asyncio + async def test_consecutive_529_exhaustion(self): + """3 consecutive 529s raises RetryExhaustedError early.""" + + async def factory(): + raise _make_overloaded_error() + + with patch("code_puppy.llm_retry._compute_backoff", return_value=0.001): + with pytest.raises(RetryExhaustedError, match="overloaded"): + await llm_run_with_retry( + factory, + config=LLMRetryConfig(max_retries=10), + ) + + @pytest.mark.asyncio + async def 
test_overload_counter_resets(self): + """Non-overload errors reset the consecutive overload counter.""" + call_count = 0 + + async def factory(): + nonlocal call_count + call_count += 1 + if call_count <= 2: + raise _make_overloaded_error() # 2 overloads + if call_count == 3: + raise _make_api_error(500) # resets counter + if call_count <= 5: + raise _make_overloaded_error() # 2 more overloads + return "ok" + + with patch("code_puppy.llm_retry._compute_backoff", return_value=0.001): + result = await llm_run_with_retry( + factory, config=LLMRetryConfig(max_retries=10) + ) + assert result == "ok" + assert call_count == 6 + + @pytest.mark.asyncio + async def test_retry_after_header_respected(self): + """Retry-After header value flows into backoff calculation.""" + call_count = 0 + + async def factory(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise _make_api_error(429, headers={"retry-after": "1.5"}) + return "ok" + + with patch( + "code_puppy.llm_retry._compute_backoff", return_value=0.001 + ) as backoff: + result = await llm_run_with_retry( + factory, config=LLMRetryConfig(max_retries=3) + ) + backoff.assert_called_once_with(1, 1.5) + assert result == "ok" + assert call_count == 2 + + @pytest.mark.asyncio + async def test_api_retry_callbacks_fired(self): + """api_retry_start callback is triggered before each retry sleep.""" + call_count = 0 + callback_calls = [] + + async def factory(): + nonlocal call_count + call_count += 1 + if call_count <= 2: + raise _make_api_error(503) + return "ok" + + async def mock_trigger(phase, **kwargs): + callback_calls.append((phase, kwargs)) + return [] + + with ( + patch("code_puppy.llm_retry._compute_backoff", return_value=0.001), + patch( + "code_puppy.callbacks._trigger_callbacks", + side_effect=mock_trigger, + ), + ): + result = await llm_run_with_retry( + factory, config=LLMRetryConfig(max_retries=5) + ) + + assert result == "ok" + # 2 retries → 2 api_retry_start + 1 api_retry_end + start_calls = [c for 
c in callback_calls if c[0] == "api_retry_start"] + end_calls = [c for c in callback_calls if c[0] == "api_retry_end"] + assert len(start_calls) == 2 + assert len(end_calls) == 1 + + @pytest.mark.asyncio + async def test_config_from_env(self): + """PUPPY_MAX_LLM_RETRIES env var overrides default.""" + with patch.dict("os.environ", {"PUPPY_MAX_LLM_RETRIES": "2"}): + config = LLMRetryConfig() + assert config.max_retries == 2 + + @pytest.mark.asyncio + async def test_config_invalid_env(self): + """Invalid env var falls back to default.""" + with patch.dict("os.environ", {"PUPPY_MAX_LLM_RETRIES": "not-a-number"}): + config = LLMRetryConfig() + assert config.max_retries == 10 + + @pytest.mark.asyncio + async def test_config_negative_env(self): + """Negative env var falls back to default.""" + with patch.dict("os.environ", {"PUPPY_MAX_LLM_RETRIES": "-3"}): + config = LLMRetryConfig() + assert config.max_retries == 10 + + @pytest.mark.asyncio + async def test_default_config(self): + """Default config uses sensible defaults.""" + config = LLMRetryConfig() + assert config.max_retries == 10 + assert config.cancel_event is None