From b410dc936f621503d2a5bf1ff29887fc334a959a Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Fri, 21 Nov 2025 11:39:44 +0900 Subject: [PATCH 1/6] Refactor ag-ui to clean up some patterns --- .../ag-ui/agent_framework_ag_ui/_agent.py | 36 +- .../ag-ui/agent_framework_ag_ui/_endpoint.py | 12 +- .../ag-ui/agent_framework_ag_ui/_events.py | 992 ++++++++---------- .../_orchestration/__init__.py | 16 + .../_orchestration/message_hygiene.py | 178 ++++ .../_orchestration/state_manager.py | 102 ++ .../_orchestration/tooling.py | 82 ++ .../agent_framework_ag_ui/_orchestrators.py | 494 ++------- .../tests/test_agent_wrapper_comprehensive.py | 23 + python/packages/ag-ui/tests/test_endpoint.py | 27 + .../ag-ui/tests/test_message_hygiene.py | 51 + .../ag-ui/tests/test_state_manager.py | 49 + python/packages/ag-ui/tests/test_tooling.py | 34 + 13 files changed, 1145 insertions(+), 951 deletions(-) create mode 100644 python/packages/ag-ui/agent_framework_ag_ui/_orchestration/__init__.py create mode 100644 python/packages/ag-ui/agent_framework_ag_ui/_orchestration/message_hygiene.py create mode 100644 python/packages/ag-ui/agent_framework_ag_ui/_orchestration/state_manager.py create mode 100644 python/packages/ag-ui/agent_framework_ag_ui/_orchestration/tooling.py create mode 100644 python/packages/ag-ui/tests/test_message_hygiene.py create mode 100644 python/packages/ag-ui/tests/test_state_manager.py create mode 100644 python/packages/ag-ui/tests/test_tooling.py diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py index 298c0acfe9..9aa54925c2 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py @@ -3,7 +3,7 @@ """AgentFrameworkAgent wrapper for AG-UI protocol - Clean Architecture.""" from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, cast from ag_ui.core import BaseEvent from agent_framework import AgentProtocol @@ -22,21 +22,45 @@ class AgentConfig: def __init__( self, - state_schema: dict[str, Any] | None = None, + state_schema: Any | None = None, predict_state_config: dict[str, dict[str, str]] | None = None, require_confirmation: bool = True, ): """Initialize agent configuration. Args: - state_schema: Optional state schema for state management + state_schema: Optional state schema for state management; accepts dict or Pydantic model/class predict_state_config: Configuration for predictive state updates require_confirmation: Whether predictive updates require confirmation """ - self.state_schema = state_schema or {} + self.state_schema = self._normalize_state_schema(state_schema) self.predict_state_config = predict_state_config or {} self.require_confirmation = require_confirmation + @staticmethod + def _normalize_state_schema(state_schema: Any | None) -> dict[str, Any]: + """Accept dict or Pydantic model/class and return a properties dict.""" + if state_schema is None: + return {} + + if isinstance(state_schema, dict): + return cast(dict[str, Any], state_schema) + + try: + from pydantic import BaseModel + except Exception: + BaseModel = None # type: ignore # noqa: N806 + + if BaseModel and isinstance(state_schema, BaseModel): + schema_dict = state_schema.__class__.model_json_schema() + return schema_dict.get("properties", {}) or {} + + if BaseModel and isinstance(state_schema, type) and issubclass(state_schema, BaseModel): + schema_dict = state_schema.model_json_schema() + return schema_dict.get("properties", {}) or {} + + return {} + class AgentFrameworkAgent: """Wraps Agent Framework agents for AG-UI protocol compatibility. @@ -55,7 +79,7 @@ def __init__( agent: AgentProtocol, name: str | None = None, description: str | None = None, - state_schema: dict[str, Any] | None = None, + state_schema: Any | None = None, predict_state_config: dict[str, dict[str, str]] | None = None, require_confirmation: bool = True, orchestrators: list[Orchestrator] | None = None, @@ -67,7 +91,7 @@ def __init__( agent: The Agent Framework agent to wrap name: Optional name for the agent description: Optional description - state_schema: Optional state schema for state management + state_schema: Optional state schema for state management; accepts dict or Pydantic model/class predict_state_config: Configuration for predictive state updates. Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}} require_confirmation: Whether predictive updates require confirmation. diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py b/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py index d1baad5561..eedf88db14 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py @@ -2,6 +2,7 @@ """FastAPI endpoint creation for AG-UI agents.""" +import copy import logging from typing import Any @@ -19,9 +20,10 @@ def add_agent_framework_fastapi_endpoint( app: FastAPI, agent: AgentProtocol | AgentFrameworkAgent, path: str = "/", - state_schema: dict[str, Any] | None = None, + state_schema: Any | None = None, predict_state_config: dict[str, dict[str, str]] | None = None, allow_origins: list[str] | None = None, + default_state: dict[str, Any] | None = None, ) -> None: """Add an AG-UI endpoint to a FastAPI app. @@ -29,10 +31,11 @@ def add_agent_framework_fastapi_endpoint( app: The FastAPI application agent: The agent to expose (can be raw AgentProtocol or wrapped) path: The endpoint path - state_schema: Optional state schema for shared state management + state_schema: Optional state schema for shared state management; accepts dict or Pydantic model/class predict_state_config: Optional predictive state update configuration. Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}} allow_origins: CORS origins (not yet implemented) + default_state: Optional initial state to seed when the client does not provide state keys """ if isinstance(agent, AgentProtocol): wrapped_agent = AgentFrameworkAgent( @@ -52,6 +55,11 @@ async def agent_endpoint(request: Request): # type: ignore[misc] """ try: input_data = await request.json() + if default_state: + state = input_data.setdefault("state", {}) + for key, value in default_state.items(): + if key not in state: + state[key] = copy.deepcopy(value) logger.debug( f"[{path}] Received request - Run ID: {input_data.get('run_id', 'no-run-id')}, " f"Thread ID: {input_data.get('thread_id', 'no-thread-id')}, " diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py index 8aec59d52c..d5679f42b5 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py @@ -100,578 +100,510 @@ async def from_agent_run_update(self, update: AgentRunResponseUpdate) -> list[Ba """ events: list[BaseEvent] = [] - logger.info(f"Processing AgentRunUpdate with {len(update.contents)} content items") + logger.info("Processing AgentRunUpdate with %s content items", len(update.contents)) for idx, content in enumerate(update.contents): - logger.info(f" Content {idx}: type={type(content).__name__}") + logger.info(" Content %s: type=%s", idx, type(content).__name__) if isinstance(content, TextContent): - logger.info( - f" TextContent found: text_length={len(content.text)}, text_preview='{content.text[:100]}'" - ) - logger.info( - f" Flags: skip_text_content={self.skip_text_content}, should_stop_after_confirm={self.should_stop_after_confirm}" - ) + events.extend(self._handle_text_content(content)) + elif isinstance(content, FunctionCallContent): + events.extend(self._handle_function_call_content(content)) + elif isinstance(content, FunctionResultContent): + events.extend(self._handle_function_result_content(content)) + elif isinstance(content, FunctionApprovalRequestContent): + events.extend(self._handle_function_approval_request_content(content)) - # Skip text content if using structured outputs (it's just the JSON) - if self.skip_text_content: - logger.info(" SKIPPING TextContent: skip_text_content is True") - continue + return events - # Skip text content if we're about to emit confirm_changes - # The summary should only appear after user confirms - if self.should_stop_after_confirm: - logger.info(" SKIPPING TextContent: waiting for confirm_changes response") - # Save the summary text to show after confirmation - self.suppressed_summary += content.text - logger.info(f" Suppressed summary now has {len(self.suppressed_summary)} chars") - continue + def _handle_text_content(self, content: TextContent) -> list[BaseEvent]: + events: list[BaseEvent] = [] + logger.info(" TextContent found: text_length=%s, text_preview='%s'", len(content.text), content.text[:100]) + logger.info( + " Flags: skip_text_content=%s, should_stop_after_confirm=%s", + self.skip_text_content, + self.should_stop_after_confirm, + ) - if not self.current_message_id: - self.current_message_id = generate_event_id() - start_event = TextMessageStartEvent( - message_id=self.current_message_id, - role="assistant", - ) - logger.info(f" EMITTING TextMessageStartEvent with message_id={self.current_message_id}") - events.append(start_event) + if self.skip_text_content: + logger.info(" SKIPPING TextContent: skip_text_content is True") + return events + + if self.should_stop_after_confirm: + logger.info(" SKIPPING TextContent: waiting for confirm_changes response") + self.suppressed_summary += content.text + logger.info(" Suppressed summary now has %s chars", len(self.suppressed_summary)) + return events + + if not self.current_message_id: + self.current_message_id = generate_event_id() + start_event = TextMessageStartEvent( + message_id=self.current_message_id, + role="assistant", + ) + logger.info(" EMITTING TextMessageStartEvent with message_id=%s", self.current_message_id) + events.append(start_event) + + event = TextMessageContentEvent( + message_id=self.current_message_id, + delta=content.text, + ) + self.accumulated_text_content += content.text + logger.info(" EMITTING TextMessageContentEvent with delta: '%s'", content.text) + events.append(event) + return events - event = TextMessageContentEvent( - message_id=self.current_message_id, - delta=content.text, - ) - # Accumulate text content for final MessagesSnapshotEvent - self.accumulated_text_content += content.text - logger.info(f" EMITTING TextMessageContentEvent with delta: '{content.text}'") - events.append(event) + def _handle_function_call_content(self, content: FunctionCallContent) -> list[BaseEvent]: + events: list[BaseEvent] = [] + if content.name: + logger.debug("Tool call: %s (call_id: %s)", content.name, content.call_id) + + if not content.name and not content.call_id and not self.current_tool_call_name: + args_preview = str(content.arguments)[:50] if content.arguments else "None" + logger.warning("FunctionCallContent missing name and call_id. Args: %s", args_preview) + + tool_call_id = self._coalesce_tool_call_id(content) + if content.name: + self.current_tool_call_id = tool_call_id + self.current_tool_call_name = content.name + + tool_start_event = ToolCallStartEvent( + tool_call_id=tool_call_id, + tool_call_name=content.name, + parent_message_id=self.current_message_id, + ) + logger.info("Emitting ToolCallStartEvent with name='%s', id='%s'", content.name, tool_call_id) + events.append(tool_start_event) + + self.pending_tool_calls.append( + { + "id": tool_call_id, + "type": "function", + "function": { + "name": content.name, + "arguments": "", + }, + } + ) + elif tool_call_id: + self.current_tool_call_id = tool_call_id + + if content.arguments: + delta_str = content.arguments if isinstance(content.arguments, str) else json.dumps(content.arguments) + logger.info("Emitting ToolCallArgsEvent with delta: %r..., id='%s'", delta_str, tool_call_id) + args_event = ToolCallArgsEvent( + tool_call_id=tool_call_id, + delta=delta_str, + ) + events.append(args_event) + + for tool_call in self.pending_tool_calls: + if tool_call["id"] == tool_call_id: + tool_call["function"]["arguments"] += delta_str + break + + events.extend(self._emit_predictive_state_deltas(delta_str)) + events.extend(self._legacy_predictive_state(content)) - elif isinstance(content, FunctionCallContent): - # Log tool calls for debugging - if content.name: - logger.debug(f"Tool call: {content.name} (call_id: {content.call_id})") - - if not content.name and not content.call_id and not self.current_tool_call_name: - args_preview = str(content.arguments)[:50] if content.arguments else "None" - logger.warning(f"FunctionCallContent missing name and call_id. Args: {args_preview}") - - # Get or use existing tool call ID - all chunks of same tool call share the same call_id - # Important: the first chunk might have name but no call_id yet - if content.call_id: - tool_call_id = content.call_id - elif self.current_tool_call_id: - tool_call_id = self.current_tool_call_id - else: - # Generate a new ID for this tool call - tool_call_id = ( - generate_event_id() - ) # Handle streaming tool calls - name comes in first chunk, arguments in subsequent chunks - if content.name: - # This is a new tool call or the first chunk with the name - self.current_tool_call_id = tool_call_id - self.current_tool_call_name = content.name - - tool_start_event = ToolCallStartEvent( - tool_call_id=tool_call_id, - tool_call_name=content.name, - parent_message_id=self.current_message_id, - ) - logger.info(f"Emitting ToolCallStartEvent with name='{content.name}', id='{tool_call_id}'") - events.append(tool_start_event) - - # Track tool call for MessagesSnapshotEvent - # Initialize a new tool call entry - self.pending_tool_calls.append( - { - "id": tool_call_id, - "type": "function", - "function": { - "name": content.name, - "arguments": "", # Will accumulate as we get argument chunks - }, - } - ) - else: - # Subsequent chunk without name - update our tracked ID if needed - if tool_call_id: - self.current_tool_call_id = tool_call_id - - # Emit arguments if present - if content.arguments: - # content.arguments is already a JSON string from the LLM for streaming calls - # For non-streaming it could be a dict, so we need to handle both - if isinstance(content.arguments, str): - delta_str = content.arguments - else: - # If it's a dict, convert to JSON - delta_str = json.dumps(content.arguments) - - logger.info(f"Emitting ToolCallArgsEvent with delta: {delta_str!r}..., id='{tool_call_id}'") - args_event = ToolCallArgsEvent( - tool_call_id=tool_call_id, - delta=delta_str, - ) - events.append(args_event) - - # Accumulate arguments for MessagesSnapshotEvent - if self.pending_tool_calls: - # Find the matching tool call and append the delta - for tool_call in self.pending_tool_calls: - if tool_call["id"] == tool_call_id: - tool_call["function"]["arguments"] += delta_str - break - - # Predictive state updates - accumulate streaming arguments and emit deltas - # Use current_tool_call_name since content.name is only present on first chunk - if self.current_tool_call_name and self.predict_state_config: - # Accumulate the argument string - if isinstance(content.arguments, str): - self.streaming_tool_args += content.arguments - else: - self.streaming_tool_args += json.dumps(content.arguments) - - logger.debug( - f"Predictive state: accumulated {len(self.streaming_tool_args)} chars for tool '{self.current_tool_call_name}'" - ) + return events - # Try to parse accumulated arguments (may be incomplete JSON) - # We use a lenient approach: try standard parsing first, then try to extract partial values - parsed_args = None - try: - parsed_args = json.loads(self.streaming_tool_args) - except json.JSONDecodeError: - # JSON is incomplete - try to extract partial string values - # For streaming "document" field, we can extract: {"document": "text... - # Look for pattern: {"field": "value (incomplete) - for state_key, config in self.predict_state_config.items(): - if config["tool"] == self.current_tool_call_name: - tool_arg_name = config["tool_argument"] - - # Try to extract partial string value for this argument - # Pattern: "argument_name": "partial text - pattern = rf'"{re.escape(tool_arg_name)}":\s*"([^"]*)' - match = re.search(pattern, self.streaming_tool_args) - - if match: - partial_value = match.group(1) - # Unescape common sequences - partial_value = ( - partial_value.replace("\\n", "\n").replace('\\"', '"').replace("\\\\", "\\") - ) - - # Emit delta if we have new content - if ( - state_key not in self.last_emitted_state - or self.last_emitted_state[state_key] != partial_value - ): - state_delta_event = StateDeltaEvent( - delta=[ - { - "op": "replace", - "path": f"/{state_key}", - "value": partial_value, - } - ], - ) - - self.state_delta_count += 1 - if self.state_delta_count % 10 == 1: - value_preview = ( - str(partial_value)[:100] + "..." - if len(str(partial_value)) > 100 - else str(partial_value) - ) - logger.info( - f"StateDeltaEvent #{self.state_delta_count} for '{state_key}': " - f"op=replace, path=/{state_key}, value={value_preview}" - ) - elif self.state_delta_count % 100 == 0: - logger.info(f"StateDeltaEvent #{self.state_delta_count} emitted") - - events.append(state_delta_event) - self.last_emitted_state[state_key] = partial_value - self.pending_state_updates[state_key] = partial_value - - # If we successfully parsed complete JSON, process it - if parsed_args: - # Check if this tool matches any predictive state config - for state_key, config in self.predict_state_config.items(): - if config["tool"] == self.current_tool_call_name: - tool_arg_name = config["tool_argument"] - - # Extract the state value - if tool_arg_name == "*": - state_value = parsed_args - elif tool_arg_name in parsed_args: - state_value = parsed_args[tool_arg_name] - else: - continue - - # Only emit if state has changed from last emission - if ( - state_key not in self.last_emitted_state - or self.last_emitted_state[state_key] != state_value - ): - # Emit StateDeltaEvent for real-time UI updates (JSON Patch format) - state_delta_event = StateDeltaEvent( - delta=[ - { - "op": "replace", # Use replace since field exists in schema - "path": f"/{state_key}", # JSON Pointer path with leading slash - "value": state_value, - } - ], - ) - - # Increment counter and log every 10th emission with sample data - self.state_delta_count += 1 - if self.state_delta_count % 10 == 1: # Log 1st, 11th, 21st, etc. - value_preview = ( - str(state_value)[:100] + "..." - if len(str(state_value)) > 100 - else str(state_value) - ) - logger.info( - f"StateDeltaEvent #{self.state_delta_count} for '{state_key}': " - f"op=replace, path=/{state_key}, value={value_preview}" - ) - elif self.state_delta_count % 100 == 0: # Also log every 100th - logger.info(f"StateDeltaEvent #{self.state_delta_count} emitted") - - events.append(state_delta_event) - - # Track what we emitted - self.last_emitted_state[state_key] = state_value - self.pending_state_updates[state_key] = state_value - - # Legacy predictive state check (for when arguments are complete) - if content.name and content.arguments: - parsed_args = content.parse_arguments() - - if parsed_args: - logger.info(f"Checking predict_state_config: {self.predict_state_config}") - for state_key, config in self.predict_state_config.items(): - logger.info(f"Checking state_key='{state_key}', config={config}") - if config["tool"] == content.name: - tool_arg_name = config["tool_argument"] - logger.info( - f"MATCHED tool '{content.name}' for state key '{state_key}', arg='{tool_arg_name}'" - ) - - # If tool_argument is "*", use all arguments as the state value - if tool_arg_name == "*": - state_value = parsed_args - logger.info(f"Using all args as state value, keys: {list(state_value.keys())}") - elif tool_arg_name in parsed_args: - state_value = parsed_args[tool_arg_name] - logger.info(f"Using specific arg '{tool_arg_name}' as state value") - else: - logger.warning(f"Tool argument '{tool_arg_name}' not found in parsed args") - continue - - # Emit predictive delta (JSON Patch format) - state_delta_event = StateDeltaEvent( - delta=[ - { - "op": "replace", # Use replace since field exists in schema - "path": f"/{state_key}", # JSON Pointer path with leading slash - "value": state_value, - } - ], - ) - logger.info( - f"Emitting StateDeltaEvent for key '{state_key}', value type: {type(state_value)}" - ) - events.append(state_delta_event) - - # Track pending update for later snapshot - self.pending_state_updates[state_key] = state_value - - # Note: ToolCallEndEvent is emitted when we receive FunctionResultContent, - # not here during streaming, since we don't know when the stream is complete + def _coalesce_tool_call_id(self, content: FunctionCallContent) -> str: + if content.call_id: + return content.call_id + if self.current_tool_call_id: + return self.current_tool_call_id + return generate_event_id() - elif isinstance(content, FunctionResultContent): - # First emit ToolCallEndEvent to close the tool call - if content.call_id: - end_event = ToolCallEndEvent( - tool_call_id=content.call_id, - ) - logger.info(f"Emitting ToolCallEndEvent for completed tool call '{content.call_id}'") - events.append(end_event) - self.tool_calls_ended.add(content.call_id) # Track that we emitted end event + def _emit_predictive_state_deltas(self, argument_chunk: str) -> list[BaseEvent]: + events: list[BaseEvent] = [] + if not self.current_tool_call_name or not self.predict_state_config: + return events + + self.streaming_tool_args += argument_chunk + logger.debug( + "Predictive state: accumulated %s chars for tool '%s'", + len(self.streaming_tool_args), + self.current_tool_call_name, + ) - # Log total StateDeltaEvent count for this tool call - if self.state_delta_count > 0: - logger.info( - f"Tool call '{content.call_id}' complete: emitted {self.state_delta_count} StateDeltaEvents total" + parsed_args = None + try: + parsed_args = json.loads(self.streaming_tool_args) + except json.JSONDecodeError: + for state_key, config in self.predict_state_config.items(): + if config["tool"] != self.current_tool_call_name: + continue + tool_arg_name = config["tool_argument"] + pattern = rf'"{re.escape(tool_arg_name)}":\s*"([^"]*)' + match = re.search(pattern, self.streaming_tool_args) + + if match: + partial_value = match.group(1).replace("\\n", "\n").replace('\\"', '"').replace("\\\\", "\\") + + if state_key not in self.last_emitted_state or self.last_emitted_state[state_key] != partial_value: + state_delta_event = StateDeltaEvent( + delta=[ + { + "op": "replace", + "path": f"/{state_key}", + "value": partial_value, + } + ], ) - # Reset streaming accumulator and counter for next tool call - self.streaming_tool_args = "" - self.state_delta_count = 0 + self.state_delta_count += 1 + if self.state_delta_count % 10 == 1: + value_preview = ( + str(partial_value)[:100] + "..." + if len(str(partial_value)) > 100 + else str(partial_value) + ) + logger.info( + "StateDeltaEvent #%s for '%s': op=replace, path=/%s, value=%s", + self.state_delta_count, + state_key, + state_key, + value_preview, + ) + elif self.state_delta_count % 100 == 0: + logger.info("StateDeltaEvent #%s emitted", self.state_delta_count) + + events.append(state_delta_event) + self.last_emitted_state[state_key] = partial_value + self.pending_state_updates[state_key] = partial_value - # Tool result - emit ToolCallResultEvent - result_message_id = generate_event_id() + if parsed_args: + for state_key, config in self.predict_state_config.items(): + if config["tool"] != self.current_tool_call_name: + continue + tool_arg_name = config["tool_argument"] - # Preserve structured data for backend tool rendering - # Serialize dicts to JSON string, otherwise convert to string - if isinstance(content.result, dict): - result_content = json.dumps(content.result) # type: ignore[arg-type] - elif content.result is not None: - result_content = str(content.result) + if tool_arg_name == "*": + state_value = parsed_args + elif tool_arg_name in parsed_args: + state_value = parsed_args[tool_arg_name] else: - result_content = "" + continue - result_event = ToolCallResultEvent( - message_id=result_message_id, - tool_call_id=content.call_id, - content=result_content, - role="tool", - ) - events.append(result_event) + if state_key not in self.last_emitted_state or self.last_emitted_state[state_key] != state_value: + state_delta_event = StateDeltaEvent( + delta=[ + { + "op": "replace", + "path": f"/{state_key}", + "value": state_value, + } + ], + ) - # Track tool result for MessagesSnapshotEvent - # AG-UI protocol expects: { role: "tool", toolCallId: ..., content: ... } - # Use camelCase for Pydantic's alias_generator=to_camel - self.tool_results.append( + self.state_delta_count += 1 + if self.state_delta_count % 10 == 1: + value_preview = ( + str(state_value)[:100] + "..." if len(str(state_value)) > 100 else str(state_value) + ) + logger.info( + "StateDeltaEvent #%s for '%s': op=replace, path=/%s, value=%s", + self.state_delta_count, + state_key, + state_key, + value_preview, + ) + elif self.state_delta_count % 100 == 0: + logger.info("StateDeltaEvent #%s emitted", self.state_delta_count) + + events.append(state_delta_event) + self.last_emitted_state[state_key] = state_value + self.pending_state_updates[state_key] = state_value + return events + + def _legacy_predictive_state(self, content: FunctionCallContent) -> list[BaseEvent]: + events: list[BaseEvent] = [] + if not (content.name and content.arguments): + return events + parsed_args = content.parse_arguments() + if not parsed_args: + return events + + logger.info("Checking predict_state_config: %s", self.predict_state_config) + for state_key, config in self.predict_state_config.items(): + logger.info("Checking state_key='%s', config=%s", state_key, config) + if config["tool"] != content.name: + continue + tool_arg_name = config["tool_argument"] + logger.info( + "MATCHED tool '%s' for state key '%s', arg='%s'", + content.name, + state_key, + tool_arg_name, + ) + + if tool_arg_name == "*": + state_value = parsed_args + logger.info("Using all args as state value, keys: %s", list(state_value.keys())) + elif tool_arg_name in parsed_args: + state_value = parsed_args[tool_arg_name] + logger.info("Using specific arg '%s' as state value", tool_arg_name) + else: + logger.warning("Tool argument '%s' not found in parsed args", tool_arg_name) + continue + + state_delta_event = StateDeltaEvent( + delta=[ { - "id": result_message_id, - "role": "tool", - "toolCallId": content.call_id, - "content": result_content, + "op": "replace", + "path": f"/{state_key}", + "value": state_value, } - ) + ], + ) + logger.info("Emitting StateDeltaEvent for key '%s', value type: %s", state_key, type(state_value)) # type: ignore + events.append(state_delta_event) + self.pending_state_updates[state_key] = state_value + return events - # Emit MessagesSnapshotEvent with the complete conversation including tool calls and results - # This is required for CopilotKit's useCopilotAction to detect tool result - # HOWEVER: Skip this for predictive tools when require_confirmation=False, because - # the agent will generate a follow-up text message and we'll emit a complete snapshot at the end. - # Emitting here would create an incomplete snapshot that gets replaced, causing UI flicker. - should_emit_snapshot = self.pending_tool_calls and self.tool_results - - # Check if this is a predictive tool that will have a follow-up message - is_predictive_without_confirmation = False - if should_emit_snapshot and self.current_tool_call_name and self.predict_state_config: - for state_key, config in self.predict_state_config.items(): - if config["tool"] == self.current_tool_call_name and not self.require_confirmation: - is_predictive_without_confirmation = True - logger.info( - f"Skipping intermediate MessagesSnapshotEvent for predictive tool '{self.current_tool_call_name}' " - "- will emit complete snapshot after follow-up message" - ) - break + def _handle_function_result_content(self, content: FunctionResultContent) -> list[BaseEvent]: + events: list[BaseEvent] = [] + if content.call_id: + end_event = ToolCallEndEvent( + tool_call_id=content.call_id, + ) + logger.info("Emitting ToolCallEndEvent for completed tool call '%s'", content.call_id) + events.append(end_event) + self.tool_calls_ended.add(content.call_id) + + if self.state_delta_count > 0: + logger.info( + "Tool call '%s' complete: emitted %s StateDeltaEvents total", + content.call_id, + self.state_delta_count, + ) - if should_emit_snapshot and not is_predictive_without_confirmation: - # Import message adapter - from ._message_adapters import agent_framework_messages_to_agui + self.streaming_tool_args = "" + self.state_delta_count = 0 + + result_message_id = generate_event_id() + if isinstance(content.result, dict): + result_content = json.dumps(content.result) # type: ignore[arg-type] + elif content.result is not None: + result_content = str(content.result) + else: + result_content = "" + + result_event = ToolCallResultEvent( + message_id=result_message_id, + tool_call_id=content.call_id, + content=result_content, + role="tool", + ) + events.append(result_event) + + self.tool_results.append( + { + "id": result_message_id, + "role": "tool", + "toolCallId": content.call_id, + "content": result_content, + } + ) - # Build assistant message with tool_calls - assistant_message = { - "id": generate_event_id(), - "role": "assistant", - "tool_calls": self.pending_tool_calls.copy(), # Copy the accumulated tool calls - } + events.extend(self._emit_snapshot_for_tool_result()) + events.extend(self._emit_state_snapshot_and_confirmation()) - # Convert Agent Framework messages to AG-UI format (adds required 'id' field) - converted_input_messages = agent_framework_messages_to_agui(self.input_messages) + return events - # Build complete messages array: input messages + assistant message + tool results - all_messages = converted_input_messages + [assistant_message] + self.tool_results.copy() + def _emit_snapshot_for_tool_result(self) -> list[BaseEvent]: + events: list[BaseEvent] = [] + should_emit_snapshot = self.pending_tool_calls and self.tool_results - # Emit MessagesSnapshotEvent using the proper event type - # Note: messages are dict[str, Any] but Pydantic will validate them as Message types - messages_snapshot_event = MessagesSnapshotEvent( - type=EventType.MESSAGES_SNAPSHOT, - messages=all_messages, # type: ignore[arg-type] + is_predictive_without_confirmation = False + if should_emit_snapshot and self.current_tool_call_name and self.predict_state_config: + for _, config in self.predict_state_config.items(): + if config["tool"] == self.current_tool_call_name and not self.require_confirmation: + is_predictive_without_confirmation = True + logger.info( + "Skipping intermediate MessagesSnapshotEvent for predictive tool '%s' - delaying until summary", + self.current_tool_call_name, ) - logger.info(f"Emitting MessagesSnapshotEvent with {len(all_messages)} messages") - events.append(messages_snapshot_event) - - # After tool execution, emit StateSnapshotEvent if we have pending state updates - if self.pending_state_updates: - # Update the current state with pending updates - for key, value in self.pending_state_updates.items(): - self.current_state[key] = value - - # Log the state structure for debugging - logger.info(f"Emitting StateSnapshotEvent with keys: {list(self.current_state.keys())}") - if "recipe" in self.current_state: - recipe = self.current_state["recipe"] - logger.info( - f"Recipe fields: title={recipe.get('title')}, " - f"skill_level={recipe.get('skill_level')}, " - f"ingredients_count={len(recipe.get('ingredients', []))}, " - f"instructions_count={len(recipe.get('instructions', []))}" - ) + break + + if should_emit_snapshot and not is_predictive_without_confirmation: + from ._message_adapters import agent_framework_messages_to_agui + + assistant_message = { + "id": generate_event_id(), + "role": "assistant", + "tool_calls": self.pending_tool_calls.copy(), + } + converted_input_messages = agent_framework_messages_to_agui(self.input_messages) + all_messages = converted_input_messages + [assistant_message] + self.tool_results.copy() + + messages_snapshot_event = MessagesSnapshotEvent( + type=EventType.MESSAGES_SNAPSHOT, + messages=all_messages, # type: ignore[arg-type] + ) + logger.info("Emitting MessagesSnapshotEvent with %s messages", len(all_messages)) + events.append(messages_snapshot_event) + return events - # Emit complete state snapshot - state_snapshot_event = StateSnapshotEvent( - snapshot=self.current_state, - ) - events.append(state_snapshot_event) - - # Check if this was a predictive state update tool (e.g., write_document_local) - # If so, emit a confirm_changes tool call for the UI modal - tool_was_predictive = False - logger.debug( - f"Checking predictive state: current_tool='{self.current_tool_call_name}', " - f"predict_config={list(self.predict_state_config.keys()) if self.predict_state_config else 'None'}" + def _emit_state_snapshot_and_confirmation(self) -> list[BaseEvent]: + events: list[BaseEvent] = [] + if self.pending_state_updates: + for key, value in self.pending_state_updates.items(): + self.current_state[key] = value + + logger.info("Emitting StateSnapshotEvent with keys: %s", list(self.current_state.keys())) + if "recipe" in self.current_state: + recipe = self.current_state["recipe"] + logger.info( + "Recipe fields: title=%s, skill_level=%s, ingredients_count=%s, instructions_count=%s", + recipe.get("title"), + recipe.get("skill_level"), + len(recipe.get("ingredients", [])), + len(recipe.get("instructions", [])), + ) + + state_snapshot_event = StateSnapshotEvent( + snapshot=self.current_state, + ) + events.append(state_snapshot_event) + + tool_was_predictive = False + logger.debug( + "Checking predictive state: current_tool='%s', predict_config=%s", + self.current_tool_call_name, + list(self.predict_state_config.keys()) if self.predict_state_config else "None", + ) + for state_key, config in self.predict_state_config.items(): + if self.current_tool_call_name and config["tool"] == self.current_tool_call_name: + logger.info( + "Tool '%s' matches predictive config for state key '%s'", + self.current_tool_call_name, + state_key, ) - for state_key, config in self.predict_state_config.items(): - # Check if this tool call matches a predictive config - # We need to match against self.current_tool_call_name - if self.current_tool_call_name and config["tool"] == self.current_tool_call_name: - logger.info( - f"Tool '{self.current_tool_call_name}' matches predictive config for state key '{state_key}'" - ) - tool_was_predictive = True - break + tool_was_predictive = True + break - if tool_was_predictive and self.require_confirmation: - # Emit confirm_changes tool call sequence - confirm_call_id = generate_event_id() + if tool_was_predictive and self.require_confirmation: + events.extend(self._emit_confirm_changes_tool_call()) + elif tool_was_predictive: + logger.info("Skipping confirm_changes - require_confirmation is False") - logger.info("Emitting confirm_changes tool call for predictive update") + self.pending_state_updates.clear() + self.last_emitted_state.clear() + self.current_tool_call_name = None + return events - # Track confirm_changes tool call for MessagesSnapshotEvent (so it persists after RUN_FINISHED) - self.pending_tool_calls.append( - { - "id": confirm_call_id, - "type": "function", - "function": { - "name": "confirm_changes", - "arguments": "{}", - }, - } - ) + def _emit_confirm_changes_tool_call(self) -> list[BaseEvent]: + events: list[BaseEvent] = [] + confirm_call_id = generate_event_id() + logger.info("Emitting confirm_changes tool call for predictive update") + + self.pending_tool_calls.append( + { + "id": confirm_call_id, + "type": "function", + "function": { + "name": "confirm_changes", + "arguments": "{}", + }, + } + ) - # Start the confirm_changes tool call - confirm_start = ToolCallStartEvent( - tool_call_id=confirm_call_id, - tool_call_name="confirm_changes", - ) - events.append(confirm_start) + confirm_start = ToolCallStartEvent( + tool_call_id=confirm_call_id, + tool_call_name="confirm_changes", + ) + events.append(confirm_start) - # Empty args for confirm_changes - confirm_args = ToolCallArgsEvent( - tool_call_id=confirm_call_id, - delta="{}", - ) - events.append(confirm_args) + confirm_args = ToolCallArgsEvent( + tool_call_id=confirm_call_id, + delta="{}", + ) + events.append(confirm_args) - # End the confirm_changes tool call - confirm_end = ToolCallEndEvent( - tool_call_id=confirm_call_id, - ) - events.append(confirm_end) - - # Emit MessagesSnapshotEvent so confirm_changes persists after RUN_FINISHED - # Import message adapter - from ._message_adapters import agent_framework_messages_to_agui - - # Build assistant message with pending confirm_changes tool call - assistant_message = { - "id": generate_event_id(), - "role": "assistant", - "tool_calls": self.pending_tool_calls.copy(), # Includes confirm_changes - } - - # Convert Agent Framework messages to AG-UI format (adds required 'id' field) - converted_input_messages = agent_framework_messages_to_agui(self.input_messages) - - # Build complete messages array: input messages + assistant message + any tool results - all_messages = converted_input_messages + [assistant_message] + self.tool_results.copy() - - # Emit MessagesSnapshotEvent - # Note: messages are dict[str, Any] but Pydantic will validate them as Message types - messages_snapshot_event = MessagesSnapshotEvent( - type=EventType.MESSAGES_SNAPSHOT, - messages=all_messages, # type: ignore[arg-type] - ) - logger.info( - f"Emitting MessagesSnapshotEvent for confirm_changes with {len(all_messages)} messages" - ) - events.append(messages_snapshot_event) + confirm_end = ToolCallEndEvent( + tool_call_id=confirm_call_id, + ) + events.append(confirm_end) - # Set flag to stop the run after this - we're waiting for user response - self.should_stop_after_confirm = True - logger.info("Set flag to stop run after confirm_changes") - elif tool_was_predictive: - logger.info("Skipping confirm_changes - require_confirmation is False") + from ._message_adapters import agent_framework_messages_to_agui - # Clear pending updates and reset tool name tracker - self.pending_state_updates.clear() - self.last_emitted_state.clear() - self.current_tool_call_name = None # Reset for next tool call + assistant_message = { + "id": generate_event_id(), + "role": "assistant", + "tool_calls": self.pending_tool_calls.copy(), + } - elif isinstance(content, FunctionApprovalRequestContent): - # Human in the loop - function approval request - logger.info("=== FUNCTION APPROVAL REQUEST ===") - logger.info(f" Function: {content.function_call.name}") - logger.info(f" Call ID: {content.function_call.call_id}") - - # Parse the arguments to extract state for predictive UI updates - parsed_args = content.function_call.parse_arguments() - logger.info(f" Parsed args keys: {list(parsed_args.keys()) if parsed_args else 'None'}") - - # Check if this matches our predict_state_config and emit state - if parsed_args and self.predict_state_config: - logger.info(f" Checking predict_state_config: {self.predict_state_config}") - for state_key, config in self.predict_state_config.items(): - if config["tool"] == content.function_call.name: - tool_arg_name = config["tool_argument"] - logger.info( - f" MATCHED tool '{content.function_call.name}' for state key '{state_key}', arg='{tool_arg_name}'" - ) + converted_input_messages = agent_framework_messages_to_agui(self.input_messages) + all_messages = converted_input_messages + [assistant_message] + self.tool_results.copy() - # Extract the state value - if tool_arg_name == "*": - state_value = parsed_args - elif tool_arg_name in parsed_args: - state_value = parsed_args[tool_arg_name] - else: - logger.warning(f" Tool argument '{tool_arg_name}' not found in parsed args") - continue - - # Update current state - self.current_state[state_key] = state_value - logger.info( - f"Emitting StateSnapshotEvent for key '{state_key}', value type: {type(state_value)}" - ) + messages_snapshot_event = MessagesSnapshotEvent( + type=EventType.MESSAGES_SNAPSHOT, + messages=all_messages, # type: ignore[arg-type] + ) + logger.info("Emitting MessagesSnapshotEvent for confirm_changes with %s messages", len(all_messages)) + events.append(messages_snapshot_event) - # Emit state snapshot - state_snapshot = StateSnapshotEvent( - snapshot=self.current_state, - ) - events.append(state_snapshot) + self.should_stop_after_confirm = True + logger.info("Set flag to stop run after confirm_changes") + return events - # The tool call has been streamed already (Start/Args events) - # Now we need to close it with an End event before the agent waits for approval - if content.function_call.call_id: - end_event = ToolCallEndEvent( - tool_call_id=content.function_call.call_id, - ) - logger.info( - f"Emitting ToolCallEndEvent for approval-required tool '{content.function_call.call_id}'" - ) - events.append(end_event) - self.tool_calls_ended.add(content.function_call.call_id) # Track that we emitted end event - - # Emit custom event for approval request - # Note: In AG-UI protocol, the frontend handles interrupts automatically - # when it sees a tool call with the configured name (via predict_state_config) - # This custom event is for additional metadata if needed - approval_event = CustomEvent( - name="function_approval_request", - value={ - "id": content.id, - "function_call": { - "call_id": content.function_call.call_id, - "name": content.function_call.name, - "arguments": content.function_call.parse_arguments(), - }, - }, + def _handle_function_approval_request_content(self, content: FunctionApprovalRequestContent) -> list[BaseEvent]: + events: list[BaseEvent] = [] + logger.info("=== FUNCTION APPROVAL REQUEST ===") + logger.info(" Function: %s", content.function_call.name) + logger.info(" Call ID: %s", content.function_call.call_id) + + parsed_args = content.function_call.parse_arguments() + logger.info(" Parsed args keys: %s", list(parsed_args.keys()) if parsed_args else "None") + + if parsed_args and self.predict_state_config: + logger.info(" Checking predict_state_config: %s", self.predict_state_config) + for state_key, config in self.predict_state_config.items(): + if config["tool"] != content.function_call.name: + continue + tool_arg_name = config["tool_argument"] + logger.info( + " MATCHED tool '%s' for state key '%s', arg='%s'", + content.function_call.name, + state_key, + tool_arg_name, ) - logger.info(f"Emitting function_approval_request custom event for '{content.function_call.name}'") - events.append(approval_event) + if tool_arg_name == "*": + state_value = parsed_args + elif tool_arg_name in parsed_args: + state_value = parsed_args[tool_arg_name] + else: + logger.warning(" Tool argument '%s' not found in parsed args", tool_arg_name) + continue + + self.current_state[state_key] = state_value + logger.info("Emitting StateSnapshotEvent for key '%s', value type: %s", state_key, type(state_value)) # type: ignore + state_snapshot = StateSnapshotEvent( + snapshot=self.current_state, + ) + events.append(state_snapshot) + + if content.function_call.call_id: + end_event = ToolCallEndEvent( + tool_call_id=content.function_call.call_id, + ) + logger.info("Emitting ToolCallEndEvent for approval-required tool '%s'", content.function_call.call_id) + events.append(end_event) + self.tool_calls_ended.add(content.function_call.call_id) + + approval_event = CustomEvent( + name="function_approval_request", + value={ + "id": content.id, + "function_call": { + "call_id": content.function_call.call_id, + "name": content.function_call.name, + "arguments": content.function_call.parse_arguments(), + }, + }, + ) + logger.info("Emitting function_approval_request custom event for '%s'", content.function_call.name) + events.append(approval_event) return events def create_run_started_event(self) -> RunStartedEvent: diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/__init__.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/__init__.py new file mode 100644 index 0000000000..acec1bdf9b --- /dev/null +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Orchestration helpers broken into focused modules.""" + +from .message_hygiene import deduplicate_messages, sanitize_tool_history +from .state_manager import StateManager +from .tooling import collect_server_tools, merge_tools, register_additional_client_tools + +__all__ = [ + "StateManager", + "sanitize_tool_history", + "deduplicate_messages", + "collect_server_tools", + "register_additional_client_tools", + "merge_tools", +] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/message_hygiene.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/message_hygiene.py new file mode 100644 index 0000000000..d70d24f2f8 --- /dev/null +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/message_hygiene.py @@ -0,0 +1,178 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Message hygiene utilities for orchestrators.""" + +import json +import logging +from typing import Any + +from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent + +logger = logging.getLogger(__name__) + + +def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: + """Normalize tool ordering and inject synthetic results for AG-UI edge cases.""" + sanitized: list[ChatMessage] = [] + pending_tool_call_ids: set[str] | None = None + pending_confirm_changes_id: str | None = None + + for msg in messages: + role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + + if role_value == "assistant": + tool_ids = { + str(content.call_id) + for content in msg.contents or [] + if isinstance(content, FunctionCallContent) and content.call_id + } + confirm_changes_call = None + for content in msg.contents or []: + if isinstance(content, FunctionCallContent) and content.name == "confirm_changes": + confirm_changes_call = content + break + + sanitized.append(msg) + pending_tool_call_ids = tool_ids if tool_ids else None + pending_confirm_changes_id = ( + str(confirm_changes_call.call_id) if confirm_changes_call and confirm_changes_call.call_id else None + ) + continue + + if role_value == "user": + if pending_confirm_changes_id: + user_text = "" + for content in msg.contents or []: + if isinstance(content, TextContent): + user_text = content.text + break + + try: + parsed = json.loads(user_text) + if "accepted" in parsed: + logger.info( + "Injecting synthetic tool result for confirm_changes call_id=%s", + pending_confirm_changes_id, + ) + synthetic_result = ChatMessage( + role="tool", + contents=[ + FunctionResultContent( + call_id=pending_confirm_changes_id, + result="Confirmed" if parsed.get("accepted") else "Rejected", + ) + ], + ) + sanitized.append(synthetic_result) + if pending_tool_call_ids: + pending_tool_call_ids.discard(pending_confirm_changes_id) + pending_confirm_changes_id = None + continue + except (json.JSONDecodeError, KeyError) as exc: + logger.debug("Could not parse user message as confirm_changes response: %s", exc) + + if pending_tool_call_ids: + logger.info( + "User message arrived with %s pending tool calls - injecting synthetic results", + len(pending_tool_call_ids), + ) + for pending_call_id in pending_tool_call_ids: + logger.info("Injecting synthetic tool result for pending call_id=%s", pending_call_id) + synthetic_result = ChatMessage( + role="tool", + contents=[ + FunctionResultContent( + call_id=pending_call_id, + result="Tool execution skipped - user provided follow-up message", + ) + ], + ) + sanitized.append(synthetic_result) + pending_tool_call_ids = None + pending_confirm_changes_id = None + + sanitized.append(msg) + pending_confirm_changes_id = None + continue + + if role_value == "tool": + if not pending_tool_call_ids: + continue + keep = False + for content in msg.contents or []: + if isinstance(content, FunctionResultContent): + call_id = str(content.call_id) + if call_id in pending_tool_call_ids: + keep = True + if call_id == pending_confirm_changes_id: + pending_confirm_changes_id = None + break + if keep: + sanitized.append(msg) + continue + + sanitized.append(msg) + pending_tool_call_ids = None + pending_confirm_changes_id = None + + return sanitized + + +def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: + """Remove duplicate messages while preserving order.""" + seen_keys: dict[Any, int] = {} + unique_messages: list[ChatMessage] = [] + + for idx, msg in enumerate(messages): + role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + + if role_value == "tool" and msg.contents and isinstance(msg.contents[0], FunctionResultContent): + call_id = str(msg.contents[0].call_id) + key: Any = (role_value, call_id) + + if key in seen_keys: + existing_idx = seen_keys[key] + existing_msg = unique_messages[existing_idx] + + existing_result = None + if existing_msg.contents and isinstance(existing_msg.contents[0], FunctionResultContent): + existing_result = existing_msg.contents[0].result + new_result = msg.contents[0].result + + if (not existing_result or existing_result == "") and new_result: + logger.info("Replacing empty tool result at index %s with data from index %s", existing_idx, idx) + unique_messages[existing_idx] = msg + else: + logger.info("Skipping duplicate tool result at index %s: call_id=%s", idx, call_id) + continue + + seen_keys[key] = len(unique_messages) + unique_messages.append(msg) + + elif ( + role_value == "assistant" and msg.contents and any(isinstance(c, FunctionCallContent) for c in msg.contents) + ): + tool_call_ids = tuple( + sorted(str(c.call_id) for c in msg.contents if isinstance(c, FunctionCallContent) and c.call_id) + ) + key = (role_value, tool_call_ids) + + if key in seen_keys: + logger.info("Skipping duplicate assistant tool call at index %s", idx) + continue + + seen_keys[key] = len(unique_messages) + unique_messages.append(msg) + + else: + content_str = str([str(c) for c in msg.contents]) if msg.contents else "" + key = (role_value, hash(content_str)) + + if key in seen_keys: + logger.info("Skipping duplicate message at index %s: role=%s", idx, role_value) + continue + + seen_keys[key] = len(unique_messages) + unique_messages.append(msg) + + return unique_messages diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/state_manager.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/state_manager.py new file mode 100644 index 0000000000..45c16afef4 --- /dev/null +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/state_manager.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""State orchestration utilities.""" + +import json +from typing import Any + +from ag_ui.core import CustomEvent, EventType +from agent_framework import ChatMessage, TextContent + + +class StateManager: + """Coordinates state defaults, snapshots, and structured updates.""" + + def __init__( + self, + state_schema: dict[str, Any] | None, + predict_state_config: dict[str, dict[str, str]] | None, + require_confirmation: bool, + ) -> None: + self.state_schema = state_schema or {} + self.predict_state_config = predict_state_config or {} + self.require_confirmation = require_confirmation + self.current_state: dict[str, Any] = {} + + def initialize(self, initial_state: dict[str, Any] | None) -> dict[str, Any]: + """Initialize state with schema defaults.""" + self.current_state = (initial_state or {}).copy() + self._apply_schema_defaults() + return self.current_state + + def predict_state_event(self) -> CustomEvent | None: + """Create predict-state custom event when configured.""" + if not self.predict_state_config: + return None + + predict_state_value = [ + { + "state_key": state_key, + "tool": config["tool"], + "tool_argument": config["tool_argument"], + } + for state_key, config in self.predict_state_config.items() + ] + + return CustomEvent( + type=EventType.CUSTOM, + name="PredictState", + value=predict_state_value, + ) + + def initial_snapshot_event(self, event_bridge: Any) -> Any: + """Emit initial snapshot when schema and state present.""" + if not self.state_schema: + return None + self._apply_schema_defaults() + return event_bridge.create_state_snapshot_event(self.current_state) + + def state_context_message(self, is_new_user_turn: bool, conversation_has_tool_calls: bool) -> ChatMessage | None: + """Inject state context only when starting a new user turn.""" + if not self.current_state or not self.state_schema: + return None + if not is_new_user_turn or conversation_has_tool_calls: + return None + + state_json = json.dumps(self.current_state, indent=2) + return ChatMessage( + role="system", + contents=[ + TextContent( + text=( + "Current state of the application:\n" + f"{state_json}\n\n" + "When modifying state, you MUST include ALL existing data plus your changes.\n" + "For example, if adding one new item to a list, include ALL existing items PLUS the one new item.\n" + "Never replace existing data - always preserve and append or merge." + ) + ) + ], + ) + + def extract_state_updates(self, response_dict: dict[str, Any]) -> dict[str, Any]: + """Extract state updates from structured response payloads.""" + if self.state_schema: + return {key: response_dict[key] for key in self.state_schema.keys() if key in response_dict} + return {k: v for k, v in response_dict.items() if k != "message"} + + def apply_state_updates(self, updates: dict[str, Any]) -> None: + """Merge state updates into current state.""" + if not updates: + return + self.current_state.update(updates) + + def _apply_schema_defaults(self) -> None: + """Fill missing state fields based on schema hints.""" + for key, schema in self.state_schema.items(): + if key in self.current_state: + continue + if isinstance(schema, dict) and schema.get("type") == "array": # type: ignore + self.current_state[key] = [] + else: + self.current_state[key] = {} diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/tooling.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/tooling.py new file mode 100644 index 0000000000..3c59ee0440 --- /dev/null +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/tooling.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tool handling helpers.""" + +import logging +from typing import Any + +from agent_framework import BaseChatClient, ChatAgent + +logger = logging.getLogger(__name__) + + +def collect_server_tools(agent: Any) -> list[Any]: + """Collect server tools from ChatAgent or duck-typed agent.""" + if isinstance(agent, ChatAgent): + tools_from_agent = agent.chat_options.tools + server_tools = list(tools_from_agent) if tools_from_agent else [] + logger.info("[TOOLS] Agent has %s configured tools", len(server_tools)) + for tool in server_tools: + tool_name = getattr(tool, "name", "unknown") + approval_mode = getattr(tool, "approval_mode", None) + logger.info("[TOOLS] - %s: approval_mode=%s", tool_name, approval_mode) + return server_tools + + try: + chat_options_attr = getattr(agent, "chat_options", None) + if chat_options_attr is not None: + return getattr(chat_options_attr, "tools", None) or [] + except AttributeError: + return [] + return [] + + +def register_additional_client_tools(agent: Any, client_tools: list[Any] | None) -> None: + """Register client tools as additional declaration-only tools to avoid server execution.""" + if not client_tools: + return + + if isinstance(agent, ChatAgent): + chat_client = agent.chat_client + if isinstance(chat_client, BaseChatClient) and chat_client.function_invocation_configuration is not None: + chat_client.function_invocation_configuration.additional_tools = client_tools + logger.debug("[TOOLS] Registered %s client tools as additional_tools (declaration-only)", len(client_tools)) + return + + try: + chat_client_attr = getattr(agent, "chat_client", None) + if chat_client_attr is not None: + fic = getattr(chat_client_attr, "function_invocation_configuration", None) + if fic is not None: + fic.additional_tools = client_tools # type: ignore[attr-defined] + logger.debug( + "[TOOLS] Registered %s client tools as additional_tools (declaration-only)", len(client_tools) + ) + except AttributeError: + return + + +def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list[Any] | None: + """Combine server and client tools without overriding server metadata.""" + if not client_tools: + logger.info("[TOOLS] No client tools - not passing tools= parameter (using agent's configured tools)") + return None + + server_tool_names = {getattr(tool, "name", None) for tool in server_tools} + unique_client_tools = [tool for tool in client_tools if getattr(tool, "name", None) not in server_tool_names] + + if not unique_client_tools: + logger.info("[TOOLS] All client tools duplicate server tools - not passing tools= parameter") + return None + + combined_tools: list[Any] = [] + if server_tools: + combined_tools.extend(server_tools) + combined_tools.extend(unique_client_tools) + logger.info( + "[TOOLS] Passing tools= parameter with %s tools (%s server + %s unique client)", + len(combined_tools), + len(server_tools), + len(unique_client_tools), + ) + return combined_tools diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 6da46d819f..81b3a75fbe 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -21,7 +21,6 @@ AgentProtocol, AgentThread, ChatAgent, - ChatMessage, FunctionCallContent, FunctionResultContent, TextContent, @@ -271,144 +270,30 @@ async def run( AG-UI events """ from ._events import AgentFrameworkEventBridge + from ._message_adapters import agui_messages_to_snapshot_format + from ._orchestration import ( + StateManager, + collect_server_tools, + deduplicate_messages, + merge_tools, + register_additional_client_tools, + sanitize_tool_history, + ) - logger.info(f"Starting default agent run for thread_id={context.thread_id}, run_id={context.run_id}") - - # Initialize state tracking - initial_state = context.input_data.get("state", {}) - current_state: dict[str, Any] = initial_state.copy() if initial_state else {} + logger.info("Starting default agent run for thread_id=%s, run_id=%s", context.thread_id, context.run_id) - # Check if agent uses structured outputs (response_format) - # Use isinstance to narrow type for proper attribute access response_format = None if isinstance(context.agent, ChatAgent): response_format = context.agent.chat_options.response_format skip_text_content = response_format is not None - # Sanitizer: ensure tool results only follow assistant tool calls - # Also inject synthetic tool results for confirm_changes - def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: - sanitized: list[ChatMessage] = [] - pending_tool_call_ids: set[str] | None = None - pending_confirm_changes_id: str | None = None - - for msg in messages: - role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - - if role_value == "assistant": - tool_ids = { - str(content.call_id) - for content in msg.contents or [] - if isinstance(content, FunctionCallContent) and content.call_id - } - # Check for confirm_changes tool call - confirm_changes_call = None - for content in msg.contents or []: - if isinstance(content, FunctionCallContent) and content.name == "confirm_changes": - confirm_changes_call = content - break - - sanitized.append(msg) - pending_tool_call_ids = tool_ids if tool_ids else None - pending_confirm_changes_id = ( - str(confirm_changes_call.call_id) - if confirm_changes_call and confirm_changes_call.call_id - else None - ) - continue - - if role_value == "user": - # Check if this user message is a confirm_changes response (JSON with "accepted" field) - # This must be checked BEFORE injecting synthetic results for pending tool calls - if pending_confirm_changes_id: - user_text = "" - for content in msg.contents or []: - if isinstance(content, TextContent): - user_text = content.text - break - - try: - parsed = json.loads(user_text) - if "accepted" in parsed: - # This is a confirm_changes response - inject synthetic tool result - logger.info( - f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}" - ) - synthetic_result = ChatMessage( - role="tool", - contents=[ - FunctionResultContent( - call_id=pending_confirm_changes_id, - result="Confirmed" if parsed.get("accepted") else "Rejected", - ) - ], - ) - sanitized.append(synthetic_result) - if pending_tool_call_ids: - pending_tool_call_ids.discard(pending_confirm_changes_id) - pending_confirm_changes_id = None - # Don't add the user message to sanitized - it's been converted to tool result - continue - except (json.JSONDecodeError, KeyError) as e: - # Failed to parse user message as confirm_changes response; continue normal processing - logger.debug(f"Could not parse user message as confirm_changes response: {e}") - - # Before processing user message, check if there are pending tool calls without results - # This happens when assistant made multiple tool calls but only some got results - # This is checked AFTER confirm_changes special handling above - if pending_tool_call_ids: - logger.info( - f"User message arrived with {len(pending_tool_call_ids)} pending tool calls - injecting synthetic results" - ) - for pending_call_id in pending_tool_call_ids: - logger.info(f"Injecting synthetic tool result for pending call_id={pending_call_id}") - synthetic_result = ChatMessage( - role="tool", - contents=[ - FunctionResultContent( - call_id=pending_call_id, - result="Tool execution skipped - user provided follow-up message", - ) - ], - ) - sanitized.append(synthetic_result) - pending_tool_call_ids = None - pending_confirm_changes_id = None - - # Normal user message processing - sanitized.append(msg) - pending_confirm_changes_id = None - continue - - if role_value == "tool": - if not pending_tool_call_ids: - continue - keep = False - for content in msg.contents or []: - if isinstance(content, FunctionResultContent): - call_id = str(content.call_id) - if call_id in pending_tool_call_ids: - keep = True - # Note: We do NOT remove call_id from pending here. - # This allows duplicate tool results to pass through sanitization - # so the deduplicator can choose the best one (prefer non-empty results). - # We only clear pending_tool_call_ids when a user message arrives. - if call_id == pending_confirm_changes_id: - # For confirm_changes specifically, we do want to clear it - # since we only expect one response - pending_confirm_changes_id = None - break - if keep: - sanitized.append(msg) - continue - - sanitized.append(msg) - pending_tool_call_ids = None - pending_confirm_changes_id = None - - return sanitized - - # Create event bridge + state_manager = StateManager( + state_schema=context.config.state_schema, + predict_state_config=context.config.predict_state_config, + require_confirmation=context.config.require_confirmation, + ) + current_state = state_manager.initialize(context.input_data.get("state", {})) + event_bridge = AgentFrameworkEventBridge( run_id=context.run_id, thread_id=context.thread_id, @@ -421,42 +306,19 @@ def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: yield event_bridge.create_run_started_event() - # Emit PredictState custom event if we have predictive state config - if context.config.predict_state_config: - from ag_ui.core import CustomEvent, EventType - - predict_state_value = [ - { - "state_key": state_key, - "tool": config["tool"], - "tool_argument": config["tool_argument"], - } - for state_key, config in context.config.predict_state_config.items() - ] - - yield CustomEvent( - type=EventType.CUSTOM, - name="PredictState", - value=predict_state_value, - ) + predict_event = state_manager.predict_state_event() + if predict_event: + yield predict_event - # If we have a state schema, ensure we emit initial state snapshot - if context.config.state_schema: - # Initialize missing state fields with appropriate empty values based on schema type - for key, schema in context.config.state_schema.items(): - if key not in current_state: - # Default to empty object; use empty array if schema specifies "array" type - current_state[key] = [] if isinstance(schema, dict) and schema.get("type") == "array" else {} # type: ignore - yield event_bridge.create_state_snapshot_event(current_state) + snapshot_event = state_manager.initial_snapshot_event(event_bridge) + if snapshot_event: + yield snapshot_event - # Create thread for context tracking thread = AgentThread() thread.metadata = { # type: ignore[attr-defined] "ag_ui_thread_id": context.thread_id, "ag_ui_run_id": context.run_id, } - - # Inject current state into thread metadata so agent can access it if current_state: thread.metadata["current_state"] = current_state # type: ignore[attr-defined] @@ -466,99 +328,29 @@ def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: yield event_bridge.create_run_finished_event() return - logger.info(f"Received {len(raw_messages)} raw messages from client") + logger.info("Received %s raw messages from client", len(raw_messages)) for i, msg in enumerate(raw_messages): role = msg.role.value if hasattr(msg.role, "value") else str(msg.role) msg_id = getattr(msg, "message_id", None) - logger.info(f" Raw message {i}: role={role}, id={msg_id}") + logger.info(" Raw message %s: role=%s, id=%s", i, role, msg_id) if hasattr(msg, "contents") and msg.contents: for j, content in enumerate(msg.contents): content_type = type(content).__name__ if isinstance(content, TextContent): - logger.debug(f" Content {j}: {content_type} - {content.text}") + logger.debug(" Content %s: %s - %s", j, content_type, content.text) elif isinstance(content, FunctionCallContent): - logger.debug(f" Content {j}: {content_type} - {content.name}({content.arguments})") + logger.debug(" Content %s: %s - %s(%s)", j, content_type, content.name, content.arguments) elif isinstance(content, FunctionResultContent): logger.debug( - f" Content {j}: {content_type} - call_id={content.call_id}, result={content.result}" + " Content %s: %s - call_id=%s, result=%s", + j, + content_type, + content.call_id, + content.result, ) else: - logger.debug(f" Content {j}: {content_type} - {content}") - - # After getting sanitized_messages, deduplicate them - def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: - """Remove duplicate messages while preserving order. - - For tool results with the same call_id, prefer the one with actual data. - """ - seen_keys: dict[Any, int] = {} # key -> index in unique_messages (key can be various tuple types) - unique_messages: list[ChatMessage] = [] - - for idx, msg in enumerate(messages): - role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - - # For tool messages, use call_id as unique key - if role_value == "tool" and msg.contents and isinstance(msg.contents[0], FunctionResultContent): - call_id = str(msg.contents[0].call_id) - key: Any = (role_value, call_id) - - # Check if we already have this tool result - if key in seen_keys: - existing_idx = seen_keys[key] - existing_msg = unique_messages[existing_idx] - - # Compare results - prefer non-empty over empty - existing_result = None - if existing_msg.contents and isinstance(existing_msg.contents[0], FunctionResultContent): - existing_result = existing_msg.contents[0].result - new_result = msg.contents[0].result - - # Replace if existing is empty/None and new has data - if (not existing_result or existing_result == "") and new_result: - logger.info( - f"Replacing empty tool result at index {existing_idx} with data from index {idx}" - ) - unique_messages[existing_idx] = msg - else: - logger.info(f"Skipping duplicate tool result at index {idx}: call_id={call_id}") - continue - - seen_keys[key] = len(unique_messages) - unique_messages.append(msg) - - elif ( - role_value == "assistant" - and msg.contents - and any(isinstance(c, FunctionCallContent) for c in msg.contents) - ): - # For assistant messages with tool_calls, use the tool call IDs - tool_call_ids = tuple( - sorted(str(c.call_id) for c in msg.contents if isinstance(c, FunctionCallContent) and c.call_id) - ) - key = (role_value, tool_call_ids) - - if key in seen_keys: - logger.info(f"Skipping duplicate assistant tool call at index {idx}") - continue - - seen_keys[key] = len(unique_messages) - unique_messages.append(msg) - - else: - # For other messages (system, user, assistant without tools), hash the content - content_str = str([str(c) for c in msg.contents]) if msg.contents else "" - key = (role_value, hash(content_str)) - - if key in seen_keys: - logger.info(f"Skipping duplicate message at index {idx}: role={role_value}") - continue - - seen_keys[key] = len(unique_messages) - unique_messages.append(msg) - - return unique_messages + logger.debug(" Content %s: %s - %s", j, content_type, content) - # Then use it: sanitized_messages = sanitize_tool_history(raw_messages) provider_messages = deduplicate_messages(sanitized_messages) @@ -567,189 +359,90 @@ def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: yield event_bridge.create_run_finished_event() return - logger.info(f"Processing {len(provider_messages)} provider messages after sanitization/deduplication") + logger.info("Processing %s provider messages after sanitization/deduplication", len(provider_messages)) for i, msg in enumerate(provider_messages): role = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - logger.info(f" Message {i}: role={role}") + logger.info(" Message %s: role=%s", i, role) if hasattr(msg, "contents") and msg.contents: for j, content in enumerate(msg.contents): content_type = type(content).__name__ if isinstance(content, TextContent): - logger.info(f" Content {j}: {content_type} - {content.text}") + logger.info(" Content %s: %s - %s", j, content_type, content.text) elif isinstance(content, FunctionCallContent): - logger.info(f" Content {j}: {content_type} - {content.name}({content.arguments})") + logger.info(" Content %s: %s - %s(%s)", j, content_type, content.name, content.arguments) elif isinstance(content, FunctionResultContent): logger.info( - f" Content {j}: {content_type} - call_id={content.call_id}, result={content.result}" + " Content %s: %s - call_id=%s, result=%s", + j, + content_type, + content.call_id, + content.result, ) else: - logger.info(f" Content {j}: {content_type} - {content}") + logger.info(" Content %s: %s - %s", j, content_type, content) - # NOTE: For AG-UI, the client sends the full conversation history on each request. - # We should NOT add to thread.on_new_messages() as that would cause duplication. - # Instead, we pass messages directly to the agent via messages_to_run. - - # Inject current state as system message context if we have state and this is a new user turn messages_to_run: list[Any] = [] - - # Check if the last message is from the user (new turn) vs assistant/tool (mid-execution) is_new_user_turn = False if provider_messages: last_msg = provider_messages[-1] - is_new_user_turn = last_msg.role.value == "user" + role_value = last_msg.role.value if hasattr(last_msg.role, "value") else str(last_msg.role) + is_new_user_turn = role_value == "user" - # Check if conversation has tool calls (indicates mid-execution) conversation_has_tool_calls = False for msg in provider_messages: - if msg.role.value == "assistant" and hasattr(msg, "contents") and msg.contents: + role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + if role_value == "assistant" and hasattr(msg, "contents") and msg.contents: if any(isinstance(content, FunctionCallContent) for content in msg.contents): conversation_has_tool_calls = True break - # Only inject state context on new user turns AND when conversation doesn't have tool calls - # (tool calls indicate we're mid-execution, so state context was already injected) - if current_state and context.config.state_schema and is_new_user_turn and not conversation_has_tool_calls: - state_json = json.dumps(current_state, indent=2) - state_context_msg = ChatMessage( - role="system", - contents=[ - TextContent( - text=f"""Current state of the application: - {state_json} - - When modifying state, you MUST include ALL existing data plus your changes. - For example, if adding one new item to a list, include ALL existing items PLUS the one new item. - Never replace existing data - always preserve and append or merge.""" - ) - ], - ) + state_context_msg = state_manager.state_context_message( + is_new_user_turn=is_new_user_turn, conversation_has_tool_calls=conversation_has_tool_calls + ) + if state_context_msg: messages_to_run.append(state_context_msg) - # Add all provider messages to messages_to_run - # AG-UI sends full conversation history on each request, so we pass it directly to the agent messages_to_run.extend(provider_messages) - # Handle client tools for hybrid execution - # Client sends tool metadata, server merges with its own tools. - # Client tools have func=None (declaration-only), so @use_function_invocation - # will return the function call without executing (passes back to client). - from agent_framework import BaseChatClient - client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools")) - logger.info(f"[TOOLS] Client sent {len(client_tools) if client_tools else 0} tools") + logger.info("[TOOLS] Client sent %s tools", len(client_tools) if client_tools else 0) if client_tools: for tool in client_tools: tool_name = getattr(tool, "name", "unknown") declaration_only = getattr(tool, "declaration_only", None) - logger.info(f"[TOOLS] - Client tool: {tool_name}, declaration_only={declaration_only}") + logger.info("[TOOLS] - Client tool: %s, declaration_only=%s", tool_name, declaration_only) - # Extract server tools - use type narrowing when possible - server_tools: list[Any] = [] - if isinstance(context.agent, ChatAgent): - tools_from_agent = context.agent.chat_options.tools - server_tools = list(tools_from_agent) if tools_from_agent else [] - logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools") - for tool in server_tools: - tool_name = getattr(tool, "name", "unknown") - approval_mode = getattr(tool, "approval_mode", None) - logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}") - else: - # AgentProtocol allows duck-typed implementations - fallback to attribute access - # This supports test mocks and custom agent implementations - try: - chat_options_attr = getattr(context.agent, "chat_options", None) - if chat_options_attr is not None: - server_tools = getattr(chat_options_attr, "tools", None) or [] - except AttributeError: - pass - - # Register client tools as additional (declaration-only) so they are not executed on server - if client_tools: - if isinstance(context.agent, ChatAgent): - # Type-safe path for ChatAgent - chat_client = context.agent.chat_client - if ( - isinstance(chat_client, BaseChatClient) - and chat_client.function_invocation_configuration is not None - ): - chat_client.function_invocation_configuration.additional_tools = client_tools - logger.debug( - f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)" - ) - else: - # Fallback for AgentProtocol implementations (test mocks, custom agents) - try: - chat_client_attr = getattr(context.agent, "chat_client", None) - if chat_client_attr is not None: - fic = getattr(chat_client_attr, "function_invocation_configuration", None) - if fic is not None: - fic.additional_tools = client_tools # type: ignore[attr-defined] - logger.debug( - f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)" - ) - except AttributeError: - pass - - # For tools parameter: only pass if we have client tools to add - # If we pass tools=, it overrides the agent's configured tools and loses metadata like approval_mode - # So only pass tools when we need to add client tools on top of server tools - # IMPORTANT: Don't include client tools that duplicate server tools (same name) - tools_param = None - if client_tools: - # Get server tool names - server_tool_names = {getattr(tool, "name", None) for tool in server_tools} - - # Filter out client tools that duplicate server tools - unique_client_tools = [ - tool for tool in client_tools if getattr(tool, "name", None) not in server_tool_names - ] - - if unique_client_tools: - combined_tools: list[Any] = [] - if server_tools: - combined_tools.extend(server_tools) - combined_tools.extend(unique_client_tools) - tools_param = combined_tools - logger.info( - f"[TOOLS] Passing tools= parameter with {len(combined_tools)} tools ({len(server_tools)} server + {len(unique_client_tools)} unique client)" - ) - else: - logger.info("[TOOLS] All client tools duplicate server tools - not passing tools= parameter") - else: - logger.info("[TOOLS] No client tools - not passing tools= parameter (using agent's configured tools)") + server_tools = collect_server_tools(context.agent) + register_additional_client_tools(context.agent, client_tools) + tools_param = merge_tools(server_tools, client_tools) - # Collect all updates to get the final structured output all_updates: list[Any] = [] update_count = 0 async for update in context.agent.run_stream(messages_to_run, thread=thread, tools=tools_param): update_count += 1 - logger.info(f"[STREAM] Received update #{update_count} from agent") + logger.info("[STREAM] Received update #%s from agent", update_count) all_updates.append(update) events = await event_bridge.from_agent_run_update(update) - logger.info(f"[STREAM] Update #{update_count} produced {len(events)} events") + logger.info("[STREAM] Update #%s produced %s events", update_count, len(events)) for event in events: - logger.info(f"[STREAM] Yielding event: {type(event).__name__}") + logger.info("[STREAM] Yielding event: %s", type(event).__name__) yield event - logger.info(f"[STREAM] Agent stream completed. Total updates: {update_count}") + logger.info("[STREAM] Agent stream completed. Total updates: %s", update_count) - # After agent completes, check if we should stop (waiting for user to confirm changes) if event_bridge.should_stop_after_confirm: logger.info("Stopping run after confirm_changes - waiting for user response") yield event_bridge.create_run_finished_event() return - # Check if there are pending tool calls (declaration-only tools that weren't executed) - # These need ToolCallEndEvent to signal the client to execute them - # Only emit for tool calls that haven't already had ToolCallEndEvent emitted - # (approval-required tools already had their end event emitted) if event_bridge.pending_tool_calls: pending_without_end = [ tc for tc in event_bridge.pending_tool_calls if tc.get("id") not in event_bridge.tool_calls_ended ] if pending_without_end: logger.info( - f"Found {len(pending_without_end)} pending tool calls without end event - emitting ToolCallEndEvent" + "Found %s pending tool calls without end event - emitting ToolCallEndEvent", + len(pending_without_end), ) for tool_call in pending_without_end: tool_call_id = tool_call.get("id") @@ -757,79 +450,56 @@ def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: from ag_ui.core import ToolCallEndEvent end_event = ToolCallEndEvent(tool_call_id=tool_call_id) - logger.info(f"Emitting ToolCallEndEvent for declaration-only tool call '{tool_call_id}'") + logger.info( + "Emitting ToolCallEndEvent for declaration-only tool call '%s'", + tool_call_id, + ) yield end_event - # After streaming completes, check if agent has response_format and extract structured output if all_updates and response_format: from agent_framework import AgentRunResponse from pydantic import BaseModel - logger.info(f"Processing structured output, update count: {len(all_updates)}") - - # Convert streaming updates to final response to get the structured output + logger.info("Processing structured output, update count: %s", len(all_updates)) final_response = AgentRunResponse.from_agent_run_response_updates( all_updates, output_format_type=response_format ) if final_response.value and isinstance(final_response.value, BaseModel): - # Convert Pydantic model to dict response_dict = final_response.value.model_dump(mode="json", exclude_none=True) - logger.info(f"Received structured output: {list(response_dict.keys())}") - - # Extract state fields based on state_schema - state_updates: dict[str, Any] = {} - - if context.config.state_schema: - # Use state_schema to determine which fields are state - for state_key in context.config.state_schema.keys(): - if state_key in response_dict: - state_updates[state_key] = response_dict[state_key] - else: - # No schema: treat all non-message fields as state - state_updates = {k: v for k, v in response_dict.items() if k != "message"} + logger.info("Received structured output: %s", list(response_dict.keys())) - # Apply state updates if any found + state_updates = state_manager.extract_state_updates(response_dict) if state_updates: - current_state.update(state_updates) - - # Emit StateSnapshotEvent with the updated state + state_manager.apply_state_updates(state_updates) state_snapshot = event_bridge.create_state_snapshot_event(current_state) yield state_snapshot - logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}") + logger.info("Emitted StateSnapshotEvent with updates: %s", list(state_updates.keys())) - # If there's a message field, emit it as chat text if "message" in response_dict and response_dict["message"]: message_id = generate_event_id() yield TextMessageStartEvent(message_id=message_id, role="assistant") yield TextMessageContentEvent(message_id=message_id, delta=response_dict["message"]) yield TextMessageEndEvent(message_id=message_id) - logger.info(f"Emitted conversational message: {response_dict['message'][:100]}...") + logger.info("Emitted conversational message: %s...", response_dict["message"][:100]) - logger.info(f"[FINALIZE] Checking for unclosed message. current_message_id={event_bridge.current_message_id}") + logger.info("[FINALIZE] Checking for unclosed message. current_message_id=%s", event_bridge.current_message_id) if event_bridge.current_message_id: - logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") + logger.info( + "[FINALIZE] Emitting TextMessageEndEvent for message_id=%s", + event_bridge.current_message_id, + ) yield event_bridge.create_message_end_event(event_bridge.current_message_id) - # Emit MessagesSnapshotEvent to persist the final assistant text message - from ._message_adapters import agui_messages_to_snapshot_format - - # Build the final assistant message with accumulated text content assistant_text_message = { "id": event_bridge.current_message_id, "role": "assistant", "content": event_bridge.accumulated_text_content, } - # Convert input messages to snapshot format (normalize content structure) - # event_bridge.input_messages are already in AG-UI format, just need normalization converted_input_messages = agui_messages_to_snapshot_format(event_bridge.input_messages) - - # Build complete messages array - # Include: input messages + any pending tool calls/results + final text message all_messages = converted_input_messages.copy() - # Add assistant message with tool calls if any if event_bridge.pending_tool_calls: tool_call_message = { "id": generate_event_id(), @@ -838,18 +508,16 @@ def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: } all_messages.append(tool_call_message) - # Add tool results if any all_messages.extend(event_bridge.tool_results.copy()) - - # Add final text message all_messages.append(assistant_text_message) messages_snapshot = MessagesSnapshotEvent( messages=all_messages, # type: ignore[arg-type] ) logger.info( - f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(all_messages)} messages " - f"(text content length: {len(event_bridge.accumulated_text_content)})" + "[FINALIZE] Emitting MessagesSnapshotEvent with %s messages (text content length: %s)", + len(all_messages), + len(event_bridge.accumulated_text_content), ) yield messages_snapshot else: @@ -857,7 +525,7 @@ def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: logger.info("[FINALIZE] Emitting RUN_FINISHED event") yield event_bridge.create_run_finished_event() - logger.info(f"Completed agent run for thread_id={context.thread_id}, run_id={context.run_id}") + logger.info("Completed agent run for thread_id=%s, run_id=%s", context.thread_id, context.run_id) __all__ = [ diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index dbf0160ae6..743f3a218a 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -7,6 +7,7 @@ import pytest from agent_framework import ChatAgent, TextContent from agent_framework._types import ChatResponseUpdate +from pydantic import BaseModel async def test_agent_initialization_basic(): @@ -56,6 +57,28 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): assert wrapper.config.predict_state_config == predict_config +async def test_agent_initialization_with_pydantic_state_schema(): + """Test agent initialization when state_schema is provided as Pydantic model/class.""" + from agent_framework.ag_ui import AgentFrameworkAgent + + class MockChatClient: + async def get_streaming_response(self, messages, chat_options, **kwargs): + yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + + class MyState(BaseModel): + document: str + tags: list[str] = [] + + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + + wrapper_class_schema = AgentFrameworkAgent(agent=agent, state_schema=MyState) + wrapper_instance_schema = AgentFrameworkAgent(agent=agent, state_schema=MyState(document="hi")) + + expected_properties = MyState.model_json_schema().get("properties", {}) + assert wrapper_class_schema.config.state_schema == expected_properties + assert wrapper_instance_schema.config.state_schema == expected_properties + + async def test_run_started_event_emission(): """Test RunStartedEvent is emitted at start of run.""" from agent_framework.ag_ui import AgentFrameworkAgent diff --git a/python/packages/ag-ui/tests/test_endpoint.py b/python/packages/ag-ui/tests/test_endpoint.py index 1ae364f818..93cd271881 100644 --- a/python/packages/ag-ui/tests/test_endpoint.py +++ b/python/packages/ag-ui/tests/test_endpoint.py @@ -70,6 +70,33 @@ async def test_endpoint_with_state_schema(): assert response.status_code == 200 +async def test_endpoint_with_default_state_seed(): + """Test endpoint seeds default state when client omits it.""" + app = FastAPI() + agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + state_schema = {"proverbs": {"type": "array"}} + default_state = {"proverbs": ["Keep the original."]} + + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/default-state", + state_schema=state_schema, + default_state=default_state, + ) + + client = TestClient(app) + response = client.post("/default-state", json={"messages": [{"role": "user", "content": "Hello"}]}) + + assert response.status_code == 200 + + content = response.content.decode("utf-8") + lines = [line for line in content.split("\n") if line.startswith("data: ")] + snapshots = [json.loads(line[6:]) for line in lines if json.loads(line[6:]).get("type") == "STATE_SNAPSHOT"] + assert snapshots, "Expected a STATE_SNAPSHOT event" + assert snapshots[0]["snapshot"]["proverbs"] == default_state["proverbs"] + + async def test_endpoint_with_predict_state_config(): """Test endpoint with predict_state_config parameter.""" app = FastAPI() diff --git a/python/packages/ag-ui/tests/test_message_hygiene.py b/python/packages/ag-ui/tests/test_message_hygiene.py new file mode 100644 index 0000000000..d5dad600ff --- /dev/null +++ b/python/packages/ag-ui/tests/test_message_hygiene.py @@ -0,0 +1,51 @@ +from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent + +from agent_framework_ag_ui._orchestration.message_hygiene import ( + deduplicate_messages, + sanitize_tool_history, +) + + +def test_sanitize_tool_history_injects_confirm_changes_result() -> None: + messages = [ + ChatMessage( + role="assistant", + contents=[ + FunctionCallContent( + name="confirm_changes", + call_id="call_confirm_123", + arguments='{"changes": "test"}', + ) + ], + ), + ChatMessage( + role="user", + contents=[TextContent(text='{"accepted": true}')], + ), + ] + + sanitized = sanitize_tool_history(messages) + + tool_messages = [ + msg for msg in sanitized if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "tool" + ] + assert len(tool_messages) == 1 + assert str(tool_messages[0].contents[0].call_id) == "call_confirm_123" + assert tool_messages[0].contents[0].result == "Confirmed" + + +def test_deduplicate_messages_prefers_non_empty_tool_results() -> None: + messages = [ + ChatMessage( + role="tool", + contents=[FunctionResultContent(call_id="call1", result="")], + ), + ChatMessage( + role="tool", + contents=[FunctionResultContent(call_id="call1", result="result data")], + ), + ] + + deduped = deduplicate_messages(messages) + assert len(deduped) == 1 + assert deduped[0].contents[0].result == "result data" diff --git a/python/packages/ag-ui/tests/test_state_manager.py b/python/packages/ag-ui/tests/test_state_manager.py new file mode 100644 index 0000000000..79d959113d --- /dev/null +++ b/python/packages/ag-ui/tests/test_state_manager.py @@ -0,0 +1,49 @@ +from ag_ui.core import CustomEvent, EventType +from agent_framework import ChatMessage, TextContent + +from agent_framework_ag_ui._events import AgentFrameworkEventBridge +from agent_framework_ag_ui._orchestration.state_manager import StateManager + + +def test_state_manager_initializes_defaults_and_snapshot() -> None: + state_manager = StateManager( + state_schema={"items": {"type": "array"}, "metadata": {"type": "object"}}, + predict_state_config=None, + require_confirmation=True, + ) + current_state = state_manager.initialize({"metadata": {"a": 1}}) + bridge = AgentFrameworkEventBridge(run_id="run", thread_id="thread", current_state=current_state) + + snapshot_event = state_manager.initial_snapshot_event(bridge) + assert snapshot_event is not None + assert snapshot_event.snapshot["items"] == [] + assert snapshot_event.snapshot["metadata"] == {"a": 1} + + +def test_state_manager_predict_state_event_shape() -> None: + state_manager = StateManager( + state_schema=None, + predict_state_config={"doc": {"tool": "write_document_local", "tool_argument": "document"}}, + require_confirmation=True, + ) + predict_event = state_manager.predict_state_event() + assert isinstance(predict_event, CustomEvent) + assert predict_event.type == EventType.CUSTOM + assert predict_event.name == "PredictState" + assert predict_event.value[0]["state_key"] == "doc" + + +def test_state_context_only_when_new_user_turn() -> None: + state_manager = StateManager( + state_schema={"items": {"type": "array"}}, + predict_state_config=None, + require_confirmation=True, + ) + state_manager.initialize({"items": [1]}) + + assert state_manager.state_context_message(is_new_user_turn=False, conversation_has_tool_calls=False) is None + + message = state_manager.state_context_message(is_new_user_turn=True, conversation_has_tool_calls=False) + assert isinstance(message, ChatMessage) + assert isinstance(message.contents[0], TextContent) + assert "Current state of the application" in message.contents[0].text diff --git a/python/packages/ag-ui/tests/test_tooling.py b/python/packages/ag-ui/tests/test_tooling.py new file mode 100644 index 0000000000..b27063f726 --- /dev/null +++ b/python/packages/ag-ui/tests/test_tooling.py @@ -0,0 +1,34 @@ +from types import SimpleNamespace + +from agent_framework_ag_ui._orchestration.tooling import merge_tools, register_additional_client_tools + + +class DummyTool: + def __init__(self, name: str) -> None: + self.name = name + self.declaration_only = True + + +def test_merge_tools_filters_duplicates() -> None: + server = [DummyTool("a"), DummyTool("b")] + client = [DummyTool("b"), DummyTool("c")] + + merged = merge_tools(server, client) + + assert merged is not None + names = [getattr(t, "name", None) for t in merged] + assert names == ["a", "b", "c"] + + +def test_register_additional_client_tools_assigns_when_configured() -> None: + class Fic: + def __init__(self) -> None: + self.additional_tools = None + + holder = SimpleNamespace(function_invocation_configuration=Fic()) + agent = SimpleNamespace(chat_client=holder) + + tools = [DummyTool("x")] + register_additional_client_tools(agent, tools) + + assert holder.function_invocation_configuration.additional_tools == tools From 7ff358b126d5ad5bc5029b6a693af15ed373d9cb Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Fri, 21 Nov 2025 12:08:36 +0900 Subject: [PATCH 2/6] Mypy fixes --- .../packages/ag-ui/agent_framework_ag_ui/_agent.py | 13 ++++++++----- .../packages/ag-ui/agent_framework_ag_ui/_events.py | 2 ++ python/packages/ag-ui/tests/test_ag_ui_client.py | 2 ++ .../packages/ag-ui/tests/test_event_converters.py | 2 ++ python/packages/ag-ui/tests/test_message_hygiene.py | 2 ++ python/packages/ag-ui/tests/test_orchestrators.py | 2 ++ python/packages/ag-ui/tests/test_state_manager.py | 2 ++ python/packages/ag-ui/tests/test_tooling.py | 2 ++ 8 files changed, 22 insertions(+), 5 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py index 9aa54925c2..23860150be 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py @@ -46,16 +46,19 @@ def _normalize_state_schema(state_schema: Any | None) -> dict[str, Any]: if isinstance(state_schema, dict): return cast(dict[str, Any], state_schema) + base_model_type: type[Any] | None try: - from pydantic import BaseModel - except Exception: - BaseModel = None # type: ignore # noqa: N806 + from pydantic import BaseModel as ImportedBaseModel - if BaseModel and isinstance(state_schema, BaseModel): + base_model_type = ImportedBaseModel + except Exception: # pragma: no cover + base_model_type = None + + if base_model_type is not None and isinstance(state_schema, base_model_type): schema_dict = state_schema.__class__.model_json_schema() return schema_dict.get("properties", {}) or {} - if BaseModel and isinstance(state_schema, type) and issubclass(state_schema, BaseModel): + if base_model_type is not None and isinstance(state_schema, type) and issubclass(state_schema, base_model_type): schema_dict = state_schema.model_json_schema() return schema_dict.get("properties", {}) or {} diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py index d5679f42b5..abf335dc83 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py @@ -335,6 +335,7 @@ def _legacy_predictive_state(self, content: FunctionCallContent) -> list[BaseEve tool_arg_name, ) + state_value: Any if tool_arg_name == "*": state_value = parsed_args logger.info("Using all args as state value, keys: %s", list(state_value.keys())) @@ -568,6 +569,7 @@ def _handle_function_approval_request_content(self, content: FunctionApprovalReq tool_arg_name, ) + state_value: Any if tool_arg_name == "*": state_value = parsed_args elif tool_arg_name in parsed_args: diff --git a/python/packages/ag-ui/tests/test_ag_ui_client.py b/python/packages/ag-ui/tests/test_ag_ui_client.py index cfececd771..742b656369 100644 --- a/python/packages/ag-ui/tests/test_ag_ui_client.py +++ b/python/packages/ag-ui/tests/test_ag_ui_client.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft. All rights reserved. + """Tests for AGUIChatClient.""" import json diff --git a/python/packages/ag-ui/tests/test_event_converters.py b/python/packages/ag-ui/tests/test_event_converters.py index d05b1fe720..ff4d2ddc91 100644 --- a/python/packages/ag-ui/tests/test_event_converters.py +++ b/python/packages/ag-ui/tests/test_event_converters.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft. All rights reserved. + """Tests for AG-UI event converter.""" from agent_framework import FinishReason, Role diff --git a/python/packages/ag-ui/tests/test_message_hygiene.py b/python/packages/ag-ui/tests/test_message_hygiene.py index d5dad600ff..14066ce89c 100644 --- a/python/packages/ag-ui/tests/test_message_hygiene.py +++ b/python/packages/ag-ui/tests/test_message_hygiene.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft. All rights reserved. + from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent from agent_framework_ag_ui._orchestration.message_hygiene import ( diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py index a400e78458..74a2083e32 100644 --- a/python/packages/ag-ui/tests/test_orchestrators.py +++ b/python/packages/ag-ui/tests/test_orchestrators.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft. All rights reserved. + """Tests for AG-UI orchestrators.""" from collections.abc import AsyncGenerator diff --git a/python/packages/ag-ui/tests/test_state_manager.py b/python/packages/ag-ui/tests/test_state_manager.py index 79d959113d..ce964d784d 100644 --- a/python/packages/ag-ui/tests/test_state_manager.py +++ b/python/packages/ag-ui/tests/test_state_manager.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft. All rights reserved. + from ag_ui.core import CustomEvent, EventType from agent_framework import ChatMessage, TextContent diff --git a/python/packages/ag-ui/tests/test_tooling.py b/python/packages/ag-ui/tests/test_tooling.py index b27063f726..dfd89c0148 100644 --- a/python/packages/ag-ui/tests/test_tooling.py +++ b/python/packages/ag-ui/tests/test_tooling.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft. All rights reserved. + from types import SimpleNamespace from agent_framework_ag_ui._orchestration.tooling import merge_tools, register_additional_client_tools From 228114f45e250eef6719308bc810b12d309bcf28 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Mon, 24 Nov 2025 11:40:17 +0900 Subject: [PATCH 3/6] Fix imports, typing, tests, logging. --- .../ag-ui/agent_framework_ag_ui/_events.py | 105 ++++--- .../_orchestration/__init__.py | 15 - ...message_hygiene.py => _message_hygiene.py} | 18 +- .../{state_manager.py => _state_manager.py} | 0 .../{tooling.py => _tooling.py} | 14 +- .../agent_framework_ag_ui/_orchestrators.py | 77 +++-- python/packages/ag-ui/tests/__init__.py | 1 + python/packages/ag-ui/tests/_test_stubs.py | 138 +++++++++ .../packages/ag-ui/tests/test_ag_ui_client.py | 166 ++++++---- .../tests/test_agent_wrapper_comprehensive.py | 293 ++++++++++-------- .../tests/test_backend_tool_rendering.py | 11 +- ...t_confirmation_strategies_comprehensive.py | 54 ++-- .../ag-ui/tests/test_document_writer_flow.py | 32 +- python/packages/ag-ui/tests/test_endpoint.py | 44 ++- .../ag-ui/tests/test_message_hygiene.py | 2 +- .../tests/test_orchestrators_coverage.py | 251 ++++++++------- .../packages/ag-ui/tests/test_shared_state.py | 33 +- .../ag-ui/tests/test_state_manager.py | 2 +- .../ag-ui/tests/test_structured_output.py | 109 ++++--- python/packages/ag-ui/tests/test_tooling.py | 2 +- python/packages/ag-ui/tests/test_utils.py | 16 +- .../packages/core/agent_framework/_agents.py | 12 +- 22 files changed, 810 insertions(+), 585 deletions(-) rename python/packages/ag-ui/agent_framework_ag_ui/_orchestration/{message_hygiene.py => _message_hygiene.py} (88%) rename python/packages/ag-ui/agent_framework_ag_ui/_orchestration/{state_manager.py => _state_manager.py} (100%) rename python/packages/ag-ui/agent_framework_ag_ui/_orchestration/{tooling.py => _tooling.py} (82%) create mode 100644 python/packages/ag-ui/tests/__init__.py create mode 100644 python/packages/ag-ui/tests/_test_stubs.py diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py index abf335dc83..f5a6c77029 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py @@ -5,6 +5,7 @@ import json import logging import re +from copy import deepcopy from typing import Any from ag_ui.core import ( @@ -100,9 +101,9 @@ async def from_agent_run_update(self, update: AgentRunResponseUpdate) -> list[Ba """ events: list[BaseEvent] = [] - logger.info("Processing AgentRunUpdate with %s content items", len(update.contents)) + logger.info(f"Processing AgentRunUpdate with {len(update.contents)} content items") for idx, content in enumerate(update.contents): - logger.info(" Content %s: type=%s", idx, type(content).__name__) + logger.info(f" Content {idx}: type={type(content).__name__}") if isinstance(content, TextContent): events.extend(self._handle_text_content(content)) elif isinstance(content, FunctionCallContent): @@ -116,7 +117,7 @@ async def from_agent_run_update(self, update: AgentRunResponseUpdate) -> list[Ba def _handle_text_content(self, content: TextContent) -> list[BaseEvent]: events: list[BaseEvent] = [] - logger.info(" TextContent found: text_length=%s, text_preview='%s'", len(content.text), content.text[:100]) + logger.info(f" TextContent found: length={len(content.text)}") logger.info( " Flags: skip_text_content=%s, should_stop_after_confirm=%s", self.skip_text_content, @@ -130,7 +131,7 @@ def _handle_text_content(self, content: TextContent) -> list[BaseEvent]: if self.should_stop_after_confirm: logger.info(" SKIPPING TextContent: waiting for confirm_changes response") self.suppressed_summary += content.text - logger.info(" Suppressed summary now has %s chars", len(self.suppressed_summary)) + logger.info(f" Suppressed summary length={len(self.suppressed_summary)}") return events if not self.current_message_id: @@ -139,7 +140,7 @@ def _handle_text_content(self, content: TextContent) -> list[BaseEvent]: message_id=self.current_message_id, role="assistant", ) - logger.info(" EMITTING TextMessageStartEvent with message_id=%s", self.current_message_id) + logger.info(f" EMITTING TextMessageStartEvent with message_id={self.current_message_id}") events.append(start_event) event = TextMessageContentEvent( @@ -147,20 +148,23 @@ def _handle_text_content(self, content: TextContent) -> list[BaseEvent]: delta=content.text, ) self.accumulated_text_content += content.text - logger.info(" EMITTING TextMessageContentEvent with delta: '%s'", content.text) + logger.info(f" EMITTING TextMessageContentEvent with text_len={len(content.text)}") events.append(event) return events def _handle_function_call_content(self, content: FunctionCallContent) -> list[BaseEvent]: events: list[BaseEvent] = [] if content.name: - logger.debug("Tool call: %s (call_id: %s)", content.name, content.call_id) + logger.debug(f"Tool call: {content.name} (call_id: {content.call_id})") if not content.name and not content.call_id and not self.current_tool_call_name: - args_preview = str(content.arguments)[:50] if content.arguments else "None" - logger.warning("FunctionCallContent missing name and call_id. Args: %s", args_preview) + args_length = len(str(content.arguments)) if content.arguments else 0 + logger.warning(f"FunctionCallContent missing name and call_id. args_length={args_length}") tool_call_id = self._coalesce_tool_call_id(content) + if content.name and tool_call_id != self.current_tool_call_id: + self.streaming_tool_args = "" + self.state_delta_count = 0 if content.name: self.current_tool_call_id = tool_call_id self.current_tool_call_name = content.name @@ -170,7 +174,7 @@ def _handle_function_call_content(self, content: FunctionCallContent) -> list[Ba tool_call_name=content.name, parent_message_id=self.current_message_id, ) - logger.info("Emitting ToolCallStartEvent with name='%s', id='%s'", content.name, tool_call_id) + logger.info(f"Emitting ToolCallStartEvent with name='{content.name}', id='{tool_call_id}'") events.append(tool_start_event) self.pending_tool_calls.append( @@ -188,7 +192,7 @@ def _handle_function_call_content(self, content: FunctionCallContent) -> list[Ba if content.arguments: delta_str = content.arguments if isinstance(content.arguments, str) else json.dumps(content.arguments) - logger.info("Emitting ToolCallArgsEvent with delta: %r..., id='%s'", delta_str, tool_call_id) + logger.info(f"Emitting ToolCallArgsEvent with delta_length={len(delta_str)}, id='{tool_call_id}'") args_event = ToolCallArgsEvent( tool_call_id=tool_call_id, delta=delta_str, @@ -251,20 +255,15 @@ def _emit_predictive_state_deltas(self, argument_chunk: str) -> list[BaseEvent]: self.state_delta_count += 1 if self.state_delta_count % 10 == 1: - value_preview = ( - str(partial_value)[:100] + "..." - if len(str(partial_value)) > 100 - else str(partial_value) - ) logger.info( - "StateDeltaEvent #%s for '%s': op=replace, path=/%s, value=%s", + "StateDeltaEvent #%s for '%s': op=replace, path=/%s, value_length=%s", self.state_delta_count, state_key, state_key, - value_preview, + len(str(partial_value)), ) elif self.state_delta_count % 100 == 0: - logger.info("StateDeltaEvent #%s emitted", self.state_delta_count) + logger.info(f"StateDeltaEvent #{self.state_delta_count} emitted") events.append(state_delta_event) self.last_emitted_state[state_key] = partial_value @@ -296,18 +295,15 @@ def _emit_predictive_state_deltas(self, argument_chunk: str) -> list[BaseEvent]: self.state_delta_count += 1 if self.state_delta_count % 10 == 1: - value_preview = ( - str(state_value)[:100] + "..." if len(str(state_value)) > 100 else str(state_value) - ) logger.info( - "StateDeltaEvent #%s for '%s': op=replace, path=/%s, value=%s", + "StateDeltaEvent #%s for '%s': op=replace, path=/%s, value_length=%s", self.state_delta_count, state_key, state_key, - value_preview, + len(str(state_value)), ) elif self.state_delta_count % 100 == 0: - logger.info("StateDeltaEvent #%s emitted", self.state_delta_count) + logger.info(f"StateDeltaEvent #{self.state_delta_count} emitted") events.append(state_delta_event) self.last_emitted_state[state_key] = state_value @@ -322,28 +318,34 @@ def _legacy_predictive_state(self, content: FunctionCallContent) -> list[BaseEve if not parsed_args: return events - logger.info("Checking predict_state_config: %s", self.predict_state_config) + logger.info( + "Checking predict_state_config keys: %s", + list(self.predict_state_config.keys()) if self.predict_state_config else "None", + ) for state_key, config in self.predict_state_config.items(): - logger.info("Checking state_key='%s', config=%s", state_key, config) + logger.info(f"Checking state_key='{state_key}'") if config["tool"] != content.name: continue tool_arg_name = config["tool_argument"] - logger.info( - "MATCHED tool '%s' for state key '%s', arg='%s'", - content.name, - state_key, - tool_arg_name, - ) + logger.info(f"MATCHED tool '{content.name}' for state key '{state_key}', arg='{tool_arg_name}'") state_value: Any if tool_arg_name == "*": state_value = parsed_args - logger.info("Using all args as state value, keys: %s", list(state_value.keys())) + logger.info(f"Using all args as state value, keys: {list(state_value.keys())}") elif tool_arg_name in parsed_args: state_value = parsed_args[tool_arg_name] - logger.info("Using specific arg '%s' as state value", tool_arg_name) + logger.info(f"Using specific arg '{tool_arg_name}' as state value") else: - logger.warning("Tool argument '%s' not found in parsed args", tool_arg_name) + logger.warning(f"Tool argument '{tool_arg_name}' not found in parsed args") + continue + + previous_value = self.last_emitted_state.get(state_key, object()) + if previous_value == state_value: + logger.info( + "Skipping duplicate StateDeltaEvent for key '%s' - value unchanged", + state_key, + ) continue state_delta_event = StateDeltaEvent( @@ -355,9 +357,10 @@ def _legacy_predictive_state(self, content: FunctionCallContent) -> list[BaseEve } ], ) - logger.info("Emitting StateDeltaEvent for key '%s', value type: %s", state_key, type(state_value)) # type: ignore + logger.info(f"Emitting StateDeltaEvent for key '{state_key}', value type: {type(state_value)}") # type: ignore events.append(state_delta_event) self.pending_state_updates[state_key] = state_value + self.last_emitted_state[state_key] = state_value return events def _handle_function_result_content(self, content: FunctionResultContent) -> list[BaseEvent]: @@ -366,7 +369,7 @@ def _handle_function_result_content(self, content: FunctionResultContent) -> lis end_event = ToolCallEndEvent( tool_call_id=content.call_id, ) - logger.info("Emitting ToolCallEndEvent for completed tool call '%s'", content.call_id) + logger.info(f"Emitting ToolCallEndEvent for completed tool call '{content.call_id}'") events.append(end_event) self.tool_calls_ended.add(content.call_id) @@ -440,7 +443,7 @@ def _emit_snapshot_for_tool_result(self) -> list[BaseEvent]: type=EventType.MESSAGES_SNAPSHOT, messages=all_messages, # type: ignore[arg-type] ) - logger.info("Emitting MessagesSnapshotEvent with %s messages", len(all_messages)) + logger.info(f"Emitting MessagesSnapshotEvent with {len(all_messages)} messages") events.append(messages_snapshot_event) return events @@ -450,7 +453,7 @@ def _emit_state_snapshot_and_confirmation(self) -> list[BaseEvent]: for key, value in self.pending_state_updates.items(): self.current_state[key] = value - logger.info("Emitting StateSnapshotEvent with keys: %s", list(self.current_state.keys())) + logger.info(f"Emitting StateSnapshotEvent with keys: {list(self.current_state.keys())}") if "recipe" in self.current_state: recipe = self.current_state["recipe"] logger.info( @@ -488,7 +491,7 @@ def _emit_state_snapshot_and_confirmation(self) -> list[BaseEvent]: logger.info("Skipping confirm_changes - require_confirmation is False") self.pending_state_updates.clear() - self.last_emitted_state.clear() + self.last_emitted_state = deepcopy(self.current_state) self.current_tool_call_name = None return events @@ -540,7 +543,7 @@ def _emit_confirm_changes_tool_call(self) -> list[BaseEvent]: type=EventType.MESSAGES_SNAPSHOT, messages=all_messages, # type: ignore[arg-type] ) - logger.info("Emitting MessagesSnapshotEvent for confirm_changes with %s messages", len(all_messages)) + logger.info(f"Emitting MessagesSnapshotEvent for confirm_changes with {len(all_messages)} messages") events.append(messages_snapshot_event) self.should_stop_after_confirm = True @@ -550,14 +553,18 @@ def _emit_confirm_changes_tool_call(self) -> list[BaseEvent]: def _handle_function_approval_request_content(self, content: FunctionApprovalRequestContent) -> list[BaseEvent]: events: list[BaseEvent] = [] logger.info("=== FUNCTION APPROVAL REQUEST ===") - logger.info(" Function: %s", content.function_call.name) - logger.info(" Call ID: %s", content.function_call.call_id) + logger.info(f" Function: {content.function_call.name}") + logger.info(f" Call ID: {content.function_call.call_id}") parsed_args = content.function_call.parse_arguments() - logger.info(" Parsed args keys: %s", list(parsed_args.keys()) if parsed_args else "None") + parsed_arg_keys = list(parsed_args.keys()) if parsed_args else "None" + logger.info(f" Parsed args keys: {parsed_arg_keys}") if parsed_args and self.predict_state_config: - logger.info(" Checking predict_state_config: %s", self.predict_state_config) + logger.info( + " Checking predict_state_config keys: %s", + list(self.predict_state_config.keys()) if self.predict_state_config else "None", + ) for state_key, config in self.predict_state_config.items(): if config["tool"] != content.function_call.name: continue @@ -575,7 +582,7 @@ def _handle_function_approval_request_content(self, content: FunctionApprovalReq elif tool_arg_name in parsed_args: state_value = parsed_args[tool_arg_name] else: - logger.warning(" Tool argument '%s' not found in parsed args", tool_arg_name) + logger.warning(f" Tool argument '{tool_arg_name}' not found in parsed args") continue self.current_state[state_key] = state_value @@ -589,7 +596,7 @@ def _handle_function_approval_request_content(self, content: FunctionApprovalReq end_event = ToolCallEndEvent( tool_call_id=content.function_call.call_id, ) - logger.info("Emitting ToolCallEndEvent for approval-required tool '%s'", content.function_call.call_id) + logger.info(f"Emitting ToolCallEndEvent for approval-required tool '{content.function_call.call_id}'") events.append(end_event) self.tool_calls_ended.add(content.function_call.call_id) @@ -604,7 +611,7 @@ def _handle_function_approval_request_content(self, content: FunctionApprovalReq }, }, ) - logger.info("Emitting function_approval_request custom event for '%s'", content.function_call.name) + logger.info(f"Emitting function_approval_request custom event for '{content.function_call.name}'") events.append(approval_event) return events diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/__init__.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/__init__.py index acec1bdf9b..2a50eae894 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/__init__.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/__init__.py @@ -1,16 +1 @@ # Copyright (c) Microsoft. All rights reserved. - -"""Orchestration helpers broken into focused modules.""" - -from .message_hygiene import deduplicate_messages, sanitize_tool_history -from .state_manager import StateManager -from .tooling import collect_server_tools, merge_tools, register_additional_client_tools - -__all__ = [ - "StateManager", - "sanitize_tool_history", - "deduplicate_messages", - "collect_server_tools", - "register_additional_client_tools", - "merge_tools", -] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/message_hygiene.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py similarity index 88% rename from python/packages/ag-ui/agent_framework_ag_ui/_orchestration/message_hygiene.py rename to python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py index d70d24f2f8..97c990781b 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/message_hygiene.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py @@ -51,8 +51,7 @@ def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: parsed = json.loads(user_text) if "accepted" in parsed: logger.info( - "Injecting synthetic tool result for confirm_changes call_id=%s", - pending_confirm_changes_id, + f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}" ) synthetic_result = ChatMessage( role="tool", @@ -69,15 +68,14 @@ def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: pending_confirm_changes_id = None continue except (json.JSONDecodeError, KeyError) as exc: - logger.debug("Could not parse user message as confirm_changes response: %s", exc) + logger.debug("Could not parse user message as confirm_changes response: %s", type(exc).__name__) if pending_tool_call_ids: logger.info( - "User message arrived with %s pending tool calls - injecting synthetic results", - len(pending_tool_call_ids), + f"User message arrived with {len(pending_tool_call_ids)} pending tool calls - injecting synthetic results" ) for pending_call_id in pending_tool_call_ids: - logger.info("Injecting synthetic tool result for pending call_id=%s", pending_call_id) + logger.info(f"Injecting synthetic tool result for pending call_id={pending_call_id}") synthetic_result = ChatMessage( role="tool", contents=[ @@ -140,10 +138,10 @@ def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: new_result = msg.contents[0].result if (not existing_result or existing_result == "") and new_result: - logger.info("Replacing empty tool result at index %s with data from index %s", existing_idx, idx) + logger.info(f"Replacing empty tool result at index {existing_idx} with data from index {idx}") unique_messages[existing_idx] = msg else: - logger.info("Skipping duplicate tool result at index %s: call_id=%s", idx, call_id) + logger.info(f"Skipping duplicate tool result at index {idx}: call_id={call_id}") continue seen_keys[key] = len(unique_messages) @@ -158,7 +156,7 @@ def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: key = (role_value, tool_call_ids) if key in seen_keys: - logger.info("Skipping duplicate assistant tool call at index %s", idx) + logger.info(f"Skipping duplicate assistant tool call at index {idx}") continue seen_keys[key] = len(unique_messages) @@ -169,7 +167,7 @@ def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: key = (role_value, hash(content_str)) if key in seen_keys: - logger.info("Skipping duplicate message at index %s: role=%s", idx, role_value) + logger.info(f"Skipping duplicate message at index {idx}: role={role_value}") continue seen_keys[key] = len(unique_messages) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/state_manager.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py similarity index 100% rename from python/packages/ag-ui/agent_framework_ag_ui/_orchestration/state_manager.py rename to python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/tooling.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py similarity index 82% rename from python/packages/ag-ui/agent_framework_ag_ui/_orchestration/tooling.py rename to python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py index 3c59ee0440..977c276627 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/tooling.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py @@ -15,11 +15,11 @@ def collect_server_tools(agent: Any) -> list[Any]: if isinstance(agent, ChatAgent): tools_from_agent = agent.chat_options.tools server_tools = list(tools_from_agent) if tools_from_agent else [] - logger.info("[TOOLS] Agent has %s configured tools", len(server_tools)) + logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools") for tool in server_tools: tool_name = getattr(tool, "name", "unknown") approval_mode = getattr(tool, "approval_mode", None) - logger.info("[TOOLS] - %s: approval_mode=%s", tool_name, approval_mode) + logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}") return server_tools try: @@ -40,7 +40,7 @@ def register_additional_client_tools(agent: Any, client_tools: list[Any] | None) chat_client = agent.chat_client if isinstance(chat_client, BaseChatClient) and chat_client.function_invocation_configuration is not None: chat_client.function_invocation_configuration.additional_tools = client_tools - logger.debug("[TOOLS] Registered %s client tools as additional_tools (declaration-only)", len(client_tools)) + logger.debug(f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)") return try: @@ -50,7 +50,7 @@ def register_additional_client_tools(agent: Any, client_tools: list[Any] | None) if fic is not None: fic.additional_tools = client_tools # type: ignore[attr-defined] logger.debug( - "[TOOLS] Registered %s client tools as additional_tools (declaration-only)", len(client_tools) + f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)" ) except AttributeError: return @@ -74,9 +74,7 @@ def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list combined_tools.extend(server_tools) combined_tools.extend(unique_client_tools) logger.info( - "[TOOLS] Passing tools= parameter with %s tools (%s server + %s unique client)", - len(combined_tools), - len(server_tools), - len(unique_client_tools), + f"[TOOLS] Passing tools= parameter with {len(combined_tools)} tools " + f"({len(server_tools)} server + {len(unique_client_tools)} unique client)" ) return combined_tools diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 81b3a75fbe..39051c042c 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -271,16 +271,15 @@ async def run( """ from ._events import AgentFrameworkEventBridge from ._message_adapters import agui_messages_to_snapshot_format - from ._orchestration import ( - StateManager, + from ._orchestration._message_hygiene import deduplicate_messages, sanitize_tool_history + from ._orchestration._state_manager import StateManager + from ._orchestration._tooling import ( collect_server_tools, - deduplicate_messages, merge_tools, register_additional_client_tools, - sanitize_tool_history, ) - logger.info("Starting default agent run for thread_id=%s, run_id=%s", context.thread_id, context.run_id) + logger.info(f"Starting default agent run for thread_id={context.thread_id}, run_id={context.run_id}") response_format = None if isinstance(context.agent, ChatAgent): @@ -328,28 +327,32 @@ async def run( yield event_bridge.create_run_finished_event() return - logger.info("Received %s raw messages from client", len(raw_messages)) + logger.info(f"Received {len(raw_messages)} raw messages from client") for i, msg in enumerate(raw_messages): role = msg.role.value if hasattr(msg.role, "value") else str(msg.role) msg_id = getattr(msg, "message_id", None) - logger.info(" Raw message %s: role=%s, id=%s", i, role, msg_id) + logger.info(f" Raw message {i}: role={role}, id={msg_id}") if hasattr(msg, "contents") and msg.contents: for j, content in enumerate(msg.contents): content_type = type(content).__name__ if isinstance(content, TextContent): - logger.debug(" Content %s: %s - %s", j, content_type, content.text) + logger.debug(" Content %s: %s - text_length=%s", j, content_type, len(content.text)) elif isinstance(content, FunctionCallContent): - logger.debug(" Content %s: %s - %s(%s)", j, content_type, content.name, content.arguments) + arg_length = len(str(content.arguments)) if content.arguments else 0 + logger.debug( + " Content %s: %s - %s args_length=%s", j, content_type, content.name, arg_length + ) elif isinstance(content, FunctionResultContent): + result_preview = type(content.result).__name__ if content.result is not None else "None" logger.debug( - " Content %s: %s - call_id=%s, result=%s", + " Content %s: %s - call_id=%s, result_type=%s", j, content_type, content.call_id, - content.result, + result_preview, ) else: - logger.debug(" Content %s: %s - %s", j, content_type, content) + logger.debug(f" Content {j}: {content_type}") sanitized_messages = sanitize_tool_history(raw_messages) provider_messages = deduplicate_messages(sanitized_messages) @@ -359,27 +362,29 @@ async def run( yield event_bridge.create_run_finished_event() return - logger.info("Processing %s provider messages after sanitization/deduplication", len(provider_messages)) + logger.info(f"Processing {len(provider_messages)} provider messages after sanitization/deduplication") for i, msg in enumerate(provider_messages): role = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - logger.info(" Message %s: role=%s", i, role) + logger.info(f" Message {i}: role={role}") if hasattr(msg, "contents") and msg.contents: for j, content in enumerate(msg.contents): content_type = type(content).__name__ if isinstance(content, TextContent): - logger.info(" Content %s: %s - %s", j, content_type, content.text) + logger.info(f" Content {j}: {content_type} - text_length={len(content.text)}") elif isinstance(content, FunctionCallContent): - logger.info(" Content %s: %s - %s(%s)", j, content_type, content.name, content.arguments) + arg_length = len(str(content.arguments)) if content.arguments else 0 + logger.info(" Content %s: %s - %s args_length=%s", j, content_type, content.name, arg_length) elif isinstance(content, FunctionResultContent): + result_preview = type(content.result).__name__ if content.result is not None else "None" logger.info( - " Content %s: %s - call_id=%s, result=%s", + " Content %s: %s - call_id=%s, result_type=%s", j, content_type, content.call_id, - content.result, + result_preview, ) else: - logger.info(" Content %s: %s - %s", j, content_type, content) + logger.info(f" Content {j}: {content_type}") messages_to_run: list[Any] = [] is_new_user_turn = False @@ -405,12 +410,12 @@ async def run( messages_to_run.extend(provider_messages) client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools")) - logger.info("[TOOLS] Client sent %s tools", len(client_tools) if client_tools else 0) + logger.info(f"[TOOLS] Client sent {len(client_tools) if client_tools else 0} tools") if client_tools: for tool in client_tools: tool_name = getattr(tool, "name", "unknown") declaration_only = getattr(tool, "declaration_only", None) - logger.info("[TOOLS] - Client tool: %s, declaration_only=%s", tool_name, declaration_only) + logger.info(f"[TOOLS] - Client tool: {tool_name}, declaration_only={declaration_only}") server_tools = collect_server_tools(context.agent) register_additional_client_tools(context.agent, client_tools) @@ -420,15 +425,15 @@ async def run( update_count = 0 async for update in context.agent.run_stream(messages_to_run, thread=thread, tools=tools_param): update_count += 1 - logger.info("[STREAM] Received update #%s from agent", update_count) + logger.info(f"[STREAM] Received update #{update_count} from agent") all_updates.append(update) events = await event_bridge.from_agent_run_update(update) - logger.info("[STREAM] Update #%s produced %s events", update_count, len(events)) + logger.info(f"[STREAM] Update #{update_count} produced {len(events)} events") for event in events: - logger.info("[STREAM] Yielding event: %s", type(event).__name__) + logger.info(f"[STREAM] Yielding event: {type(event).__name__}") yield event - logger.info("[STREAM] Agent stream completed. Total updates: %s", update_count) + logger.info(f"[STREAM] Agent stream completed. Total updates: {update_count}") if event_bridge.should_stop_after_confirm: logger.info("Stopping run after confirm_changes - waiting for user response") @@ -450,45 +455,39 @@ async def run( from ag_ui.core import ToolCallEndEvent end_event = ToolCallEndEvent(tool_call_id=tool_call_id) - logger.info( - "Emitting ToolCallEndEvent for declaration-only tool call '%s'", - tool_call_id, - ) + logger.info(f"Emitting ToolCallEndEvent for declaration-only tool call '{tool_call_id}'") yield end_event if all_updates and response_format: from agent_framework import AgentRunResponse from pydantic import BaseModel - logger.info("Processing structured output, update count: %s", len(all_updates)) + logger.info(f"Processing structured output, update count: {len(all_updates)}") final_response = AgentRunResponse.from_agent_run_response_updates( all_updates, output_format_type=response_format ) if final_response.value and isinstance(final_response.value, BaseModel): response_dict = final_response.value.model_dump(mode="json", exclude_none=True) - logger.info("Received structured output: %s", list(response_dict.keys())) + logger.info(f"Received structured output keys: {list(response_dict.keys())}") state_updates = state_manager.extract_state_updates(response_dict) if state_updates: state_manager.apply_state_updates(state_updates) state_snapshot = event_bridge.create_state_snapshot_event(current_state) yield state_snapshot - logger.info("Emitted StateSnapshotEvent with updates: %s", list(state_updates.keys())) + logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}") if "message" in response_dict and response_dict["message"]: message_id = generate_event_id() yield TextMessageStartEvent(message_id=message_id, role="assistant") yield TextMessageContentEvent(message_id=message_id, delta=response_dict["message"]) yield TextMessageEndEvent(message_id=message_id) - logger.info("Emitted conversational message: %s...", response_dict["message"][:100]) + logger.info(f"Emitted conversational message with length={len(response_dict['message'])}") - logger.info("[FINALIZE] Checking for unclosed message. current_message_id=%s", event_bridge.current_message_id) + logger.info(f"[FINALIZE] Checking for unclosed message. current_message_id={event_bridge.current_message_id}") if event_bridge.current_message_id: - logger.info( - "[FINALIZE] Emitting TextMessageEndEvent for message_id=%s", - event_bridge.current_message_id, - ) + logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") yield event_bridge.create_message_end_event(event_bridge.current_message_id) assistant_text_message = { @@ -525,7 +524,7 @@ async def run( logger.info("[FINALIZE] Emitting RUN_FINISHED event") yield event_bridge.create_run_finished_event() - logger.info("Completed agent run for thread_id=%s, run_id=%s", context.thread_id, context.run_id) + logger.info(f"Completed agent run for thread_id={context.thread_id}, run_id={context.run_id}") __all__ = [ diff --git a/python/packages/ag-ui/tests/__init__.py b/python/packages/ag-ui/tests/__init__.py new file mode 100644 index 0000000000..2a50eae894 --- /dev/null +++ b/python/packages/ag-ui/tests/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Microsoft. All rights reserved. diff --git a/python/packages/ag-ui/tests/_test_stubs.py b/python/packages/ag-ui/tests/_test_stubs.py new file mode 100644 index 0000000000..bfb528511e --- /dev/null +++ b/python/packages/ag-ui/tests/_test_stubs.py @@ -0,0 +1,138 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Shared test stubs for AG-UI tests.""" + +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, MutableSequence +from types import SimpleNamespace +from typing import Any + +from agent_framework import ( + AgentProtocol, + AgentRunResponse, + AgentRunResponseUpdate, + AgentThread, + ChatMessage, + ChatOptions, + TextContent, +) +from agent_framework._clients import BaseChatClient +from agent_framework._types import ChatResponse, ChatResponseUpdate + +from agent_framework_ag_ui._orchestrators import ExecutionContext + +StreamFn = Callable[..., AsyncIterator[ChatResponseUpdate]] +ResponseFn = Callable[..., Awaitable[ChatResponse]] + + +class StreamingChatClientStub(BaseChatClient): + """Typed streaming stub that satisfies ChatClientProtocol.""" + + def __init__(self, stream_fn: StreamFn, response_fn: ResponseFn | None = None) -> None: + super().__init__() + self._stream_fn = stream_fn + self._response_fn = response_fn + + async def _inner_get_streaming_response( + self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + async for update in self._stream_fn(messages, chat_options, **kwargs): + yield update + + async def _inner_get_response( + self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> ChatResponse: + if self._response_fn is not None: + return await self._response_fn(messages, chat_options, **kwargs) + + contents: list[Any] = [] + async for update in self._stream_fn(messages, chat_options, **kwargs): + contents.extend(update.contents) + + return ChatResponse( + messages=[ChatMessage(role="assistant", contents=contents)], + response_id="stub-response", + ) + + +def stream_from_updates(updates: list[ChatResponseUpdate]) -> StreamFn: + """Create a stream function that yields from a static list of updates.""" + + async def _stream( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + for update in updates: + yield update + + return _stream + + +class StubAgent(AgentProtocol): + """Minimal AgentProtocol stub for orchestrator tests.""" + + def __init__( + self, + updates: list[AgentRunResponseUpdate] | None = None, + *, + agent_id: str = "stub-agent", + agent_name: str | None = "stub-agent", + chat_options: Any | None = None, + chat_client: Any | None = None, + ) -> None: + self._id = agent_id + self._name = agent_name + self._description = "stub agent" + self.updates = updates or [AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")] + self.chat_options = chat_options or SimpleNamespace(tools=None, response_format=None) + self.chat_client = chat_client or SimpleNamespace(function_invocation_configuration=None) + self.messages_received: list[Any] = [] + self.tools_received: list[Any] | None = None + + @property + def id(self) -> str: + return self._id + + @property + def name(self) -> str | None: + return self._name + + @property + def display_name(self) -> str: + return self._name or self._id + + @property + def description(self) -> str | None: + return self._description + + async def run( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AgentRunResponse: + return AgentRunResponse(messages=[], response_id="stub-response") + + def run_stream( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterable[AgentRunResponseUpdate]: + async def _stream() -> AsyncIterator[AgentRunResponseUpdate]: + self.messages_received = [] if messages is None else list(messages) # type: ignore[arg-type] + self.tools_received = kwargs.get("tools") + for update in self.updates: + yield update + + return _stream() + + def get_new_thread(self, **kwargs: Any) -> AgentThread: + return AgentThread() + + +class TestExecutionContext(ExecutionContext): + """ExecutionContext helper that allows setting messages for tests.""" + + def set_messages(self, messages: list[ChatMessage]) -> None: + self._messages = messages diff --git a/python/packages/ag-ui/tests/test_ag_ui_client.py b/python/packages/ag-ui/tests/test_ag_ui_client.py index 742b656369..09570c1be4 100644 --- a/python/packages/ag-ui/tests/test_ag_ui_client.py +++ b/python/packages/ag-ui/tests/test_ag_ui_client.py @@ -3,10 +3,59 @@ """Tests for AGUIChatClient.""" import json - -from agent_framework import ChatMessage, ChatOptions, FunctionCallContent, Role, ai_function +from collections.abc import AsyncGenerator, AsyncIterable, MutableSequence +from typing import Any + +from agent_framework import ( + ChatMessage, + ChatOptions, + ChatResponseUpdate, + FunctionCallContent, + Role, + TextContent, + ai_function, +) +from agent_framework._types import ChatResponse +from pytest import MonkeyPatch from agent_framework_ag_ui._client import AGUIChatClient, ServerFunctionCallContent +from agent_framework_ag_ui._http_service import AGUIHttpService + + +class TestableAGUIChatClient(AGUIChatClient): + """Testable wrapper exposing protected helpers.""" + + @property + def http_service(self) -> AGUIHttpService: + """Expose http service for monkeypatching.""" + return self._http_service + + def extract_state_from_messages( + self, messages: list[ChatMessage] + ) -> tuple[list[ChatMessage], dict[str, Any] | None]: + """Expose state extraction helper.""" + return self._extract_state_from_messages(messages) + + def convert_messages_to_agui_format(self, messages: list[ChatMessage]) -> list[dict[str, Any]]: + """Expose message conversion helper.""" + return self._convert_messages_to_agui_format(messages) + + def get_thread_id(self, chat_options: ChatOptions) -> str: + """Expose thread id helper.""" + return self._get_thread_id(chat_options) + + async def inner_get_streaming_response( + self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions + ) -> AsyncIterable[ChatResponseUpdate]: + """Proxy to protected streaming call.""" + async for update in self._inner_get_streaming_response(messages=messages, chat_options=chat_options): + yield update + + async def inner_get_response( + self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions + ) -> ChatResponse: + """Proxy to protected response call.""" + return await self._inner_get_response(messages=messages, chat_options=chat_options) class TestAGUIChatClient: @@ -14,25 +63,25 @@ class TestAGUIChatClient: async def test_client_initialization(self) -> None: """Test client initialization.""" - client = AGUIChatClient(endpoint="http://localhost:8888/") + client = TestableAGUIChatClient(endpoint="http://localhost:8888/") - assert client._http_service is not None - assert client._http_service.endpoint.startswith("http://localhost:8888") + assert client.http_service is not None + assert client.http_service.endpoint.startswith("http://localhost:8888") async def test_client_context_manager(self) -> None: """Test client as async context manager.""" - async with AGUIChatClient(endpoint="http://localhost:8888/") as client: + async with TestableAGUIChatClient(endpoint="http://localhost:8888/") as client: assert client is not None async def test_extract_state_from_messages_no_state(self) -> None: """Test state extraction when no state is present.""" - client = AGUIChatClient(endpoint="http://localhost:8888/") + client = TestableAGUIChatClient(endpoint="http://localhost:8888/") messages = [ ChatMessage(role="user", text="Hello"), ChatMessage(role="assistant", text="Hi there"), ] - result_messages, state = client._extract_state_from_messages(messages) + result_messages, state = client.extract_state_from_messages(messages) assert result_messages == messages assert state is None @@ -41,7 +90,7 @@ async def test_extract_state_from_messages_with_state(self) -> None: """Test state extraction from last message.""" import base64 - client = AGUIChatClient(endpoint="http://localhost:8888/") + client = TestableAGUIChatClient(endpoint="http://localhost:8888/") state_data = {"key": "value", "count": 42} state_json = json.dumps(state_data) @@ -57,7 +106,7 @@ async def test_extract_state_from_messages_with_state(self) -> None: ), ] - result_messages, state = client._extract_state_from_messages(messages) + result_messages, state = client.extract_state_from_messages(messages) assert len(result_messages) == 1 assert result_messages[0].text == "Hello" @@ -67,7 +116,7 @@ async def test_extract_state_invalid_json(self) -> None: """Test state extraction with invalid JSON.""" import base64 - client = AGUIChatClient(endpoint="http://localhost:8888/") + client = TestableAGUIChatClient(endpoint="http://localhost:8888/") invalid_json = "not valid json" state_b64 = base64.b64encode(invalid_json.encode("utf-8")).decode("utf-8") @@ -81,20 +130,20 @@ async def test_extract_state_invalid_json(self) -> None: ), ] - result_messages, state = client._extract_state_from_messages(messages) + result_messages, state = client.extract_state_from_messages(messages) assert result_messages == messages assert state is None async def test_convert_messages_to_agui_format(self) -> None: """Test message conversion to AG-UI format.""" - client = AGUIChatClient(endpoint="http://localhost:8888/") + client = TestableAGUIChatClient(endpoint="http://localhost:8888/") messages = [ ChatMessage(role=Role.USER, text="What is the weather?"), ChatMessage(role=Role.ASSISTANT, text="Let me check.", message_id="msg_123"), ] - agui_messages = client._convert_messages_to_agui_format(messages) + agui_messages = client.convert_messages_to_agui_format(messages) assert len(agui_messages) == 2 assert agui_messages[0]["role"] == "user" @@ -105,24 +154,24 @@ async def test_convert_messages_to_agui_format(self) -> None: async def test_get_thread_id_from_metadata(self) -> None: """Test thread ID extraction from metadata.""" - client = AGUIChatClient(endpoint="http://localhost:8888/") + client = TestableAGUIChatClient(endpoint="http://localhost:8888/") chat_options = ChatOptions(metadata={"thread_id": "existing_thread_123"}) - thread_id = client._get_thread_id(chat_options) + thread_id = client.get_thread_id(chat_options) assert thread_id == "existing_thread_123" async def test_get_thread_id_generation(self) -> None: """Test automatic thread ID generation.""" - client = AGUIChatClient(endpoint="http://localhost:8888/") + client = TestableAGUIChatClient(endpoint="http://localhost:8888/") chat_options = ChatOptions() - thread_id = client._get_thread_id(chat_options) + thread_id = client.get_thread_id(chat_options) assert thread_id.startswith("thread_") assert len(thread_id) > 7 - async def test_get_streaming_response(self, monkeypatch) -> None: + async def test_get_streaming_response(self, monkeypatch: MonkeyPatch) -> None: """Test streaming response method.""" mock_events = [ {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"}, @@ -131,26 +180,32 @@ async def test_get_streaming_response(self, monkeypatch) -> None: {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"}, ] - async def mock_post_run(*args, **kwargs): + async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]: for event in mock_events: yield event - client = AGUIChatClient(endpoint="http://localhost:8888/") - monkeypatch.setattr(client._http_service, "post_run", mock_post_run) + client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [ChatMessage(role="user", text="Test message")] chat_options = ChatOptions() - updates = [] - async for update in client._inner_get_streaming_response(messages=messages, chat_options=chat_options): + updates: list[ChatResponseUpdate] = [] + async for update in client.inner_get_streaming_response(messages=messages, chat_options=chat_options): updates.append(update) assert len(updates) == 4 + assert updates[0].additional_properties is not None assert updates[0].additional_properties["thread_id"] == "thread_1" - assert updates[1].contents[0].text == "Hello" - assert updates[2].contents[0].text == " world" - async def test_get_response_non_streaming(self, monkeypatch) -> None: + first_content = updates[1].contents[0] + second_content = updates[2].contents[0] + assert isinstance(first_content, TextContent) + assert isinstance(second_content, TextContent) + assert first_content.text == "Hello" + assert second_content.text == " world" + + async def test_get_response_non_streaming(self, monkeypatch: MonkeyPatch) -> None: """Test non-streaming response method.""" mock_events = [ {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"}, @@ -158,23 +213,23 @@ async def test_get_response_non_streaming(self, monkeypatch) -> None: {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"}, ] - async def mock_post_run(*args, **kwargs): + async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]: for event in mock_events: yield event - client = AGUIChatClient(endpoint="http://localhost:8888/") - monkeypatch.setattr(client._http_service, "post_run", mock_post_run) + client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [ChatMessage(role="user", text="Test message")] chat_options = ChatOptions() - response = await client._inner_get_response(messages=messages, chat_options=chat_options) + response = await client.inner_get_response(messages=messages, chat_options=chat_options) assert response is not None assert len(response.messages) > 0 assert "Complete response" in response.text - async def test_tool_handling(self, monkeypatch) -> None: + async def test_tool_handling(self, monkeypatch: MonkeyPatch) -> None: """Test that client tool metadata is sent to server. Client tool metadata (name, description, schema) is sent to server for planning. @@ -193,28 +248,29 @@ def test_tool(param: str) -> str: {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"}, ] - async def mock_post_run(*args, **kwargs): + async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]: # Client tool metadata should be sent to server - tools = kwargs.get("tools") + tools: list[dict[str, Any]] | None = kwargs.get("tools") assert tools is not None assert len(tools) == 1 - assert tools[0]["name"] == "test_tool" - assert tools[0]["description"] == "Test tool." - assert "parameters" in tools[0] + tool_entry = tools[0] + assert tool_entry["name"] == "test_tool" + assert tool_entry["description"] == "Test tool." + assert "parameters" in tool_entry for event in mock_events: yield event - client = AGUIChatClient(endpoint="http://localhost:8888/") - monkeypatch.setattr(client._http_service, "post_run", mock_post_run) + client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [ChatMessage(role="user", text="Test with tools")] chat_options = ChatOptions(tools=[test_tool]) - response = await client._inner_get_response(messages=messages, chat_options=chat_options) + response = await client.inner_get_response(messages=messages, chat_options=chat_options) assert response is not None - async def test_server_tool_calls_unwrapped_after_invocation(self, monkeypatch) -> None: + async def test_server_tool_calls_unwrapped_after_invocation(self, monkeypatch: MonkeyPatch) -> None: """Ensure server-side tool calls are exposed as FunctionCallContent after processing.""" mock_events = [ @@ -224,17 +280,17 @@ async def test_server_tool_calls_unwrapped_after_invocation(self, monkeypatch) - {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"}, ] - async def mock_post_run(*args, **kwargs): + async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]: for event in mock_events: yield event - client = AGUIChatClient(endpoint="http://localhost:8888/") - monkeypatch.setattr(client._http_service, "post_run", mock_post_run) + client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [ChatMessage(role="user", text="Test server tool execution")] chat_options = ChatOptions() - updates = [] + updates: list[ChatResponseUpdate] = [] async for update in client.get_streaming_response(messages, chat_options=chat_options): updates.append(update) @@ -247,7 +303,7 @@ async def mock_post_run(*args, **kwargs): isinstance(content, ServerFunctionCallContent) for update in updates for content in update.contents ) - async def test_server_tool_calls_not_executed_locally(self, monkeypatch) -> None: + async def test_server_tool_calls_not_executed_locally(self, monkeypatch: MonkeyPatch) -> None: """Server tools should not trigger local function invocation even when client tools exist.""" @ai_function @@ -262,18 +318,18 @@ def client_tool() -> str: {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"}, ] - async def mock_post_run(*args, **kwargs): + async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]: for event in mock_events: yield event - async def fake_auto_invoke(*args, **kwargs): + async def fake_auto_invoke(*args: object, **kwargs: Any) -> None: function_call = kwargs.get("function_call_content") or args[0] raise AssertionError(f"Unexpected local execution of server tool: {getattr(function_call, 'name', '?')}") monkeypatch.setattr("agent_framework._tools._auto_invoke_function", fake_auto_invoke) - client = AGUIChatClient(endpoint="http://localhost:8888/") - monkeypatch.setattr(client._http_service, "post_run", mock_post_run) + client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [ChatMessage(role="user", text="Test server tool execution")] chat_options = ChatOptions(tool_choice="auto", tools=[client_tool]) @@ -281,7 +337,7 @@ async def fake_auto_invoke(*args, **kwargs): async for _ in client.get_streaming_response(messages, chat_options=chat_options): pass - async def test_state_transmission(self, monkeypatch) -> None: + async def test_state_transmission(self, monkeypatch: MonkeyPatch) -> None: """Test state is properly transmitted to server.""" import base64 @@ -304,16 +360,16 @@ async def test_state_transmission(self, monkeypatch) -> None: {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"}, ] - async def mock_post_run(*args, **kwargs): + async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]: assert kwargs.get("state") == state_data for event in mock_events: yield event - client = AGUIChatClient(endpoint="http://localhost:8888/") - monkeypatch.setattr(client._http_service, "post_run", mock_post_run) + client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + monkeypatch.setattr(client.http_service, "post_run", mock_post_run) chat_options = ChatOptions() - response = await client._inner_get_response(messages=messages, chat_options=chat_options) + response = await client.inner_get_response(messages=messages, chat_options=chat_options) assert response is not None diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index 743f3a218a..015bbdfc61 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -3,22 +3,27 @@ """Comprehensive tests for AgentFrameworkAgent (_agent.py).""" import json +from collections.abc import AsyncIterator, MutableSequence +from typing import Any import pytest -from agent_framework import ChatAgent, TextContent +from agent_framework import ChatAgent, ChatMessage, ChatOptions, TextContent from agent_framework._types import ChatResponseUpdate from pydantic import BaseModel +from ._test_stubs import StreamingChatClientStub + async def test_agent_initialization_basic(): """Test basic agent initialization without state schema.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) assert wrapper.name == "test_agent" @@ -31,12 +36,13 @@ async def test_agent_initialization_with_state_schema(): """Test agent initialization with state_schema.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) - state_schema = {"document": {"type": "string"}} + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) + state_schema: dict[str, dict[str, Any]] = {"document": {"type": "string"}} wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema) assert wrapper.config.state_schema == state_schema @@ -46,11 +52,12 @@ async def test_agent_initialization_with_predict_state_config(): """Test agent initialization with predict_state_config.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) predict_config = {"document": {"tool": "write_doc", "tool_argument": "content"}} wrapper = AgentFrameworkAgent(agent=agent, predict_state_config=predict_config) @@ -61,15 +68,16 @@ async def test_agent_initialization_with_pydantic_state_schema(): """Test agent initialization when state_schema is provided as Pydantic model/class.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) class MyState(BaseModel): document: str tags: list[str] = [] - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper_class_schema = AgentFrameworkAgent(agent=agent, state_schema=MyState) wrapper_instance_schema = AgentFrameworkAgent(agent=agent, state_schema=MyState(document="hi")) @@ -83,16 +91,17 @@ async def test_run_started_event_emission(): """Test RunStartedEvent is emitted at start of run.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) input_data = {"messages": [{"role": "user", "content": "Hi"}]} - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -106,11 +115,12 @@ async def test_predict_state_custom_event_emission(): """Test PredictState CustomEvent is emitted when predict_state_config is present.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) predict_config = { "document": {"tool": "write_doc", "tool_argument": "content"}, "summary": {"tool": "summarize", "tool_argument": "text"}, @@ -119,7 +129,7 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): input_data = {"messages": [{"role": "user", "content": "Hi"}]} - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -137,11 +147,12 @@ async def test_initial_state_snapshot_with_schema(): """Test initial StateSnapshotEvent emission when state_schema present.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) state_schema = {"document": {"type": "string"}} wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema) @@ -150,7 +161,7 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): "state": {"document": "Initial content"}, } - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -166,17 +177,18 @@ async def test_state_initialization_object_type(): """Test state initialization with object type in schema.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) - state_schema = {"recipe": {"type": "object", "properties": {}}} + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) + state_schema: dict[str, dict[str, Any]] = {"recipe": {"type": "object", "properties": {}}} wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema) input_data = {"messages": [{"role": "user", "content": "Hi"}]} - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -192,17 +204,18 @@ async def test_state_initialization_array_type(): """Test state initialization with array type in schema.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) - state_schema = {"steps": {"type": "array", "items": {}}} + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) + state_schema: dict[str, dict[str, Any]] = {"steps": {"type": "array", "items": {}}} wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema) input_data = {"messages": [{"role": "user", "content": "Hi"}]} - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -218,16 +231,17 @@ async def test_run_finished_event_emission(): """Test RunFinishedEvent is emitted at end of run.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) input_data = {"messages": [{"role": "user", "content": "Hi"}]} - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -239,11 +253,12 @@ async def test_tool_result_confirm_changes_accepted(): """Test confirm_changes tool result handling when accepted.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Document updated")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="Document updated")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent( agent=agent, state_schema={"document": {"type": "string"}}, @@ -251,8 +266,8 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): ) # Simulate tool result message with acceptance - tool_result = {"accepted": True, "steps": []} - input_data = { + tool_result: dict[str, Any] = {"accepted": True, "steps": []} + input_data: dict[str, Any] = { "messages": [ { "role": "tool", # Tool result from UI @@ -263,7 +278,7 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): "state": {"document": "Updated content"}, } - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -285,16 +300,17 @@ async def test_tool_result_confirm_changes_rejected(): """Test confirm_changes tool result handling when rejected.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="OK")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="OK")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) # Simulate tool result message with rejection - tool_result = {"accepted": False, "steps": []} - input_data = { + tool_result: dict[str, Any] = {"accepted": False, "steps": []} + input_data: dict[str, Any] = { "messages": [ { "role": "tool", @@ -304,7 +320,7 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): ], } - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -318,22 +334,23 @@ async def test_tool_result_function_approval_accepted(): """Test function approval tool result when steps are accepted.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="OK")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="OK")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) # Simulate tool result with multiple steps - tool_result = { + tool_result: dict[str, Any] = { "accepted": True, "steps": [ {"id": "step1", "description": "Send email", "status": "enabled"}, {"id": "step2", "description": "Create calendar event", "status": "enabled"}, ], } - input_data = { + input_data: dict[str, Any] = { "messages": [ { "role": "tool", @@ -343,7 +360,7 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): ], } - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -363,19 +380,20 @@ async def test_tool_result_function_approval_rejected(): """Test function approval tool result when rejected.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="OK")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="OK")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) # Simulate tool result rejection with steps - tool_result = { + tool_result: dict[str, Any] = { "accepted": False, "steps": [{"id": "step1", "description": "Send email", "status": "disabled"}], } - input_data = { + input_data: dict[str, Any] = { "messages": [ { "role": "tool", @@ -385,7 +403,7 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): ], } - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -399,17 +417,16 @@ async def test_thread_metadata_tracking(): """Test that thread metadata includes ag_ui_thread_id and ag_ui_run_id.""" from agent_framework.ag_ui import AgentFrameworkAgent - thread_metadata = {} + thread_metadata: dict[str, Any] = {} - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - # Capture thread metadata from kwargs - nonlocal thread_metadata - if "thread" in kwargs: - thread_metadata = kwargs["thread"].metadata - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + if "thread" in kwargs: + thread_metadata.update(kwargs["thread"].metadata) + yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) input_data = { @@ -418,28 +435,28 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): "run_id": "test_run_456", } - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) - # Check thread metadata was set - # Note: This test may need adjustment based on actual thread passing mechanism + assert thread_metadata.get("ag_ui_thread_id") == "test_thread_123" + assert thread_metadata.get("ag_ui_run_id") == "test_run_456" async def test_state_context_injection(): """Test that current state is injected into thread metadata.""" from agent_framework.ag_ui import AgentFrameworkAgent - thread_metadata = {} + thread_metadata: dict[str, Any] = {} - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - # Track if state context message was added - nonlocal thread_metadata - # In actual implementation, thread is passed and state is in metadata - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + if "thread" in kwargs: + thread_metadata.update(kwargs["thread"].metadata) + yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent( agent=agent, state_schema={"document": {"type": "string"}}, @@ -450,27 +467,28 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): "state": {"document": "Test content"}, } - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) - # State should be injected - this is validated by agent execution flow + assert thread_metadata.get("current_state") == {"document": "Test content"} async def test_no_messages_provided(): """Test handling when no messages are provided.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) - input_data = {"messages": []} + input_data: dict[str, Any] = {"messages": []} - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -484,16 +502,17 @@ async def test_message_end_event_emission(): """Test TextMessageEndEvent is emitted for assistant messages.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Hello world")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="Hello world")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) - input_data = {"messages": [{"role": "user", "content": "Hi"}]} + input_data: dict[str, Any] = {"messages": [{"role": "user", "content": "Hi"}]} - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -511,19 +530,20 @@ async def test_error_handling_with_exception(): """Test that exceptions during agent execution are re-raised.""" from agent_framework.ag_ui import AgentFrameworkAgent - class FailingChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - if False: - yield - raise RuntimeError("Simulated failure") + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + if False: + yield ChatResponseUpdate(contents=[]) + raise RuntimeError("Simulated failure") - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=FailingChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) - input_data = {"messages": [{"role": "user", "content": "Hi"}]} + input_data: dict[str, Any] = {"messages": [{"role": "user", "content": "Hi"}]} with pytest.raises(RuntimeError, match="Simulated failure"): - async for event in wrapper.run_agent(input_data): + async for _ in wrapper.run_agent(input_data): pass @@ -531,18 +551,18 @@ async def test_json_decode_error_in_tool_result(): """Test handling of orphaned tool result - should be sanitized out.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - # Should not be called since orphaned tool result is dropped - if False: - yield - raise AssertionError("ChatClient should not be called with orphaned tool result") + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + if False: + yield ChatResponseUpdate(contents=[]) + raise AssertionError("ChatClient should not be called with orphaned tool result") - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) # Send invalid JSON as tool result without preceding tool call - input_data = { + input_data: dict[str, Any] = { "messages": [ { "role": "tool", @@ -552,7 +572,7 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): ], } - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -568,11 +588,12 @@ async def test_suppressed_summary_with_document_state(): """Test suppressed summary uses document state for confirmation message.""" from agent_framework.ag_ui import AgentFrameworkAgent, DocumentWriterConfirmationStrategy - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Response")]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="Response")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent( agent=agent, state_schema={"document": {"type": "string"}}, @@ -581,8 +602,8 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): ) # Simulate confirmation with document state - tool_result = {"accepted": True, "steps": []} - input_data = { + tool_result: dict[str, Any] = {"accepted": True, "steps": []} + input_data: dict[str, Any] = { "messages": [ { "role": "tool", @@ -593,7 +614,7 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): "state": {"document": "This is the beginning of a document. It contains important information."}, } - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) diff --git a/python/packages/ag-ui/tests/test_backend_tool_rendering.py b/python/packages/ag-ui/tests/test_backend_tool_rendering.py index fbd27ee8bb..6fefc14665 100644 --- a/python/packages/ag-ui/tests/test_backend_tool_rendering.py +++ b/python/packages/ag-ui/tests/test_backend_tool_rendering.py @@ -2,6 +2,8 @@ """Tests for backend tool rendering.""" +from typing import cast + from ag_ui.core import ( TextMessageContentEvent, TextMessageStartEvent, @@ -119,6 +121,9 @@ async def test_multiple_tool_results(): assert isinstance(events[end_idx], ToolCallEndEvent) assert isinstance(events[result_idx], ToolCallResultEvent) - assert events[end_idx].tool_call_id == f"tool-{i + 1}" - assert events[result_idx].tool_call_id == f"tool-{i + 1}" - assert f"Result {i + 1}" in events[result_idx].content + end_event = cast(ToolCallEndEvent, events[end_idx]) + result_event = cast(ToolCallResultEvent, events[result_idx]) + + assert end_event.tool_call_id == f"tool-{i + 1}" + assert result_event.tool_call_id == f"tool-{i + 1}" + assert f"Result {i + 1}" in result_event.content diff --git a/python/packages/ag-ui/tests/test_confirmation_strategies_comprehensive.py b/python/packages/ag-ui/tests/test_confirmation_strategies_comprehensive.py index 205182d58d..ab355d8995 100644 --- a/python/packages/ag-ui/tests/test_confirmation_strategies_comprehensive.py +++ b/python/packages/ag-ui/tests/test_confirmation_strategies_comprehensive.py @@ -14,7 +14,7 @@ @pytest.fixture -def sample_steps(): +def sample_steps() -> list[dict[str, str]]: """Sample steps for testing approval messages.""" return [ {"description": "Step 1: Do something", "status": "enabled"}, @@ -24,7 +24,7 @@ def sample_steps(): @pytest.fixture -def all_enabled_steps(): +def all_enabled_steps() -> list[dict[str, str]]: """All steps enabled.""" return [ {"description": "Task A", "status": "enabled"}, @@ -34,7 +34,7 @@ def all_enabled_steps(): @pytest.fixture -def empty_steps(): +def empty_steps() -> list[dict[str, str]]: """Empty steps list.""" return [] @@ -42,7 +42,7 @@ def empty_steps(): class TestDefaultConfirmationStrategy: """Tests for DefaultConfirmationStrategy.""" - def test_on_approval_accepted_with_enabled_steps(self, sample_steps): + def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None: strategy = DefaultConfirmationStrategy() message = strategy.on_approval_accepted(sample_steps) @@ -52,7 +52,7 @@ def test_on_approval_accepted_with_enabled_steps(self, sample_steps): assert "Step 3" not in message # Disabled step shouldn't appear assert "All steps completed successfully!" in message - def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps): + def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None: strategy = DefaultConfirmationStrategy() message = strategy.on_approval_accepted(all_enabled_steps) @@ -61,28 +61,28 @@ def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps): assert "Task B" in message assert "Task C" in message - def test_on_approval_accepted_with_empty_steps(self, empty_steps): + def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None: strategy = DefaultConfirmationStrategy() message = strategy.on_approval_accepted(empty_steps) assert "Executing 0 approved steps" in message assert "All steps completed successfully!" in message - def test_on_approval_rejected(self, sample_steps): + def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None: strategy = DefaultConfirmationStrategy() message = strategy.on_approval_rejected(sample_steps) assert "No problem!" in message assert "What would you like me to change" in message - def test_on_state_confirmed(self): + def test_on_state_confirmed(self) -> None: strategy = DefaultConfirmationStrategy() message = strategy.on_state_confirmed() assert "Changes confirmed" in message assert "successfully" in message - def test_on_state_rejected(self): + def test_on_state_rejected(self) -> None: strategy = DefaultConfirmationStrategy() message = strategy.on_state_rejected() @@ -93,7 +93,7 @@ def test_on_state_rejected(self): class TestTaskPlannerConfirmationStrategy: """Tests for TaskPlannerConfirmationStrategy.""" - def test_on_approval_accepted_with_enabled_steps(self, sample_steps): + def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None: strategy = TaskPlannerConfirmationStrategy() message = strategy.on_approval_accepted(sample_steps) @@ -103,7 +103,7 @@ def test_on_approval_accepted_with_enabled_steps(self, sample_steps): assert "Step 3" not in message assert "All tasks completed successfully!" in message - def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps): + def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None: strategy = TaskPlannerConfirmationStrategy() message = strategy.on_approval_accepted(all_enabled_steps) @@ -112,28 +112,28 @@ def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps): assert "2. Task B" in message assert "3. Task C" in message - def test_on_approval_accepted_with_empty_steps(self, empty_steps): + def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None: strategy = TaskPlannerConfirmationStrategy() message = strategy.on_approval_accepted(empty_steps) assert "Executing your requested tasks" in message assert "All tasks completed successfully!" in message - def test_on_approval_rejected(self, sample_steps): + def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None: strategy = TaskPlannerConfirmationStrategy() message = strategy.on_approval_rejected(sample_steps) assert "No problem!" in message assert "revise the plan" in message - def test_on_state_confirmed(self): + def test_on_state_confirmed(self) -> None: strategy = TaskPlannerConfirmationStrategy() message = strategy.on_state_confirmed() assert "Tasks confirmed" in message assert "ready to execute" in message - def test_on_state_rejected(self): + def test_on_state_rejected(self) -> None: strategy = TaskPlannerConfirmationStrategy() message = strategy.on_state_rejected() @@ -144,7 +144,7 @@ def test_on_state_rejected(self): class TestRecipeConfirmationStrategy: """Tests for RecipeConfirmationStrategy.""" - def test_on_approval_accepted_with_enabled_steps(self, sample_steps): + def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None: strategy = RecipeConfirmationStrategy() message = strategy.on_approval_accepted(sample_steps) @@ -154,7 +154,7 @@ def test_on_approval_accepted_with_enabled_steps(self, sample_steps): assert "Step 3" not in message assert "Recipe updated successfully!" in message - def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps): + def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None: strategy = RecipeConfirmationStrategy() message = strategy.on_approval_accepted(all_enabled_steps) @@ -163,28 +163,28 @@ def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps): assert "2. Task B" in message assert "3. Task C" in message - def test_on_approval_accepted_with_empty_steps(self, empty_steps): + def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None: strategy = RecipeConfirmationStrategy() message = strategy.on_approval_accepted(empty_steps) assert "Updating your recipe" in message assert "Recipe updated successfully!" in message - def test_on_approval_rejected(self, sample_steps): + def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None: strategy = RecipeConfirmationStrategy() message = strategy.on_approval_rejected(sample_steps) assert "No problem!" in message assert "ingredients or steps" in message - def test_on_state_confirmed(self): + def test_on_state_confirmed(self) -> None: strategy = RecipeConfirmationStrategy() message = strategy.on_state_confirmed() assert "Recipe changes applied" in message assert "successfully" in message - def test_on_state_rejected(self): + def test_on_state_rejected(self) -> None: strategy = RecipeConfirmationStrategy() message = strategy.on_state_rejected() @@ -195,7 +195,7 @@ def test_on_state_rejected(self): class TestDocumentWriterConfirmationStrategy: """Tests for DocumentWriterConfirmationStrategy.""" - def test_on_approval_accepted_with_enabled_steps(self, sample_steps): + def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None: strategy = DocumentWriterConfirmationStrategy() message = strategy.on_approval_accepted(sample_steps) @@ -205,7 +205,7 @@ def test_on_approval_accepted_with_enabled_steps(self, sample_steps): assert "Step 3" not in message assert "Document updated successfully!" in message - def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps): + def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None: strategy = DocumentWriterConfirmationStrategy() message = strategy.on_approval_accepted(all_enabled_steps) @@ -214,27 +214,27 @@ def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps): assert "2. Task B" in message assert "3. Task C" in message - def test_on_approval_accepted_with_empty_steps(self, empty_steps): + def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None: strategy = DocumentWriterConfirmationStrategy() message = strategy.on_approval_accepted(empty_steps) assert "Applying your edits" in message assert "Document updated successfully!" in message - def test_on_approval_rejected(self, sample_steps): + def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None: strategy = DocumentWriterConfirmationStrategy() message = strategy.on_approval_rejected(sample_steps) assert "No problem!" in message assert "keep or modify" in message - def test_on_state_confirmed(self): + def test_on_state_confirmed(self) -> None: strategy = DocumentWriterConfirmationStrategy() message = strategy.on_state_confirmed() assert "Document edits applied!" in message - def test_on_state_rejected(self): + def test_on_state_rejected(self) -> None: strategy = DocumentWriterConfirmationStrategy() message = strategy.on_state_rejected() diff --git a/python/packages/ag-ui/tests/test_document_writer_flow.py b/python/packages/ag-ui/tests/test_document_writer_flow.py index d46b9bf7a0..1ea164beef 100644 --- a/python/packages/ag-ui/tests/test_document_writer_flow.py +++ b/python/packages/ag-ui/tests/test_document_writer_flow.py @@ -2,7 +2,7 @@ """Tests for document writer predictive state flow with confirm_changes.""" -from ag_ui.core import EventType +from ag_ui.core import EventType, StateDeltaEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallStartEvent from agent_framework import FunctionCallContent, FunctionResultContent, TextContent from agent_framework._types import AgentRunResponseUpdate @@ -35,16 +35,12 @@ async def test_streaming_document_with_state_deltas(): assert any(e.type == EventType.TOOL_CALL_ARGS for e in events1) # Second chunk - incomplete JSON, should try partial extraction - tool_call_chunk2 = FunctionCallContent( - call_id="call_123", - name=None, # Name only in first chunk - arguments=" upon a time", - ) + tool_call_chunk2 = FunctionCallContent(call_id="call_123", name="write_document_local", arguments=" upon a time") update2 = AgentRunResponseUpdate(contents=[tool_call_chunk2]) events2 = await bridge.from_agent_run_update(update2) # Should emit StateDeltaEvent with partial document - state_deltas = [e for e in events2 if e.type == EventType.STATE_DELTA] + state_deltas = [e for e in events2 if isinstance(e, StateDeltaEvent)] assert len(state_deltas) >= 1 # Check JSON Patch format @@ -62,7 +58,7 @@ async def test_confirm_changes_emission(): "document": {"tool": "write_document_local", "tool_argument": "document"}, } - current_state = {} + current_state: dict[str, str] = {} bridge = AgentFrameworkEventBridge( run_id="test_run", @@ -90,15 +86,13 @@ async def test_confirm_changes_emission(): assert any(e.type == EventType.STATE_SNAPSHOT for e in events) # Check for confirm_changes tool call - confirm_starts = [ - e for e in events if e.type == EventType.TOOL_CALL_START and e.tool_call_name == "confirm_changes" - ] + confirm_starts = [e for e in events if isinstance(e, ToolCallStartEvent) and e.tool_call_name == "confirm_changes"] assert len(confirm_starts) == 1 - confirm_args = [e for e in events if e.type == EventType.TOOL_CALL_ARGS and e.delta == "{}"] + confirm_args = [e for e in events if isinstance(e, ToolCallArgsEvent) and e.delta == "{}"] assert len(confirm_args) >= 1 - confirm_ends = [e for e in events if e.type == EventType.TOOL_CALL_END] + confirm_ends = [e for e in events if isinstance(e, ToolCallEndEvent)] # At least 2: one for write_document_local, one for confirm_changes assert len(confirm_ends) >= 2 @@ -141,7 +135,7 @@ async def test_no_confirm_for_non_predictive_tools(): "document": {"tool": "write_document_local", "tool_argument": "document"}, } - current_state = {} + current_state: dict[str, str] = {} bridge = AgentFrameworkEventBridge( run_id="test_run", @@ -162,9 +156,7 @@ async def test_no_confirm_for_non_predictive_tools(): events = await bridge.from_agent_run_update(update) # Should NOT have confirm_changes - confirm_starts = [ - e for e in events if e.type == EventType.TOOL_CALL_START and e.tool_call_name == "confirm_changes" - ] + confirm_starts = [e for e in events if isinstance(e, ToolCallStartEvent) and e.tool_call_name == "confirm_changes"] assert len(confirm_starts) == 0 # Stop flag should NOT be set @@ -193,14 +185,14 @@ async def test_state_delta_deduplication(): events1 = await bridge.from_agent_run_update(update1) # Count state deltas - state_deltas_1 = [e for e in events1 if e.type == EventType.STATE_DELTA] + state_deltas_1 = [e for e in events1 if isinstance(e, StateDeltaEvent)] assert len(state_deltas_1) >= 1 # Second tool call with SAME document (shouldn't emit new delta) bridge.current_tool_call_name = "write_document_local" tool_call2 = FunctionCallContent( call_id="call_2", - name=None, + name="write_document_local", arguments='{"document":"Same text"}', # Identical content ) update2 = AgentRunResponseUpdate(contents=[tool_call2]) @@ -234,7 +226,7 @@ async def test_predict_state_config_multiple_fields(): events = await bridge.from_agent_run_update(update) # Should emit StateDeltaEvent for both fields - state_deltas = [e for e in events if e.type == EventType.STATE_DELTA] + state_deltas = [e for e in events if isinstance(e, StateDeltaEvent)] assert len(state_deltas) >= 2 # Check both fields are present diff --git a/python/packages/ag-ui/tests/test_endpoint.py b/python/packages/ag-ui/tests/test_endpoint.py index 93cd271881..829662ab38 100644 --- a/python/packages/ag-ui/tests/test_endpoint.py +++ b/python/packages/ag-ui/tests/test_endpoint.py @@ -3,7 +3,6 @@ """Tests for FastAPI endpoint creation (_endpoint.py).""" import json -from typing import Any from agent_framework import ChatAgent, TextContent from agent_framework._types import ChatResponseUpdate @@ -13,22 +12,19 @@ from agent_framework_ag_ui._agent import AgentFrameworkAgent from agent_framework_ag_ui._endpoint import add_agent_framework_fastapi_endpoint +from ._test_stubs import StreamingChatClientStub, stream_from_updates -class MockChatClient: - """Mock chat client for testing.""" - def __init__(self, response_text: str = "Test response"): - self.response_text = response_text - - async def get_streaming_response(self, messages: list[Any], chat_options: Any, **kwargs: Any): - """Mock streaming response.""" - yield ChatResponseUpdate(contents=[TextContent(text=self.response_text)]) +def build_chat_client(response_text: str = "Test response") -> StreamingChatClientStub: + """Create a typed chat client stub for endpoint tests.""" + updates = [ChatResponseUpdate(contents=[TextContent(text=response_text)])] + return StreamingChatClientStub(stream_from_updates(updates)) async def test_add_endpoint_with_agent_protocol(): """Test adding endpoint with raw AgentProtocol.""" app = FastAPI() - agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) add_agent_framework_fastapi_endpoint(app, agent, path="/test-agent") @@ -42,7 +38,7 @@ async def test_add_endpoint_with_agent_protocol(): async def test_add_endpoint_with_wrapped_agent(): """Test adding endpoint with pre-wrapped AgentFrameworkAgent.""" app = FastAPI() - agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) wrapped_agent = AgentFrameworkAgent(agent=agent, name="wrapped") add_agent_framework_fastapi_endpoint(app, wrapped_agent, path="/wrapped-agent") @@ -57,7 +53,7 @@ async def test_add_endpoint_with_wrapped_agent(): async def test_endpoint_with_state_schema(): """Test endpoint with state_schema parameter.""" app = FastAPI() - agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) state_schema = {"document": {"type": "string"}} add_agent_framework_fastapi_endpoint(app, agent, path="/stateful", state_schema=state_schema) @@ -73,7 +69,7 @@ async def test_endpoint_with_state_schema(): async def test_endpoint_with_default_state_seed(): """Test endpoint seeds default state when client omits it.""" app = FastAPI() - agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) state_schema = {"proverbs": {"type": "array"}} default_state = {"proverbs": ["Keep the original."]} @@ -100,7 +96,7 @@ async def test_endpoint_with_default_state_seed(): async def test_endpoint_with_predict_state_config(): """Test endpoint with predict_state_config parameter.""" app = FastAPI() - agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) predict_config = {"document": {"tool": "write_doc", "tool_argument": "content"}} add_agent_framework_fastapi_endpoint(app, agent, path="/predictive", predict_state_config=predict_config) @@ -114,7 +110,7 @@ async def test_endpoint_with_predict_state_config(): async def test_endpoint_request_logging(): """Test that endpoint logs request details.""" app = FastAPI() - agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) add_agent_framework_fastapi_endpoint(app, agent, path="/logged") @@ -134,7 +130,7 @@ async def test_endpoint_request_logging(): async def test_endpoint_event_streaming(): """Test that endpoint streams events correctly.""" app = FastAPI() - agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient("Streamed response")) + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client("Streamed response")) add_agent_framework_fastapi_endpoint(app, agent, path="/stream") @@ -168,14 +164,14 @@ async def test_endpoint_event_streaming(): async def test_endpoint_error_handling(): """Test endpoint error handling during request parsing.""" app = FastAPI() - agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) add_agent_framework_fastapi_endpoint(app, agent, path="/failing") client = TestClient(app) # Send invalid JSON to trigger parsing error before streaming - response = client.post("/failing", data="invalid json", headers={"content-type": "application/json"}) + response = client.post("/failing", data=b"invalid json", headers={"content-type": "application/json"}) # type: ignore # The exception handler catches it and returns JSON error assert response.status_code == 200 @@ -187,8 +183,8 @@ async def test_endpoint_error_handling(): async def test_endpoint_multiple_paths(): """Test adding multiple endpoints with different paths.""" app = FastAPI() - agent1 = ChatAgent(name="agent1", instructions="First agent", chat_client=MockChatClient("Response 1")) - agent2 = ChatAgent(name="agent2", instructions="Second agent", chat_client=MockChatClient("Response 2")) + agent1 = ChatAgent(name="agent1", instructions="First agent", chat_client=build_chat_client("Response 1")) + agent2 = ChatAgent(name="agent2", instructions="Second agent", chat_client=build_chat_client("Response 2")) add_agent_framework_fastapi_endpoint(app, agent1, path="/agent1") add_agent_framework_fastapi_endpoint(app, agent2, path="/agent2") @@ -205,7 +201,7 @@ async def test_endpoint_multiple_paths(): async def test_endpoint_default_path(): """Test endpoint with default path.""" app = FastAPI() - agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) add_agent_framework_fastapi_endpoint(app, agent) @@ -218,7 +214,7 @@ async def test_endpoint_default_path(): async def test_endpoint_response_headers(): """Test that endpoint sets correct response headers.""" app = FastAPI() - agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) add_agent_framework_fastapi_endpoint(app, agent, path="/headers") @@ -234,7 +230,7 @@ async def test_endpoint_response_headers(): async def test_endpoint_empty_messages(): """Test endpoint with empty messages list.""" app = FastAPI() - agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) add_agent_framework_fastapi_endpoint(app, agent, path="/empty") @@ -247,7 +243,7 @@ async def test_endpoint_empty_messages(): async def test_endpoint_complex_input(): """Test endpoint with complex input data.""" app = FastAPI() - agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) add_agent_framework_fastapi_endpoint(app, agent, path="/complex") diff --git a/python/packages/ag-ui/tests/test_message_hygiene.py b/python/packages/ag-ui/tests/test_message_hygiene.py index 14066ce89c..ba775fa7d9 100644 --- a/python/packages/ag-ui/tests/test_message_hygiene.py +++ b/python/packages/ag-ui/tests/test_message_hygiene.py @@ -2,7 +2,7 @@ from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent -from agent_framework_ag_ui._orchestration.message_hygiene import ( +from agent_framework_ag_ui._orchestration._message_hygiene import ( deduplicate_messages, sanitize_tool_history, ) diff --git a/python/packages/ag-ui/tests/test_orchestrators_coverage.py b/python/packages/ag-ui/tests/test_orchestrators_coverage.py index 81e41dee5f..cdf39a3e88 100644 --- a/python/packages/ag-ui/tests/test_orchestrators_coverage.py +++ b/python/packages/ag-ui/tests/test_orchestrators_coverage.py @@ -15,11 +15,9 @@ from pydantic import BaseModel from agent_framework_ag_ui._agent import AgentConfig -from agent_framework_ag_ui._orchestrators import ( - DefaultOrchestrator, - ExecutionContext, - HumanInTheLoopOrchestrator, -) +from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, HumanInTheLoopOrchestrator + +from ._test_stubs import StubAgent, TestExecutionContext @ai_function(approval_mode="always_require") @@ -28,34 +26,14 @@ def approval_tool(param: str) -> str: return f"executed: {param}" -class MockAgent: - """Mock agent for testing.""" - - def __init__(self, updates: list[AgentRunResponseUpdate] | None = None) -> None: - self.updates = updates or [AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")] - self.chat_options = SimpleNamespace(tools=[approval_tool], response_format=None) - self.chat_client = SimpleNamespace(function_invocation_configuration=None) - self.messages_received: list[Any] = [] - self.tools_received: list[Any] | None = None - - async def run_stream( - self, - messages: list[Any], - *, - thread: Any = None, - tools: list[Any] | None = None, - ) -> AsyncGenerator[AgentRunResponseUpdate, None]: - self.messages_received = messages - self.tools_received = tools - for update in self.updates: - yield update +DEFAULT_CHAT_OPTIONS = SimpleNamespace(tools=[approval_tool], response_format=None) async def test_human_in_the_loop_json_decode_error() -> None: """Test HumanInTheLoopOrchestrator handles invalid JSON in tool result.""" orchestrator = HumanInTheLoopOrchestrator() - input_data = { + input_data: dict[str, Any] = { "messages": [ { "role": "tool", @@ -72,21 +50,25 @@ async def test_human_in_the_loop_json_decode_error() -> None: ) ] - context = ExecutionContext( + agent = StubAgent( + chat_options=SimpleNamespace(tools=[approval_tool], response_format=None), + updates=[AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")], + ) + context = TestExecutionContext( input_data=input_data, - agent=MockAgent(), + agent=agent, config=AgentConfig(), ) - context._messages = messages + context.set_messages(messages) assert orchestrator.can_handle(context) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) # Should emit RunErrorEvent for invalid JSON - error_events = [e for e in events if e.type == "RUN_ERROR"] + error_events: list[Any] = [e for e in events if e.type == "RUN_ERROR"] assert len(error_events) == 1 assert "Invalid tool result format" in error_events[0].message @@ -118,18 +100,20 @@ async def test_sanitize_tool_history_confirm_changes() -> None: orchestrator = DefaultOrchestrator() # Use pre-constructed ChatMessage objects to bypass message adapter - input_data = {"messages": []} + input_data: dict[str, Any] = {"messages": []} - agent = MockAgent() - context = ExecutionContext( + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(), ) # Override the messages property to use our pre-constructed messages - context._messages = messages + context.set_messages(messages) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -162,16 +146,18 @@ async def test_sanitize_tool_history_orphaned_tool_result() -> None: ] orchestrator = DefaultOrchestrator() - input_data = {"messages": []} - agent = MockAgent() - context = ExecutionContext( + input_data: dict[str, Any] = {"messages": []} + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(), ) - context._messages = messages + context.set_messages(messages) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -188,7 +174,7 @@ async def test_orphaned_tool_result_sanitization() -> None: """Test that orphaned tool results are filtered out.""" orchestrator = DefaultOrchestrator() - input_data = { + input_data: dict[str, Any] = { "messages": [ { "role": "tool", @@ -201,14 +187,16 @@ async def test_orphaned_tool_result_sanitization() -> None: ], } - agent = MockAgent() - context = ExecutionContext( + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(), ) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -241,16 +229,18 @@ async def test_deduplicate_messages_empty_tool_results() -> None: ] orchestrator = DefaultOrchestrator() - input_data = {"messages": []} - agent = MockAgent() - context = ExecutionContext( + input_data: dict[str, Any] = {"messages": []} + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(), ) - context._messages = messages + context.set_messages(messages) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -284,16 +274,18 @@ async def test_deduplicate_messages_duplicate_assistant_tool_calls() -> None: ] orchestrator = DefaultOrchestrator() - input_data = {"messages": []} - agent = MockAgent() - context = ExecutionContext( + input_data: dict[str, Any] = {"messages": []} + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(), ) - context._messages = messages + context.set_messages(messages) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -326,16 +318,18 @@ async def test_deduplicate_messages_duplicate_system_messages() -> None: ] orchestrator = DefaultOrchestrator() - input_data = {"messages": []} - agent = MockAgent() - context = ExecutionContext( + input_data: dict[str, Any] = {"messages": []} + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(), ) - context._messages = messages + context.set_messages(messages) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -354,7 +348,7 @@ async def test_state_context_injection() -> None: """Test state context message injection for first request.""" orchestrator = DefaultOrchestrator() - input_data = { + input_data: dict[str, Any] = { "messages": [ { "role": "user", @@ -364,14 +358,16 @@ async def test_state_context_injection() -> None: "state": {"items": ["apple", "banana"]}, } - agent = MockAgent() - context = ExecutionContext( + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(state_schema={"items": {"type": "array"}}), ) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -406,16 +402,18 @@ async def test_no_state_context_injection_with_tool_calls() -> None: ] orchestrator = DefaultOrchestrator() - input_data = {"messages": [], "state": {"weather": "sunny"}} - agent = MockAgent() - context = ExecutionContext( + input_data: dict[str, Any] = {"messages": [], "state": {"weather": "sunny"}} + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(state_schema={"weather": {"type": "string"}}), ) - context._messages = messages + context.set_messages(messages) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -437,7 +435,7 @@ class RecipeState(BaseModel): orchestrator = DefaultOrchestrator() - input_data = { + input_data: dict[str, Any] = { "messages": [ { "role": "user", @@ -447,32 +445,33 @@ class RecipeState(BaseModel): } # Agent with structured output - agent = MockAgent( + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, updates=[ AgentRunResponseUpdate( contents=[TextContent(text='{"ingredients": ["tomato"], "message": "Added tomato"}')], role="assistant", ) - ] + ], ) agent.chat_options.response_format = RecipeState - context = ExecutionContext( + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(state_schema={"ingredients": {"type": "array"}}), ) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) # Should emit StateSnapshotEvent with ingredients - state_events = [e for e in events if e.type == "STATE_SNAPSHOT"] + state_events: list[Any] = [e for e in events if e.type == "STATE_SNAPSHOT"] assert len(state_events) >= 1 # Should emit TextMessage with message field - text_content_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"] + text_content_events: list[Any] = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"] assert len(text_content_events) >= 1 assert any("Added tomato" in e.delta for e in text_content_events) @@ -487,7 +486,7 @@ def get_weather(location: str) -> str: orchestrator = DefaultOrchestrator() - input_data = { + input_data: dict[str, Any] = { "messages": [ { "role": "user", @@ -507,16 +506,18 @@ def get_weather(location: str) -> str: ], } - agent = MockAgent() + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) agent.chat_options.tools = [get_weather] - context = ExecutionContext( + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(), ) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -534,7 +535,7 @@ def server_tool() -> str: orchestrator = DefaultOrchestrator() - input_data = { + input_data: dict[str, Any] = { "messages": [ { "role": "user", @@ -554,16 +555,18 @@ def server_tool() -> str: ], } - agent = MockAgent() + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) agent.chat_options.tools = [server_tool] - context = ExecutionContext( + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(), ) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -578,16 +581,18 @@ async def test_empty_messages_handling() -> None: """Test orchestrator handles empty message list gracefully.""" orchestrator = DefaultOrchestrator() - input_data = {"messages": []} + input_data: dict[str, Any] = {"messages": []} - agent = MockAgent() - context = ExecutionContext( + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(), ) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -603,7 +608,7 @@ async def test_all_messages_filtered_handling() -> None: """Test orchestrator handles case where all messages are filtered out.""" orchestrator = DefaultOrchestrator() - input_data = { + input_data: dict[str, Any] = { "messages": [ { "role": "tool", @@ -612,14 +617,16 @@ async def test_all_messages_filtered_handling() -> None: ] } - agent = MockAgent() - context = ExecutionContext( + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(), ) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -651,16 +658,18 @@ async def test_confirm_changes_with_invalid_json_fallback() -> None: ] orchestrator = DefaultOrchestrator() - input_data = {"messages": []} - agent = MockAgent() - context = ExecutionContext( + input_data: dict[str, Any] = {"messages": []} + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(), ) - context._messages = messages + context.set_messages(messages) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -689,16 +698,18 @@ async def test_tool_result_kept_when_call_id_matches() -> None: ] orchestrator = DefaultOrchestrator() - input_data = {"messages": []} - agent = MockAgent() - context = ExecutionContext( + input_data: dict[str, Any] = {"messages": []} + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(), ) - context._messages = messages + context.set_messages(messages) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -738,16 +749,16 @@ async def run_stream( messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])] orchestrator = DefaultOrchestrator() - input_data = {"messages": []} + input_data: dict[str, Any] = {"messages": []} agent = CustomAgent() - context = ExecutionContext( + context = TestExecutionContext( input_data=input_data, agent=agent, # type: ignore config=AgentConfig(), ) - context._messages = messages + context.set_messages(messages) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) @@ -762,21 +773,23 @@ async def test_initial_state_snapshot_with_array_schema() -> None: messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])] orchestrator = DefaultOrchestrator() - input_data = {"messages": [], "state": {}} - agent = MockAgent() - context = ExecutionContext( + input_data: dict[str, Any] = {"messages": [], "state": {}} + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(state_schema={"items": {"type": "array"}}), ) - context._messages = messages + context.set_messages(messages) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) # Should emit state snapshot with empty array for items - state_events = [e for e in events if e.type == "STATE_SNAPSHOT"] + state_events: list[Any] = [e for e in events if e.type == "STATE_SNAPSHOT"] assert len(state_events) >= 1 @@ -791,19 +804,21 @@ class OutputModel(BaseModel): messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])] orchestrator = DefaultOrchestrator() - input_data = {"messages": []} + input_data: dict[str, Any] = {"messages": []} - agent = MockAgent() + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + ) agent.chat_options.response_format = OutputModel - context = ExecutionContext( + context = TestExecutionContext( input_data=input_data, agent=agent, config=AgentConfig(), ) - context._messages = messages + context.set_messages(messages) - events = [] + events: list[Any] = [] async for event in orchestrator.run(context): events.append(event) diff --git a/python/packages/ag-ui/tests/test_shared_state.py b/python/packages/ag-ui/tests/test_shared_state.py index 578d48ecd0..bfdc64081d 100644 --- a/python/packages/ag-ui/tests/test_shared_state.py +++ b/python/packages/ag-ui/tests/test_shared_state.py @@ -2,6 +2,8 @@ """Tests for shared state management.""" +from typing import Any + import pytest from ag_ui.core import StateSnapshotEvent from agent_framework import ChatAgent, TextContent @@ -10,20 +12,15 @@ from agent_framework_ag_ui._agent import AgentFrameworkAgent from agent_framework_ag_ui._events import AgentFrameworkEventBridge +from ._test_stubs import StreamingChatClientStub, stream_from_updates + @pytest.fixture -def mock_agent(): +def mock_agent() -> ChatAgent: """Create a mock agent for testing.""" - - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Hello!")]) - - return ChatAgent( - name="test_agent", - instructions="Test agent", - chat_client=MockChatClient(), - ) + updates = [ChatResponseUpdate(contents=[TextContent(text="Hello!")])] + chat_client = StreamingChatClientStub(stream_from_updates(updates)) + return ChatAgent(name="test_agent", instructions="Test agent", chat_client=chat_client) def test_state_snapshot_event(): @@ -65,9 +62,9 @@ def test_state_delta_event(): assert event.delta[1]["op"] == "replace" -async def test_agent_with_initial_state(mock_agent): +async def test_agent_with_initial_state(mock_agent: ChatAgent) -> None: """Test agent emits state snapshot when initial state provided.""" - state_schema = {"recipe": {"type": "object", "properties": {"name": {"type": "string"}}}} + state_schema: dict[str, Any] = {"recipe": {"type": "object", "properties": {"name": {"type": "string"}}}} agent = AgentFrameworkAgent( agent=mock_agent, @@ -76,12 +73,12 @@ async def test_agent_with_initial_state(mock_agent): initial_state = {"recipe": {"name": "Test Recipe"}} - input_data = { + input_data: dict[str, Any] = { "messages": [{"role": "user", "content": "Hello"}], "state": initial_state, } - events = [] + events: list[Any] = [] async for event in agent.run_agent(input_data): events.append(event) @@ -91,16 +88,16 @@ async def test_agent_with_initial_state(mock_agent): assert snapshot_events[0].snapshot == initial_state -async def test_agent_without_state_schema(mock_agent): +async def test_agent_without_state_schema(mock_agent: ChatAgent) -> None: """Test agent doesn't emit state events without state schema.""" agent = AgentFrameworkAgent(agent=mock_agent) - input_data = { + input_data: dict[str, Any] = { "messages": [{"role": "user", "content": "Hello"}], "state": {"some": "state"}, } - events = [] + events: list[Any] = [] async for event in agent.run_agent(input_data): events.append(event) diff --git a/python/packages/ag-ui/tests/test_state_manager.py b/python/packages/ag-ui/tests/test_state_manager.py index ce964d784d..bc0a7b6a19 100644 --- a/python/packages/ag-ui/tests/test_state_manager.py +++ b/python/packages/ag-ui/tests/test_state_manager.py @@ -4,7 +4,7 @@ from agent_framework import ChatMessage, TextContent from agent_framework_ag_ui._events import AgentFrameworkEventBridge -from agent_framework_ag_ui._orchestration.state_manager import StateManager +from agent_framework_ag_ui._orchestration._state_manager import StateManager def test_state_manager_initializes_defaults_and_snapshot() -> None: diff --git a/python/packages/ag-ui/tests/test_structured_output.py b/python/packages/ag-ui/tests/test_structured_output.py index 10307356a5..26742b7e63 100644 --- a/python/packages/ag-ui/tests/test_structured_output.py +++ b/python/packages/ag-ui/tests/test_structured_output.py @@ -3,12 +3,15 @@ """Tests for structured output handling in _agent.py.""" import json +from collections.abc import AsyncIterator, MutableSequence from typing import Any -from agent_framework import ChatAgent, ChatOptions, TextContent +from agent_framework import ChatAgent, ChatMessage, ChatOptions, TextContent from agent_framework._types import ChatResponseUpdate from pydantic import BaseModel +from ._test_stubs import StreamingChatClientStub, stream_from_updates + class RecipeOutput(BaseModel): """Test Pydantic model for recipe output.""" @@ -34,14 +37,14 @@ async def test_structured_output_with_recipe(): """Test structured output processing with recipe state.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - # Simulate structured output - yield ChatResponseUpdate( - contents=[TextContent(text='{"recipe": {"name": "Pasta"}, "message": "Here is your recipe"}')] - ) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate( + contents=[TextContent(text='{"recipe": {"name": "Pasta"}, "message": "Here is your recipe"}')] + ) - agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) agent.chat_options = ChatOptions(response_format=RecipeOutput) wrapper = AgentFrameworkAgent( @@ -51,7 +54,7 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): input_data = {"messages": [{"role": "user", "content": "Make pasta"}]} - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -72,17 +75,18 @@ async def test_structured_output_with_steps(): """Test structured output processing with steps state.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - steps_data = { - "steps": [ - {"id": "1", "description": "Step 1", "status": "pending"}, - {"id": "2", "description": "Step 2", "status": "pending"}, - ] - } - yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(steps_data))]) - - agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient()) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + steps_data = { + "steps": [ + {"id": "1", "description": "Step 1", "status": "pending"}, + {"id": "2", "description": "Step 2", "status": "pending"}, + ] + } + yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(steps_data))]) + + agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) agent.chat_options = ChatOptions(response_format=StepsOutput) wrapper = AgentFrameworkAgent( @@ -92,7 +96,7 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): input_data = {"messages": [{"role": "user", "content": "Do steps"}]} - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -111,12 +115,13 @@ async def test_structured_output_with_no_schema_match(): """Test structured output when response fields don't match state_schema keys.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - # Response has "data" field but schema expects "result" field - yield ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}}')]) + updates = [ + ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}}')]), + ] - agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent( + name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_from_updates(updates)) + ) agent.chat_options = ChatOptions(response_format=GenericOutput) wrapper = AgentFrameworkAgent( @@ -126,7 +131,7 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): input_data = {"messages": [{"role": "user", "content": "Generate data"}]} - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -146,11 +151,12 @@ class DataOutput(BaseModel): data: dict[str, Any] info: str - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}, "info": "processed"}')]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}, "info": "processed"}')]) - agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) agent.chat_options = ChatOptions(response_format=DataOutput) wrapper = AgentFrameworkAgent( @@ -160,7 +166,7 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): input_data = {"messages": [{"role": "user", "content": "Generate data"}]} - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -177,18 +183,20 @@ async def test_no_structured_output_when_no_response_format(): """Test that structured output path is skipped when no response_format.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - yield ChatResponseUpdate(contents=[TextContent(text="Regular text")]) + updates = [ChatResponseUpdate(contents=[TextContent(text="Regular text")])] - agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent( + name="test", + instructions="Test", + chat_client=StreamingChatClientStub(stream_from_updates(updates)), + ) # No response_format set wrapper = AgentFrameworkAgent(agent=agent) input_data = {"messages": [{"role": "user", "content": "Hi"}]} - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -202,12 +210,13 @@ async def test_structured_output_with_message_field(): """Test structured output that includes a message field.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - output_data = {"recipe": {"name": "Salad"}, "message": "Fresh salad recipe ready"} - yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(output_data))]) + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + output_data = {"recipe": {"name": "Salad"}, "message": "Fresh salad recipe ready"} + yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(output_data))]) - agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) agent.chat_options = ChatOptions(response_format=RecipeOutput) wrapper = AgentFrameworkAgent( @@ -217,7 +226,7 @@ async def get_streaming_response(self, messages, chat_options, **kwargs): input_data = {"messages": [{"role": "user", "content": "Make salad"}]} - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) @@ -236,20 +245,20 @@ async def test_empty_updates_no_structured_processing(): """Test that empty updates don't trigger structured output processing.""" from agent_framework.ag_ui import AgentFrameworkAgent - class MockChatClient: - async def get_streaming_response(self, messages, chat_options, **kwargs): - # Return nothing - if False: - yield + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + if False: + yield ChatResponseUpdate(contents=[]) - agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient()) + agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) agent.chat_options = ChatOptions(response_format=RecipeOutput) wrapper = AgentFrameworkAgent(agent=agent) input_data = {"messages": [{"role": "user", "content": "Test"}]} - events = [] + events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) diff --git a/python/packages/ag-ui/tests/test_tooling.py b/python/packages/ag-ui/tests/test_tooling.py index dfd89c0148..b802d654c6 100644 --- a/python/packages/ag-ui/tests/test_tooling.py +++ b/python/packages/ag-ui/tests/test_tooling.py @@ -2,7 +2,7 @@ from types import SimpleNamespace -from agent_framework_ag_ui._orchestration.tooling import merge_tools, register_additional_client_tools +from agent_framework_ag_ui._orchestration._tooling import merge_tools, register_additional_client_tools class DummyTool: diff --git a/python/packages/ag-ui/tests/test_utils.py b/python/packages/ag-ui/tests/test_utils.py index e4324ab187..4a6d0360bd 100644 --- a/python/packages/ag-ui/tests/test_utils.py +++ b/python/packages/ag-ui/tests/test_utils.py @@ -20,8 +20,8 @@ def test_generate_event_id(): def test_merge_state(): """Test state merging.""" - current = {"a": 1, "b": 2} - update = {"b": 3, "c": 4} + current: dict[str, int] = {"a": 1, "b": 2} + update: dict[str, int] = {"b": 3, "c": 4} result = merge_state(current, update) @@ -32,8 +32,8 @@ def test_merge_state(): def test_merge_state_empty_update(): """Test merging with empty update.""" - current = {"x": 10, "y": 20} - update = {} + current: dict[str, int] = {"x": 10, "y": 20} + update: dict[str, int] = {} result = merge_state(current, update) @@ -43,8 +43,8 @@ def test_merge_state_empty_update(): def test_merge_state_empty_current(): """Test merging with empty current state.""" - current = {} - update = {"a": 1, "b": 2} + current: dict[str, int] = {} + update: dict[str, int] = {"a": 1, "b": 2} result = merge_state(current, update) @@ -53,8 +53,8 @@ def test_merge_state_empty_current(): def test_merge_state_deep_copy(): """Test that merge_state creates a deep copy preventing mutation of original.""" - current = {"recipe": {"name": "Cake", "ingredients": ["flour", "sugar"]}} - update = {"other": "value"} + current: dict[str, dict[str, object]] = {"recipe": {"name": "Cake", "ingredients": ["flour", "sugar"]}} + update: dict[str, str] = {"other": "value"} result = merge_state(current, update) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 591d255490..f9519acade 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -875,7 +875,12 @@ async def run( ) # Filter chat_options from kwargs to prevent duplicate keyword argument filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"} - response = await self.chat_client.get_response(messages=thread_messages, chat_options=co, **filtered_kwargs) + response = await self.chat_client.get_response( + messages=thread_messages, + chat_options=co, + thread=thread, + **filtered_kwargs, + ) await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id) @@ -1011,7 +1016,10 @@ async def run_stream( filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"} response_updates: list[ChatResponseUpdate] = [] async for update in self.chat_client.get_streaming_response( - messages=thread_messages, chat_options=co, **filtered_kwargs + messages=thread_messages, + chat_options=co, + thread=thread, + **filtered_kwargs, ): response_updates.append(update) From b55238f1f4503743b0db93309087c93be6534841 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Mon, 24 Nov 2025 12:08:46 +0900 Subject: [PATCH 4/6] Fix test import error --- python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py | 2 +- python/packages/ag-ui/tests/test_endpoint.py | 2 +- python/packages/ag-ui/tests/test_orchestrators_coverage.py | 2 +- python/packages/ag-ui/tests/test_shared_state.py | 2 +- python/packages/ag-ui/tests/test_structured_output.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index 015bbdfc61..76086dca79 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -11,7 +11,7 @@ from agent_framework._types import ChatResponseUpdate from pydantic import BaseModel -from ._test_stubs import StreamingChatClientStub +from tests._test_stubs import StreamingChatClientStub async def test_agent_initialization_basic(): diff --git a/python/packages/ag-ui/tests/test_endpoint.py b/python/packages/ag-ui/tests/test_endpoint.py index 829662ab38..c413562a3a 100644 --- a/python/packages/ag-ui/tests/test_endpoint.py +++ b/python/packages/ag-ui/tests/test_endpoint.py @@ -12,7 +12,7 @@ from agent_framework_ag_ui._agent import AgentFrameworkAgent from agent_framework_ag_ui._endpoint import add_agent_framework_fastapi_endpoint -from ._test_stubs import StreamingChatClientStub, stream_from_updates +from tests._test_stubs import StreamingChatClientStub, stream_from_updates def build_chat_client(response_text: str = "Test response") -> StreamingChatClientStub: diff --git a/python/packages/ag-ui/tests/test_orchestrators_coverage.py b/python/packages/ag-ui/tests/test_orchestrators_coverage.py index cdf39a3e88..c5db90d372 100644 --- a/python/packages/ag-ui/tests/test_orchestrators_coverage.py +++ b/python/packages/ag-ui/tests/test_orchestrators_coverage.py @@ -17,7 +17,7 @@ from agent_framework_ag_ui._agent import AgentConfig from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, HumanInTheLoopOrchestrator -from ._test_stubs import StubAgent, TestExecutionContext +from tests._test_stubs import StubAgent, TestExecutionContext @ai_function(approval_mode="always_require") diff --git a/python/packages/ag-ui/tests/test_shared_state.py b/python/packages/ag-ui/tests/test_shared_state.py index bfdc64081d..f99074f7eb 100644 --- a/python/packages/ag-ui/tests/test_shared_state.py +++ b/python/packages/ag-ui/tests/test_shared_state.py @@ -12,7 +12,7 @@ from agent_framework_ag_ui._agent import AgentFrameworkAgent from agent_framework_ag_ui._events import AgentFrameworkEventBridge -from ._test_stubs import StreamingChatClientStub, stream_from_updates +from tests._test_stubs import StreamingChatClientStub, stream_from_updates @pytest.fixture diff --git a/python/packages/ag-ui/tests/test_structured_output.py b/python/packages/ag-ui/tests/test_structured_output.py index 26742b7e63..cf7d0e868c 100644 --- a/python/packages/ag-ui/tests/test_structured_output.py +++ b/python/packages/ag-ui/tests/test_structured_output.py @@ -10,7 +10,7 @@ from agent_framework._types import ChatResponseUpdate from pydantic import BaseModel -from ._test_stubs import StreamingChatClientStub, stream_from_updates +from tests._test_stubs import StreamingChatClientStub, stream_from_updates class RecipeOutput(BaseModel): From d18c5dd1e4cd1afd1d1ac96a86ef19d8383f713e Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Mon, 24 Nov 2025 12:37:34 +0900 Subject: [PATCH 5/6] Fix imports again --- .../packages/ag-ui/tests/test_agent_wrapper_comprehensive.py | 5 ++++- python/packages/ag-ui/tests/test_endpoint.py | 5 ++++- .../ag-ui/tests/{_test_stubs.py => test_helpers_ag_ui.py} | 0 python/packages/ag-ui/tests/test_orchestrators_coverage.py | 5 ++++- python/packages/ag-ui/tests/test_shared_state.py | 5 ++++- python/packages/ag-ui/tests/test_structured_output.py | 5 ++++- 6 files changed, 20 insertions(+), 5 deletions(-) rename python/packages/ag-ui/tests/{_test_stubs.py => test_helpers_ag_ui.py} (100%) diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index 76086dca79..3b4e59c9e4 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -3,7 +3,9 @@ """Comprehensive tests for AgentFrameworkAgent (_agent.py).""" import json +import sys from collections.abc import AsyncIterator, MutableSequence +from pathlib import Path from typing import Any import pytest @@ -11,7 +13,8 @@ from agent_framework._types import ChatResponseUpdate from pydantic import BaseModel -from tests._test_stubs import StreamingChatClientStub +sys.path.insert(0, str(Path(__file__).parent)) +from test_helpers_ag_ui import StreamingChatClientStub async def test_agent_initialization_basic(): diff --git a/python/packages/ag-ui/tests/test_endpoint.py b/python/packages/ag-ui/tests/test_endpoint.py index c413562a3a..36c9e3bc32 100644 --- a/python/packages/ag-ui/tests/test_endpoint.py +++ b/python/packages/ag-ui/tests/test_endpoint.py @@ -3,6 +3,8 @@ """Tests for FastAPI endpoint creation (_endpoint.py).""" import json +import sys +from pathlib import Path from agent_framework import ChatAgent, TextContent from agent_framework._types import ChatResponseUpdate @@ -12,7 +14,8 @@ from agent_framework_ag_ui._agent import AgentFrameworkAgent from agent_framework_ag_ui._endpoint import add_agent_framework_fastapi_endpoint -from tests._test_stubs import StreamingChatClientStub, stream_from_updates +sys.path.insert(0, str(Path(__file__).parent)) +from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates def build_chat_client(response_text: str = "Test response") -> StreamingChatClientStub: diff --git a/python/packages/ag-ui/tests/_test_stubs.py b/python/packages/ag-ui/tests/test_helpers_ag_ui.py similarity index 100% rename from python/packages/ag-ui/tests/_test_stubs.py rename to python/packages/ag-ui/tests/test_helpers_ag_ui.py diff --git a/python/packages/ag-ui/tests/test_orchestrators_coverage.py b/python/packages/ag-ui/tests/test_orchestrators_coverage.py index c5db90d372..a9c63b8a81 100644 --- a/python/packages/ag-ui/tests/test_orchestrators_coverage.py +++ b/python/packages/ag-ui/tests/test_orchestrators_coverage.py @@ -2,7 +2,9 @@ """Comprehensive tests for orchestrator coverage.""" +import sys from collections.abc import AsyncGenerator +from pathlib import Path from types import SimpleNamespace from typing import Any @@ -17,7 +19,8 @@ from agent_framework_ag_ui._agent import AgentConfig from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, HumanInTheLoopOrchestrator -from tests._test_stubs import StubAgent, TestExecutionContext +sys.path.insert(0, str(Path(__file__).parent)) +from test_helpers_ag_ui import StubAgent, TestExecutionContext @ai_function(approval_mode="always_require") diff --git a/python/packages/ag-ui/tests/test_shared_state.py b/python/packages/ag-ui/tests/test_shared_state.py index f99074f7eb..36f80b9d47 100644 --- a/python/packages/ag-ui/tests/test_shared_state.py +++ b/python/packages/ag-ui/tests/test_shared_state.py @@ -2,6 +2,8 @@ """Tests for shared state management.""" +import sys +from pathlib import Path from typing import Any import pytest @@ -12,7 +14,8 @@ from agent_framework_ag_ui._agent import AgentFrameworkAgent from agent_framework_ag_ui._events import AgentFrameworkEventBridge -from tests._test_stubs import StreamingChatClientStub, stream_from_updates +sys.path.insert(0, str(Path(__file__).parent)) +from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates @pytest.fixture diff --git a/python/packages/ag-ui/tests/test_structured_output.py b/python/packages/ag-ui/tests/test_structured_output.py index cf7d0e868c..c5f9719938 100644 --- a/python/packages/ag-ui/tests/test_structured_output.py +++ b/python/packages/ag-ui/tests/test_structured_output.py @@ -3,14 +3,17 @@ """Tests for structured output handling in _agent.py.""" import json +import sys from collections.abc import AsyncIterator, MutableSequence +from pathlib import Path from typing import Any from agent_framework import ChatAgent, ChatMessage, ChatOptions, TextContent from agent_framework._types import ChatResponseUpdate from pydantic import BaseModel -from tests._test_stubs import StreamingChatClientStub, stream_from_updates +sys.path.insert(0, str(Path(__file__).parent)) +from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates class RecipeOutput(BaseModel): From 4986f850e3e3bbe39208c2fb6e34edcf5d952901 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Thu, 27 Nov 2025 10:56:13 +0900 Subject: [PATCH 6/6] Fix thread handling --- .../agent_framework_ag_ui/_orchestrators.py | 20 ++++++++++++++++++- .../tests/test_agent_wrapper_comprehensive.py | 13 +++++++----- .../ag-ui/tests/test_orchestrators.py | 1 + .../tests/test_orchestrators_coverage.py | 1 + .../packages/core/agent_framework/_agents.py | 2 -- 5 files changed, 29 insertions(+), 8 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 39051c042c..654498e371 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -423,7 +423,25 @@ async def run( all_updates: list[Any] = [] update_count = 0 - async for update in context.agent.run_stream(messages_to_run, thread=thread, tools=tools_param): + # Prepare metadata for chat client (Azure requires string values) + safe_metadata: dict[str, Any] = {} + thread_metadata = getattr(thread, "metadata", None) + if thread_metadata: + for key, value in thread_metadata.items(): + value_str = value if isinstance(value, str) else json.dumps(value) + if len(value_str) > 512: + value_str = value_str[:512] + safe_metadata[key] = value_str + + run_kwargs: dict[str, Any] = { + "thread": thread, + "tools": tools_param, + "metadata": safe_metadata, + } + if safe_metadata: + run_kwargs["store"] = True + + async for update in context.agent.run_stream(messages_to_run, **run_kwargs): update_count += 1 logger.info(f"[STREAM] Received update #{update_count} from agent") all_updates.append(update) diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index 3b4e59c9e4..beb6f8af2c 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -425,8 +425,8 @@ async def test_thread_metadata_tracking(): async def stream_fn( messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - if "thread" in kwargs: - thread_metadata.update(kwargs["thread"].metadata) + if chat_options.metadata: + thread_metadata.update(chat_options.metadata) yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) @@ -455,8 +455,8 @@ async def test_state_context_injection(): async def stream_fn( messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - if "thread" in kwargs: - thread_metadata.update(kwargs["thread"].metadata) + if chat_options.metadata: + thread_metadata.update(chat_options.metadata) yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) @@ -474,7 +474,10 @@ async def stream_fn( async for event in wrapper.run_agent(input_data): events.append(event) - assert thread_metadata.get("current_state") == {"document": "Test content"} + current_state = thread_metadata.get("current_state") + if isinstance(current_state, str): + current_state = json.loads(current_state) + assert current_state == {"document": "Test content"} async def test_no_messages_provided(): diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py index 74a2083e32..10843a259c 100644 --- a/python/packages/ag-ui/tests/test_orchestrators.py +++ b/python/packages/ag-ui/tests/test_orchestrators.py @@ -36,6 +36,7 @@ async def run_stream( *, thread: Any, tools: list[Any] | None = None, + **kwargs: Any, ) -> AsyncGenerator[AgentRunResponseUpdate, None]: self.seen_tools = tools yield AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant") diff --git a/python/packages/ag-ui/tests/test_orchestrators_coverage.py b/python/packages/ag-ui/tests/test_orchestrators_coverage.py index a9c63b8a81..1da11bffbc 100644 --- a/python/packages/ag-ui/tests/test_orchestrators_coverage.py +++ b/python/packages/ag-ui/tests/test_orchestrators_coverage.py @@ -743,6 +743,7 @@ async def run_stream( *, thread: Any = None, tools: list[Any] | None = None, + **kwargs: Any, ) -> AsyncGenerator[AgentRunResponseUpdate, None]: self.messages_received = messages yield AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant") diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 5f50c0e0e3..7ded0af8d5 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -879,7 +879,6 @@ async def run( response = await self.chat_client.get_response( messages=thread_messages, chat_options=co, - thread=thread, **filtered_kwargs, ) @@ -1020,7 +1019,6 @@ async def run_stream( async for update in self.chat_client.get_streaming_response( messages=thread_messages, chat_options=co, - thread=thread, **filtered_kwargs, ): response_updates.append(update)