diff --git a/README.md b/README.md index 1c8cd61b..a7c14f0a 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,8 @@ pip install -e . playwright install ``` +MCPMark defaults to the built-in orchestration agent (`MCPMarkAgent`). To experiment with the ReAct-style agent, pass `--agent react` to `pipeline.py` (other settings stay the same). + Docker ```bash ./build-docker.sh diff --git a/pipeline.py b/pipeline.py index b6f442ef..8fe78fc7 100644 --- a/pipeline.py +++ b/pipeline.py @@ -14,6 +14,7 @@ from src.logger import get_logger from src.evaluator import MCPEvaluator +from src.agents import AGENT_REGISTRY from src.factory import MCPServiceFactory from src.model_config import ModelConfig @@ -41,6 +42,13 @@ def main(): required=True, help="Comma-separated list of models to evaluate (e.g., 'o3,k2,gpt-4.1')", ) + + parser.add_argument( + "--agent", + default="mcpmark", + choices=sorted(AGENT_REGISTRY.keys()), + help="Agent implementation to use (default: mcpmark)", + ) parser.add_argument( "--tasks", default="all", @@ -138,6 +146,7 @@ def main(): exp_name=run_exp_name, output_dir=run_output_dir, reasoning_effort=args.reasoning_effort, + agent_name=args.agent, ) pipeline.run_evaluation(args.tasks) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index ccbf7149..ea1e057a 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -2,10 +2,17 @@ MCPMark Agent Module ==================== -Provides a unified agent implementation using LiteLLM for model interactions -and minimal MCP server management. +Provides agent implementations and registry for MCPMark. """ +from .base_agent import BaseMCPAgent from .mcpmark_agent import MCPMarkAgent +from .react_agent import ReActAgent + +AGENT_REGISTRY = { + "mcpmark": MCPMarkAgent, + "react": ReActAgent, +} + +__all__ = ["BaseMCPAgent", "MCPMarkAgent", "ReActAgent", "AGENT_REGISTRY"] -__all__ = ["MCPMarkAgent"] \ No newline at end of file diff --git a/src/agents/base_agent.py b/src/agents/base_agent.py new file mode 100644 index 00000000..81fb87a7 --- /dev/null +++ b/src/agents/base_agent.py @@ -0,0 +1,465 @@ +"""Shared base agent functionality for MCPMark agents.""" + +from __future__ import annotations + +import asyncio +import copy +import json +import uuid +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Callable + +from src.logger import get_logger +from .mcp import MCPStdioServer, MCPHttpServer +from .utils import TokenUsageTracker + +logger = get_logger(__name__) + + +class BaseMCPAgent(ABC): + """Base class with shared functionality for MCPMark agents.""" + + STDIO_SERVICES = ["notion", "filesystem", "playwright", "playwright_webarena", "postgres"] + HTTP_SERVICES = ["github"] + DEFAULT_TIMEOUT = 600 + + CLAUDE_THINKING_BUDGETS = { + "low": 1024, + "medium": 2048, + "high": 4096, + } + + def __init__( + self, + litellm_input_model_name: str, + api_key: str, + base_url: str, + mcp_service: str, + timeout: int = DEFAULT_TIMEOUT, + service_config: Optional[Dict[str, Any]] = None, + service_config_provider: Optional[Callable[[], Dict[str, Any]]] = None, + reasoning_effort: Optional[str] = "default", + ): + self.litellm_input_model_name = litellm_input_model_name + self.api_key = api_key + self.base_url = base_url + self.mcp_service = mcp_service + self.timeout = timeout + self.service_config = service_config or {} + self._service_config_provider = service_config_provider + self.reasoning_effort = reasoning_effort or "default" + + self.is_claude = self._is_anthropic_model(litellm_input_model_name) + 
self.use_claude_thinking = self.is_claude and self.reasoning_effort != "default" + + self.usage_tracker = TokenUsageTracker() + self.litellm_run_model_name = None + + self._partial_messages: List[Dict[str, Any]] = [] + self._partial_token_usage: Dict[str, int] = {} + self._partial_turn_count: int = 0 + + logger.debug( + "Initialized %s for service '%s' with model '%s'", + self.__class__.__name__, + self.mcp_service, + self.litellm_input_model_name, + ) + + def __repr__(self) -> str: # pragma: no cover - debug helper + return ( + f"{self.__class__.__name__}(service='{self.mcp_service}', " + f"model='{self.litellm_input_model_name}')" + ) + + @abstractmethod + async def execute( + self, + instruction: str, + tool_call_log_file: Optional[str] = None, + ) -> Dict[str, Any]: + """Execute the agent logic and return execution metadata.""" + + def execute_sync( + self, + instruction: str, + tool_call_log_file: Optional[str] = None, + ) -> Dict[str, Any]: + """Synchronous wrapper for async execution.""" + return asyncio.run(self.execute(instruction, tool_call_log_file)) + + def get_usage_stats(self) -> Dict[str, Any]: + """Return aggregated usage statistics.""" + return self.usage_tracker.get_stats() + + def reset_usage_stats(self): + """Clear usage statistics.""" + self.usage_tracker.reset() + + # ------------------------------------------------------------------ + # Shared helpers + # ------------------------------------------------------------------ + + def _is_anthropic_model(self, model_name: str) -> bool: + return "claude" in model_name.lower() + + def _get_claude_thinking_budget(self) -> Optional[int]: + if not self.use_claude_thinking: + return None + return self.CLAUDE_THINKING_BUDGETS.get(self.reasoning_effort, 2048) + + def _refresh_service_config(self): + if not self._service_config_provider: + return + try: + latest_cfg = self._service_config_provider() or {} + self.service_config.update(latest_cfg) + except Exception as exc: # pragma: no cover - best effort refresh + logger.warning("Failed to refresh service config: %s", exc) + + def _reset_progress(self): + self._partial_messages = [] + self._partial_token_usage = {} + self._partial_turn_count = 0 + + def _update_progress( + self, + messages: List[Dict[str, Any]], + token_usage: Dict[str, Any], + turn_count: int, + ): + try: + self._partial_messages = copy.deepcopy(messages) + self._partial_token_usage = dict(token_usage or {}) + self._partial_turn_count = int(turn_count or 0) + except Exception: # pragma: no cover - defensive copy + pass + + # ------------------------------------------------------------------ + # MCP server management + # ------------------------------------------------------------------ + + async def _create_mcp_server(self) -> Any: + if self.mcp_service in self.STDIO_SERVICES: + return self._create_stdio_server() + if self.mcp_service in self.HTTP_SERVICES: + return self._create_http_server() + raise ValueError(f"Unsupported MCP service: {self.mcp_service}") + + def _create_stdio_server(self) -> MCPStdioServer: + if self.mcp_service == "notion": + notion_key = self.service_config.get("notion_key") + if not notion_key: + raise ValueError("Notion API key required") + return MCPStdioServer( + command="npx", + args=["-y", "@notionhq/notion-mcp-server"], + env={ + "OPENAPI_MCP_HEADERS": ( + '{"Authorization": "Bearer ' + notion_key + '", ' + '"Notion-Version": "2022-06-28"}' + ) + }, + ) + + if self.mcp_service == "filesystem": + test_directory = self.service_config.get("test_directory") + if not test_directory: + 
raise ValueError("Test directory required for filesystem service") + return MCPStdioServer( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem", str(test_directory)], + ) + + if self.mcp_service in ("playwright", "playwright_webarena"): + browser = self.service_config.get("browser", "chromium") + headless = self.service_config.get("headless", True) + viewport_width = self.service_config.get("viewport_width", 1280) + viewport_height = self.service_config.get("viewport_height", 720) + + args = ["-y", "@playwright/mcp@latest"] + if headless: + args.append("--headless") + args.extend( + [ + "--isolated", + "--no-sandbox", + "--browser", + browser, + "--viewport-size", + f"{viewport_width},{viewport_height}", + ] + ) + return MCPStdioServer(command="npx", args=args) + + if self.mcp_service == "postgres": + host = self.service_config.get("host", "localhost") + port = self.service_config.get("port", 5432) + username = self.service_config.get("username") + password = self.service_config.get("password") + database = self.service_config.get("current_database") or self.service_config.get("database") + if not all([username, password, database]): + raise ValueError("PostgreSQL requires username, password, and database") + database_url = f"postgresql://{username}:{password}@{host}:{port}/{database}" + return MCPStdioServer( + command="pipx", + args=["run", "postgres-mcp", "--access-mode=unrestricted"], + env={"DATABASE_URI": database_url}, + ) + + raise ValueError(f"Unsupported stdio service: {self.mcp_service}") + + def _create_http_server(self) -> MCPHttpServer: + if self.mcp_service == "github": + github_token = self.service_config.get("github_token") + if not github_token: + raise ValueError("GitHub token required") + return MCPHttpServer( + url="https://api.githubcopilot.com/mcp/", + headers={ + "Authorization": f"Bearer {github_token}", + "User-Agent": "MCPMark/1.0", + }, + ) + raise ValueError(f"Unsupported HTTP service: {self.mcp_service}") + + # ------------------------------------------------------------------ + # Message/Tool formatting helpers + # ------------------------------------------------------------------ + + def _convert_to_sdk_format(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + sdk_format: List[Dict[str, Any]] = [] + function_call_map: Dict[str, str] = {} + + for msg in messages: + role = msg.get("role") + + if role == "user": + user_content = msg.get("content", "") + if isinstance(user_content, list): + tool_results = [ + item + for item in user_content + if isinstance(item, dict) and item.get("type") == "tool_result" + ] + if tool_results: + for tr in tool_results: + content_items = tr.get("content", []) + text_content = "" + for ci in content_items: + if isinstance(ci, dict) and ci.get("type") == "text": + text_content = ci.get("text", "") + break + sdk_format.append( + { + "call_id": tr.get("tool_use_id", ""), + "output": json.dumps( + { + "type": "text", + "text": text_content, + "annotations": None, + "meta": None, + } + ), + "type": "function_call_output", + } + ) + else: + text_parts = [] + for item in user_content: + if isinstance(item, dict) and item.get("type") == "text": + text_parts.append(item.get("text", "")) + sdk_format.append({"content": "\n".join(text_parts), "role": "user"}) + else: + sdk_format.append({"content": user_content, "role": "user"}) + + elif role == "assistant": + tool_calls = msg.get("tool_calls", []) + function_call = msg.get("function_call") + content = msg.get("content") + + if isinstance(content, list): + 
text_parts = [] + claude_tool_uses = [] + for block in content: + if isinstance(block, dict): + if block.get("type") == "text": + text_parts.append(block.get("text", "")) + elif block.get("type") == "thinking": + thinking_text = block.get("thinking", "") + if thinking_text: + text_parts.append(f"\n{thinking_text}\n") + elif block.get("type") == "tool_use": + claude_tool_uses.append(block) + content = "\n".join(text_parts) + if claude_tool_uses and not tool_calls: + tool_calls = [] + for tu in claude_tool_uses: + tool_calls.append( + { + "id": tu.get("id"), + "function": { + "name": tu.get("name"), + "arguments": json.dumps(tu.get("input", {})), + }, + } + ) + + if content: + sdk_format.append( + { + "id": "__fake_id__", + "content": [ + { + "annotations": [], + "text": content, + "type": "output_text", + } + ], + "role": "assistant", + "status": "completed", + "type": "message", + } + ) + + if tool_calls: + for tool_call in tool_calls: + call_id = tool_call.get("id", f"call_{uuid.uuid4().hex}") + func_name = tool_call.get("function", {}).get("name", "") + sdk_format.append( + { + "arguments": tool_call.get("function", {}).get("arguments", "{}"), + "call_id": call_id, + "name": func_name, + "type": "function_call", + "id": "__fake_id__", + } + ) + + if function_call: + func_name = function_call.get("name", "") + call_id = f"call_{uuid.uuid4().hex}" + function_call_map[func_name] = call_id + sdk_format.append( + { + "arguments": function_call.get("arguments", "{}"), + "call_id": call_id, + "name": func_name, + "type": "function_call", + "id": "__fake_id__", + } + ) + + elif role == "tool": + sdk_format.append( + { + "call_id": msg.get("tool_call_id", ""), + "output": json.dumps( + { + "type": "text", + "text": msg.get("content", ""), + "annotations": None, + "meta": None, + } + ), + "type": "function_call_output", + } + ) + + elif role == "function": + func_name = msg.get("name", "") + call_id = function_call_map.get(func_name, f"call_{uuid.uuid4().hex}") + sdk_format.append( + { + "call_id": call_id, + "output": json.dumps( + { + "type": "text", + "text": msg.get("content", ""), + "annotations": None, + "meta": None, + } + ), + "type": "function_call_output", + } + ) + + return sdk_format + + def _convert_to_anthropic_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + anthropic_tools = [] + for tool in tools: + anthropic_tool = { + "name": tool.get("name"), + "description": tool.get("description", ""), + "input_schema": tool.get( + "inputSchema", + {"type": "object", "properties": {}, "required": []}, + ), + } + anthropic_tools.append(anthropic_tool) + return anthropic_tools + + def _is_gemini_model(self) -> bool: + model_lower = self.litellm_input_model_name.lower() + return "gemini" in model_lower or "bison" in model_lower + + def _simplify_schema_for_gemini(self, schema: Optional[Dict[str, Any]]) -> Dict[str, Any]: + if not isinstance(schema, dict): + return schema or {} + + simplified: Dict[str, Any] = {} + for key, value in schema.items(): + if key == "type" and isinstance(value, list): + simplified[key] = value[0] if value else "string" + elif key == "items" and isinstance(value, dict): + simplified[key] = self._simplify_schema_for_gemini(value) + elif key == "properties" and isinstance(value, dict): + simplified[key] = { + prop_key: self._simplify_schema_for_gemini(prop_val) + for prop_key, prop_val in value.items() + } + elif isinstance(value, dict): + simplified[key] = self._simplify_schema_for_gemini(value) + elif isinstance(value, list) and key not in 
("required", "enum"): + simplified[key] = [ + self._simplify_schema_for_gemini(item) if isinstance(item, dict) else item + for item in value + ] + else: + simplified[key] = value + return simplified + + def _convert_to_openai_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + functions = [] + is_gemini = self._is_gemini_model() + + if is_gemini: + logger.debug( + "Detected Gemini model '%s' – simplifying tool schemas", + self.litellm_input_model_name, + ) + + for tool in tools: + input_schema = tool.get( + "inputSchema", {"type": "object", "properties": {}, "required": []} + ) + if is_gemini: + simplified = self._simplify_schema_for_gemini(input_schema) + if simplified != input_schema: + input_schema = simplified + logger.debug("Simplified schema for tool '%s'", tool.get("name")) + + functions.append( + { + "name": tool.get("name"), + "description": tool.get("description", ""), + "parameters": input_schema, + } + ) + + if is_gemini: + logger.info("Converted %d tools for Gemini compatibility", len(functions)) + + return functions + diff --git a/src/agents/mcpmark_agent.py b/src/agents/mcpmark_agent.py index f143bda1..db0343cd 100644 --- a/src/agents/mcpmark_agent.py +++ b/src/agents/mcpmark_agent.py @@ -8,9 +8,6 @@ import asyncio import json import time -import uuid -import copy - from typing import Any, Dict, List, Optional, Callable import httpx @@ -18,8 +15,8 @@ import nest_asyncio from src.logger import get_logger +from .base_agent import BaseMCPAgent from .mcp import MCPStdioServer, MCPHttpServer -from .utils import TokenUsageTracker # Apply nested asyncio support nest_asyncio.apply() @@ -29,34 +26,20 @@ logger = get_logger(__name__) -class MCPMarkAgent: +class MCPMarkAgent(BaseMCPAgent): """ Unified agent for LLM and MCP server management using LiteLLM. - + - Anthropic models: Native MCP support via extra_body - Other models: Manual MCP server management with function calling """ - - # Constants + MAX_TURNS = 100 - DEFAULT_TIMEOUT = 600 SYSTEM_PROMPT = ( "You are a helpful agent that uses tools iteratively to complete the user's task, " "and when finished, provides the final answer or simply states \"Task completed\" without further tool calls." ) - - # Service categories - STDIO_SERVICES = ["notion", "filesystem", "playwright", "playwright_webarena", "postgres"] - HTTP_SERVICES = ["github"] - - # Claude thinking budget mapping - CLAUDE_THINKING_BUDGETS = { - "low": 1024, - "medium": 2048, - "high": 4096 - } - - # ==================== Initialization and Configuration ==================== + DEFAULT_TIMEOUT = BaseMCPAgent.DEFAULT_TIMEOUT def __init__( self, @@ -66,99 +49,28 @@ def __init__( mcp_service: str, timeout: int = DEFAULT_TIMEOUT, service_config: Optional[Dict[str, Any]] = None, - service_config_provider: Optional[Callable[[], Dict]] = None, + service_config_provider: Optional[Callable[[], Dict[str, Any]]] = None, reasoning_effort: Optional[str] = "default", ): - """ - Initialize the MCPMark agent. 
- - Args: - model_name: Name of the LLM model - api_key: API key for the model provider - base_url: Base url - mcp_service: MCP service type - timeout: Execution timeout in seconds - service_config: Service-specific configuration - service_config_provider: Optional provider for dynamic config - reasoning_effort: Reasoning effort level ("default", "minimal", "low", "medium", "high") - """ - self.litellm_input_model_name = litellm_input_model_name - self.api_key = api_key - self.base_url = base_url - self.mcp_service = mcp_service - self.timeout = timeout - self.service_config = service_config or {} - self._service_config_provider = service_config_provider - self.reasoning_effort = reasoning_effort - - # Detect if this is a Claude model - self.is_claude = self._is_anthropic_model(litellm_input_model_name) - - # Determine execution path: Claude with thinking or LiteLLM - self.use_claude_thinking = self.is_claude and reasoning_effort != "default" - - # Initialize usage tracker - self.usage_tracker = TokenUsageTracker() - - # Track the actual model name from responses - self.litellm_run_model_name = None - - # Track partial progress for error/timeout handling - self._partial_messages = [] - self._partial_token_usage = {} - self._partial_turn_count = 0 - - logger.debug( - f"Initialized MCPMarkAgent for '{mcp_service}' with model '{litellm_input_model_name}' " - f"(Claude: {self.is_claude}, Thinking: {self.use_claude_thinking}, Reasoning: {reasoning_effort})" + super().__init__( + litellm_input_model_name=litellm_input_model_name, + api_key=api_key, + base_url=base_url, + mcp_service=mcp_service, + timeout=timeout, + service_config=service_config, + service_config_provider=service_config_provider, + reasoning_effort=reasoning_effort, ) - - - def __repr__(self): - return ( - f"MCPMarkAgent(service='{self.mcp_service}', model='{self.litellm_input_model_name}', " + logger.debug( + "Initialized MCPMarkAgent for '%s' with model '%s' (Claude: %s, Thinking: %s, Reasoning: %s)", + mcp_service, + litellm_input_model_name, + self.is_claude, + self.use_claude_thinking, + reasoning_effort, ) - def _is_anthropic_model(self, model_name: str) -> bool: - """Check if the model is an Anthropic model.""" - return "claude" in model_name.lower() - - - def _get_claude_thinking_budget(self) -> Optional[int]: - """Get thinking budget for Claude based on reasoning effort.""" - if not self.use_claude_thinking: - return None - return self.CLAUDE_THINKING_BUDGETS.get(self.reasoning_effort, 2048) - - - def _refresh_service_config(self): - """Refresh service config from provider if available.""" - if self._service_config_provider: - try: - latest_cfg = self._service_config_provider() or {} - self.service_config.update(latest_cfg) - except Exception as e: - logger.warning(f"| Failed to refresh service config: {e}") - - def _reset_progress(self): - """Reset stored partial progress for a new execution run.""" - self._partial_messages = [] - self._partial_token_usage = {} - self._partial_turn_count = 0 - - def _update_progress(self, messages: List[Dict], token_usage: Dict, turn_count: int): - """Record partial progress so we can return it on timeout/errors.""" - try: - # Deep copy to avoid mutation by callers - self._partial_messages = copy.deepcopy(messages) - self._partial_token_usage = dict(token_usage or {}) - self._partial_turn_count = int(turn_count or 0) - except Exception: - # Best-effort; don't let progress recording crash execution - pass - - - # ==================== Public Interface Methods ==================== async 
def execute( @@ -865,293 +777,6 @@ async def _execute_litellm_tool_loop( # ==================== Format Conversion Methods ==================== - def _convert_to_sdk_format(self, messages: List[Dict]) -> List[Dict]: - """Convert OpenAI messages format to old SDK format for backward compatibility.""" - sdk_format = [] - function_call_map = {} # Track function names to call IDs for legacy format - - for msg in messages: - role = msg.get("role") - - if role == "user": - # User messages stay mostly the same - user_content = msg.get("content", "") - - # Handle tool_result messages (content as list) - if isinstance(user_content, list): - # Check if this is a tool_result message - tool_results = [item for item in user_content if isinstance(item, dict) and item.get("type") == "tool_result"] - if tool_results: - # Convert tool_results to function_call_output format - for tr in tool_results: - content_items = tr.get("content", []) - text_content = "" - for ci in content_items: - if isinstance(ci, dict) and ci.get("type") == "text": - text_content = ci.get("text", "") - break - sdk_format.append({ - "call_id": tr.get("tool_use_id", ""), - "output": json.dumps({ - "type": "text", - "text": text_content, - "annotations": None, - "meta": None - }), - "type": "function_call_output" - }) - else: - # Regular user content as list - extract text - text_parts = [] - for item in user_content: - if isinstance(item, dict) and item.get("type") == "text": - text_parts.append(item.get("text", "")) - sdk_format.append({ - "content": "\n".join(text_parts) if text_parts else "", - "role": "user" - }) - else: - # String content - sdk_format.append({ - "content": user_content, - "role": "user" - }) - - elif role == "assistant": - # === CHANGED ORDER START === - tool_calls = msg.get("tool_calls", []) - function_call = msg.get("function_call") - content = msg.get("content") - - # Handle both string content and list content (for Claude thinking) - if isinstance(content, list): - # Extract text from content blocks (e.g., Claude responses with thinking) - text_parts = [] - claude_tool_uses = [] - for block in content: - if isinstance(block, dict): - if block.get("type") == "text": - text_parts.append(block.get("text", "")) - elif block.get("type") == "thinking": - # Include thinking in output (marked as such) - thinking_text = block.get("thinking", "") - if thinking_text: - text_parts.append(f"\n{thinking_text}\n") - elif block.get("type") == "tool_use": - # Store tool_use blocks for later processing - claude_tool_uses.append(block) - content = "\n".join(text_parts) if text_parts else "" - - # Add Claude tool_uses to regular tool_calls - if claude_tool_uses and not tool_calls: - tool_calls = [] - for tu in claude_tool_uses: - tool_calls.append({ - "id": tu.get("id"), - "function": { - "name": tu.get("name"), - "arguments": json.dumps(tu.get("input", {})) - } - }) - - # 1) First add assistant's text content (if present) - if content: - sdk_format.append({ - "id": "__fake_id__", - "content": [ - { - "annotations": [], - "text": content if content else "", - "type": "output_text" - } - ], - "role": "assistant", - "status": "completed", - "type": "message" - }) - - # 2) Then add (new format) tool_calls - if tool_calls: - for tool_call in tool_calls: - call_id = tool_call.get("id", f"call_{uuid.uuid4().hex}") - func_name = tool_call.get("function", {}).get("name", "") - sdk_format.append({ - "arguments": tool_call.get("function", {}).get("arguments", "{}"), - "call_id": call_id, - "name": func_name, - "type": "function_call", - 
"id": "__fake_id__" - }) - - # 3) Finally handle (legacy format) function_call - if function_call: - func_name = function_call.get("name", "") - call_id = f"call_{uuid.uuid4().hex}" - function_call_map[func_name] = call_id # Store for matching responses - sdk_format.append({ - "arguments": function_call.get("arguments", "{}"), - "call_id": call_id, - "name": func_name, - "type": "function_call", - "id": "__fake_id__" - }) - - # 4) If neither content nor any calls exist, maintain fallback behavior - if not content and not tool_calls and not function_call: - sdk_format.append({ - "id": "__fake_id__", - "content": [ - { - "annotations": [], - "text": "", - "type": "output_text" - } - ], - "role": "assistant", - "status": "completed", - "type": "message" - }) - # === CHANGED ORDER END === - - elif role == "tool": - # Tool responses - sdk_format.append({ - "call_id": msg.get("tool_call_id", ""), - "output": json.dumps({ - "type": "text", - "text": msg.get("content", ""), - "annotations": None, - "meta": None - }), - "type": "function_call_output" - }) - - elif role == "function": - # Legacy function responses - try to match with stored call ID - func_name = msg.get("name", "") - call_id = function_call_map.get(func_name, f"call_{uuid.uuid4().hex}") - sdk_format.append({ - "call_id": call_id, - "output": json.dumps({ - "type": "text", - "text": msg.get("content", ""), - "annotations": None, - "meta": None - }), - "type": "function_call_output" - }) - - return sdk_format - - - - def _convert_to_anthropic_format(self, tools: List[Dict]) -> List[Dict]: - """Convert MCP tool definitions to Anthropic format.""" - anthropic_tools = [] - - for tool in tools: - anthropic_tool = { - "name": tool.get("name"), - "description": tool.get("description", ""), - "input_schema": tool.get("inputSchema", { - "type": "object", - "properties": {}, - "required": [] - }) - } - anthropic_tools.append(anthropic_tool) - - return anthropic_tools - - def _is_gemini_model(self) -> bool: - """Check if the model is a Gemini model.""" - model_lower = self.litellm_input_model_name.lower() - return "gemini" in model_lower or "bison" in model_lower - - def _simplify_schema_for_gemini(self, schema: Dict) -> Dict: - """ - Simplify nested schemas for Gemini compatibility. - Gemini has issues with deeply nested array type definitions. - - Note: This is a compatibility layer for Gemini API via LiteLLM. - Can be removed once LiteLLM handles this internally. 
- """ - if not isinstance(schema, dict): - return schema - - simplified = {} - - for key, value in schema.items(): - if key == "type" and isinstance(value, list): - # Gemini doesn't like type as array, use first type - simplified[key] = value[0] if value else "string" - elif key == "items" and isinstance(value, dict): - # Recursively simplify items - simplified[key] = self._simplify_schema_for_gemini(value) - elif key == "properties" and isinstance(value, dict): - # Recursively simplify each property - simplified[key] = { - prop_key: self._simplify_schema_for_gemini(prop_val) - for prop_key, prop_val in value.items() - } - elif isinstance(value, dict): - # Recursively simplify nested objects - simplified[key] = self._simplify_schema_for_gemini(value) - elif isinstance(value, list) and key not in ["required", "enum"]: - # For non-special arrays, check if they contain schemas - simplified[key] = [ - self._simplify_schema_for_gemini(item) if isinstance(item, dict) else item - for item in value - ] - else: - simplified[key] = value - - return simplified - - - def _convert_to_openai_format(self, tools: List[Dict]) -> List[Dict]: - """ - Convert MCP tool definitions to OpenAI function format. - - For Gemini models, applies schema simplification to handle - compatibility issues with deeply nested array type definitions. - """ - functions = [] - is_gemini = self._is_gemini_model() - - if is_gemini: - logger.debug(f"Detected Gemini model: {self.litellm_input_model_name}") - logger.debug(f"Processing {len(tools)} tools for Gemini compatibility") - - for i, tool in enumerate(tools): - # Get the input schema - input_schema = tool.get("inputSchema", { - "type": "object", - "properties": {}, - "required": [] - }) - - # Simplify schema for Gemini if needed - if is_gemini: - original_schema = input_schema.copy() # Keep for debugging - input_schema = self._simplify_schema_for_gemini(input_schema) - - # Log significant changes for debugging - if input_schema != original_schema: - logger.debug(f"Simplified schema for tool #{i} '{tool.get('name')}'") - - function = { - "name": tool.get("name"), - "description": tool.get("description", ""), - "parameters": input_schema - } - functions.append(function) - - if is_gemini: - logger.info(f"| Converted {len(functions)} tools for Gemini model with schema simplification") - - return functions - diff --git a/src/agents/react_agent.py b/src/agents/react_agent.py new file mode 100644 index 00000000..53312653 --- /dev/null +++ b/src/agents/react_agent.py @@ -0,0 +1,424 @@ +"""ReAct agent implementation for the MCPMark pipeline.""" + +from __future__ import annotations + +import asyncio +import json +import time +from typing import Any, Dict, List, Optional, Callable + +import litellm + +from src.logger import get_logger +from .base_agent import BaseMCPAgent + +logger = get_logger(__name__) + + +class ReActAgent(BaseMCPAgent): + """ReAct-style agent that reuses MCPMark infrastructure.""" + + DEFAULT_SYSTEM_PROMPT = ( + "You are a careful ReAct (reasoning and acting) agent. " + "At each step you must decide whether to call a tool or provide a final response. " + "Only use the tools that are listed for you. When you finish, respond with either the final answer " + "or the phrase \"Task completed.\" if no further detail is required. " + "Every reply must be valid JSON without code fences." 
+ ) + + def __init__( + self, + litellm_input_model_name: str, + api_key: str, + base_url: str, + mcp_service: str, + timeout: int = BaseMCPAgent.DEFAULT_TIMEOUT, + service_config: Optional[Dict[str, Any]] = None, + service_config_provider: Optional[Callable[[], Dict[str, Any]]] = None, + reasoning_effort: Optional[str] = "default", + max_iterations: int = 100, + system_prompt: Optional[str] = None, + ): + super().__init__( + litellm_input_model_name=litellm_input_model_name, + api_key=api_key, + base_url=base_url, + mcp_service=mcp_service, + timeout=timeout, + service_config=service_config, + service_config_provider=service_config_provider, + reasoning_effort=reasoning_effort, + ) + self.max_iterations = max_iterations + self.react_system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + + async def execute( + self, + instruction: str, + tool_call_log_file: Optional[str] = None, + ) -> Dict[str, Any]: + start_time = time.time() + + try: + self._reset_progress() + self._refresh_service_config() + + async def _run_react(): + return await self._execute_react_loop(instruction, tool_call_log_file) + + result = await asyncio.wait_for(_run_react(), timeout=self.timeout) + execution_time = time.time() - start_time + self.usage_tracker.update( + success=result.get("success", False), + token_usage=result.get("token_usage", {}), + turn_count=result.get("turn_count", 0), + execution_time=execution_time, + ) + result["execution_time"] = execution_time + return result + except Exception as exc: # noqa: BLE001 + execution_time = time.time() - start_time + + if isinstance(exc, asyncio.TimeoutError): + error_msg = f"Execution timed out after {self.timeout} seconds" + logger.error(error_msg) + else: + error_msg = f"ReAct agent execution failed: {exc}" + logger.error(error_msg, exc_info=True) + + self.usage_tracker.update( + success=False, + token_usage=self._partial_token_usage or {}, + turn_count=self._partial_turn_count or 0, + execution_time=execution_time, + ) + + if self._partial_messages: + final_msg = self._convert_to_sdk_format(self._partial_messages) + else: + final_msg = [] + + return { + "success": False, + "output": final_msg, + "token_usage": self._partial_token_usage or {}, + "turn_count": self._partial_turn_count or 0, + "execution_time": execution_time, + "error": error_msg, + "litellm_run_model_name": self.litellm_run_model_name, + } + + async def _execute_react_loop( + self, + instruction: str, + tool_call_log_file: Optional[str], + ) -> Dict[str, Any]: + system_message = {"role": "system", "content": self.react_system_prompt} + total_tokens = { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "reasoning_tokens": 0, + } + turn_count = 0 + success = False + final_error: Optional[str] = None + + mcp_server = await self._create_mcp_server() + async with mcp_server: + tools = await mcp_server.list_tools() + tool_map = {tool.get("name"): tool for tool in tools} + tools_description = self._render_tools_description(tools) + + task_message = { + "role": "user", + "content": self._build_task_prompt( + instruction=instruction, + tools_description=tools_description, + ), + } + messages: List[Dict[str, Any]] = [system_message, task_message] + self._update_progress(messages, total_tokens, turn_count) + + for step in range(1, self.max_iterations + 1): + completion_kwargs = { + "model": self.litellm_input_model_name, + "messages": messages, + "api_key": self.api_key, + } + if self.base_url: + completion_kwargs["base_url"] = self.base_url + if self.reasoning_effort != "default": + 
completion_kwargs["reasoning_effort"] = self.reasoning_effort + + try: + response = await asyncio.wait_for( + litellm.acompletion(**completion_kwargs), + timeout=self.timeout / 2, + ) + except asyncio.TimeoutError: + final_error = f"LLM call timed out on step {step}" + logger.error(final_error) + break + except Exception as exc: # noqa: BLE001 + final_error = f"LLM call failed on step {step}: {exc}" + logger.error(final_error) + break + + if turn_count == 0 and getattr(response, "model", None): + self.litellm_run_model_name = response.model.split("/")[-1] + + usage = getattr(response, "usage", None) + if usage: + prompt_tokens = ( + getattr(usage, "prompt_tokens", None) + or getattr(usage, "input_tokens", None) + or 0 + ) + completion_tokens = ( + getattr(usage, "completion_tokens", None) + or getattr(usage, "output_tokens", None) + or 0 + ) + total_tokens_count = getattr(usage, "total_tokens", None) + if total_tokens_count is None: + total_tokens_count = prompt_tokens + completion_tokens + + total_tokens["input_tokens"] += prompt_tokens + total_tokens["output_tokens"] += completion_tokens + total_tokens["total_tokens"] += total_tokens_count + + # Extract reasoning tokens if available + if hasattr(response.usage, 'completion_tokens_details'): + details = response.usage.completion_tokens_details + if hasattr(details, 'reasoning_tokens'): + total_tokens["reasoning_tokens"] += details.reasoning_tokens or 0 + + choice = response.choices[0] + message_obj = getattr(choice, "message", None) + if message_obj is None and isinstance(choice, dict): + message_obj = choice.get("message") + + if message_obj is None: + content_raw = getattr(choice, "text", "") + else: + content_raw = message_obj.get("content", "") + + assistant_text = self._normalize_content(content_raw) + assistant_message = {"role": "assistant", "content": assistant_text} + messages.append(assistant_message) + turn_count += 1 + self._update_progress(messages, total_tokens, turn_count) + + parsed = self._parse_react_response(assistant_text) + if not parsed or "thought" not in parsed: + warning = ( + "The previous response was not valid JSON following the required schema. " + "Please respond again using the JSON formats provided." + ) + messages.append({"role": "user", "content": warning}) + self._update_progress(messages, total_tokens, turn_count) + final_error = "Model produced an invalid response format." + continue + + thought = parsed.get("thought", "") + action = parsed.get("action") + answer = parsed.get("answer") + result = parsed.get("result") + + logger.info(f"|\n| \033[1;3mThought\033[0m: {str(thought)}") + if tool_call_log_file: + try: + with open(tool_call_log_file, "a", encoding="utf-8") as log_file: + log_file.write(f"| {str(thought)}\n") + except Exception: # noqa: BLE001 + pass + if action is not None: + func_name = action.get("tool") + arguments = action.get("arguments", {}) or {} + args_str = json.dumps(arguments, separators=(",", ": ")) + display_arguments = args_str[:140] + "..." if len(args_str) > 140 else args_str + logger.info(f"| \033[1;3mAction\033[0m: \033[1m{func_name}\033[0m \033[2;37m{display_arguments}\033[0m") + + + if answer is not None: + success = True + break + + if action is not None and isinstance(action, dict): + tool_name = action.get("tool") + arguments = action.get("arguments", {}) or {} + + if tool_name not in tool_map: + observation = ( + f"Invalid tool '{tool_name}'. 
Available tools: " + f"{', '.join(tool_map)}" + ) + else: + try: + tool_response = await asyncio.wait_for( + mcp_server.call_tool(tool_name, arguments), + timeout=60, + ) + observation = self._tool_result_to_text(tool_response) + except asyncio.TimeoutError: + observation = f"Tool '{tool_name}' timed out" + except Exception as tool_exc: # noqa: BLE001 + observation = f"Tool '{tool_name}' failed: {tool_exc}" + + if tool_call_log_file: + try: + with open(tool_call_log_file, "a", encoding="utf-8") as log_file: + log_file.write(f"| {tool_name} {json.dumps(arguments, ensure_ascii=False)}\n") + except Exception: # noqa: BLE001 + pass + + observation_message = { + "role": "user", + "content": ( + f"Observation:\n{observation}\n" + "Please continue reasoning and reply using the required JSON format." + ), + } + messages.append(observation_message) + self._update_progress(messages, total_tokens, turn_count) + continue + + if result is not None: + observation_message = { + "role": "user", + "content": ( + f"Observation:\n{result}\n" + "Please continue reasoning and reply using the required JSON format." + ), + } + messages.append(observation_message) + self._update_progress(messages, total_tokens, turn_count) + continue + + # Unexpected structure: ask model to restate properly + messages.append( + { + "role": "user", + "content": ( + "The previous reply did not include an action, result, or answer. " + "Please respond again using the JSON formats provided." + ), + } + ) + self._update_progress(messages, total_tokens, turn_count) + + if not success and final_error is None: + final_error = ( + f"Max iterations ({self.max_iterations}) reached without a final answer." + ) + + if total_tokens["total_tokens"] > 0: + log_msg = ( + f"|\n|\n| Token usage: Total: {total_tokens['total_tokens']:,} | " + f"Input: {total_tokens['input_tokens']:,} | " + f"Output: {total_tokens['output_tokens']:,}" + ) + if total_tokens.get("reasoning_tokens", 0) > 0: + log_msg += f" | Reasoning: {total_tokens['reasoning_tokens']:,}" + logger.info(log_msg) + logger.info(f"| Turns: {turn_count}") + + sdk_messages = self._convert_to_sdk_format(messages) + + return { + "success": success, + "output": sdk_messages, + "token_usage": total_tokens, + "turn_count": turn_count, + "error": None if success else final_error, + "litellm_run_model_name": self.litellm_run_model_name, + } + + def _build_task_prompt( + self, + instruction: str, + tools_description: str, + ) -> str: + return ( + f"Task:\n{instruction}\n\n" + f"Available MCP tools:\n{tools_description}\n\n" + "Respond using the JSON formats below.\n\n" + "If you need to use a tool:\n" + "{\n" + ' "thought": "Reasoning for the next action",\n' + ' "action": {\n' + ' "tool": "tool-name",\n' + ' "arguments": {\n' + ' "parameter": value\n' + " }\n" + " }\n" + "}\n\n" + "If you can provide the final answer:\n" + "{\n" + ' "thought": "Reasoning that justifies the answer",\n' + ' "answer": "Either the final solution or \'Task completed.\' when no more detail is required"\n' + "}\n\n" + "Remember: omitting the action object ends the task, so only do this when finished." 
+ ) + + def _render_tools_description(self, tools: List[Dict[str, Any]]) -> str: + descriptions = [] + for tool in tools: + name = tool.get("name", "unknown") + description = tool.get("description", "No description provided.") + input_schema = tool.get("inputSchema", {}) or {} + properties = input_schema.get("properties", {}) or {} + required = set(input_schema.get("required", []) or []) + + arg_lines = [] + for prop_name, prop_details in properties.items(): + details = json.dumps(prop_details, ensure_ascii=False, indent=2) + suffix = " (required)" if prop_name in required else "" + arg_lines.append(f"- {prop_name}{suffix}: {details}") + + if arg_lines: + arguments_text = "\n".join(arg_lines) + else: + arguments_text = "(no arguments)" + + descriptions.append( + f"Tool: {name}\nDescription: {description}\nArguments:\n{arguments_text}" + ) + + return "\n\n".join(descriptions) if descriptions else "(no tools available)" + + def _normalize_content(self, content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for block in content: + if isinstance(block, dict): + if block.get("type") == "text": + parts.append(block.get("text", "")) + elif "text" in block: + parts.append(str(block.get("text"))) + else: + parts.append(str(block)) + return "\n".join(part for part in parts if part) + return json.dumps(content, ensure_ascii=False) + + def _parse_react_response(self, payload: str) -> Dict[str, Any]: + candidate = payload.strip().strip("`").strip() + if candidate.lower().startswith("json"): + candidate = candidate[4:].lstrip() + try: + return json.loads(candidate) + except json.JSONDecodeError: + return {} + + def _tool_result_to_text(self, result: Any) -> str: + if result is None: + return "" + if isinstance(result, str): + return result + try: + return json.dumps(result, ensure_ascii=False) + except TypeError: + return str(result) diff --git a/src/evaluator.py b/src/evaluator.py index b03abae9..31f1b711 100644 --- a/src/evaluator.py +++ b/src/evaluator.py @@ -11,7 +11,7 @@ from src.model_config import ModelConfig from src.results_reporter import EvaluationReport, ResultsReporter, TaskResult from src.errors import is_retryable_error -from src.agents import MCPMarkAgent +from src.agents import AGENT_REGISTRY # Initialize logger logger = get_logger(__name__) @@ -26,10 +26,14 @@ def __init__( exp_name: str = "test-run", output_dir: Path = None, reasoning_effort: str = "default", + agent_name: str = "mcpmark", ): # Main configuration self.mcp_service = mcp_service self.timeout = timeout + self.agent_name = (agent_name or "mcpmark").lower() + if self.agent_name not in AGENT_REGISTRY: + raise ValueError(f"Unsupported agent '{agent_name}'. Available: {sorted(AGENT_REGISTRY)}") # Initialize model configuration self.reasoning_effort = reasoning_effort @@ -54,8 +58,9 @@ def __init__( # automatically refresh its service configuration from the state # manager before each execution, so per-task manual updates are no # longer needed. 
- self.agent = MCPMarkAgent( - litellm_input_model_name=self.litellm_input_model_name, # Use the original model name for detection + agent_cls = AGENT_REGISTRY[self.agent_name] + self.agent = agent_cls( + litellm_input_model_name=self.litellm_input_model_name, api_key=self.api_key, base_url=self.base_url, mcp_service=mcp_service, @@ -352,6 +357,7 @@ def run_evaluation(self, task_filter: str) -> EvaluationReport: "litellm_run_model_name": self.litellm_run_model_name, "reasoning_effort": self.reasoning_effort, "timeout": self.timeout, + "agent_name": self.agent_name, } self.results_reporter.save_meta_json( task_result, @@ -398,6 +404,7 @@ def _matches_filter(tr: TaskResult, flt: str) -> bool: "litellm_run_model_name": self.litellm_run_model_name, "reasoning_effort": self.reasoning_effort, "timeout": self.timeout, + "agent_name": self.agent_name, }, total_tasks=len(final_results), successful_tasks=sum(1 for r in final_results if r.success),
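
---

As a sketch of how the new registry is meant to be extended (hypothetical; `MyEchoAgent` and the `"my-echo"` key are not part of this patch), the only contract a custom agent needs is the one introduced above: subclass `BaseMCPAgent`, implement `execute`, and add the class to `AGENT_REGISTRY`, which `MCPEvaluator` resolves via `agent_cls = AGENT_REGISTRY[self.agent_name]`.

```python
# Hypothetical sketch -- not part of this patch. Assumes only the interfaces
# introduced above: BaseMCPAgent.execute(), AGENT_REGISTRY, and the evaluator's
# AGENT_REGISTRY[self.agent_name] lookup.
from typing import Any, Dict, Optional

from src.agents import AGENT_REGISTRY
from src.agents.base_agent import BaseMCPAgent


class MyEchoAgent(BaseMCPAgent):
    """Toy agent that skips MCP tools and just echoes the instruction back."""

    async def execute(
        self,
        instruction: str,
        tool_call_log_file: Optional[str] = None,
    ) -> Dict[str, Any]:
        # Mirror the result shape returned by MCPMarkAgent / ReActAgent so the
        # evaluator and results reporter can consume it unchanged.
        return {
            "success": True,
            "output": [{"role": "user", "content": instruction}],
            "token_usage": {},
            "turn_count": 0,
            # The built-in agents set execution_time in their execute() wrappers;
            # a stub value keeps the shape consistent here.
            "execution_time": 0.0,
            "error": None,
            "litellm_run_model_name": None,
        }


# Register the class under a new name, mirroring the built-in
# "mcpmark" and "react" entries.
AGENT_REGISTRY["my-echo"] = MyEchoAgent
```

For `--agent my-echo` to be accepted on the command line, the registration would have to run at import time of `src.agents` (e.g. in `src/agents/__init__.py`), since `pipeline.py` builds its `choices` from `sorted(AGENT_REGISTRY.keys())` when the parser is constructed.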