From d4f21bf90d7b57bd863e3692a5a8431bd91daeff Mon Sep 17 00:00:00 2001 From: Mario Mol Date: Tue, 13 May 2025 00:28:26 -0300 Subject: [PATCH 1/4] Base models for AI response format. Including initial testes for serve --- .../autogenstudio/datamodel/types.py | 69 ++++++++++++++++++- .../autogen-studio/tests/test_serve.py | 53 ++++++++++++++ 2 files changed, 119 insertions(+), 3 deletions(-) create mode 100644 python/packages/autogen-studio/tests/test_serve.py diff --git a/python/packages/autogen-studio/autogenstudio/datamodel/types.py b/python/packages/autogen-studio/autogenstudio/datamodel/types.py index e36ad0e8edf8..5a23a5140336 100644 --- a/python/packages/autogen-studio/autogenstudio/datamodel/types.py +++ b/python/packages/autogen-studio/autogenstudio/datamodel/types.py @@ -1,13 +1,13 @@ # from dataclasses import Field from datetime import datetime -from typing import Any, Dict, List, Literal, Optional, Sequence +from typing import Any, Dict, List, Literal, Optional, Sequence, Union from autogen_agentchat.base import TaskResult from autogen_agentchat.messages import ChatMessage, TextMessage from autogen_core import ComponentModel from autogen_core.models import UserMessage from autogen_ext.models.openai import OpenAIChatCompletionClient -from pydantic import BaseModel, ConfigDict, SecretStr +from pydantic import BaseModel, ConfigDict, SecretStr, Field class MessageConfig(BaseModel): @@ -109,10 +109,73 @@ class SettingsConfig(BaseModel): # web request/response data models +class RequestUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + + +class FunctionCall(BaseModel): + id: str + arguments: str # Could also be Dict[str, Any] if parsed + name: str + + +class FunctionExecutionResult(BaseModel): + content: str + name: str + call_id: str + is_error: bool + + +class BaseMessage(BaseModel): + source: str + models_usage: Optional[RequestUsage] = None + metadata: Dict[str, Optional[str]] = Field(default_factory=dict) + type: str # Overridden by subclasses + + +class TextMessage(BaseMessage): + content: str + type: Literal["TextMessage"] = "TextMessage" + + +class ToolCallRequestEvent(BaseMessage): + content: List[FunctionCall] + type: Literal["ToolCallRequestEvent"] = "ToolCallRequestEvent" + + +class ToolCallExecutionEvent(BaseMessage): + content: List[FunctionExecutionResult] + type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent" + + +class ToolCallSummaryMessage(BaseMessage): + content: str + type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage" + + +MessageUnion = Union[ + TextMessage, + ToolCallRequestEvent, + ToolCallExecutionEvent, + ToolCallSummaryMessage, +] + +class TaskResult(BaseModel): + messages: List[MessageUnion] + stop_reason: Optional[str] = None + + +class TaskResponse(BaseModel): + task_result: TaskResult + usage: Optional[str] = "" + duration: float + + class Response(BaseModel): message: str status: bool - data: Optional[Any] = None + data: Optional[TaskResponse] = None class SocketMessage(BaseModel): diff --git a/python/packages/autogen-studio/tests/test_serve.py b/python/packages/autogen-studio/tests/test_serve.py new file mode 100644 index 000000000000..e7c418c4c7af --- /dev/null +++ b/python/packages/autogen-studio/tests/test_serve.py @@ -0,0 +1,53 @@ +import os +from fastapi.testclient import TestClient +from autogenstudio.web.serve import app + +client = TestClient(app) + +def test_predict_success(monkeypatch): + # Mock environment variable + monkeypatch.setenv("AUTOGENSTUDIO_TEAM_FILE", "test_team_config.json") + + # Mock the team_manager.run method + async def mock_run(*args, **kwargs): + assert kwargs.get('task') == 'test_task', f"Expected task='test_task', got {kwargs.get('task')}" + return "Test result" + + from autogenstudio.web.serve import team_manager + team_manager.run = mock_run + + response = client.get("/predict/test_task") + assert response.status_code == 200 + data = response.json() + assert data["status"] is True + assert data["message"] == "Task successfully completed" + assert data["data"] == "Test result" + +def test_predict_missing_env_var(): + # Ensure environment variable is not set + if "AUTOGENSTUDIO_TEAM_FILE" in os.environ: + del os.environ["AUTOGENSTUDIO_TEAM_FILE"] + + response = client.get("/predict/test_task") + assert response.status_code == 200 + data = response.json() + assert data["status"] is False + assert "AUTOGENSTUDIO_TEAM_FILE environment variable is not set" in data["message"] + +def test_predict_team_manager_error(monkeypatch): + # Mock environment variable + monkeypatch.setenv("AUTOGENSTUDIO_TEAM_FILE", "test_team_config.json") + + # Mock the team_manager.run method to raise an exception + async def mock_run(*args, **kwargs): + assert kwargs.get('task') == 'test_task', f"Expected task='test_task', got {kwargs.get('task')}" + raise Exception("Test error") + + from autogenstudio.web.serve import team_manager + team_manager.run = mock_run + + response = client.get("/predict/test_task") + assert response.status_code == 200 + data = response.json() + assert data["status"] is False + assert data["message"] == "Test error" \ No newline at end of file From b33629725c4a8083fec4c0a2e6478bc22655b337 Mon Sep 17 00:00:00 2001 From: Mario Mol Date: Tue, 13 May 2025 01:01:13 -0300 Subject: [PATCH 2/4] revert --- .../autogenstudio/datamodel/types.py | 69 +------------------ 1 file changed, 3 insertions(+), 66 deletions(-) diff --git a/python/packages/autogen-studio/autogenstudio/datamodel/types.py b/python/packages/autogen-studio/autogenstudio/datamodel/types.py index 5a23a5140336..e36ad0e8edf8 100644 --- a/python/packages/autogen-studio/autogenstudio/datamodel/types.py +++ b/python/packages/autogen-studio/autogenstudio/datamodel/types.py @@ -1,13 +1,13 @@ # from dataclasses import Field from datetime import datetime -from typing import Any, Dict, List, Literal, Optional, Sequence, Union +from typing import Any, Dict, List, Literal, Optional, Sequence from autogen_agentchat.base import TaskResult from autogen_agentchat.messages import ChatMessage, TextMessage from autogen_core import ComponentModel from autogen_core.models import UserMessage from autogen_ext.models.openai import OpenAIChatCompletionClient -from pydantic import BaseModel, ConfigDict, SecretStr, Field +from pydantic import BaseModel, ConfigDict, SecretStr class MessageConfig(BaseModel): @@ -109,73 +109,10 @@ class SettingsConfig(BaseModel): # web request/response data models -class RequestUsage(BaseModel): - prompt_tokens: int - completion_tokens: int - - -class FunctionCall(BaseModel): - id: str - arguments: str # Could also be Dict[str, Any] if parsed - name: str - - -class FunctionExecutionResult(BaseModel): - content: str - name: str - call_id: str - is_error: bool - - -class BaseMessage(BaseModel): - source: str - models_usage: Optional[RequestUsage] = None - metadata: Dict[str, Optional[str]] = Field(default_factory=dict) - type: str # Overridden by subclasses - - -class TextMessage(BaseMessage): - content: str - type: Literal["TextMessage"] = "TextMessage" - - -class ToolCallRequestEvent(BaseMessage): - content: List[FunctionCall] - type: Literal["ToolCallRequestEvent"] = "ToolCallRequestEvent" - - -class ToolCallExecutionEvent(BaseMessage): - content: List[FunctionExecutionResult] - type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent" - - -class ToolCallSummaryMessage(BaseMessage): - content: str - type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage" - - -MessageUnion = Union[ - TextMessage, - ToolCallRequestEvent, - ToolCallExecutionEvent, - ToolCallSummaryMessage, -] - -class TaskResult(BaseModel): - messages: List[MessageUnion] - stop_reason: Optional[str] = None - - -class TaskResponse(BaseModel): - task_result: TaskResult - usage: Optional[str] = "" - duration: float - - class Response(BaseModel): message: str status: bool - data: Optional[TaskResponse] = None + data: Optional[Any] = None class SocketMessage(BaseModel): From 979bc3817b3cebb5f27ff62d510525746f37999f Mon Sep 17 00:00:00 2001 From: Mario Mol Date: Tue, 13 May 2025 02:16:47 -0300 Subject: [PATCH 3/4] forcing model dump --- .../autogen-studio/autogenstudio/web/serve.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/python/packages/autogen-studio/autogenstudio/web/serve.py b/python/packages/autogen-studio/autogenstudio/web/serve.py index 8dbbf1823328..9b09baf200ce 100644 --- a/python/packages/autogen-studio/autogenstudio/web/serve.py +++ b/python/packages/autogen-studio/autogenstudio/web/serve.py @@ -1,6 +1,8 @@ import os from fastapi import FastAPI +from pydantic import BaseModel +from typing import Any from ..datamodel import Response from ..teammanager import TeamManager @@ -20,8 +22,26 @@ async def predict(task: str): raise ValueError("AUTOGENSTUDIO_TEAM_FILE environment variable is not set") result_message = await team_manager.run(task=task, team_config=team_file_path) - response.data = result_message + response.data = force_model_dump(result_message) except Exception as e: response.message = str(e) response.status = False return response + +def force_model_dump(obj: Any) -> Any: + """ + Force dump all fields of a Pydantic BaseModel, even when inherited + from ABCs as BaseAgentEvent and BaseChatMessage. + """ + if isinstance(obj, BaseModel): + output = {} + for name, _field in obj.__fields__.items(): + value = getattr(obj, name) + output[name] = force_model_dump(value) + return output + elif isinstance(obj, list): + return [force_model_dump(item) for item in obj] + elif isinstance(obj, dict): + return {k: force_model_dump(v) for k, v in obj.items()} + else: + return obj \ No newline at end of file From cc8c18f01080aaae9c1e46ecf8dc2921fdb619d7 Mon Sep 17 00:00:00 2001 From: Mario Mol Date: Tue, 13 May 2025 02:28:36 -0300 Subject: [PATCH 4/4] fixing unit test --- .../autogen-studio/autogenstudio/web/serve.py | 2 +- .../autogen-studio/tests/test_serve.py | 24 ++++++++++++++----- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/python/packages/autogen-studio/autogenstudio/web/serve.py b/python/packages/autogen-studio/autogenstudio/web/serve.py index 9b09baf200ce..7e98ae0801b6 100644 --- a/python/packages/autogen-studio/autogenstudio/web/serve.py +++ b/python/packages/autogen-studio/autogenstudio/web/serve.py @@ -35,7 +35,7 @@ def force_model_dump(obj: Any) -> Any: """ if isinstance(obj, BaseModel): output = {} - for name, _field in obj.__fields__.items(): + for name, _field in obj.model_fields.items(): value = getattr(obj, name) output[name] = force_model_dump(value) return output diff --git a/python/packages/autogen-studio/tests/test_serve.py b/python/packages/autogen-studio/tests/test_serve.py index e7c418c4c7af..ad8219f0181a 100644 --- a/python/packages/autogen-studio/tests/test_serve.py +++ b/python/packages/autogen-studio/tests/test_serve.py @@ -1,17 +1,29 @@ import os from fastapi.testclient import TestClient from autogenstudio.web.serve import app +from autogen_agentchat.messages import TextMessage +from autogen_agentchat.base import TaskResult +from autogenstudio.datamodel.types import TeamResult client = TestClient(app) def test_predict_success(monkeypatch): - # Mock environment variable monkeypatch.setenv("AUTOGENSTUDIO_TEAM_FILE", "test_team_config.json") - - # Mock the team_manager.run method + async def mock_run(*args, **kwargs): assert kwargs.get('task') == 'test_task', f"Expected task='test_task', got {kwargs.get('task')}" - return "Test result" + text_message = TextMessage( + source="agent1", + content="Mission accomplished.", + metadata={"topic": "status"} + ) + task_result = TaskResult(messages=[text_message], stop_reason="test") + team_result = TeamResult( + task_result=task_result, + usage="3 tokens", + duration=0.45 + ) + return team_result from autogenstudio.web.serve import team_manager team_manager.run = mock_run @@ -21,7 +33,8 @@ async def mock_run(*args, **kwargs): data = response.json() assert data["status"] is True assert data["message"] == "Task successfully completed" - assert data["data"] == "Test result" + # It should be able to serialize the message content + assert data["data"]['task_result']['messages'][0]['content'] == "Mission accomplished." def test_predict_missing_env_var(): # Ensure environment variable is not set @@ -35,7 +48,6 @@ def test_predict_missing_env_var(): assert "AUTOGENSTUDIO_TEAM_FILE environment variable is not set" in data["message"] def test_predict_team_manager_error(monkeypatch): - # Mock environment variable monkeypatch.setenv("AUTOGENSTUDIO_TEAM_FILE", "test_team_config.json") # Mock the team_manager.run method to raise an exception