Fix - Issue #171 #235

Draft · wants to merge 12 commits into main
19 changes: 18 additions & 1 deletion src/mcp_agent/agents/workflow/chain_agent.py
@@ -71,12 +71,21 @@ async def generate(
# # Get the original user message (last message in the list)
user_message = multipart_messages[-1] if multipart_messages else None

aggregator = getattr(self.context, "response_aggregator", None)

if not self.cumulative:
response: PromptMessageMultipart = await self.agents[0].generate(multipart_messages)
if aggregator:
await aggregator.add_agent_response(self.agents[0].name, response.all_text())
# Process the rest of the agents in the chain
for agent in self.agents[1:]:
next_message = Prompt.user(*response.content)
response = await agent.generate([next_message])
if aggregator:
await aggregator.add_agent_response(agent.name, response.all_text())

if aggregator and await aggregator.should_send_response():
await aggregator.get_aggregated_response()

return response

@@ -96,6 +105,8 @@ async def generate(
chain_messages = multipart_messages.copy()
chain_messages.extend(all_responses)
current_response = await agent.generate(chain_messages, request_params)
if aggregator:
await aggregator.add_agent_response(agent.name, current_response.all_text())

# Store the response
all_responses.append(current_response)
@@ -111,10 +122,16 @@

# For cumulative mode, return the properly formatted output with XML tags
response_text = "\n\n".join(final_results)
return PromptMessageMultipart(
final_message = PromptMessageMultipart(
role="assistant",
content=[TextContent(type="text", text=response_text)],
)
if aggregator:
await aggregator.add_agent_response(self.name, response_text)
if await aggregator.should_send_response():
await aggregator.get_aggregated_response()

return final_message

async def structured(
self,
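Note: a minimal sketch of the aggregation flow the non-cumulative branch above follows, assuming the `ChainResponseAggregator` added in this PR. The `FakeAgent` class, the agent names, and the plain `send(str)` interface are simplifications for illustration, not the real agent API.

```python
# Sketch only: mirrors the per-step aggregation added to chain_agent.py.
# FakeAgent and the agent names are hypothetical stand-ins.
import asyncio

from mcp_agent.server.response_aggregator import ChainResponseAggregator


class FakeAgent:
    def __init__(self, name: str) -> None:
        self.name = name

    async def send(self, message: str) -> str:
        return f"{self.name} handled: {message}"


async def run_chain(message: str) -> dict:
    agents = [FakeAgent("research"), FakeAgent("summarize")]
    aggregator = ChainResponseAggregator("demo_chain", total_agents=len(agents))

    response = message
    for agent in agents:
        response = await agent.send(response)
        # Record each agent's output, as the chain now does after every step.
        await aggregator.add_agent_response(agent.name, response)

    # Emit the aggregate only once every agent has reported.
    if await aggregator.should_send_response():
        return await aggregator.get_aggregated_response()
    return {}


print(asyncio.run(run_chain("hello")))
```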
37 changes: 30 additions & 7 deletions src/mcp_agent/mcp_server/agent_server.py
@@ -68,18 +68,29 @@ def register_agent_tools(self, agent_name: str, agent) -> None:
)
async def send_message(message: str, ctx: MCPContext) -> str:
"""Send a message to the agent and return its response."""
# Get the agent's context
from mcp_agent.agents.workflow.chain_agent import ChainAgent

# For chain agents, handle execution without SSE aggregation
if isinstance(agent, ChainAgent):
response = await agent.send(message)

if hasattr(response, "all_text"):
return response.all_text()
elif isinstance(response, dict):
import json

return json.dumps(response)
return str(response)

# Non-chain agents use normal flow
agent_context = getattr(agent, "context", None)

# Define the function to execute
async def execute_send():
return await agent.send(message)

# Execute with bridged context
if agent_context and ctx:
return await self.with_bridged_context(agent_context, ctx, execute_send)
else:
return await execute_send()
return await self.with_bridged_context(agent_context, ctx, execute_send)
return await execute_send()

# Register a history prompt for this agent
@self.mcp_server.prompt(
@@ -368,7 +379,14 @@ async def _close_sse_connections(self):
except Exception as e:
logger.error(f"Error during ASGI lifespan shutdown: {e}")

async def with_bridged_context(self, agent_context, mcp_context, func, *args, **kwargs):
async def with_bridged_context(
self,
agent_context,
mcp_context,
func,
*args,
**kwargs,
):
"""
Execute a function with bridged context between MCP and agent

Expand Down Expand Up @@ -397,6 +415,9 @@ async def bridged_progress(progress, total=None) -> None:
if hasattr(agent_context, "progress_reporter"):
agent_context.progress_reporter = bridged_progress

if aggregator is not None:
agent_context.response_aggregator = aggregator

try:
# Call the function
return await func(*args, **kwargs)
@@ -408,6 +429,8 @@
# Remove MCP context reference
if hasattr(agent_context, "mcp_context"):
delattr(agent_context, "mcp_context")
if aggregator is not None and hasattr(agent_context, "response_aggregator"):
delattr(agent_context, "response_aggregator")

async def _cleanup_stdio(self):
"""Minimal cleanup for STDIO transport to avoid keeping process alive."""
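Note: a minimal sketch of the attach-and-clean-up pattern `with_bridged_context` now uses for the aggregator. Where the `aggregator` instance is created lies outside the visible diff, so it is constructed directly here, and `SimpleNamespace` stands in for the real agent context.

```python
# Sketch only: illustrates how with_bridged_context temporarily attaches a
# response_aggregator to the agent context and removes it in a finally block.
# SimpleNamespace stands in for the real agent context object.
import asyncio
from types import SimpleNamespace

from mcp_agent.server.response_aggregator import ChainResponseAggregator


async def with_aggregator(agent_context, aggregator, func):
    if aggregator is not None:
        agent_context.response_aggregator = aggregator
    try:
        return await func()
    finally:
        # Mirror the cleanup added to agent_server.py.
        if aggregator is not None and hasattr(agent_context, "response_aggregator"):
            delattr(agent_context, "response_aggregator")


async def main() -> None:
    ctx = SimpleNamespace()
    agg = ChainResponseAggregator("demo_chain", total_agents=1)

    async def work() -> str:
        # A chain running under the bridged context would read this attribute.
        await ctx.response_aggregator.add_agent_response("a1", "done")
        return "ok"

    print(await with_aggregator(ctx, agg, work))
    print(hasattr(ctx, "response_aggregator"))  # False after cleanup


asyncio.run(main())
```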
43 changes: 43 additions & 0 deletions src/mcp_agent/server/response_aggregator.py
@@ -0,0 +1,43 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Dict


class ChainResponseAggregator:
"""Aggregate responses for a multi-agent chain."""

def __init__(self, chain_name: str, total_agents: int) -> None:
self.chain_name = chain_name
self.total_agents = total_agents
self.agent_responses: Dict[str, Any] = {}
self.completed_agents = 0
self._response_sent = False

async def add_agent_response(self, agent_name: str, response: Any) -> None:
"""Record a response from an agent in the chain."""
self.agent_responses[agent_name] = response
self.completed_agents += 1

async def should_send_response(self) -> bool:
"""Return ``True`` if the aggregated response should be sent."""
return not self._response_sent and self.completed_agents >= self.total_agents

async def get_aggregated_response(self) -> Dict[str, Any]:
"""Return the aggregated response for the chain."""
self._response_sent = True
return {"chain": self.chain_name, "responses": self.agent_responses}


class SSEEventType(Enum):
AGENT_START = "agent_start"
AGENT_PROGRESS = "agent_progress"
AGENT_COMPLETE = "agent_complete"
CHAIN_COMPLETE = "chain_complete"
ERROR = "error"


async def send_sse_event(event_type: SSEEventType, data: Dict[str, Any], stream: Any) -> None:
"""Send an SSE event to the provided stream if possible."""
if stream is not None and hasattr(stream, "send"):
await stream.send({"event": event_type.value, "data": data})
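Note: a short usage sketch of the new module, driving `ChainResponseAggregator` together with `send_sse_event`. `MemoryStream` is a hypothetical in-memory stand-in for a real SSE stream; it only needs the async `send` method that `send_sse_event` duck-types on.

```python
# Sketch only: drives ChainResponseAggregator and send_sse_event together.
# MemoryStream is a hypothetical in-memory stand-in for a real SSE stream.
import asyncio

from mcp_agent.server.response_aggregator import (
    ChainResponseAggregator,
    SSEEventType,
    send_sse_event,
)


class MemoryStream:
    def __init__(self) -> None:
        self.events = []

    async def send(self, data) -> None:
        self.events.append(data)


async def main() -> None:
    stream = MemoryStream()
    aggregator = ChainResponseAggregator("demo_chain", total_agents=2)

    for name, text in [("a1", "first"), ("a2", "second")]:
        await aggregator.add_agent_response(name, text)
        await send_sse_event(SSEEventType.AGENT_COMPLETE, {"agent": name}, stream)

    if await aggregator.should_send_response():
        final = await aggregator.get_aggregated_response()
        await send_sse_event(SSEEventType.CHAIN_COMPLETE, final, stream)

    print(stream.events)


asyncio.run(main())
```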
43 changes: 43 additions & 0 deletions tests/unit/mcp_agent/server/test_response_aggregator.py
@@ -0,0 +1,43 @@
import importlib.util
from pathlib import Path

import pytest

MODULE_PATH = (
Path(__file__).resolve().parents[4] / "src" / "mcp_agent" / "server" / "response_aggregator.py"
)
spec = importlib.util.spec_from_file_location("response_aggregator", MODULE_PATH)
response_aggregator = importlib.util.module_from_spec(spec)
assert spec.loader
spec.loader.exec_module(response_aggregator)

ChainResponseAggregator = response_aggregator.ChainResponseAggregator
SSEEventType = response_aggregator.SSEEventType
send_sse_event = response_aggregator.send_sse_event


@pytest.mark.asyncio
async def test_chain_response_aggregator():
agg = ChainResponseAggregator("chain", 2)
await agg.add_agent_response("a1", "one")
assert not await agg.should_send_response()
await agg.add_agent_response("a2", "two")
assert await agg.should_send_response()
result = await agg.get_aggregated_response()
assert result["chain"] == "chain"
assert result["responses"] == {"a1": "one", "a2": "two"}


class _DummyStream:
def __init__(self) -> None:
self.sent = []

async def send(self, data):
self.sent.append(data)


@pytest.mark.asyncio
async def test_send_sse_event():
stream = _DummyStream()
await send_sse_event(SSEEventType.AGENT_START, {"foo": "bar"}, stream)
assert stream.sent == [{"event": "agent_start", "data": {"foo": "bar"}}]