diff --git a/python/packages/autogen-studio/autogenstudio/web/serve.py b/python/packages/autogen-studio/autogenstudio/web/serve.py index 8dbbf1823328..7e98ae0801b6 100644 --- a/python/packages/autogen-studio/autogenstudio/web/serve.py +++ b/python/packages/autogen-studio/autogenstudio/web/serve.py @@ -1,6 +1,8 @@ import os from fastapi import FastAPI +from pydantic import BaseModel +from typing import Any from ..datamodel import Response from ..teammanager import TeamManager @@ -20,8 +22,26 @@ async def predict(task: str): raise ValueError("AUTOGENSTUDIO_TEAM_FILE environment variable is not set") result_message = await team_manager.run(task=task, team_config=team_file_path) - response.data = result_message + response.data = force_model_dump(result_message) except Exception as e: response.message = str(e) response.status = False return response + +def force_model_dump(obj: Any) -> Any: + """ + Force dump all fields of a Pydantic BaseModel, even when inherited + from ABCs as BaseAgentEvent and BaseChatMessage. + """ + if isinstance(obj, BaseModel): + output = {} + for name, _field in obj.model_fields.items(): + value = getattr(obj, name) + output[name] = force_model_dump(value) + return output + elif isinstance(obj, list): + return [force_model_dump(item) for item in obj] + elif isinstance(obj, dict): + return {k: force_model_dump(v) for k, v in obj.items()} + else: + return obj \ No newline at end of file diff --git a/python/packages/autogen-studio/tests/test_serve.py b/python/packages/autogen-studio/tests/test_serve.py new file mode 100644 index 000000000000..ad8219f0181a --- /dev/null +++ b/python/packages/autogen-studio/tests/test_serve.py @@ -0,0 +1,65 @@ +import os +from fastapi.testclient import TestClient +from autogenstudio.web.serve import app +from autogen_agentchat.messages import TextMessage +from autogen_agentchat.base import TaskResult +from autogenstudio.datamodel.types import TeamResult + +client = TestClient(app) + +def test_predict_success(monkeypatch): + monkeypatch.setenv("AUTOGENSTUDIO_TEAM_FILE", "test_team_config.json") + + async def mock_run(*args, **kwargs): + assert kwargs.get('task') == 'test_task', f"Expected task='test_task', got {kwargs.get('task')}" + text_message = TextMessage( + source="agent1", + content="Mission accomplished.", + metadata={"topic": "status"} + ) + task_result = TaskResult(messages=[text_message], stop_reason="test") + team_result = TeamResult( + task_result=task_result, + usage="3 tokens", + duration=0.45 + ) + return team_result + + from autogenstudio.web.serve import team_manager + team_manager.run = mock_run + + response = client.get("/predict/test_task") + assert response.status_code == 200 + data = response.json() + assert data["status"] is True + assert data["message"] == "Task successfully completed" + # It should be able to serialize the message content + assert data["data"]['task_result']['messages'][0]['content'] == "Mission accomplished." + +def test_predict_missing_env_var(): + # Ensure environment variable is not set + if "AUTOGENSTUDIO_TEAM_FILE" in os.environ: + del os.environ["AUTOGENSTUDIO_TEAM_FILE"] + + response = client.get("/predict/test_task") + assert response.status_code == 200 + data = response.json() + assert data["status"] is False + assert "AUTOGENSTUDIO_TEAM_FILE environment variable is not set" in data["message"] + +def test_predict_team_manager_error(monkeypatch): + monkeypatch.setenv("AUTOGENSTUDIO_TEAM_FILE", "test_team_config.json") + + # Mock the team_manager.run method to raise an exception + async def mock_run(*args, **kwargs): + assert kwargs.get('task') == 'test_task', f"Expected task='test_task', got {kwargs.get('task')}" + raise Exception("Test error") + + from autogenstudio.web.serve import team_manager + team_manager.run = mock_run + + response = client.get("/predict/test_task") + assert response.status_code == 200 + data = response.json() + assert data["status"] is False + assert data["message"] == "Test error" \ No newline at end of file