Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions python/packages/ag-ui/agent_framework_ag_ui/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""AgentFrameworkAgent wrapper for AG-UI protocol - Clean Architecture."""

from collections.abc import AsyncGenerator
from typing import Any
from typing import Any, cast

from ag_ui.core import BaseEvent
from agent_framework import AgentProtocol
Expand All @@ -22,21 +22,48 @@ class AgentConfig:

def __init__(
self,
state_schema: dict[str, Any] | None = None,
state_schema: Any | None = None,
predict_state_config: dict[str, dict[str, str]] | None = None,
require_confirmation: bool = True,
):
"""Initialize agent configuration.

Args:
state_schema: Optional state schema for state management
state_schema: Optional state schema for state management; accepts dict or Pydantic model/class
predict_state_config: Configuration for predictive state updates
require_confirmation: Whether predictive updates require confirmation
"""
self.state_schema = state_schema or {}
self.state_schema = self._normalize_state_schema(state_schema)
self.predict_state_config = predict_state_config or {}
self.require_confirmation = require_confirmation

@staticmethod
def _normalize_state_schema(state_schema: Any | None) -> dict[str, Any]:
"""Accept dict or Pydantic model/class and return a properties dict."""
if state_schema is None:
return {}

if isinstance(state_schema, dict):
return cast(dict[str, Any], state_schema)

base_model_type: type[Any] | None
try:
from pydantic import BaseModel as ImportedBaseModel

base_model_type = ImportedBaseModel
except Exception: # pragma: no cover
base_model_type = None

if base_model_type is not None and isinstance(state_schema, base_model_type):
schema_dict = state_schema.__class__.model_json_schema()
return schema_dict.get("properties", {}) or {}

if base_model_type is not None and isinstance(state_schema, type) and issubclass(state_schema, base_model_type):
schema_dict = state_schema.model_json_schema()
return schema_dict.get("properties", {}) or {}

return {}


class AgentFrameworkAgent:
"""Wraps Agent Framework agents for AG-UI protocol compatibility.
Expand All @@ -55,7 +82,7 @@ def __init__(
agent: AgentProtocol,
name: str | None = None,
description: str | None = None,
state_schema: dict[str, Any] | None = None,
state_schema: Any | None = None,
predict_state_config: dict[str, dict[str, str]] | None = None,
require_confirmation: bool = True,
orchestrators: list[Orchestrator] | None = None,
Expand All @@ -67,7 +94,7 @@ def __init__(
agent: The Agent Framework agent to wrap
name: Optional name for the agent
description: Optional description
state_schema: Optional state schema for state management
state_schema: Optional state schema for state management; accepts dict or Pydantic model/class
predict_state_config: Configuration for predictive state updates.
Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}}
require_confirmation: Whether predictive updates require confirmation.
Expand Down
12 changes: 10 additions & 2 deletions python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

"""FastAPI endpoint creation for AG-UI agents."""

import copy
import logging
from typing import Any

Expand All @@ -19,20 +20,22 @@ def add_agent_framework_fastapi_endpoint(
app: FastAPI,
agent: AgentProtocol | AgentFrameworkAgent,
path: str = "/",
state_schema: dict[str, Any] | None = None,
state_schema: Any | None = None,
predict_state_config: dict[str, dict[str, str]] | None = None,
allow_origins: list[str] | None = None,
default_state: dict[str, Any] | None = None,
) -> None:
"""Add an AG-UI endpoint to a FastAPI app.

Args:
app: The FastAPI application
agent: The agent to expose (can be raw AgentProtocol or wrapped)
path: The endpoint path
state_schema: Optional state schema for shared state management
state_schema: Optional state schema for shared state management; accepts dict or Pydantic model/class
predict_state_config: Optional predictive state update configuration.
Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}}
allow_origins: CORS origins (not yet implemented)
default_state: Optional initial state to seed when the client does not provide state keys
"""
if isinstance(agent, AgentProtocol):
wrapped_agent = AgentFrameworkAgent(
Expand All @@ -52,6 +55,11 @@ async def agent_endpoint(request: Request): # type: ignore[misc]
"""
try:
input_data = await request.json()
if default_state:
state = input_data.setdefault("state", {})
for key, value in default_state.items():
if key not in state:
state[key] = copy.deepcopy(value)
logger.debug(
f"[{path}] Received request - Run ID: {input_data.get('run_id', 'no-run-id')}, "
f"Thread ID: {input_data.get('thread_id', 'no-thread-id')}, "
Expand Down
Loading
Loading