Skip to content

AutoGen Studio: Serve responding complete messages #6520

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion python/packages/autogen-studio/autogenstudio/web/serve.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os

from fastapi import FastAPI
from pydantic import BaseModel
from typing import Any

from ..datamodel import Response
from ..teammanager import TeamManager
Expand All @@ -20,8 +22,26 @@ async def predict(task: str):
raise ValueError("AUTOGENSTUDIO_TEAM_FILE environment variable is not set")

result_message = await team_manager.run(task=task, team_config=team_file_path)
response.data = result_message
response.data = force_model_dump(result_message)
except Exception as e:
response.message = str(e)
response.status = False
return response

def force_model_dump(obj: Any) -> Any:
"""
Force dump all fields of a Pydantic BaseModel, even when inherited
from ABCs as BaseAgentEvent and BaseChatMessage.
"""
if isinstance(obj, BaseModel):
output = {}
for name, _field in obj.model_fields.items():
value = getattr(obj, name)
output[name] = force_model_dump(value)
return output
elif isinstance(obj, list):
return [force_model_dump(item) for item in obj]
elif isinstance(obj, dict):
return {k: force_model_dump(v) for k, v in obj.items()}
else:
return obj
65 changes: 65 additions & 0 deletions python/packages/autogen-studio/tests/test_serve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
from fastapi.testclient import TestClient
from autogenstudio.web.serve import app
from autogen_agentchat.messages import TextMessage
from autogen_agentchat.base import TaskResult
from autogenstudio.datamodel.types import TeamResult

client = TestClient(app)

def test_predict_success(monkeypatch):
monkeypatch.setenv("AUTOGENSTUDIO_TEAM_FILE", "test_team_config.json")

async def mock_run(*args, **kwargs):
assert kwargs.get('task') == 'test_task', f"Expected task='test_task', got {kwargs.get('task')}"
text_message = TextMessage(
source="agent1",
content="Mission accomplished.",
metadata={"topic": "status"}
)
task_result = TaskResult(messages=[text_message], stop_reason="test")
team_result = TeamResult(
task_result=task_result,
usage="3 tokens",
duration=0.45
)
return team_result

from autogenstudio.web.serve import team_manager
team_manager.run = mock_run

response = client.get("/predict/test_task")
assert response.status_code == 200
data = response.json()
assert data["status"] is True
assert data["message"] == "Task successfully completed"
# It should be able to serialize the message content
assert data["data"]['task_result']['messages'][0]['content'] == "Mission accomplished."

def test_predict_missing_env_var():
# Ensure environment variable is not set
if "AUTOGENSTUDIO_TEAM_FILE" in os.environ:
del os.environ["AUTOGENSTUDIO_TEAM_FILE"]

response = client.get("/predict/test_task")
assert response.status_code == 200
data = response.json()
assert data["status"] is False
assert "AUTOGENSTUDIO_TEAM_FILE environment variable is not set" in data["message"]

def test_predict_team_manager_error(monkeypatch):
monkeypatch.setenv("AUTOGENSTUDIO_TEAM_FILE", "test_team_config.json")

# Mock the team_manager.run method to raise an exception
async def mock_run(*args, **kwargs):
assert kwargs.get('task') == 'test_task', f"Expected task='test_task', got {kwargs.get('task')}"
raise Exception("Test error")

from autogenstudio.web.serve import team_manager
team_manager.run = mock_run

response = client.get("/predict/test_task")
assert response.status_code == 200
data = response.json()
assert data["status"] is False
assert data["message"] == "Test error"
Loading