Skip to content

Commit 7291feb

Browse files
committed
Implements mcp_config_path; adds the "test_pytest_mcp_config" test.
1 parent 6a1c136 commit 7291feb

File tree

8 files changed

+464
-89
lines changed

8 files changed

+464
-89
lines changed

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import asyncio
12
import json
23
import os
3-
from typing import Any, List, Optional
4+
from typing import Any, List, Optional, Union
45

56
from mcp.types import CallToolResult
7+
from openai import NOT_GIVEN, NotGiven
68
from openai.types.chat import ChatCompletionMessage, ChatCompletionToolParam
79
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
810

@@ -22,46 +24,73 @@ def __init__(self, model: str, initial_messages: list[Message], config_path: str
2224
self.messages: list[Message] = initial_messages
2325
self._policy = LiteLLMPolicy(model_id=model)
2426
self.mcp_client = MCPMultiClient(config_path=config_path) if config_path else None
27+
self.tools: Union[List[ChatCompletionToolParam], NotGiven] = NOT_GIVEN
2528

2629
async def setup(self):
2730
if self.mcp_client:
2831
await self.mcp_client.connect_to_servers()
2932

33+
async def _get_tools(self) -> "Optional[List[ChatCompletionToolParam]]":
    """Lazily fetch and memoize the MCP tool list.

    The openai NOT_GIVEN sentinel marks "never fetched"; once resolved, the
    cached value (a tool list, or None when no MCP client exists) is reused
    on every subsequent call.
    """
    if self.tools is not NOT_GIVEN:
        return self.tools
    self.tools = await self.mcp_client.get_available_tools() if self.mcp_client else None
    return self.tools
37+
3038
async def call_agent(self) -> str:
    """Drive the agent loop until the model produces a final text answer.

    Repeatedly requests a completion; whenever the reply contains tool
    calls, all of them are executed concurrently, their results appended to
    the conversation as ``tool`` messages, and the model queried again.
    Implemented as a loop rather than recursion so that long tool-call
    chains cannot exhaust the call stack.

    Returns:
        The assistant's final message content.
    """
    while True:
        # Tool schemas are only fetchable (and only useful) with an MCP client.
        tools = await self._get_tools() if self.mcp_client else None

        message = await self._call_model(self.messages, tools)
        self.messages.append(message)

        if not message["tool_calls"]:
            return message["content"]

        # Fan out all tool calls concurrently; gather() preserves input order,
        # so results line up with the tool_calls entries.
        tool_tasks = []
        for tool_call in message["tool_calls"]:
            tool_call_id = tool_call["id"]
            tool_name = tool_call["function"]["name"]
            # Tool arguments arrive JSON-encoded from the model.
            tool_args_dict = json.loads(tool_call["function"]["arguments"])
            tool_tasks.append(self._execute_tool_call(tool_call_id, tool_name, tool_args_dict))

        tool_results = await asyncio.gather(*tool_tasks)

        # Each result carries its originating tool_call_id, so no re-pairing
        # with message["tool_calls"] is needed.
        for tool_call_id, content in tool_results:
            self.messages.append(
                {
                    "role": "tool",
                    "content": content,
                    "tool_call_id": tool_call_id,
                }
            )
5473

5574
async def _call_model(
5675
self, messages: list[Message], tools: Optional[list[ChatCompletionToolParam]]
5776
) -> ChatCompletionMessage:
5877
messages = [message.model_dump() if hasattr(message, "model_dump") else message for message in messages]
78+
tools = [{"function": tool["function"].model_dump(), "type": "function"} for tool in tools]
5979
response = await self._policy._make_llm_call(
6080
messages=messages,
6181
tools=tools,
6282
)
6383
return response["choices"][0]["message"]
6484

85+
async def _execute_tool_call(self, tool_call_id: str, tool_name: str, tool_args_dict: dict) -> tuple[str, str]:
86+
"""
87+
Execute a single tool call and return the tool_call_id and content.
88+
This method is designed to be used with asyncio.gather() for parallel execution.
89+
"""
90+
tool_result = await self.mcp_client.call_tool(tool_name, tool_args_dict)
91+
content = self._get_content_from_tool_result(tool_result)
92+
return tool_call_id, content
93+
6594
def _get_content_from_tool_result(self, tool_result: CallToolResult) -> str:
6695
if tool_result.structuredContent:
6796
return json.dumps(tool_result.structuredContent)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ tau2 = { git = "https://github.com/sierra-research/tau2-bench.git" }
140140

141141
[dependency-groups]
142142
dev = [
143+
"fastmcp>=2.10.6",
143144
"haikus==0.3.8",
144145
"pytest>=8.4.1",
145146
]

tests/pytest/data/mcp_config.jsonl

Whitespace-only changes.

0 commit comments

Comments
 (0)