Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ms_agent/agent/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(self,
self.mcp_server_file = kwargs.get('mcp_server_file', None)
self.mcp_config: Dict[str, Any] = self._parse_mcp_servers(
kwargs.get('mcp_config', {}))
self.mcp_client = kwargs.get('mcp_client', None)
self._task_begin()

def register_callback(self, callback: Callback):
Expand Down Expand Up @@ -182,7 +183,7 @@ async def _parallel_tool_call(self,

async def _prepare_tools(self):
"""Initialize and connect the tool manager."""
self.tool_manager = ToolManager(self.config, self.mcp_config)
self.tool_manager = ToolManager(self.config, self.mcp_config, self.mcp_client)
await self.tool_manager.connect()

async def _cleanup_tools(self):
Expand Down
111 changes: 85 additions & 26 deletions ms_agent/tools/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import os
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Any, Dict, List, Literal, Optional
from os import environb
from types import TracebackType
from typing import Any, Dict, List, Literal, Optional, Union

from mcp import ClientSession, ListToolsResult, StdioServerParameters
from mcp.client.sse import sse_client
Expand All @@ -13,7 +15,7 @@
from ms_agent.llm.utils import Tool
from ms_agent.tools.base import ToolBase
from ms_agent.utils import get_logger
from omegaconf import DictConfig
from omegaconf import DictConfig, ListConfig

logger = get_logger()

Expand All @@ -24,7 +26,7 @@

DEFAULT_HTTP_TIMEOUT = 5
DEFAULT_SSE_READ_TIMEOUT = 60 * 5
TOOL_CALL_TIMEOUT = os.getenv('TOOL_CALL_TIMEOUT', 15)
TOOL_CALL_TIMEOUT = os.getenv('TOOL_CALL_TIMEOUT', 30)

DEFAULT_STREAMABLE_HTTP_TIMEOUT = timedelta(seconds=30)
DEFAULT_STREAMABLE_HTTP_SSE_READ_TIMEOUT = timedelta(seconds=60 * 5)
Expand All @@ -40,17 +42,23 @@ class MCPClient(ToolBase):
mcp_config(`Optional[Dict[str, Any]]`): Extra mcp servers in json format.
"""

def __init__(self,
config: DictConfig,
mcp_config: Optional[Dict[str, Any]] = None):
def __init__(
self,
mcp_config: Optional[Dict[str, Any]] = None,
config: Union[DictConfig, ListConfig, None] = None,
):
super().__init__(config)
self.sessions: Dict[str, ClientSession] = {}
self.exit_stack = AsyncExitStack()
self.mcp_config: Dict[str, Dict[
str, Any]] = Config.convert_mcp_servers_to_json(config)
self.mcp_config: Dict[str, Dict[str, Any]] = {'mcpServers': {}}
if config is not None:
config_from_file = Config.convert_mcp_servers_to_json(config)
self.mcp_config['mcpServers'].update(
config_from_file.get('mcpServers', {}))
self._exclude_functions = {}
if mcp_config is not None:
self.mcp_config.update(mcp_config)
self.mcp_config['mcpServers'].update(
mcp_config.get('mcpServers', {}))

async def call_tool(self, server_name: str, tool_name: str,
tool_args: dict):
Expand Down Expand Up @@ -79,24 +87,37 @@ async def call_tool(self, server_name: str, tool_name: str,

async def get_tools(self) -> Dict:
tools = {}
error = dict()
for key, session in self.sessions.items():
tools[key] = []
response = await session.list_tools()
_session_tools = response.tools
exclude = []
if key in self._exclude_functions:
exclude = self._exclude_functions[key]
_session_tools = [
t for t in _session_tools if t.name not in exclude
]
_session_tools = [
Tool(
tool_name=t.name,
server_name=key,
description=t.description,
parameters=t.inputSchema) for t in _session_tools
]
tools[key].extend(_session_tools)
try:
tools[key] = []
response = await asyncio.wait_for(
session.list_tools(), timeout=TOOL_CALL_TIMEOUT)
_session_tools = response.tools
exclude = []
if key in self._exclude_functions:
exclude = self._exclude_functions[key]
_session_tools = [
t for t in _session_tools if t.name not in exclude
]
_session_tools = [
Tool(
tool_name=t.name,
server_name=key,
description=t.description,
parameters=t.inputSchema) for t in _session_tools
]
tools[key].extend(_session_tools)
except asyncio.TimeoutError:
error[key] = 'timeout'
except BaseException as exc:
error[key] = exc
if error:
error_messages = '; '.join(f'`{srv}`: {msg}'
for srv, msg in error.items())
raise ConnectionError(
f'get MCP tool failed for: {error_messages}. Please check MCP servers and retry.'
)
return tools

@staticmethod
Expand Down Expand Up @@ -226,6 +247,44 @@ async def connect(self):
f'MCP connections failed for: {error_messages}. Please check mcp configurations and retry.'
)

async def add_mcp_config(self, mcp_config: Dict[str, Dict[str, Any]]):
if mcp_config is None:
return
new_mcp_config = mcp_config.get('mcpServers', {})
servers = self.mcp_config.setdefault('mcpServers', {})
envs = Env.load_env()
for name, server in new_mcp_config.items():
if name in servers and servers[name] == server:
continue
else:
servers[name] = server
env_dict = server.pop('env', {})
env_dict = {
key: value if value else envs.get(key, '')
for key, value in env_dict.items()
}
if 'exclude' in server:
self._exclude_functions[name] = server.pop('exclude')
await self.connect_to_server(
server_name=name, env=env_dict, **server)
self.mcp_config['mcpServers'].update(new_mcp_config)

async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()

async def __aenter__(self) -> 'MCPClient':
try:
await self.connect()
return self
except Exception:
await self.exit_stack.aclose()
raise

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.exit_stack.aclose()
37 changes: 33 additions & 4 deletions ms_agent/tools/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import asyncio
import os
from copy import copy
from types import TracebackType
from typing import Any, Dict, List, Optional

import json
Expand All @@ -20,9 +21,12 @@ class ToolManager:

TOOL_SPLITER = '---'

def __init__(self, config, mcp_config: Optional[Dict[str, Any]] = None):
def __init__(self,
config,
mcp_config: Optional[Dict[str, Any]] = None,
mcp_client: Optional[MCPClient] = None):
self.config = config
self.servers = MCPClient(config, mcp_config)

self.extra_tools: List[ToolBase] = []
self.has_split_task_tool = False
if hasattr(config, 'tools') and hasattr(config.tools, 'split_task'):
Expand All @@ -31,17 +35,30 @@ def __init__(self, config, mcp_config: Optional[Dict[str, Any]] = None):
self.extra_tools.append(FileSystemTool(config))
self._tool_index = {}

# Used temporarily during async initialization; the actual client is managed in self.servers
self.mcp_client = mcp_client
self.mcp_config = mcp_config
self._managed_client = mcp_client is None

def register_tool(self, tool: ToolBase):
self.extra_tools.append(tool)

async def connect(self):
await self.servers.connect()
if self.mcp_client and isinstance(self.mcp_client, MCPClient):
self.servers = self.mcp_client
await self.servers.add_mcp_config(self.mcp_config)
self.mcp_config = self.servers.mcp_config
else:
self.servers = MCPClient(self.mcp_config, self.config)
await self.servers.connect()
for tool in self.extra_tools:
await tool.connect()
await self.reindex_tool()

async def cleanup(self):
await self.servers.cleanup()
if self._managed_client and self.servers:
await self.servers.cleanup()
self.servers = None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Replacing await self.servers.cleanup() with self.servers = None introduces a resource leak. If the ToolManager creates its own MCPClient instance (in the else branch of the connect method), it is responsible for cleaning it up. The current implementation fails to do so.

The cleanup should be conditional. You could track if the client is managed internally. For example:

In __init__:

self._managed_client = mcp_client is None

In cleanup:

if self._managed_client and self.servers:
    await self.servers.cleanup()
self.servers = None

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

for tool in self.extra_tools:
await tool.cleanup()

Expand Down Expand Up @@ -92,3 +109,15 @@ async def parallel_call_tool(self, tool_list: List[ToolCall]):
tasks = [self.single_call_tool(tool) for tool in tool_list]
result = await asyncio.gather(*tasks)
return result

async def __aenter__(self) -> 'ToolManager':

return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
pass
75 changes: 75 additions & 0 deletions tests/tools/test_mcp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import asyncio
import unittest

from ms_agent.tools.mcp_client import MCPClient


from modelscope.utils.test_utils import test_level

class TestMCPClient(unittest.TestCase):
mcp_config = {
"mcpServers": {
"fetch": {
"type": "sse",
"url": os.getenv("MCP_SERVER_FETCH_URL"),
}
}
}
mcp_config2 = {
"mcpServers": {
"time": {
"type": "sse",
"url": os.getenv("MCP_SERVER_TIME_URL"),
}
}
}

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_outside_init(self):
async def main():
async with MCPClient(self.mcp_config) as mcp_client:
mcps = await mcp_client.get_tools()
assert('fetch' in mcps)

res = await mcp_client.call_tool(server_name='fetch',
tool_name='fetch',
tool_args={'url': 'http://www.baidu.com'})
assert('baidu' in res)
asyncio.run(main())

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_aenter(self):
async def main():
mcp_client = MCPClient(self.mcp_config)
await mcp_client.__aenter__()
mcps = await mcp_client.get_tools()
assert('fetch' in mcps)
await mcp_client.__aexit__(None, None, None)

asyncio.run(main())

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_normal_connect(self):
async def main():
mcp_client = MCPClient(self.mcp_config)
await mcp_client.connect()
mcps = await mcp_client.get_tools()
assert('fetch' in mcps)
await mcp_client.cleanup()

asyncio.run(main())

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_add_config(self):
async def main():
async with MCPClient(self.mcp_config) as mcp_client:
await mcp_client.add_mcp_config(self.mcp_config2)
mcps = await mcp_client.get_tools()
assert ('fetch' in mcps and 'time' in mcps)

asyncio.run(main())

if __name__ == '__main__':
unittest.main()
Loading