33"""AgentFrameworkAgent wrapper for AG-UI protocol - Clean Architecture."""
44
55from collections .abc import AsyncGenerator
6- from typing import Any
6+ from typing import Any , cast
77
88from ag_ui .core import BaseEvent
99from agent_framework import AgentProtocol
@@ -22,21 +22,48 @@ class AgentConfig:
2222
2323 def __init__ (
2424 self ,
25- state_schema : dict [ str , Any ] | None = None ,
25+ state_schema : Any | None = None ,
2626 predict_state_config : dict [str , dict [str , str ]] | None = None ,
2727 require_confirmation : bool = True ,
2828 ):
2929 """Initialize agent configuration.
3030
3131 Args:
32- state_schema: Optional state schema for state management
32+ state_schema: Optional state schema for state management; accepts dict or Pydantic model/class
3333 predict_state_config: Configuration for predictive state updates
3434 require_confirmation: Whether predictive updates require confirmation
3535 """
36- self .state_schema = state_schema or {}
36+ self .state_schema = self . _normalize_state_schema ( state_schema )
3737 self .predict_state_config = predict_state_config or {}
3838 self .require_confirmation = require_confirmation
3939
40+ @staticmethod
41+ def _normalize_state_schema (state_schema : Any | None ) -> dict [str , Any ]:
42+ """Accept dict or Pydantic model/class and return a properties dict."""
43+ if state_schema is None :
44+ return {}
45+
46+ if isinstance (state_schema , dict ):
47+ return cast (dict [str , Any ], state_schema )
48+
49+ base_model_type : type [Any ] | None
50+ try :
51+ from pydantic import BaseModel as ImportedBaseModel
52+
53+ base_model_type = ImportedBaseModel
54+ except Exception : # pragma: no cover
55+ base_model_type = None
56+
57+ if base_model_type is not None and isinstance (state_schema , base_model_type ):
58+ schema_dict = state_schema .__class__ .model_json_schema ()
59+ return schema_dict .get ("properties" , {}) or {}
60+
61+ if base_model_type is not None and isinstance (state_schema , type ) and issubclass (state_schema , base_model_type ):
62+ schema_dict = state_schema .model_json_schema ()
63+ return schema_dict .get ("properties" , {}) or {}
64+
65+ return {}
66+
4067
4168class AgentFrameworkAgent :
4269 """Wraps Agent Framework agents for AG-UI protocol compatibility.
@@ -55,7 +82,7 @@ def __init__(
5582 agent : AgentProtocol ,
5683 name : str | None = None ,
5784 description : str | None = None ,
58- state_schema : dict [ str , Any ] | None = None ,
85+ state_schema : Any | None = None ,
5986 predict_state_config : dict [str , dict [str , str ]] | None = None ,
6087 require_confirmation : bool = True ,
6188 orchestrators : list [Orchestrator ] | None = None ,
@@ -67,7 +94,7 @@ def __init__(
6794 agent: The Agent Framework agent to wrap
6895 name: Optional name for the agent
6996 description: Optional description
70- state_schema: Optional state schema for state management
97+ state_schema: Optional state schema for state management; accepts dict or Pydantic model/class
7198 predict_state_config: Configuration for predictive state updates.
7299 Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}}
73100 require_confirmation: Whether predictive updates require confirmation.
0 commit comments