Skip to content

Commit 4fcc5a4

Browse files
authored
Python: propagate as_tool() kwargs. Add sample for runtime context with as_tool kwargs and middleware. (#2311)
* as tool kwargs * simplify
1 parent 79bb870 commit 4fcc5a4

File tree

7 files changed

+829
-13
lines changed

7 files changed

+829
-13
lines changed

python/packages/core/agent_framework/_agents.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -454,13 +454,16 @@ async def agent_wrapper(**kwargs: Any) -> str:
454454
# Extract the input from kwargs using the specified arg_name
455455
input_text = kwargs.get(arg_name, "")
456456

457+
# Forward all kwargs except the arg_name to support runtime context propagation
458+
forwarded_kwargs = {k: v for k, v in kwargs.items() if k != arg_name}
459+
457460
if stream_callback is None:
458461
# Use non-streaming mode
459-
return (await self.run(input_text)).text
462+
return (await self.run(input_text, **forwarded_kwargs)).text
460463

461464
# Use streaming mode - accumulate updates and create final response
462465
response_updates: list[AgentRunResponseUpdate] = []
463-
async for update in self.run_stream(input_text):
466+
async for update in self.run_stream(input_text, **forwarded_kwargs):
464467
response_updates.append(update)
465468
if is_async_callback:
466469
await stream_callback(update) # type: ignore[misc]
@@ -470,12 +473,14 @@ async def agent_wrapper(**kwargs: Any) -> str:
470473
# Create final text from accumulated updates
471474
return AgentRunResponse.from_agent_run_response_updates(response_updates).text
472475

473-
return AIFunction(
476+
agent_tool: AIFunction[BaseModel, str] = AIFunction(
474477
name=tool_name,
475478
description=tool_description,
476479
func=agent_wrapper,
477480
input_model=input_model, # type: ignore
478481
)
482+
agent_tool._forward_runtime_kwargs = True # type: ignore
483+
return agent_tool
479484

480485
def _normalize_messages(
481486
self,
@@ -868,7 +873,9 @@ async def run(
868873
user=user,
869874
**(additional_chat_options or {}),
870875
)
871-
response = await self.chat_client.get_response(messages=thread_messages, chat_options=co, **kwargs)
876+
# Filter chat_options from kwargs to prevent duplicate keyword argument
877+
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
878+
response = await self.chat_client.get_response(messages=thread_messages, chat_options=co, **filtered_kwargs)
872879

873880
await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
874881

@@ -1000,9 +1007,11 @@ async def run_stream(
10001007
**(additional_chat_options or {}),
10011008
)
10021009

1010+
# Filter chat_options from kwargs to prevent duplicate keyword argument
1011+
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
10031012
response_updates: list[ChatResponseUpdate] = []
10041013
async for update in self.chat_client.get_streaming_response(
1005-
messages=thread_messages, chat_options=co, **kwargs
1014+
messages=thread_messages, chat_options=co, **filtered_kwargs
10061015
):
10071016
response_updates.append(update)
10081017

python/packages/core/agent_framework/_tools.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,7 @@ def __init__(
627627
self.invocation_exception_count = 0
628628
self._invocation_duration_histogram = _default_histogram()
629629
self.type: Literal["ai_function"] = "ai_function"
630+
self._forward_runtime_kwargs: bool = False
630631

631632
@property
632633
def declaration_only(self) -> bool:
@@ -728,11 +729,16 @@ async def invoke(
728729
global OBSERVABILITY_SETTINGS
729730
from .observability import OBSERVABILITY_SETTINGS
730731

731-
tool_call_id = kwargs.pop("tool_call_id", None)
732+
original_kwargs = dict(kwargs)
733+
tool_call_id = original_kwargs.pop("tool_call_id", None)
732734
if arguments is not None:
733735
if not isinstance(arguments, self.input_model):
734736
raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}")
735737
kwargs = arguments.model_dump(exclude_none=True)
738+
if getattr(self, "_forward_runtime_kwargs", False) and original_kwargs:
739+
kwargs.update(original_kwargs)
740+
else:
741+
kwargs = original_kwargs
736742
if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined]
737743
logger.info(f"Function name: {self.name}")
738744
logger.debug(f"Function arguments: {kwargs}")
@@ -1272,15 +1278,20 @@ async def _auto_invoke_function(
12721278

12731279
parsed_args: dict[str, Any] = dict(function_call_content.parse_arguments() or {})
12741280

1275-
# Merge with user-supplied args; right-hand side dominates, so parsed args win on conflicts.
1276-
merged_args: dict[str, Any] = (custom_args or {}) | parsed_args
1281+
# Filter out internal framework kwargs before passing to tools.
1282+
runtime_kwargs: dict[str, Any] = {
1283+
key: value
1284+
for key, value in (custom_args or {}).items()
1285+
if key not in {"_function_middleware_pipeline", "middleware"}
1286+
}
12771287
try:
1278-
args = tool.input_model.model_validate(merged_args)
1288+
args = tool.input_model.model_validate(parsed_args)
12791289
except ValidationError as exc:
12801290
message = "Error: Argument parsing failed."
12811291
if config.include_detailed_errors:
12821292
message = f"{message} Exception: {exc}"
12831293
return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
1294+
12841295
if not middleware_pipeline or (
12851296
not hasattr(middleware_pipeline, "has_middlewares") and not middleware_pipeline.has_middlewares
12861297
):
@@ -1289,7 +1300,8 @@ async def _auto_invoke_function(
12891300
function_result = await tool.invoke(
12901301
arguments=args,
12911302
tool_call_id=function_call_content.call_id,
1292-
) # type: ignore[arg-type]
1303+
**runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {},
1304+
)
12931305
return FunctionResultContent(
12941306
call_id=function_call_content.call_id,
12951307
result=function_result,
@@ -1305,13 +1317,14 @@ async def _auto_invoke_function(
13051317
middleware_context = FunctionInvocationContext(
13061318
function=tool,
13071319
arguments=args,
1308-
kwargs=custom_args or {},
1320+
kwargs=runtime_kwargs.copy(),
13091321
)
13101322

13111323
async def final_function_handler(context_obj: Any) -> Any:
13121324
return await tool.invoke(
13131325
arguments=context_obj.arguments,
13141326
tool_call_id=function_call_content.call_id,
1327+
**context_obj.kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {},
13151328
)
13161329

13171330
try:

python/packages/core/agent_framework/observability.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,7 @@ async def trace_run(
11041104
if not OBSERVABILITY_SETTINGS.ENABLED:
11051105
# If model diagnostics are not enabled, just return the completion
11061106
return await run_func(self, messages=messages, thread=thread, **kwargs)
1107+
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
11071108
attributes = _get_span_attributes(
11081109
operation_name=OtelAttr.AGENT_INVOKE_OPERATION,
11091110
provider_name=provider_name,
@@ -1112,7 +1113,7 @@ async def trace_run(
11121113
agent_description=self.description,
11131114
thread_id=thread.service_thread_id if thread else None,
11141115
chat_options=getattr(self, "chat_options", None),
1115-
**kwargs,
1116+
**filtered_kwargs,
11161117
)
11171118
with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span:
11181119
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages:
@@ -1173,6 +1174,7 @@ async def trace_run_streaming(
11731174

11741175
all_updates: list["AgentRunResponseUpdate"] = []
11751176

1177+
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
11761178
attributes = _get_span_attributes(
11771179
operation_name=OtelAttr.AGENT_INVOKE_OPERATION,
11781180
provider_name=provider_name,
@@ -1181,7 +1183,7 @@ async def trace_run_streaming(
11811183
agent_description=self.description,
11821184
thread_id=thread.service_thread_id if thread else None,
11831185
chat_options=getattr(self, "chat_options", None),
1184-
**kwargs,
1186+
**filtered_kwargs,
11851187
)
11861188
with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span:
11871189
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages:

0 commit comments

Comments
 (0)