1+ import asyncio
12import json
23import os
3- from typing import Any , List , Optional
4+ from typing import Any , List , Optional , Union
45
56from mcp .types import CallToolResult
7+ from openai import NOT_GIVEN , NotGiven
68from openai .types .chat import ChatCompletionMessage , ChatCompletionToolParam
79from openai .types .chat .chat_completion_message_param import ChatCompletionMessageParam
810
@@ -22,46 +24,73 @@ def __init__(self, model: str, initial_messages: list[Message], config_path: str
2224 self .messages : list [Message ] = initial_messages
2325 self ._policy = LiteLLMPolicy (model_id = model )
2426 self .mcp_client = MCPMultiClient (config_path = config_path ) if config_path else None
27+ self .tools : Union [List [ChatCompletionToolParam ], NotGiven ] = NOT_GIVEN
2528
2629 async def setup (self ):
2730 if self .mcp_client :
2831 await self .mcp_client .connect_to_servers ()
2932
33+ async def _get_tools (self ) -> Optional [List [ChatCompletionToolParam ]]:
34+ if self .tools is NOT_GIVEN :
35+ self .tools = await self .mcp_client .get_available_tools () if self .mcp_client else None
36+ return self .tools
37+
3038 async def call_agent (self ) -> str :
3139 """
3240 Call the assistant with the user query.
3341 """
34- tools = await self .mcp_client . get_available_tools () if self .mcp_client else None
42+ tools = await self ._get_tools () if self .mcp_client else None
3543
3644 message = await self ._call_model (self .messages , tools )
3745 self .messages .append (message )
3846 if message ["tool_calls" ]:
47+ # Create tasks for all tool calls to run them in parallel
48+ tool_tasks = []
3949 for tool_call in message ["tool_calls" ]:
4050 tool_call_id = tool_call ["id" ]
4151 tool_name = tool_call ["function" ]["name" ]
4252 tool_args = tool_call ["function" ]["arguments" ]
4353 tool_args_dict = json .loads (tool_args )
44- tool_result = await self .mcp_client .call_tool (tool_name , tool_args_dict )
45- content = self ._get_content_from_tool_result (tool_result )
54+
55+ # Create a task for each tool call
56+ task = self ._execute_tool_call (tool_call_id , tool_name , tool_args_dict )
57+ tool_tasks .append (task )
58+
59+ # Execute all tool calls in parallel
60+ tool_results = await asyncio .gather (* tool_tasks )
61+
62+ # Add all tool results to messages (they will be in the same order as tool_calls)
63+ for tool_call , (tool_call_id , content ) in zip (message ["tool_calls" ], tool_results ):
4664 self .messages .append (
4765 {
4866 "role" : "tool" ,
4967 "content" : content ,
5068 "tool_call_id" : tool_call_id ,
5169 }
5270 )
71+ return await self .call_agent ()
5372 return message ["content" ]
5473
5574 async def _call_model (
5675 self , messages : list [Message ], tools : Optional [list [ChatCompletionToolParam ]]
5776 ) -> ChatCompletionMessage :
5877 messages = [message .model_dump () if hasattr (message , "model_dump" ) else message for message in messages ]
78+ tools = [{"function" : tool ["function" ].model_dump (), "type" : "function" } for tool in tools ]
5979 response = await self ._policy ._make_llm_call (
6080 messages = messages ,
6181 tools = tools ,
6282 )
6383 return response ["choices" ][0 ]["message" ]
6484
85+ async def _execute_tool_call (self , tool_call_id : str , tool_name : str , tool_args_dict : dict ) -> tuple [str , str ]:
86+ """
87+ Execute a single tool call and return the tool_call_id and content.
88+ This method is designed to be used with asyncio.gather() for parallel execution.
89+ """
90+ tool_result = await self .mcp_client .call_tool (tool_name , tool_args_dict )
91+ content = self ._get_content_from_tool_result (tool_result )
92+ return tool_call_id , content
93+
6594 def _get_content_from_tool_result (self , tool_result : CallToolResult ) -> str :
6695 if tool_result .structuredContent :
6796 return json .dumps (tool_result .structuredContent )
0 commit comments