Skip to content

Commit 8cf8b0f

Browse files
authored
Python: Refactor ag-ui to clean up some patterns (#2363)
* Refactor ag-ui to clean up some patterns * Mypy fixes * Fix imports, typing, tests, logging. * Fix test import error * Fix imports again * Fix thread handling
1 parent 6c62431 commit 8cf8b0f

26 files changed

+1860
-1388
lines changed

python/packages/ag-ui/agent_framework_ag_ui/_agent.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""AgentFrameworkAgent wrapper for AG-UI protocol - Clean Architecture."""
44

55
from collections.abc import AsyncGenerator
6-
from typing import Any
6+
from typing import Any, cast
77

88
from ag_ui.core import BaseEvent
99
from agent_framework import AgentProtocol
@@ -22,21 +22,48 @@ class AgentConfig:
2222

2323
def __init__(
2424
self,
25-
state_schema: dict[str, Any] | None = None,
25+
state_schema: Any | None = None,
2626
predict_state_config: dict[str, dict[str, str]] | None = None,
2727
require_confirmation: bool = True,
2828
):
2929
"""Initialize agent configuration.
3030
3131
Args:
32-
state_schema: Optional state schema for state management
32+
state_schema: Optional state schema for state management; accepts dict or Pydantic model/class
3333
predict_state_config: Configuration for predictive state updates
3434
require_confirmation: Whether predictive updates require confirmation
3535
"""
36-
self.state_schema = state_schema or {}
36+
self.state_schema = self._normalize_state_schema(state_schema)
3737
self.predict_state_config = predict_state_config or {}
3838
self.require_confirmation = require_confirmation
3939

40+
@staticmethod
41+
def _normalize_state_schema(state_schema: Any | None) -> dict[str, Any]:
42+
"""Accept dict or Pydantic model/class and return a properties dict."""
43+
if state_schema is None:
44+
return {}
45+
46+
if isinstance(state_schema, dict):
47+
return cast(dict[str, Any], state_schema)
48+
49+
base_model_type: type[Any] | None
50+
try:
51+
from pydantic import BaseModel as ImportedBaseModel
52+
53+
base_model_type = ImportedBaseModel
54+
except Exception: # pragma: no cover
55+
base_model_type = None
56+
57+
if base_model_type is not None and isinstance(state_schema, base_model_type):
58+
schema_dict = state_schema.__class__.model_json_schema()
59+
return schema_dict.get("properties", {}) or {}
60+
61+
if base_model_type is not None and isinstance(state_schema, type) and issubclass(state_schema, base_model_type):
62+
schema_dict = state_schema.model_json_schema()
63+
return schema_dict.get("properties", {}) or {}
64+
65+
return {}
66+
4067

4168
class AgentFrameworkAgent:
4269
"""Wraps Agent Framework agents for AG-UI protocol compatibility.
@@ -55,7 +82,7 @@ def __init__(
5582
agent: AgentProtocol,
5683
name: str | None = None,
5784
description: str | None = None,
58-
state_schema: dict[str, Any] | None = None,
85+
state_schema: Any | None = None,
5986
predict_state_config: dict[str, dict[str, str]] | None = None,
6087
require_confirmation: bool = True,
6188
orchestrators: list[Orchestrator] | None = None,
@@ -67,7 +94,7 @@ def __init__(
6794
agent: The Agent Framework agent to wrap
6895
name: Optional name for the agent
6996
description: Optional description
70-
state_schema: Optional state schema for state management
97+
state_schema: Optional state schema for state management; accepts dict or Pydantic model/class
7198
predict_state_config: Configuration for predictive state updates.
7299
Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}}
73100
require_confirmation: Whether predictive updates require confirmation.

python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
"""FastAPI endpoint creation for AG-UI agents."""
44

5+
import copy
56
import logging
67
from typing import Any
78

@@ -19,20 +20,22 @@ def add_agent_framework_fastapi_endpoint(
1920
app: FastAPI,
2021
agent: AgentProtocol | AgentFrameworkAgent,
2122
path: str = "/",
22-
state_schema: dict[str, Any] | None = None,
23+
state_schema: Any | None = None,
2324
predict_state_config: dict[str, dict[str, str]] | None = None,
2425
allow_origins: list[str] | None = None,
26+
default_state: dict[str, Any] | None = None,
2527
) -> None:
2628
"""Add an AG-UI endpoint to a FastAPI app.
2729
2830
Args:
2931
app: The FastAPI application
3032
agent: The agent to expose (can be raw AgentProtocol or wrapped)
3133
path: The endpoint path
32-
state_schema: Optional state schema for shared state management
34+
state_schema: Optional state schema for shared state management; accepts dict or Pydantic model/class
3335
predict_state_config: Optional predictive state update configuration.
3436
Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}}
3537
allow_origins: CORS origins (not yet implemented)
38+
default_state: Optional initial state to seed when the client does not provide state keys
3639
"""
3740
if isinstance(agent, AgentProtocol):
3841
wrapped_agent = AgentFrameworkAgent(
@@ -52,6 +55,11 @@ async def agent_endpoint(request: Request): # type: ignore[misc]
5255
"""
5356
try:
5457
input_data = await request.json()
58+
if default_state:
59+
state = input_data.setdefault("state", {})
60+
for key, value in default_state.items():
61+
if key not in state:
62+
state[key] = copy.deepcopy(value)
5563
logger.debug(
5664
f"[{path}] Received request - Run ID: {input_data.get('run_id', 'no-run-id')}, "
5765
f"Thread ID: {input_data.get('thread_id', 'no-thread-id')}, "

0 commit comments

Comments
 (0)