diff --git a/README.md b/README.md
index 1c8cd61b..a7c14f0a 100644
--- a/README.md
+++ b/README.md
@@ -71,6 +71,8 @@ pip install -e .
playwright install
```
+MCPMark defaults to the built-in orchestration agent (`MCPMarkAgent`). To try the ReAct-style agent (`ReActAgent`) instead, pass `--agent react` to `pipeline.py`; all other flags work unchanged.
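+
+For example (a minimal sketch; the model name is a placeholder and flags other than `--agent` follow the usual `pipeline.py` options):
+```bash
+python pipeline.py --models o3 --tasks all --agent react
+```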
+
Docker
```bash
./build-docker.sh
diff --git a/pipeline.py b/pipeline.py
index b6f442ef..8fe78fc7 100644
--- a/pipeline.py
+++ b/pipeline.py
@@ -14,6 +14,7 @@
from src.logger import get_logger
from src.evaluator import MCPEvaluator
+from src.agents import AGENT_REGISTRY
from src.factory import MCPServiceFactory
from src.model_config import ModelConfig
@@ -41,6 +42,13 @@ def main():
required=True,
help="Comma-separated list of models to evaluate (e.g., 'o3,k2,gpt-4.1')",
)
+
+ parser.add_argument(
+ "--agent",
+ default="mcpmark",
+ choices=sorted(AGENT_REGISTRY.keys()),
+ help="Agent implementation to use (default: mcpmark)",
+ )
parser.add_argument(
"--tasks",
default="all",
@@ -138,6 +146,7 @@ def main():
exp_name=run_exp_name,
output_dir=run_output_dir,
reasoning_effort=args.reasoning_effort,
+ agent_name=args.agent,
)
pipeline.run_evaluation(args.tasks)
diff --git a/src/agents/__init__.py b/src/agents/__init__.py
index ccbf7149..ea1e057a 100644
--- a/src/agents/__init__.py
+++ b/src/agents/__init__.py
@@ -2,10 +2,17 @@
MCPMark Agent Module
====================
-Provides a unified agent implementation using LiteLLM for model interactions
-and minimal MCP server management.
+Provides agent implementations and registry for MCPMark.
"""
+from .base_agent import BaseMCPAgent
from .mcpmark_agent import MCPMarkAgent
+from .react_agent import ReActAgent
+
+AGENT_REGISTRY = {
+ "mcpmark": MCPMarkAgent,
+ "react": ReActAgent,
+}
+
+__all__ = ["BaseMCPAgent", "MCPMarkAgent", "ReActAgent", "AGENT_REGISTRY"]
-__all__ = ["MCPMarkAgent"]
\ No newline at end of file
diff --git a/src/agents/base_agent.py b/src/agents/base_agent.py
new file mode 100644
index 00000000..81fb87a7
--- /dev/null
+++ b/src/agents/base_agent.py
@@ -0,0 +1,465 @@
+"""Shared base agent functionality for MCPMark agents."""
+
+from __future__ import annotations
+
+import asyncio
+import copy
+import json
+import uuid
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Callable
+
+from src.logger import get_logger
+from .mcp import MCPStdioServer, MCPHttpServer
+from .utils import TokenUsageTracker
+
+logger = get_logger(__name__)
+
+
+class BaseMCPAgent(ABC):
+ """Base class with shared functionality for MCPMark agents."""
+
+ STDIO_SERVICES = ["notion", "filesystem", "playwright", "playwright_webarena", "postgres"]
+ HTTP_SERVICES = ["github"]
+ DEFAULT_TIMEOUT = 600
+
+ CLAUDE_THINKING_BUDGETS = {
+ "low": 1024,
+ "medium": 2048,
+ "high": 4096,
+ }
+
+ def __init__(
+ self,
+ litellm_input_model_name: str,
+ api_key: str,
+ base_url: str,
+ mcp_service: str,
+ timeout: int = DEFAULT_TIMEOUT,
+ service_config: Optional[Dict[str, Any]] = None,
+ service_config_provider: Optional[Callable[[], Dict[str, Any]]] = None,
+ reasoning_effort: Optional[str] = "default",
+ ):
+ self.litellm_input_model_name = litellm_input_model_name
+ self.api_key = api_key
+ self.base_url = base_url
+ self.mcp_service = mcp_service
+ self.timeout = timeout
+ self.service_config = service_config or {}
+ self._service_config_provider = service_config_provider
+ self.reasoning_effort = reasoning_effort or "default"
+
+ self.is_claude = self._is_anthropic_model(litellm_input_model_name)
+ self.use_claude_thinking = self.is_claude and self.reasoning_effort != "default"
+
+ self.usage_tracker = TokenUsageTracker()
+ self.litellm_run_model_name = None
+
+ self._partial_messages: List[Dict[str, Any]] = []
+ self._partial_token_usage: Dict[str, int] = {}
+ self._partial_turn_count: int = 0
+
+ logger.debug(
+ "Initialized %s for service '%s' with model '%s'",
+ self.__class__.__name__,
+ self.mcp_service,
+ self.litellm_input_model_name,
+ )
+
+ def __repr__(self) -> str: # pragma: no cover - debug helper
+ return (
+ f"{self.__class__.__name__}(service='{self.mcp_service}', "
+ f"model='{self.litellm_input_model_name}')"
+ )
+
+ @abstractmethod
+ async def execute(
+ self,
+ instruction: str,
+ tool_call_log_file: Optional[str] = None,
+ ) -> Dict[str, Any]:
+ """Execute the agent logic and return execution metadata."""
+
+ def execute_sync(
+ self,
+ instruction: str,
+ tool_call_log_file: Optional[str] = None,
+ ) -> Dict[str, Any]:
+ """Synchronous wrapper for async execution."""
+ return asyncio.run(self.execute(instruction, tool_call_log_file))
+
+ def get_usage_stats(self) -> Dict[str, Any]:
+ """Return aggregated usage statistics."""
+ return self.usage_tracker.get_stats()
+
+ def reset_usage_stats(self):
+ """Clear usage statistics."""
+ self.usage_tracker.reset()
+
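+    # Illustrative synchronous usage via a concrete subclass (a sketch only:
+    # model name, key, base URL, and directory are placeholders; the
+    # "filesystem" config key mirrors _create_stdio_server below):
+    #
+    #   agent = MCPMarkAgent(
+    #       litellm_input_model_name="gpt-4.1",
+    #       api_key="sk-...",
+    #       base_url="https://api.openai.com/v1",
+    #       mcp_service="filesystem",
+    #       service_config={"test_directory": "/tmp/mcpmark"},
+    #   )
+    #   result = agent.execute_sync("List the files in the test directory")
+    #   print(result["success"], agent.get_usage_stats())
+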
+ # ------------------------------------------------------------------
+ # Shared helpers
+ # ------------------------------------------------------------------
+
+ def _is_anthropic_model(self, model_name: str) -> bool:
+ return "claude" in model_name.lower()
+
+ def _get_claude_thinking_budget(self) -> Optional[int]:
+ if not self.use_claude_thinking:
+ return None
+ return self.CLAUDE_THINKING_BUDGETS.get(self.reasoning_effort, 2048)
+
+ def _refresh_service_config(self):
+ if not self._service_config_provider:
+ return
+ try:
+ latest_cfg = self._service_config_provider() or {}
+ self.service_config.update(latest_cfg)
+ except Exception as exc: # pragma: no cover - best effort refresh
+ logger.warning("Failed to refresh service config: %s", exc)
+
+ def _reset_progress(self):
+ self._partial_messages = []
+ self._partial_token_usage = {}
+ self._partial_turn_count = 0
+
+ def _update_progress(
+ self,
+ messages: List[Dict[str, Any]],
+ token_usage: Dict[str, Any],
+ turn_count: int,
+ ):
+ try:
+ self._partial_messages = copy.deepcopy(messages)
+ self._partial_token_usage = dict(token_usage or {})
+ self._partial_turn_count = int(turn_count or 0)
+ except Exception: # pragma: no cover - defensive copy
+ pass
+
+ # ------------------------------------------------------------------
+ # MCP server management
+ # ------------------------------------------------------------------
+
+ async def _create_mcp_server(self) -> Any:
+ if self.mcp_service in self.STDIO_SERVICES:
+ return self._create_stdio_server()
+ if self.mcp_service in self.HTTP_SERVICES:
+ return self._create_http_server()
+ raise ValueError(f"Unsupported MCP service: {self.mcp_service}")
+
+ def _create_stdio_server(self) -> MCPStdioServer:
+ if self.mcp_service == "notion":
+ notion_key = self.service_config.get("notion_key")
+ if not notion_key:
+ raise ValueError("Notion API key required")
+ return MCPStdioServer(
+ command="npx",
+ args=["-y", "@notionhq/notion-mcp-server"],
+ env={
+ "OPENAPI_MCP_HEADERS": (
+ '{"Authorization": "Bearer ' + notion_key + '", '
+ '"Notion-Version": "2022-06-28"}'
+ )
+ },
+ )
+
+ if self.mcp_service == "filesystem":
+ test_directory = self.service_config.get("test_directory")
+ if not test_directory:
+ raise ValueError("Test directory required for filesystem service")
+ return MCPStdioServer(
+ command="npx",
+ args=["-y", "@modelcontextprotocol/server-filesystem", str(test_directory)],
+ )
+
+ if self.mcp_service in ("playwright", "playwright_webarena"):
+ browser = self.service_config.get("browser", "chromium")
+ headless = self.service_config.get("headless", True)
+ viewport_width = self.service_config.get("viewport_width", 1280)
+ viewport_height = self.service_config.get("viewport_height", 720)
+
+ args = ["-y", "@playwright/mcp@latest"]
+ if headless:
+ args.append("--headless")
+ args.extend(
+ [
+ "--isolated",
+ "--no-sandbox",
+ "--browser",
+ browser,
+ "--viewport-size",
+ f"{viewport_width},{viewport_height}",
+ ]
+ )
+ return MCPStdioServer(command="npx", args=args)
+
+ if self.mcp_service == "postgres":
+ host = self.service_config.get("host", "localhost")
+ port = self.service_config.get("port", 5432)
+ username = self.service_config.get("username")
+ password = self.service_config.get("password")
+ database = self.service_config.get("current_database") or self.service_config.get("database")
+ if not all([username, password, database]):
+ raise ValueError("PostgreSQL requires username, password, and database")
+ database_url = f"postgresql://{username}:{password}@{host}:{port}/{database}"
+ return MCPStdioServer(
+ command="pipx",
+ args=["run", "postgres-mcp", "--access-mode=unrestricted"],
+ env={"DATABASE_URI": database_url},
+ )
+
+ raise ValueError(f"Unsupported stdio service: {self.mcp_service}")
+
+ def _create_http_server(self) -> MCPHttpServer:
+ if self.mcp_service == "github":
+ github_token = self.service_config.get("github_token")
+ if not github_token:
+ raise ValueError("GitHub token required")
+ return MCPHttpServer(
+ url="https://api.githubcopilot.com/mcp/",
+ headers={
+ "Authorization": f"Bearer {github_token}",
+ "User-Agent": "MCPMark/1.0",
+ },
+ )
+ raise ValueError(f"Unsupported HTTP service: {self.mcp_service}")
+
+ # ------------------------------------------------------------------
+ # Message/Tool formatting helpers
+ # ------------------------------------------------------------------
+
+ def _convert_to_sdk_format(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
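+        """Convert OpenAI-style messages to the legacy SDK format for backward compatibility."""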
+ sdk_format: List[Dict[str, Any]] = []
+ function_call_map: Dict[str, str] = {}
+
+ for msg in messages:
+ role = msg.get("role")
+
+ if role == "user":
+ user_content = msg.get("content", "")
+ if isinstance(user_content, list):
+ tool_results = [
+ item
+ for item in user_content
+ if isinstance(item, dict) and item.get("type") == "tool_result"
+ ]
+ if tool_results:
+ for tr in tool_results:
+ content_items = tr.get("content", [])
+ text_content = ""
+ for ci in content_items:
+ if isinstance(ci, dict) and ci.get("type") == "text":
+ text_content = ci.get("text", "")
+ break
+ sdk_format.append(
+ {
+ "call_id": tr.get("tool_use_id", ""),
+ "output": json.dumps(
+ {
+ "type": "text",
+ "text": text_content,
+ "annotations": None,
+ "meta": None,
+ }
+ ),
+ "type": "function_call_output",
+ }
+ )
+ else:
+ text_parts = []
+ for item in user_content:
+ if isinstance(item, dict) and item.get("type") == "text":
+ text_parts.append(item.get("text", ""))
+ sdk_format.append({"content": "\n".join(text_parts), "role": "user"})
+ else:
+ sdk_format.append({"content": user_content, "role": "user"})
+
+ elif role == "assistant":
+ tool_calls = msg.get("tool_calls", [])
+ function_call = msg.get("function_call")
+ content = msg.get("content")
+
+ if isinstance(content, list):
+ text_parts = []
+ claude_tool_uses = []
+ for block in content:
+ if isinstance(block, dict):
+ if block.get("type") == "text":
+ text_parts.append(block.get("text", ""))
+ elif block.get("type") == "thinking":
+ thinking_text = block.get("thinking", "")
+ if thinking_text:
+ text_parts.append(f"\n{thinking_text}\n")
+ elif block.get("type") == "tool_use":
+ claude_tool_uses.append(block)
+ content = "\n".join(text_parts)
+ if claude_tool_uses and not tool_calls:
+ tool_calls = []
+ for tu in claude_tool_uses:
+ tool_calls.append(
+ {
+ "id": tu.get("id"),
+ "function": {
+ "name": tu.get("name"),
+ "arguments": json.dumps(tu.get("input", {})),
+ },
+ }
+ )
+
+ if content:
+ sdk_format.append(
+ {
+ "id": "__fake_id__",
+ "content": [
+ {
+ "annotations": [],
+ "text": content,
+ "type": "output_text",
+ }
+ ],
+ "role": "assistant",
+ "status": "completed",
+ "type": "message",
+ }
+ )
+
+ if tool_calls:
+ for tool_call in tool_calls:
+ call_id = tool_call.get("id", f"call_{uuid.uuid4().hex}")
+ func_name = tool_call.get("function", {}).get("name", "")
+ sdk_format.append(
+ {
+ "arguments": tool_call.get("function", {}).get("arguments", "{}"),
+ "call_id": call_id,
+ "name": func_name,
+ "type": "function_call",
+ "id": "__fake_id__",
+ }
+ )
+
+ if function_call:
+ func_name = function_call.get("name", "")
+ call_id = f"call_{uuid.uuid4().hex}"
+ function_call_map[func_name] = call_id
+ sdk_format.append(
+ {
+ "arguments": function_call.get("arguments", "{}"),
+ "call_id": call_id,
+ "name": func_name,
+ "type": "function_call",
+ "id": "__fake_id__",
+ }
+ )
+
+                # Keep the legacy fallback: emit an empty assistant message when nothing else was produced
+                if not content and not tool_calls and not function_call:
+                    sdk_format.append(
+                        {
+                            "id": "__fake_id__",
+                            "content": [
+                                {
+                                    "annotations": [],
+                                    "text": "",
+                                    "type": "output_text",
+                                }
+                            ],
+                            "role": "assistant",
+                            "status": "completed",
+                            "type": "message",
+                        }
+                    )
+
+ elif role == "tool":
+ sdk_format.append(
+ {
+ "call_id": msg.get("tool_call_id", ""),
+ "output": json.dumps(
+ {
+ "type": "text",
+ "text": msg.get("content", ""),
+ "annotations": None,
+ "meta": None,
+ }
+ ),
+ "type": "function_call_output",
+ }
+ )
+
+ elif role == "function":
+ func_name = msg.get("name", "")
+ call_id = function_call_map.get(func_name, f"call_{uuid.uuid4().hex}")
+ sdk_format.append(
+ {
+ "call_id": call_id,
+ "output": json.dumps(
+ {
+ "type": "text",
+ "text": msg.get("content", ""),
+ "annotations": None,
+ "meta": None,
+ }
+ ),
+ "type": "function_call_output",
+ }
+ )
+
+ return sdk_format
+
+ def _convert_to_anthropic_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
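+        """Convert MCP tool definitions to the Anthropic tools format."""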
+ anthropic_tools = []
+ for tool in tools:
+ anthropic_tool = {
+ "name": tool.get("name"),
+ "description": tool.get("description", ""),
+ "input_schema": tool.get(
+ "inputSchema",
+ {"type": "object", "properties": {}, "required": []},
+ ),
+ }
+ anthropic_tools.append(anthropic_tool)
+ return anthropic_tools
+
+ def _is_gemini_model(self) -> bool:
+ model_lower = self.litellm_input_model_name.lower()
+ return "gemini" in model_lower or "bison" in model_lower
+
+ def _simplify_schema_for_gemini(self, schema: Optional[Dict[str, Any]]) -> Dict[str, Any]:
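+        """Simplify nested schemas for Gemini compatibility (e.g. list-valued
+        "type" fields); can be removed once LiteLLM handles this internally."""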
+ if not isinstance(schema, dict):
+ return schema or {}
+
+ simplified: Dict[str, Any] = {}
+ for key, value in schema.items():
+ if key == "type" and isinstance(value, list):
+ simplified[key] = value[0] if value else "string"
+ elif key == "items" and isinstance(value, dict):
+ simplified[key] = self._simplify_schema_for_gemini(value)
+ elif key == "properties" and isinstance(value, dict):
+ simplified[key] = {
+ prop_key: self._simplify_schema_for_gemini(prop_val)
+ for prop_key, prop_val in value.items()
+ }
+ elif isinstance(value, dict):
+ simplified[key] = self._simplify_schema_for_gemini(value)
+ elif isinstance(value, list) and key not in ("required", "enum"):
+ simplified[key] = [
+ self._simplify_schema_for_gemini(item) if isinstance(item, dict) else item
+ for item in value
+ ]
+ else:
+ simplified[key] = value
+ return simplified
+
+ def _convert_to_openai_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
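+        """Convert MCP tool definitions to OpenAI function format, simplifying
+        schemas for Gemini models where needed."""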
+ functions = []
+ is_gemini = self._is_gemini_model()
+
+ if is_gemini:
+ logger.debug(
+ "Detected Gemini model '%s' – simplifying tool schemas",
+ self.litellm_input_model_name,
+ )
+
+ for tool in tools:
+ input_schema = tool.get(
+ "inputSchema", {"type": "object", "properties": {}, "required": []}
+ )
+ if is_gemini:
+ simplified = self._simplify_schema_for_gemini(input_schema)
+ if simplified != input_schema:
+ input_schema = simplified
+ logger.debug("Simplified schema for tool '%s'", tool.get("name"))
+
+ functions.append(
+ {
+ "name": tool.get("name"),
+ "description": tool.get("description", ""),
+ "parameters": input_schema,
+ }
+ )
+
+ if is_gemini:
+ logger.info("Converted %d tools for Gemini compatibility", len(functions))
+
+ return functions
+
diff --git a/src/agents/mcpmark_agent.py b/src/agents/mcpmark_agent.py
index f143bda1..db0343cd 100644
--- a/src/agents/mcpmark_agent.py
+++ b/src/agents/mcpmark_agent.py
@@ -8,9 +8,6 @@
import asyncio
import json
import time
-import uuid
-import copy
-
from typing import Any, Dict, List, Optional, Callable
import httpx
@@ -18,8 +15,8 @@
import nest_asyncio
from src.logger import get_logger
+from .base_agent import BaseMCPAgent
from .mcp import MCPStdioServer, MCPHttpServer
-from .utils import TokenUsageTracker
# Apply nested asyncio support
nest_asyncio.apply()
@@ -29,34 +26,20 @@
logger = get_logger(__name__)
-class MCPMarkAgent:
+class MCPMarkAgent(BaseMCPAgent):
"""
Unified agent for LLM and MCP server management using LiteLLM.
-
+
- Anthropic models: Native MCP support via extra_body
- Other models: Manual MCP server management with function calling
"""
-
- # Constants
+
MAX_TURNS = 100
- DEFAULT_TIMEOUT = 600
SYSTEM_PROMPT = (
"You are a helpful agent that uses tools iteratively to complete the user's task, "
"and when finished, provides the final answer or simply states \"Task completed\" without further tool calls."
)
-
- # Service categories
- STDIO_SERVICES = ["notion", "filesystem", "playwright", "playwright_webarena", "postgres"]
- HTTP_SERVICES = ["github"]
-
- # Claude thinking budget mapping
- CLAUDE_THINKING_BUDGETS = {
- "low": 1024,
- "medium": 2048,
- "high": 4096
- }
-
- # ==================== Initialization and Configuration ====================
+ DEFAULT_TIMEOUT = BaseMCPAgent.DEFAULT_TIMEOUT
def __init__(
self,
@@ -66,99 +49,28 @@ def __init__(
mcp_service: str,
timeout: int = DEFAULT_TIMEOUT,
service_config: Optional[Dict[str, Any]] = None,
- service_config_provider: Optional[Callable[[], Dict]] = None,
+ service_config_provider: Optional[Callable[[], Dict[str, Any]]] = None,
reasoning_effort: Optional[str] = "default",
):
- """
- Initialize the MCPMark agent.
-
- Args:
- model_name: Name of the LLM model
- api_key: API key for the model provider
- base_url: Base url
- mcp_service: MCP service type
- timeout: Execution timeout in seconds
- service_config: Service-specific configuration
- service_config_provider: Optional provider for dynamic config
- reasoning_effort: Reasoning effort level ("default", "minimal", "low", "medium", "high")
- """
- self.litellm_input_model_name = litellm_input_model_name
- self.api_key = api_key
- self.base_url = base_url
- self.mcp_service = mcp_service
- self.timeout = timeout
- self.service_config = service_config or {}
- self._service_config_provider = service_config_provider
- self.reasoning_effort = reasoning_effort
-
- # Detect if this is a Claude model
- self.is_claude = self._is_anthropic_model(litellm_input_model_name)
-
- # Determine execution path: Claude with thinking or LiteLLM
- self.use_claude_thinking = self.is_claude and reasoning_effort != "default"
-
- # Initialize usage tracker
- self.usage_tracker = TokenUsageTracker()
-
- # Track the actual model name from responses
- self.litellm_run_model_name = None
-
- # Track partial progress for error/timeout handling
- self._partial_messages = []
- self._partial_token_usage = {}
- self._partial_turn_count = 0
-
- logger.debug(
- f"Initialized MCPMarkAgent for '{mcp_service}' with model '{litellm_input_model_name}' "
- f"(Claude: {self.is_claude}, Thinking: {self.use_claude_thinking}, Reasoning: {reasoning_effort})"
+ super().__init__(
+ litellm_input_model_name=litellm_input_model_name,
+ api_key=api_key,
+ base_url=base_url,
+ mcp_service=mcp_service,
+ timeout=timeout,
+ service_config=service_config,
+ service_config_provider=service_config_provider,
+ reasoning_effort=reasoning_effort,
)
-
-
- def __repr__(self):
- return (
- f"MCPMarkAgent(service='{self.mcp_service}', model='{self.litellm_input_model_name}', "
+ logger.debug(
+ "Initialized MCPMarkAgent for '%s' with model '%s' (Claude: %s, Thinking: %s, Reasoning: %s)",
+ mcp_service,
+ litellm_input_model_name,
+ self.is_claude,
+ self.use_claude_thinking,
+ reasoning_effort,
)
- def _is_anthropic_model(self, model_name: str) -> bool:
- """Check if the model is an Anthropic model."""
- return "claude" in model_name.lower()
-
-
- def _get_claude_thinking_budget(self) -> Optional[int]:
- """Get thinking budget for Claude based on reasoning effort."""
- if not self.use_claude_thinking:
- return None
- return self.CLAUDE_THINKING_BUDGETS.get(self.reasoning_effort, 2048)
-
-
- def _refresh_service_config(self):
- """Refresh service config from provider if available."""
- if self._service_config_provider:
- try:
- latest_cfg = self._service_config_provider() or {}
- self.service_config.update(latest_cfg)
- except Exception as e:
- logger.warning(f"| Failed to refresh service config: {e}")
-
- def _reset_progress(self):
- """Reset stored partial progress for a new execution run."""
- self._partial_messages = []
- self._partial_token_usage = {}
- self._partial_turn_count = 0
-
- def _update_progress(self, messages: List[Dict], token_usage: Dict, turn_count: int):
- """Record partial progress so we can return it on timeout/errors."""
- try:
- # Deep copy to avoid mutation by callers
- self._partial_messages = copy.deepcopy(messages)
- self._partial_token_usage = dict(token_usage or {})
- self._partial_turn_count = int(turn_count or 0)
- except Exception:
- # Best-effort; don't let progress recording crash execution
- pass
-
-
-
# ==================== Public Interface Methods ====================
async def execute(
@@ -865,293 +777,6 @@ async def _execute_litellm_tool_loop(
# ==================== Format Conversion Methods ====================
- def _convert_to_sdk_format(self, messages: List[Dict]) -> List[Dict]:
- """Convert OpenAI messages format to old SDK format for backward compatibility."""
- sdk_format = []
- function_call_map = {} # Track function names to call IDs for legacy format
-
- for msg in messages:
- role = msg.get("role")
-
- if role == "user":
- # User messages stay mostly the same
- user_content = msg.get("content", "")
-
- # Handle tool_result messages (content as list)
- if isinstance(user_content, list):
- # Check if this is a tool_result message
- tool_results = [item for item in user_content if isinstance(item, dict) and item.get("type") == "tool_result"]
- if tool_results:
- # Convert tool_results to function_call_output format
- for tr in tool_results:
- content_items = tr.get("content", [])
- text_content = ""
- for ci in content_items:
- if isinstance(ci, dict) and ci.get("type") == "text":
- text_content = ci.get("text", "")
- break
- sdk_format.append({
- "call_id": tr.get("tool_use_id", ""),
- "output": json.dumps({
- "type": "text",
- "text": text_content,
- "annotations": None,
- "meta": None
- }),
- "type": "function_call_output"
- })
- else:
- # Regular user content as list - extract text
- text_parts = []
- for item in user_content:
- if isinstance(item, dict) and item.get("type") == "text":
- text_parts.append(item.get("text", ""))
- sdk_format.append({
- "content": "\n".join(text_parts) if text_parts else "",
- "role": "user"
- })
- else:
- # String content
- sdk_format.append({
- "content": user_content,
- "role": "user"
- })
-
- elif role == "assistant":
- # === CHANGED ORDER START ===
- tool_calls = msg.get("tool_calls", [])
- function_call = msg.get("function_call")
- content = msg.get("content")
-
- # Handle both string content and list content (for Claude thinking)
- if isinstance(content, list):
- # Extract text from content blocks (e.g., Claude responses with thinking)
- text_parts = []
- claude_tool_uses = []
- for block in content:
- if isinstance(block, dict):
- if block.get("type") == "text":
- text_parts.append(block.get("text", ""))
- elif block.get("type") == "thinking":
- # Include thinking in output (marked as such)
- thinking_text = block.get("thinking", "")
- if thinking_text:
- text_parts.append(f"\n{thinking_text}\n")
- elif block.get("type") == "tool_use":
- # Store tool_use blocks for later processing
- claude_tool_uses.append(block)
- content = "\n".join(text_parts) if text_parts else ""
-
- # Add Claude tool_uses to regular tool_calls
- if claude_tool_uses and not tool_calls:
- tool_calls = []
- for tu in claude_tool_uses:
- tool_calls.append({
- "id": tu.get("id"),
- "function": {
- "name": tu.get("name"),
- "arguments": json.dumps(tu.get("input", {}))
- }
- })
-
- # 1) First add assistant's text content (if present)
- if content:
- sdk_format.append({
- "id": "__fake_id__",
- "content": [
- {
- "annotations": [],
- "text": content if content else "",
- "type": "output_text"
- }
- ],
- "role": "assistant",
- "status": "completed",
- "type": "message"
- })
-
- # 2) Then add (new format) tool_calls
- if tool_calls:
- for tool_call in tool_calls:
- call_id = tool_call.get("id", f"call_{uuid.uuid4().hex}")
- func_name = tool_call.get("function", {}).get("name", "")
- sdk_format.append({
- "arguments": tool_call.get("function", {}).get("arguments", "{}"),
- "call_id": call_id,
- "name": func_name,
- "type": "function_call",
- "id": "__fake_id__"
- })
-
- # 3) Finally handle (legacy format) function_call
- if function_call:
- func_name = function_call.get("name", "")
- call_id = f"call_{uuid.uuid4().hex}"
- function_call_map[func_name] = call_id # Store for matching responses
- sdk_format.append({
- "arguments": function_call.get("arguments", "{}"),
- "call_id": call_id,
- "name": func_name,
- "type": "function_call",
- "id": "__fake_id__"
- })
-
- # 4) If neither content nor any calls exist, maintain fallback behavior
- if not content and not tool_calls and not function_call:
- sdk_format.append({
- "id": "__fake_id__",
- "content": [
- {
- "annotations": [],
- "text": "",
- "type": "output_text"
- }
- ],
- "role": "assistant",
- "status": "completed",
- "type": "message"
- })
- # === CHANGED ORDER END ===
-
- elif role == "tool":
- # Tool responses
- sdk_format.append({
- "call_id": msg.get("tool_call_id", ""),
- "output": json.dumps({
- "type": "text",
- "text": msg.get("content", ""),
- "annotations": None,
- "meta": None
- }),
- "type": "function_call_output"
- })
-
- elif role == "function":
- # Legacy function responses - try to match with stored call ID
- func_name = msg.get("name", "")
- call_id = function_call_map.get(func_name, f"call_{uuid.uuid4().hex}")
- sdk_format.append({
- "call_id": call_id,
- "output": json.dumps({
- "type": "text",
- "text": msg.get("content", ""),
- "annotations": None,
- "meta": None
- }),
- "type": "function_call_output"
- })
-
- return sdk_format
-
-
-
- def _convert_to_anthropic_format(self, tools: List[Dict]) -> List[Dict]:
- """Convert MCP tool definitions to Anthropic format."""
- anthropic_tools = []
-
- for tool in tools:
- anthropic_tool = {
- "name": tool.get("name"),
- "description": tool.get("description", ""),
- "input_schema": tool.get("inputSchema", {
- "type": "object",
- "properties": {},
- "required": []
- })
- }
- anthropic_tools.append(anthropic_tool)
-
- return anthropic_tools
-
- def _is_gemini_model(self) -> bool:
- """Check if the model is a Gemini model."""
- model_lower = self.litellm_input_model_name.lower()
- return "gemini" in model_lower or "bison" in model_lower
-
- def _simplify_schema_for_gemini(self, schema: Dict) -> Dict:
- """
- Simplify nested schemas for Gemini compatibility.
- Gemini has issues with deeply nested array type definitions.
-
- Note: This is a compatibility layer for Gemini API via LiteLLM.
- Can be removed once LiteLLM handles this internally.
- """
- if not isinstance(schema, dict):
- return schema
-
- simplified = {}
-
- for key, value in schema.items():
- if key == "type" and isinstance(value, list):
- # Gemini doesn't like type as array, use first type
- simplified[key] = value[0] if value else "string"
- elif key == "items" and isinstance(value, dict):
- # Recursively simplify items
- simplified[key] = self._simplify_schema_for_gemini(value)
- elif key == "properties" and isinstance(value, dict):
- # Recursively simplify each property
- simplified[key] = {
- prop_key: self._simplify_schema_for_gemini(prop_val)
- for prop_key, prop_val in value.items()
- }
- elif isinstance(value, dict):
- # Recursively simplify nested objects
- simplified[key] = self._simplify_schema_for_gemini(value)
- elif isinstance(value, list) and key not in ["required", "enum"]:
- # For non-special arrays, check if they contain schemas
- simplified[key] = [
- self._simplify_schema_for_gemini(item) if isinstance(item, dict) else item
- for item in value
- ]
- else:
- simplified[key] = value
-
- return simplified
-
-
- def _convert_to_openai_format(self, tools: List[Dict]) -> List[Dict]:
- """
- Convert MCP tool definitions to OpenAI function format.
-
- For Gemini models, applies schema simplification to handle
- compatibility issues with deeply nested array type definitions.
- """
- functions = []
- is_gemini = self._is_gemini_model()
-
- if is_gemini:
- logger.debug(f"Detected Gemini model: {self.litellm_input_model_name}")
- logger.debug(f"Processing {len(tools)} tools for Gemini compatibility")
-
- for i, tool in enumerate(tools):
- # Get the input schema
- input_schema = tool.get("inputSchema", {
- "type": "object",
- "properties": {},
- "required": []
- })
-
- # Simplify schema for Gemini if needed
- if is_gemini:
- original_schema = input_schema.copy() # Keep for debugging
- input_schema = self._simplify_schema_for_gemini(input_schema)
-
- # Log significant changes for debugging
- if input_schema != original_schema:
- logger.debug(f"Simplified schema for tool #{i} '{tool.get('name')}'")
-
- function = {
- "name": tool.get("name"),
- "description": tool.get("description", ""),
- "parameters": input_schema
- }
- functions.append(function)
-
- if is_gemini:
- logger.info(f"| Converted {len(functions)} tools for Gemini model with schema simplification")
-
- return functions
-
diff --git a/src/agents/react_agent.py b/src/agents/react_agent.py
new file mode 100644
index 00000000..53312653
--- /dev/null
+++ b/src/agents/react_agent.py
@@ -0,0 +1,424 @@
+"""ReAct agent implementation for the MCPMark pipeline."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import time
+from typing import Any, Dict, List, Optional, Callable
+
+import litellm
+
+from src.logger import get_logger
+from .base_agent import BaseMCPAgent
+
+logger = get_logger(__name__)
+
+
+class ReActAgent(BaseMCPAgent):
+ """ReAct-style agent that reuses MCPMark infrastructure."""
+
+ DEFAULT_SYSTEM_PROMPT = (
+ "You are a careful ReAct (reasoning and acting) agent. "
+ "At each step you must decide whether to call a tool or provide a final response. "
+ "Only use the tools that are listed for you. When you finish, respond with either the final answer "
+ "or the phrase \"Task completed.\" if no further detail is required. "
+ "Every reply must be valid JSON without code fences."
+ )
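+
+    # Illustrative well-formed replies that the loop below accepts (the tool
+    # name and arguments are placeholders, not a real MCP tool):
+    #
+    #   {"thought": "Inspect the workspace first.",
+    #    "action": {"tool": "list_directory", "arguments": {"path": "."}}}
+    #
+    #   {"thought": "The requested change is done.", "answer": "Task completed."}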
+
+ def __init__(
+ self,
+ litellm_input_model_name: str,
+ api_key: str,
+ base_url: str,
+ mcp_service: str,
+ timeout: int = BaseMCPAgent.DEFAULT_TIMEOUT,
+ service_config: Optional[Dict[str, Any]] = None,
+ service_config_provider: Optional[Callable[[], Dict[str, Any]]] = None,
+ reasoning_effort: Optional[str] = "default",
+ max_iterations: int = 100,
+ system_prompt: Optional[str] = None,
+ ):
+ super().__init__(
+ litellm_input_model_name=litellm_input_model_name,
+ api_key=api_key,
+ base_url=base_url,
+ mcp_service=mcp_service,
+ timeout=timeout,
+ service_config=service_config,
+ service_config_provider=service_config_provider,
+ reasoning_effort=reasoning_effort,
+ )
+ self.max_iterations = max_iterations
+ self.react_system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
+
+ async def execute(
+ self,
+ instruction: str,
+ tool_call_log_file: Optional[str] = None,
+ ) -> Dict[str, Any]:
+ start_time = time.time()
+
+ try:
+ self._reset_progress()
+ self._refresh_service_config()
+
+ async def _run_react():
+ return await self._execute_react_loop(instruction, tool_call_log_file)
+
+ result = await asyncio.wait_for(_run_react(), timeout=self.timeout)
+ execution_time = time.time() - start_time
+ self.usage_tracker.update(
+ success=result.get("success", False),
+ token_usage=result.get("token_usage", {}),
+ turn_count=result.get("turn_count", 0),
+ execution_time=execution_time,
+ )
+ result["execution_time"] = execution_time
+ return result
+ except Exception as exc: # noqa: BLE001
+ execution_time = time.time() - start_time
+
+ if isinstance(exc, asyncio.TimeoutError):
+ error_msg = f"Execution timed out after {self.timeout} seconds"
+ logger.error(error_msg)
+ else:
+ error_msg = f"ReAct agent execution failed: {exc}"
+ logger.error(error_msg, exc_info=True)
+
+ self.usage_tracker.update(
+ success=False,
+ token_usage=self._partial_token_usage or {},
+ turn_count=self._partial_turn_count or 0,
+ execution_time=execution_time,
+ )
+
+ if self._partial_messages:
+ final_msg = self._convert_to_sdk_format(self._partial_messages)
+ else:
+ final_msg = []
+
+ return {
+ "success": False,
+ "output": final_msg,
+ "token_usage": self._partial_token_usage or {},
+ "turn_count": self._partial_turn_count or 0,
+ "execution_time": execution_time,
+ "error": error_msg,
+ "litellm_run_model_name": self.litellm_run_model_name,
+ }
+
+ async def _execute_react_loop(
+ self,
+ instruction: str,
+ tool_call_log_file: Optional[str],
+ ) -> Dict[str, Any]:
+ system_message = {"role": "system", "content": self.react_system_prompt}
+ total_tokens = {
+ "input_tokens": 0,
+ "output_tokens": 0,
+ "total_tokens": 0,
+ "reasoning_tokens": 0,
+ }
+ turn_count = 0
+ success = False
+ final_error: Optional[str] = None
+
+ mcp_server = await self._create_mcp_server()
+ async with mcp_server:
+ tools = await mcp_server.list_tools()
+ tool_map = {tool.get("name"): tool for tool in tools}
+ tools_description = self._render_tools_description(tools)
+
+ task_message = {
+ "role": "user",
+ "content": self._build_task_prompt(
+ instruction=instruction,
+ tools_description=tools_description,
+ ),
+ }
+ messages: List[Dict[str, Any]] = [system_message, task_message]
+ self._update_progress(messages, total_tokens, turn_count)
+
+ for step in range(1, self.max_iterations + 1):
+ completion_kwargs = {
+ "model": self.litellm_input_model_name,
+ "messages": messages,
+ "api_key": self.api_key,
+ }
+ if self.base_url:
+ completion_kwargs["base_url"] = self.base_url
+ if self.reasoning_effort != "default":
+ completion_kwargs["reasoning_effort"] = self.reasoning_effort
+
+ try:
+ response = await asyncio.wait_for(
+ litellm.acompletion(**completion_kwargs),
+ timeout=self.timeout / 2,
+ )
+ except asyncio.TimeoutError:
+ final_error = f"LLM call timed out on step {step}"
+ logger.error(final_error)
+ break
+ except Exception as exc: # noqa: BLE001
+ final_error = f"LLM call failed on step {step}: {exc}"
+ logger.error(final_error)
+ break
+
+ if turn_count == 0 and getattr(response, "model", None):
+ self.litellm_run_model_name = response.model.split("/")[-1]
+
+ usage = getattr(response, "usage", None)
+ if usage:
+ prompt_tokens = (
+ getattr(usage, "prompt_tokens", None)
+ or getattr(usage, "input_tokens", None)
+ or 0
+ )
+ completion_tokens = (
+ getattr(usage, "completion_tokens", None)
+ or getattr(usage, "output_tokens", None)
+ or 0
+ )
+ total_tokens_count = getattr(usage, "total_tokens", None)
+ if total_tokens_count is None:
+ total_tokens_count = prompt_tokens + completion_tokens
+
+ total_tokens["input_tokens"] += prompt_tokens
+ total_tokens["output_tokens"] += completion_tokens
+ total_tokens["total_tokens"] += total_tokens_count
+
+ # Extract reasoning tokens if available
+                    if hasattr(usage, "completion_tokens_details"):
+                        details = usage.completion_tokens_details
+                        if hasattr(details, "reasoning_tokens"):
+                            total_tokens["reasoning_tokens"] += details.reasoning_tokens or 0
+
+ choice = response.choices[0]
+ message_obj = getattr(choice, "message", None)
+ if message_obj is None and isinstance(choice, dict):
+ message_obj = choice.get("message")
+
+ if message_obj is None:
+ content_raw = getattr(choice, "text", "")
+ else:
+ content_raw = message_obj.get("content", "")
+
+                assistant_text = self._normalize_content(content_raw)
+                assistant_message = {"role": "assistant", "content": assistant_text}
+                messages.append(assistant_message)
+                turn_count += 1
+                self._update_progress(messages, total_tokens, turn_count)
+
+                parsed = self._parse_react_response(assistant_text)
+                if not parsed or "thought" not in parsed:
+                    warning = (
+                        "The previous response was not valid JSON following the required schema. "
+                        "Please respond again using the JSON formats provided."
+                    )
+                    messages.append({"role": "user", "content": warning})
+                    self._update_progress(messages, total_tokens, turn_count)
+                    final_error = "Model produced an invalid response format."
+                    continue
+
+                thought = parsed.get("thought", "")
+                action = parsed.get("action")
+                answer = parsed.get("answer")
+                result = parsed.get("result")
+
+                logger.info(f"|\n| \033[1;3mThought\033[0m: {str(thought)}")
+                if tool_call_log_file:
+                    try:
+                        with open(tool_call_log_file, "a", encoding="utf-8") as log_file:
+                            log_file.write(f"| {str(thought)}\n")
+                    except Exception:  # noqa: BLE001
+                        pass
+                if isinstance(action, dict):
+                    func_name = action.get("tool")
+                    arguments = action.get("arguments", {}) or {}
+                    args_str = json.dumps(arguments, separators=(",", ": "))
+                    display_arguments = args_str[:140] + "..." if len(args_str) > 140 else args_str
+                    logger.info(f"| \033[1;3mAction\033[0m: \033[1m{func_name}\033[0m \033[2;37m{display_arguments}\033[0m")
+
+                if answer is not None:
+                    success = True
+                    break
+
+                if action is not None and isinstance(action, dict):
+                    tool_name = action.get("tool")
+                    arguments = action.get("arguments", {}) or {}
+
+                    if tool_name not in tool_map:
+                        observation = (
+                            f"Invalid tool '{tool_name}'. Available tools: "
+                            f"{', '.join(tool_map)}"
+                        )
+                    else:
+                        try:
+                            tool_response = await asyncio.wait_for(
+                                mcp_server.call_tool(tool_name, arguments),
+                                timeout=60,
+                            )
+                            observation = self._tool_result_to_text(tool_response)
+                        except asyncio.TimeoutError:
+                            observation = f"Tool '{tool_name}' timed out"
+                        except Exception as tool_exc:  # noqa: BLE001
+                            observation = f"Tool '{tool_name}' failed: {tool_exc}"
+
+                    if tool_call_log_file:
+                        try:
+                            with open(tool_call_log_file, "a", encoding="utf-8") as log_file:
+                                log_file.write(f"| {tool_name} {json.dumps(arguments, ensure_ascii=False)}\n")
+                        except Exception:  # noqa: BLE001
+                            pass
+
+                    observation_message = {
+                        "role": "user",
+                        "content": (
+                            f"Observation:\n{observation}\n"
+                            "Please continue reasoning and reply using the required JSON format."
+                        ),
+                    }
+                    messages.append(observation_message)
+                    self._update_progress(messages, total_tokens, turn_count)
+                    continue
+
+                if result is not None:
+                    observation_message = {
+                        "role": "user",
+                        "content": (
+                            f"Observation:\n{result}\n"
+                            "Please continue reasoning and reply using the required JSON format."
+                        ),
+                    }
+                    messages.append(observation_message)
+                    self._update_progress(messages, total_tokens, turn_count)
+                    continue
+
+                # Unexpected structure: ask model to restate properly
+                messages.append(
+                    {
+                        "role": "user",
+                        "content": (
+                            "The previous reply did not include an action, result, or answer. "
+                            "Please respond again using the JSON formats provided."
+                        ),
+                    }
+                )
+                self._update_progress(messages, total_tokens, turn_count)
+
+ if not success and final_error is None:
+ final_error = (
+ f"Max iterations ({self.max_iterations}) reached without a final answer."
+ )
+
+ if total_tokens["total_tokens"] > 0:
+ log_msg = (
+ f"|\n|\n| Token usage: Total: {total_tokens['total_tokens']:,} | "
+ f"Input: {total_tokens['input_tokens']:,} | "
+ f"Output: {total_tokens['output_tokens']:,}"
+ )
+ if total_tokens.get("reasoning_tokens", 0) > 0:
+ log_msg += f" | Reasoning: {total_tokens['reasoning_tokens']:,}"
+ logger.info(log_msg)
+ logger.info(f"| Turns: {turn_count}")
+
+ sdk_messages = self._convert_to_sdk_format(messages)
+
+ return {
+ "success": success,
+ "output": sdk_messages,
+ "token_usage": total_tokens,
+ "turn_count": turn_count,
+ "error": None if success else final_error,
+ "litellm_run_model_name": self.litellm_run_model_name,
+ }
+
+ def _build_task_prompt(
+ self,
+ instruction: str,
+ tools_description: str,
+ ) -> str:
+ return (
+ f"Task:\n{instruction}\n\n"
+ f"Available MCP tools:\n{tools_description}\n\n"
+ "Respond using the JSON formats below.\n\n"
+ "If you need to use a tool:\n"
+ "{\n"
+ ' "thought": "Reasoning for the next action",\n'
+ ' "action": {\n'
+ ' "tool": "tool-name",\n'
+ ' "arguments": {\n'
+ ' "parameter": value\n'
+ " }\n"
+ " }\n"
+ "}\n\n"
+ "If you can provide the final answer:\n"
+ "{\n"
+ ' "thought": "Reasoning that justifies the answer",\n'
+ ' "answer": "Either the final solution or \'Task completed.\' when no more detail is required"\n'
+ "}\n\n"
+            "Remember: replying with the \"answer\" object ends the task, so only do this when you are finished."
+ )
+
+ def _render_tools_description(self, tools: List[Dict[str, Any]]) -> str:
+ descriptions = []
+ for tool in tools:
+ name = tool.get("name", "unknown")
+ description = tool.get("description", "No description provided.")
+ input_schema = tool.get("inputSchema", {}) or {}
+ properties = input_schema.get("properties", {}) or {}
+ required = set(input_schema.get("required", []) or [])
+
+ arg_lines = []
+ for prop_name, prop_details in properties.items():
+ details = json.dumps(prop_details, ensure_ascii=False, indent=2)
+ suffix = " (required)" if prop_name in required else ""
+ arg_lines.append(f"- {prop_name}{suffix}: {details}")
+
+ if arg_lines:
+ arguments_text = "\n".join(arg_lines)
+ else:
+ arguments_text = "(no arguments)"
+
+ descriptions.append(
+ f"Tool: {name}\nDescription: {description}\nArguments:\n{arguments_text}"
+ )
+
+ return "\n\n".join(descriptions) if descriptions else "(no tools available)"
+
+ def _normalize_content(self, content: Any) -> str:
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ parts = []
+ for block in content:
+ if isinstance(block, dict):
+ if block.get("type") == "text":
+ parts.append(block.get("text", ""))
+ elif "text" in block:
+ parts.append(str(block.get("text")))
+ else:
+ parts.append(str(block))
+ return "\n".join(part for part in parts if part)
+ return json.dumps(content, ensure_ascii=False)
+
+ def _parse_react_response(self, payload: str) -> Dict[str, Any]:
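+        """Parse a possibly code-fenced JSON reply; return {} if parsing fails."""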
+ candidate = payload.strip().strip("`").strip()
+ if candidate.lower().startswith("json"):
+ candidate = candidate[4:].lstrip()
+ try:
+ return json.loads(candidate)
+ except json.JSONDecodeError:
+ return {}
+
+ def _tool_result_to_text(self, result: Any) -> str:
+ if result is None:
+ return ""
+ if isinstance(result, str):
+ return result
+ try:
+ return json.dumps(result, ensure_ascii=False)
+ except TypeError:
+ return str(result)
diff --git a/src/evaluator.py b/src/evaluator.py
index b03abae9..31f1b711 100644
--- a/src/evaluator.py
+++ b/src/evaluator.py
@@ -11,7 +11,7 @@
from src.model_config import ModelConfig
from src.results_reporter import EvaluationReport, ResultsReporter, TaskResult
from src.errors import is_retryable_error
-from src.agents import MCPMarkAgent
+from src.agents import AGENT_REGISTRY
# Initialize logger
logger = get_logger(__name__)
@@ -26,10 +26,14 @@ def __init__(
exp_name: str = "test-run",
output_dir: Path = None,
reasoning_effort: str = "default",
+ agent_name: str = "mcpmark",
):
# Main configuration
self.mcp_service = mcp_service
self.timeout = timeout
+ self.agent_name = (agent_name or "mcpmark").lower()
+ if self.agent_name not in AGENT_REGISTRY:
+ raise ValueError(f"Unsupported agent '{agent_name}'. Available: {sorted(AGENT_REGISTRY)}")
# Initialize model configuration
self.reasoning_effort = reasoning_effort
@@ -54,8 +58,9 @@ def __init__(
# automatically refresh its service configuration from the state
# manager before each execution, so per-task manual updates are no
# longer needed.
- self.agent = MCPMarkAgent(
- litellm_input_model_name=self.litellm_input_model_name, # Use the original model name for detection
+ agent_cls = AGENT_REGISTRY[self.agent_name]
+ self.agent = agent_cls(
+ litellm_input_model_name=self.litellm_input_model_name,
api_key=self.api_key,
base_url=self.base_url,
mcp_service=mcp_service,
@@ -352,6 +357,7 @@ def run_evaluation(self, task_filter: str) -> EvaluationReport:
"litellm_run_model_name": self.litellm_run_model_name,
"reasoning_effort": self.reasoning_effort,
"timeout": self.timeout,
+ "agent_name": self.agent_name,
}
self.results_reporter.save_meta_json(
task_result,
@@ -398,6 +404,7 @@ def _matches_filter(tr: TaskResult, flt: str) -> bool:
"litellm_run_model_name": self.litellm_run_model_name,
"reasoning_effort": self.reasoning_effort,
"timeout": self.timeout,
+ "agent_name": self.agent_name,
},
total_tasks=len(final_results),
successful_tasks=sum(1 for r in final_results if r.success),