Skip to content

feat: expose internals through tool decorator param StrandsContext #557

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/strands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from . import agent, models, telemetry, types
from .agent.agent import Agent
from .tools.decorator import tool
from .types.tools import StrandsContext

__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry"]
__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry", "StrandsContext"]
54 changes: 47 additions & 7 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
from pydantic import BaseModel, Field, create_model
from typing_extensions import override

from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolSpec, ToolUse
from ..types.tools import AgentTool, JSONSchema, StrandsContext, ToolGenerator, ToolSpec, ToolUse

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -113,16 +113,16 @@ def _create_input_model(self) -> Type[BaseModel]:
This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can
validate input data before passing it to the function.

Special parameters like 'self', 'cls', and 'agent' are excluded from the model.
Special parameters that can be automatically injected are excluded from the model.

Returns:
A Pydantic BaseModel class customized for the function's parameters.
"""
field_definitions: dict[str, Any] = {}

for name, param in self.signature.parameters.items():
# Skip special parameters
if name in ("self", "cls", "agent"):
# Skip parameters that will be automatically injected
if self._is_special_parameter(name):
continue

# Get parameter type and default
Expand Down Expand Up @@ -252,6 +252,47 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
error_msg = str(e)
raise ValueError(f"Validation failed for input parameters: {error_msg}") from e

def inject_special_parameters(
self, validated_input: dict[str, Any], tool_use: ToolUse, invocation_state: dict[str, Any]
) -> None:
"""Inject special framework-provided parameters into the validated input.

This method automatically provides framework-level context to tools that request it
through their function signature.

Args:
validated_input: The validated input parameters (modified in place).
tool_use: The tool use request containing tool invocation details.
invocation_state: Context for the tool invocation, including agent state.
"""
# Inject StrandsContext if requested
if "strands_context" in self.signature.parameters:
strands_context: StrandsContext = {
"tool_use": tool_use,
"invocation_state": invocation_state,
}
validated_input["strands_context"] = strands_context

# Inject agent if requested (backward compatibility)
if "agent" in self.signature.parameters and "agent" in invocation_state:
validated_input["agent"] = invocation_state["agent"]

def _is_special_parameter(self, param_name: str) -> bool:
"""Check if a parameter should be automatically injected by the framework.

Special parameters include:
- Standard Python parameters: self, cls
- Framework-provided context parameters: agent, strands_context

Args:
param_name: The name of the parameter to check.

Returns:
True if the parameter should be excluded from input validation and
automatically injected during tool execution.
"""
return param_name in {"self", "cls", "agent", "strands_context"}


P = ParamSpec("P") # Captures all parameters
R = TypeVar("R") # Return type
Expand Down Expand Up @@ -402,9 +443,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
# Validate input against the Pydantic model
validated_input = self._metadata.validate_input(tool_input)

# Pass along the agent if provided and expected by the function
if "agent" in invocation_state and "agent" in self._metadata.signature.parameters:
validated_input["agent"] = invocation_state.get("agent")
# Inject special framework-provided parameters
self._metadata.inject_special_parameters(validated_input, tool_use, invocation_state)

# "Too few arguments" expected, hence the type ignore
if inspect.iscoroutinefunction(self._tool_func):
Expand Down
15 changes: 15 additions & 0 deletions src/strands/types/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,21 @@ class ToolChoiceTool(TypedDict):
name: str


class StrandsContext(TypedDict, total=False):
"""Context object containing framework-provided data for decorated tools.

This object provides access to framework-level information that may be useful
for tool implementations. All fields are optional to maintain backward compatibility.

Attributes:
tool_use: The complete ToolUse object containing tool invocation details.
invocation_state: Context for the tool invocation, including agent state.
"""

tool_use: ToolUse
invocation_state: dict[str, Any]


ToolChoice = Union[
dict[Literal["auto"], ToolChoiceAuto],
dict[Literal["any"], ToolChoiceAny],
Expand Down
49 changes: 48 additions & 1 deletion tests/strands/tools/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pytest

import strands
from strands.types.tools import ToolUse
from strands import Agent
from strands.types.tools import StrandsContext, ToolUse


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -1036,3 +1037,49 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]
result = (await alist(stream))[-1]
assert result["status"] == "success"
assert "NoneType: None" in result["content"][0]["text"]


@pytest.mark.asyncio
async def test_strands_context_injection(alist):
"""Test that StrandsContext is properly injected into tools that request it."""

@strands.tool
def context_tool(message: str, agent: Agent, strands_context: StrandsContext) -> dict:
"""Tool that uses StrandsContext to access tool_use_id."""
tool_use_id = strands_context["tool_use"]["toolUseId"]
tool_name = strands_context["tool_use"]["name"]
agent_info = strands_context["invocation_state"].get("agent", "no-agent")

return {
"status": "success",
"content": [
{
"text": f"""
Tool '{tool_name}' (ID: {tool_use_id})
with agent '{agent_info}'
and injected agent '{agent}' processed: {message}
"""
}
],
}

# Test tool use with context injection
tool_use = {"toolUseId": "test-context-123", "name": "context_tool", "input": {"message": "hello world"}}
invocation_state = {"agent": "test-agent"}

stream = context_tool.stream(tool_use, invocation_state)
result = (await alist(stream))[-1]

assert result["status"] == "success"
assert result["toolUseId"] == "test-context-123"
assert "Tool 'context_tool' (ID: test-context-123)" in result["content"][0]["text"]
assert "with agent 'test-agent'" in result["content"][0]["text"]
assert "and injected agent 'test-agent'" in result["content"][0]["text"]
assert "processed: hello world" in result["content"][0]["text"]

# Verify strands_context and agent are excluded from schema
tool_spec = context_tool.tool_spec
schema_properties = tool_spec["inputSchema"]["json"].get("properties", {})
assert "message" in schema_properties
assert "strands_context" not in schema_properties
assert "agent" not in schema_properties
52 changes: 52 additions & 0 deletions tests_integ/test_strands_context_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/usr/bin/env python3
"""
Integration test for StrandsContext functionality with real agent interactions.
"""

import logging

from strands import Agent, StrandsContext, tool

logging.getLogger("strands").setLevel(logging.DEBUG)
logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()])


@tool
def tool_with_context(message: str, strands_context: StrandsContext) -> dict:
"""Tool that uses StrandsContext to access tool_use_id."""
tool_use_id = strands_context["tool_use"]["toolUseId"]
return {
"status": "success",
"content": [{"text": f"Context tool processed '{message}' with ID: {tool_use_id}"}],
}


@tool
def tool_with_agent_and_context(message: str, agent: Agent, strands_context: StrandsContext) -> dict:
"""Tool that uses both agent and StrandsContext."""
tool_use_id = strands_context["tool_use"]["toolUseId"]
agent_name = getattr(agent, "name", "unknown-agent")
return {
"status": "success",
"content": [{"text": f"Agent '{agent_name}' processed '{message}' with ID: {tool_use_id}"}],
}


def test_strands_context_integration():
"""Test StrandsContext functionality with real agent interactions."""

# Initialize agent with tools
agent = Agent(tools=[tool_with_context, tool_with_agent_and_context])

# Test tool with StrandsContext
result_with_context = agent.tool.tool_with_context(message="hello world")
assert (
"Context tool processed 'hello world' with ID: tooluse_tool_with_context_"
in result_with_context["content"][0]["text"]
)

result_with_agent_and_context = agent.tool.tool_with_agent_and_context(message="hello agent", agent=agent)
assert (
"Agent 'Strands Agents' processed 'hello agent' with ID: tooluse_tool_with_agent_and_context_"
in result_with_agent_and_context["content"][0]["text"]
)
Loading