diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index ffcb6a5c9..ae21d4c6d 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -28,7 +28,12 @@ from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools from ..types.content import Message -from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +from ..types.exceptions import ( + ContextWindowOverflowException, + EventLoopException, + MaxTokensReachedException, + ModelThrottledException, +) from ..types.streaming import Metrics, StopReason from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse from .streaming import stream_messages @@ -187,6 +192,22 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> raise e try: + if stop_reason == "max_tokens": + """ + Handle max_tokens limit reached by the model. + + When the model reaches its maximum token limit, this represents a potentially unrecoverable + state where the model's response was truncated. By default, Strands fails hard with an + MaxTokensReachedException to maintain consistency with other failure types. + """ + raise MaxTokensReachedException( + message=( + "Agent has reached an unrecoverable state due to max_tokens limit. " + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ), + incomplete_message=message, + ) # Add message in trace and mark the end of the stream messages trace stream_trace.add_message(message) stream_trace.end() @@ -231,7 +252,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> # Don't yield or log the exception - we already did it when we # raised the exception and we don't need that duplication. raise - except ContextWindowOverflowException as e: + except (ContextWindowOverflowException, MaxTokensReachedException) as e: + # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException if cycle_span: tracer.end_span_with_error(cycle_span, str(e), e) raise e diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 0d734b762..975fca3e9 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -414,7 +414,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") + raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') content = messages["content"] output_response: dict[str, Any] | None = None diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 9b36b4244..4ea1453a4 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -631,7 +631,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") + raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') content = messages["content"] output_response: dict[str, Any] | None = None diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 4bd3fd88e..71ea28b9f 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -2,6 +2,8 @@ from typing import Any +from strands.types.content import Message + class EventLoopException(Exception): """Exception raised by the event loop.""" @@ -18,6 +20,25 @@ def __init__(self, original_exception: Exception, request_state: Any = None) -> super().__init__(str(original_exception)) +class MaxTokensReachedException(Exception): + """Exception raised when the model reaches its maximum token generation limit. + + This exception is raised when the model stops generating tokens because it has reached the maximum number of + tokens allowed for output generation. This can occur when the model's max_tokens parameter is set too low for + the complexity of the response, or when the model naturally reaches its configured output limit during generation. + """ + + def __init__(self, message: str, incomplete_message: Message): + """Initialize the exception with an error message and the incomplete message object. + + Args: + message: The error message describing the token limit issue + incomplete_message: The valid Message object with incomplete content due to token limits + """ + self.incomplete_message = incomplete_message + super().__init__(message) + + class ContextWindowOverflowException(Exception): """Exception raised when the context window is exceeded. diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 1ac2f8258..3886df8b9 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -19,7 +19,12 @@ ) from strands.telemetry.metrics import EventLoopMetrics from strands.tools.registry import ToolRegistry -from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +from strands.types.exceptions import ( + ContextWindowOverflowException, + EventLoopException, + MaxTokensReachedException, + ModelThrottledException, +) from tests.fixtures.mock_hook_provider import MockHookProvider @@ -556,6 +561,51 @@ async def test_event_loop_tracing_with_model_error( mock_tracer.end_span_with_error.assert_called_once_with(model_span, "Input too long", model.stream.side_effect) +@pytest.mark.asyncio +async def test_event_loop_cycle_max_tokens_exception( + agent, + model, + agenerator, + alist, +): + """Test that max_tokens stop reason raises MaxTokensReachedException.""" + + # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495 + model.stream.return_value = agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": {}, + }, + }, + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "max_tokens"}}, + ] + ) + + # Call event_loop_cycle, expecting it to raise MaxTokensReachedException + with pytest.raises(MaxTokensReachedException) as exc_info: + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + # Verify the exception message contains the expected content + expected_message = ( + "Agent has reached an unrecoverable state due to max_tokens limit. " + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ) + assert str(exc_info.value) == expected_message + + # Verify that the message has not been appended to the messages array + assert len(agent.messages) == 1 + assert exc_info.value.incomplete_message not in agent.messages + + @patch("strands.event_loop.event_loop.get_tracer") @pytest.mark.asyncio async def test_event_loop_tracing_with_tool_execution( diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py new file mode 100644 index 000000000..d9c2817b3 --- /dev/null +++ b/tests_integ/test_max_tokens_reached.py @@ -0,0 +1,20 @@ +import pytest + +from strands import Agent, tool +from strands.models.bedrock import BedrockModel +from strands.types.exceptions import MaxTokensReachedException + + +@tool +def story_tool(story: str) -> str: + return story + + +def test_context_window_overflow(): + model = BedrockModel(max_tokens=100) + agent = Agent(model=model, tools=[story_tool]) + + with pytest.raises(MaxTokensReachedException): + agent("Tell me a story!") + + assert len(agent.messages) == 1