def recover_message_on_max_tokens_reached(message: Message) -> Message:
    """Replace every tool use in a max_tokens-truncated message with explanatory text.

    When a model response stops because the maximum token limit was reached, any tool use
    block it contains may be incomplete or unreliable. This function builds a new message
    in which:

    1. Every content block whose ``toolUse`` entry is present (complete, partial, or an
       empty dict) is replaced by a text block stating that the tool use was incomplete.
    2. Every other content block (text, images, etc.) is kept as-is, in its original position.

    The input message is not mutated; preserved content blocks are reused by reference.

    Args:
        message: The potentially incomplete message from the model that was truncated
            due to max token limits. Its ``content`` may be empty or ``None``.

    Returns:
        A new Message with the same role as the input, whose content holds the preserved
        blocks plus one replacement text block per tool use.

    Example:
        A content block such as:
        ```
        {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}}
        ```

        is replaced with:
        ```
        {"text": "The selected tool calculator's tool use was incomplete due to maximum token limits being reached."}
        ```

        When the tool name is missing or empty, ``<unknown>`` is used in its place.
    """
    logger.info("handling max_tokens stop reason - replacing all tool uses with error messages")

    valid_content: list[ContentBlock] = []
    for content in message["content"] or []:
        tool_use: ToolUse | None = content.get("toolUse")
        # Compare against None rather than relying on truthiness: an empty dict is still a
        # tool use block (presumably one cut off before any field was emitted) and must be
        # replaced, not carried into the conversation history.
        if tool_use is None:
            valid_content.append(content)
            continue

        # Replace all tool uses with error messages when max_tokens is reached.
        # A placeholder keeps the sentence readable when the name itself was truncated away.
        display_name = tool_use.get("name") or "<unknown>"
        logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name)

        valid_content.append(
            {
                "text": f"The selected tool {display_name}'s tool use was incomplete due "
                f"to maximum token limits being reached."
            }
        )

    return {"content": valid_content, "role": message["role"]}
Break out of retry loop @@ -192,6 +196,19 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> raise e try: + # Add message in trace and mark the end of the stream messages trace + stream_trace.add_message(message) + stream_trace.end() + + # Add the response message to the conversation + agent.messages.append(message) + agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) + yield {"callback": {"message": message}} + + # Update metrics + agent.event_loop_metrics.update_usage(usage) + agent.event_loop_metrics.update_metrics(metrics) + if stop_reason == "max_tokens": """ Handle max_tokens limit reached by the model. @@ -205,21 +222,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> "Agent has reached an unrecoverable state due to max_tokens limit. " "For more information see: " "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" - ), - incomplete_message=message, + ) ) - # Add message in trace and mark the end of the stream messages trace - stream_trace.add_message(message) - stream_trace.end() - - # Add the response message to the conversation - agent.messages.append(message) - agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) - yield {"callback": {"message": message}} - - # Update metrics - agent.event_loop_metrics.update_usage(usage) - agent.event_loop_metrics.update_metrics(metrics) # If the model is requesting to use tools if stop_reason == "tool_use": diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 71ea28b9f..90f2b8d7f 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -2,8 +2,6 @@ from typing import Any -from strands.types.content import Message - class EventLoopException(Exception): """Exception raised by the event loop.""" @@ -28,14 +26,12 @@ class MaxTokensReachedException(Exception): the complexity of the response, or when 
def __init__(self, message: str):
    """Initialize the exception with an error message.

    Args:
        message: The error message describing the token limit issue
    """
    # The message is handed to the base Exception; nothing else is stored on the instance.
    super().__init__(message)
"contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": "asdf", + "input": {}, # empty + }, + }, }, }, - }, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "max_tokens"}}, - ] - ) + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "max_tokens"}}, + ] + ), + ] # Call event_loop_cycle, expecting it to raise MaxTokensReachedException - with pytest.raises(MaxTokensReachedException) as exc_info: + expected_message = ( + "Agent has reached an unrecoverable state due to max_tokens limit. " + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ) + with pytest.raises(MaxTokensReachedException, match=expected_message): stream = strands.event_loop.event_loop.event_loop_cycle( agent=agent, invocation_state={}, @@ -594,16 +609,8 @@ async def test_event_loop_cycle_max_tokens_exception( await alist(stream) # Verify the exception message contains the expected content - expected_message = ( - "Agent has reached an unrecoverable state due to max_tokens limit. 
" - "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" - ) - assert str(exc_info.value) == expected_message - - # Verify that the message has not been appended to the messages array - assert len(agent.messages) == 1 - assert exc_info.value.incomplete_message not in agent.messages + assert len(agent.messages) == 2 + assert "tool use was incomplete due" in agent.messages[1]["content"][0]["text"] @patch("strands.event_loop.event_loop.get_tracer") diff --git a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py new file mode 100644 index 000000000..402e90966 --- /dev/null +++ b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py @@ -0,0 +1,269 @@ +"""Tests for token limit recovery utility.""" + +from strands.event_loop._recover_message_on_max_tokens_reached import ( + recover_message_on_max_tokens_reached, +) +from strands.types.content import Message + + +def test_recover_message_on_max_tokens_reached_with_incomplete_tool_use(): + """Test recovery when incomplete tool use is present in the message.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"text": "I'll help you with that."}, + {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 2 + + # First content block should be preserved + assert result["content"][0] == {"text": "I'll help you with that."} + + # Second content block should be replaced with error message + assert "text" in result["content"][1] + assert "calculator" in result["content"][1]["text"] + assert "incomplete due to maximum token limits" in result["content"][1]["text"] + + +def 
def test_recover_message_on_max_tokens_reached_with_missing_input():
    """A tool use that lacks its ``input`` field is swapped for the explanatory text block."""
    tool_use_without_input = {"name": "calculator", "toolUseId": "123"}
    truncated: Message = {
        "role": "assistant",
        "content": [{"toolUse": tool_use_without_input}],
    }

    recovered = recover_message_on_max_tokens_reached(truncated)

    # Role is carried over and the block count is unchanged.
    assert recovered["role"] == "assistant"
    blocks = recovered["content"]
    assert len(blocks) == 1

    # The single block is now a text block that names the tool and explains the truncation.
    replacement = blocks[0]
    assert "text" in replacement
    assert "calculator" in replacement["text"]
    assert "incomplete due to maximum token limits" in replacement["text"]
def test_recover_message_on_max_tokens_reached_with_valid_tool_use():
    """A fully-formed tool use is still replaced, while the text before it is kept."""
    leading_text = {"text": "I'll help you with that."}
    well_formed_call = {
        "toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"},
    }
    complete_message: Message = {"role": "assistant", "content": [leading_text, well_formed_call]}

    recovered = recover_message_on_max_tokens_reached(complete_message)

    assert recovered["role"] == "assistant"
    blocks = recovered["content"]
    assert len(blocks) == 2

    # The leading text block passes through untouched.
    assert blocks[0] == {"text": "I'll help you with that."}

    # Even a complete tool use becomes the explanatory text block.
    replacement = blocks[1]
    assert "text" in replacement
    assert "calculator" in replacement["text"]
    assert "incomplete due to maximum token limits" in replacement["text"]
"Let me calculate this for you."}, + {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Incomplete + {"text": "And then I'll explain the result."}, + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 3 + + # First and third content blocks should be preserved + assert result["content"][0] == {"text": "Let me calculate this for you."} + assert result["content"][2] == {"text": "And then I'll explain the result."} + + # Second content block should be replaced with error message + assert "text" in result["content"][1] + assert "calculator" in result["content"][1]["text"] + assert "incomplete due to maximum token limits" in result["content"][1]["text"] + + +def test_recover_message_on_max_tokens_reached_preserves_non_tool_content(): + """Test that non-tool content is preserved as-is.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"text": "Here's some text."}, + {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}}, + {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Incomplete + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 3 + + # First two content blocks should be preserved exactly + assert result["content"][0] == {"text": "Here's some text."} + assert result["content"][1] == {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}} + + # Third content block should be replaced with error message + assert "text" in result["content"][2] + assert "" in result["content"][2]["text"] + assert "incomplete due to maximum token limits" in result["content"][2]["text"] + + +def test_recover_message_on_max_tokens_reached_multiple_incomplete_tools(): + """Test recovery with multiple incomplete tool uses.""" + 
def test_recover_message_on_max_tokens_reached_preserves_user_role():
    """The role of the incoming message is carried over to the recovered message."""
    user_message: Message = {
        "role": "user",
        "content": [{"toolUse": {"name": "calculator", "input": {}}}],  # no toolUseId
    }

    recovered = recover_message_on_max_tokens_reached(user_message)

    # The role must come from the input, not be forced to "assistant".
    assert recovered["role"] == "user"

    blocks = recovered["content"]
    assert len(blocks) == 1
    only_block = blocks[0]
    assert "text" in only_block
    assert "calculator" in only_block["text"]
# Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 3 + + # First two content blocks should be preserved + assert result["content"][0] == {"text": "Regular text content."} + assert result["content"][1] == {"someOtherKey": "someValue"} + + # Third content block should be replaced with error message + assert "text" in result["content"][2] + assert "calculator" in result["content"][2]["text"] + assert "incomplete due to maximum token limits" in result["content"][2]["text"] diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index d9c2817b3..bf5668349 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -1,20 +1,48 @@ +import logging + import pytest +from src.strands.agent import AgentResult from strands import Agent, tool from strands.models.bedrock import BedrockModel from strands.types.exceptions import MaxTokensReachedException +logger = logging.getLogger(__name__) + @tool def story_tool(story: str) -> str: + """ + Tool that writes a story that is minimum 50,000 lines long. 
def test_max_tokens_reached():
    """Hitting max_tokens raises MaxTokensReachedException, yet a later invocation still completes."""
    agent = Agent(model=BedrockModel(max_tokens=100), tools=[story_tool])

    # The first call is expected to be cut short by the 100-token cap.
    with pytest.raises(MaxTokensReachedException):
        agent("Tell me a story!")

    # At least one text block in the history must carry the incomplete-tool-use notice.
    expected_text = "tool use was incomplete due to maximum token limits being reached"
    texts_in_history = []
    for message in agent.messages:
        for content_block in message.get("content", []):
            if "text" in content_block:
                texts_in_history.append(content_block["text"])

    notice_found = any(expected_text in text for text in texts_in_history)
    assert notice_found, f"Expected to find message containing '{expected_text}' in agent messages"

    # Clear the registered tools, then ask a plain question on the same agent.
    agent.tool_registry.registry = {}
    agent.tool_registry.tool_config = {}

    result: AgentResult = agent("What is 3+3")
    assert result.stop_reason == "end_turn"