diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 18c19275..fbfeae04 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -71,6 +71,7 @@ def __init__(self, self.mcp_server_file = kwargs.get('mcp_server_file', None) self.mcp_config: Dict[str, Any] = self._parse_mcp_servers( kwargs.get('mcp_config', {})) + self.mcp_client = kwargs.get('mcp_client', None) self._task_begin() def register_callback(self, callback: Callback): @@ -182,7 +183,7 @@ async def _parallel_tool_call(self, async def _prepare_tools(self): """Initialize and connect the tool manager.""" - self.tool_manager = ToolManager(self.config, self.mcp_config) + self.tool_manager = ToolManager(self.config, self.mcp_config, self.mcp_client) await self.tool_manager.connect() async def _cleanup_tools(self): diff --git a/ms_agent/tools/mcp_client.py b/ms_agent/tools/mcp_client.py index d8fdddf0..d0210847 100644 --- a/ms_agent/tools/mcp_client.py +++ b/ms_agent/tools/mcp_client.py @@ -3,7 +3,9 @@ import os from contextlib import AsyncExitStack from datetime import timedelta -from typing import Any, Dict, List, Literal, Optional +from os import environb +from types import TracebackType +from typing import Any, Dict, List, Literal, Optional, Union from mcp import ClientSession, ListToolsResult, StdioServerParameters from mcp.client.sse import sse_client @@ -13,7 +15,7 @@ from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig logger = get_logger() @@ -24,7 +26,7 @@ DEFAULT_HTTP_TIMEOUT = 5 DEFAULT_SSE_READ_TIMEOUT = 60 * 5 -TOOL_CALL_TIMEOUT = os.getenv('TOOL_CALL_TIMEOUT', 15) +TOOL_CALL_TIMEOUT = os.getenv('TOOL_CALL_TIMEOUT', 30) DEFAULT_STREAMABLE_HTTP_TIMEOUT = timedelta(seconds=30) DEFAULT_STREAMABLE_HTTP_SSE_READ_TIMEOUT = timedelta(seconds=60 * 5) @@ -40,17 +42,23 @@ class MCPClient(ToolBase): mcp_config(`Optional[Dict[str, Any]]`): Extra mcp servers in json format. """ - def __init__(self, - config: DictConfig, - mcp_config: Optional[Dict[str, Any]] = None): + def __init__( + self, + mcp_config: Optional[Dict[str, Any]] = None, + config: Union[DictConfig, ListConfig, None] = None, + ): super().__init__(config) self.sessions: Dict[str, ClientSession] = {} self.exit_stack = AsyncExitStack() - self.mcp_config: Dict[str, Dict[ - str, Any]] = Config.convert_mcp_servers_to_json(config) + self.mcp_config: Dict[str, Dict[str, Any]] = {'mcpServers': {}} + if config is not None: + config_from_file = Config.convert_mcp_servers_to_json(config) + self.mcp_config['mcpServers'].update( + config_from_file.get('mcpServers', {})) self._exclude_functions = {} if mcp_config is not None: - self.mcp_config.update(mcp_config) + self.mcp_config['mcpServers'].update( + mcp_config.get('mcpServers', {})) async def call_tool(self, server_name: str, tool_name: str, tool_args: dict): @@ -79,24 +87,37 @@ async def call_tool(self, server_name: str, tool_name: str, async def get_tools(self) -> Dict: tools = {} + error = dict() for key, session in self.sessions.items(): - tools[key] = [] - response = await session.list_tools() - _session_tools = response.tools - exclude = [] - if key in self._exclude_functions: - exclude = self._exclude_functions[key] - _session_tools = [ - t for t in _session_tools if t.name not in exclude - ] - _session_tools = [ - Tool( - tool_name=t.name, - server_name=key, - description=t.description, - parameters=t.inputSchema) for t in _session_tools - ] - tools[key].extend(_session_tools) + try: + tools[key] = [] + response = await asyncio.wait_for( + session.list_tools(), timeout=TOOL_CALL_TIMEOUT) + _session_tools = response.tools + exclude = [] + if key in self._exclude_functions: + exclude = self._exclude_functions[key] + _session_tools = [ + t for t in _session_tools if t.name not in exclude + ] + _session_tools = [ + Tool( + tool_name=t.name, + server_name=key, + description=t.description, + parameters=t.inputSchema) for t in _session_tools + ] + tools[key].extend(_session_tools) + except asyncio.TimeoutError: + error[key] = 'timeout' + except BaseException as exc: + error[key] = exc + if error: + error_messages = '; '.join(f'`{srv}`: {msg}' + for srv, msg in error.items()) + raise ConnectionError( + f'get MCP tool failed for: {error_messages}. Please check MCP servers and retry.' + ) return tools @staticmethod @@ -226,6 +247,44 @@ async def connect(self): f'MCP connections failed for: {error_messages}. Please check mcp configurations and retry.' ) + async def add_mcp_config(self, mcp_config: Dict[str, Dict[str, Any]]): + if mcp_config is None: + return + new_mcp_config = mcp_config.get('mcpServers', {}) + servers = self.mcp_config.setdefault('mcpServers', {}) + envs = Env.load_env() + for name, server in new_mcp_config.items(): + if name in servers and servers[name] == server: + continue + else: + servers[name] = server + env_dict = server.pop('env', {}) + env_dict = { + key: value if value else envs.get(key, '') + for key, value in env_dict.items() + } + if 'exclude' in server: + self._exclude_functions[name] = server.pop('exclude') + await self.connect_to_server( + server_name=name, env=env_dict, **server) + self.mcp_config['mcpServers'].update(new_mcp_config) + async def cleanup(self): """Clean up resources""" await self.exit_stack.aclose() + + async def __aenter__(self) -> 'MCPClient': + try: + await self.connect() + return self + except Exception: + await self.exit_stack.aclose() + raise + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.exit_stack.aclose() diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index 372c3bfc..e5d82f48 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -2,6 +2,7 @@ import asyncio import os from copy import copy +from types import TracebackType from typing import Any, Dict, List, Optional import json @@ -20,9 +21,12 @@ class ToolManager: TOOL_SPLITER = '---' - def __init__(self, config, mcp_config: Optional[Dict[str, Any]] = None): + def __init__(self, + config, + mcp_config: Optional[Dict[str, Any]] = None, + mcp_client: Optional[MCPClient] = None): self.config = config - self.servers = MCPClient(config, mcp_config) + self.extra_tools: List[ToolBase] = [] self.has_split_task_tool = False if hasattr(config, 'tools') and hasattr(config.tools, 'split_task'): @@ -31,17 +35,30 @@ def __init__(self, config, mcp_config: Optional[Dict[str, Any]] = None): self.extra_tools.append(FileSystemTool(config)) self._tool_index = {} + # Used temporarily during async initialization; the actual client is managed in self.servers + self.mcp_client = mcp_client + self.mcp_config = mcp_config + self._managed_client = mcp_client is None + def register_tool(self, tool: ToolBase): self.extra_tools.append(tool) async def connect(self): - await self.servers.connect() + if self.mcp_client and isinstance(self.mcp_client, MCPClient): + self.servers = self.mcp_client + await self.servers.add_mcp_config(self.mcp_config) + self.mcp_config = self.servers.mcp_config + else: + self.servers = MCPClient(self.mcp_config, self.config) + await self.servers.connect() for tool in self.extra_tools: await tool.connect() await self.reindex_tool() async def cleanup(self): - await self.servers.cleanup() + if self._managed_client and self.servers: + await self.servers.cleanup() + self.servers = None for tool in self.extra_tools: await tool.cleanup() @@ -92,3 +109,15 @@ async def parallel_call_tool(self, tool_list: List[ToolCall]): tasks = [self.single_call_tool(tool) for tool in tool_list] result = await asyncio.gather(*tasks) return result + + async def __aenter__(self) -> 'ToolManager': + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + pass diff --git a/tests/tools/test_mcp_client.py b/tests/tools/test_mcp_client.py new file mode 100644 index 00000000..ca9a24a8 --- /dev/null +++ b/tests/tools/test_mcp_client.py @@ -0,0 +1,75 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import asyncio +import unittest + +from ms_agent.tools.mcp_client import MCPClient + + +from modelscope.utils.test_utils import test_level + +class TestMCPClient(unittest.TestCase): + mcp_config = { + "mcpServers": { + "fetch": { + "type": "sse", + "url": os.getenv("MCP_SERVER_FETCH_URL"), + } + } + } + mcp_config2 = { + "mcpServers": { + "time": { + "type": "sse", + "url": os.getenv("MCP_SERVER_TIME_URL"), + } + } + } + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_outside_init(self): + async def main(): + async with MCPClient(self.mcp_config) as mcp_client: + mcps = await mcp_client.get_tools() + assert('fetch' in mcps) + + res = await mcp_client.call_tool(server_name='fetch', + tool_name='fetch', + tool_args={'url': 'http://www.baidu.com'}) + assert('baidu' in res) + asyncio.run(main()) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_aenter(self): + async def main(): + mcp_client = MCPClient(self.mcp_config) + await mcp_client.__aenter__() + mcps = await mcp_client.get_tools() + assert('fetch' in mcps) + await mcp_client.__aexit__(None, None, None) + + asyncio.run(main()) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_normal_connect(self): + async def main(): + mcp_client = MCPClient(self.mcp_config) + await mcp_client.connect() + mcps = await mcp_client.get_tools() + assert('fetch' in mcps) + await mcp_client.cleanup() + + asyncio.run(main()) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_add_config(self): + async def main(): + async with MCPClient(self.mcp_config) as mcp_client: + await mcp_client.add_mcp_config(self.mcp_config2) + mcps = await mcp_client.get_tools() + assert ('fetch' in mcps and 'time' in mcps) + + asyncio.run(main()) + +if __name__ == '__main__': + unittest.main()