From 628a6b12f94e7878cec9f55a270242b03d2625d4 Mon Sep 17 00:00:00 2001 From: Aishwarya Date: Sun, 8 Jun 2025 22:15:15 +0545 Subject: [PATCH 1/4] Added property resource_templates and reas_resource that were present in session but missing in session_group. TESTED=unit tests --- src/mcp/client/session_group.py | 46 +++++++++++++------ tests/client/test_session_group.py | 71 ++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 13 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index a77dc7a1e..0fe15da1c 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -16,7 +16,7 @@ from typing import Any, TypeAlias import anyio -from pydantic import BaseModel +from pydantic import BaseModel, AnyUrl from typing_extensions import Self import mcp @@ -100,6 +100,7 @@ class _ComponentNames(BaseModel): # Client-server connection management. _sessions: dict[mcp.ClientSession, _ComponentNames] _tool_to_session: dict[str, mcp.ClientSession] + _resource_to_session: dict[str, mcp.ClientSession] _exit_stack: contextlib.AsyncExitStack _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] @@ -116,20 +117,16 @@ def __init__( ) -> None: """Initializes the MCP client.""" - self._tools = {} - self._resources = {} + self._exit_stack = exit_stack or contextlib.AsyncExitStack() + self._owns_exit_stack = exit_stack is None + self._session_exit_stacks = {} + self._component_name_hook = component_name_hook self._prompts = {} - + self._resources = {} + self._tools = {} self._sessions = {} self._tool_to_session = {} - if exit_stack is None: - self._exit_stack = contextlib.AsyncExitStack() - self._owns_exit_stack = True - else: - self._exit_stack = exit_stack - self._owns_exit_stack = False - self._session_exit_stacks = {} - self._component_name_hook = component_name_hook + self._resource_to_session = {} # New mapping async def __aenter__(self) -> Self: # Enter the exit stack only if we created it ourselves @@ -174,6 +171,16 @@ def tools(self) -> dict[str, types.Tool]: """Returns the tools as a dictionary of names to tools.""" return self._tools + @property + def resource_templates(self) -> list[types.ResourceTemplate]: + """Return all unique resource templates from the resources.""" + templates: list[types.ResourceTemplate] = [] + for r in self._resources.values(): + t = getattr(r, "template", None) + if t is not None and t not in templates: + templates.append(t) + return templates + async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: """Executes a tool given its name and arguments.""" session = self._tool_to_session[name] @@ -296,8 +303,8 @@ async def _aggregate_components( resources_temp: dict[str, types.Resource] = {} tools_temp: dict[str, types.Tool] = {} tool_to_session_temp: dict[str, mcp.ClientSession] = {} + resource_to_session_temp: dict[str, mcp.ClientSession] = {} - # Query the server for its prompts and aggregate to list. try: prompts = (await session.list_prompts()).prompts for prompt in prompts: @@ -314,6 +321,7 @@ async def _aggregate_components( name = self._component_name(resource.name, server_info) resources_temp[name] = resource component_names.resources.add(name) + resource_to_session_temp[name] = session except McpError as err: logging.warning(f"Could not fetch resources: {err}") @@ -365,8 +373,20 @@ async def _aggregate_components( self._resources.update(resources_temp) self._tools.update(tools_temp) self._tool_to_session.update(tool_to_session_temp) + self._resource_to_session.update(resource_to_session_temp) def _component_name(self, name: str, server_info: types.Implementation) -> str: if self._component_name_hook: return self._component_name_hook(name, server_info) return name + + async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: + """Read a resource from the appropriate session based on the URI.""" + print(self._resources) + print(self._resource_to_session) + for name, resource in self._resources.items(): + if resource.uri == uri: + session = self._resource_to_session.get(name) + if session: + return await session.read_resource(uri) + raise ValueError(f"Resource not found: {uri}") diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 924ef7a06..0ab59f527 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -2,6 +2,7 @@ from unittest import mock import pytest +from pydantic import AnyUrl import mcp from mcp import types @@ -395,3 +396,73 @@ async def test_establish_session_parameterized( # 3. Assert returned values assert returned_server_info is mock_initialize_result.serverInfo assert returned_session is mock_entered_session + + @pytest.mark.anyio + async def test_read_resource_not_found(self): + """Test reading a non-existent resource from a session group.""" + # --- Mock Dependencies --- + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + test_resource = types.Resource( + name="test_resource", + uri=AnyUrl("test://resource/1"), + description="Test resource" + ) + + # Mock all list methods + mock_session.list_resources.return_value = types.ListResourcesResult(resources=[test_resource]) + mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[]) + mock_session.list_tools.return_value = types.ListToolsResult(tools=[]) + + # --- Test Setup --- + group = ClientSessionGroup() + group._session_exit_stacks[mock_session] = mock.AsyncMock(spec=contextlib.AsyncExitStack) + await group.connect_with_session( + types.Implementation(name="test_server", version="1.0.0"), + mock_session + ) + + # --- Test Execution & Assertions --- + with pytest.raises(ValueError, match="Resource not found: test://nonexistent"): + await group.read_resource(AnyUrl("test://nonexistent")) + + @pytest.mark.anyio + async def test_read_resource_success(self): + """Test successfully reading a resource from a session group.""" + # --- Mock Dependencies --- + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + test_resource = types.Resource( + name="test_resource", + uri=AnyUrl("test://resource/1"), + description="Test resource" + ) + + # Mock all list methods + mock_session.list_resources.return_value = types.ListResourcesResult(resources=[test_resource]) + mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[]) + mock_session.list_tools.return_value = types.ListToolsResult(tools=[]) + + # Mock the session's read_resource method + mock_read_result = mock.AsyncMock(spec=types.ReadResourceResult) + mock_read_result.content = [types.TextContent(type="text", text="Resource content")] + mock_session.read_resource.return_value = mock_read_result + + # --- Test Setup --- + group = ClientSessionGroup() + group._session_exit_stacks[mock_session] = mock.AsyncMock(spec=contextlib.AsyncExitStack) + await group.connect_with_session( + types.Implementation(name="test_server", version="1.0.0"), + mock_session + ) + + # Verify resource was added + assert "test_resource" in group._resources + assert group._resources["test_resource"] == test_resource + assert "test_resource" in group._resource_to_session + assert group._resource_to_session["test_resource"] == mock_session + + # --- Test Execution --- + result = await group.read_resource(AnyUrl("test://resource/1")) + + # --- Assertions --- + assert result.content == [types.TextContent(type="text", text="Resource content")] + mock_session.read_resource.assert_called_once_with(AnyUrl("test://resource/1")) From a92dea37af998214bef9b49a4e3d05de2f965c4c Mon Sep 17 00:00:00 2001 From: Aishwarya Date: Sun, 8 Jun 2025 22:46:22 +0545 Subject: [PATCH 2/4] ruff fixes --- src/mcp/client/session_group.py | 4 +-- tests/client/test_session_group.py | 52 ++++++++++++++++++------------ 2 files changed, 33 insertions(+), 23 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 0fe15da1c..906b76ab3 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -16,7 +16,7 @@ from typing import Any, TypeAlias import anyio -from pydantic import BaseModel, AnyUrl +from pydantic import AnyUrl, BaseModel from typing_extensions import Self import mcp @@ -303,7 +303,7 @@ async def _aggregate_components( resources_temp: dict[str, types.Resource] = {} tools_temp: dict[str, types.Tool] = {} tool_to_session_temp: dict[str, mcp.ClientSession] = {} - resource_to_session_temp: dict[str, mcp.ClientSession] = {} + resource_to_session_temp: dict[str, mcp.ClientSession] = {} try: prompts = (await session.list_prompts()).prompts diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 0ab59f527..5dcd72953 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -405,22 +405,25 @@ async def test_read_resource_not_found(self): test_resource = types.Resource( name="test_resource", uri=AnyUrl("test://resource/1"), - description="Test resource" + description="Test resource", ) - + # Mock all list methods - mock_session.list_resources.return_value = types.ListResourcesResult(resources=[test_resource]) + mock_session.list_resources.return_value = types.ListResourcesResult( + resources=[test_resource] + ) mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[]) mock_session.list_tools.return_value = types.ListToolsResult(tools=[]) - + # --- Test Setup --- group = ClientSessionGroup() - group._session_exit_stacks[mock_session] = mock.AsyncMock(spec=contextlib.AsyncExitStack) + group._session_exit_stacks[mock_session] = mock.AsyncMock( + spec=contextlib.AsyncExitStack + ) await group.connect_with_session( - types.Implementation(name="test_server", version="1.0.0"), - mock_session + types.Implementation(name="test_server", version="1.0.0"), mock_session ) - + # --- Test Execution & Assertions --- with pytest.raises(ValueError, match="Resource not found: test://nonexistent"): await group.read_resource(AnyUrl("test://nonexistent")) @@ -433,36 +436,43 @@ async def test_read_resource_success(self): test_resource = types.Resource( name="test_resource", uri=AnyUrl("test://resource/1"), - description="Test resource" + description="Test resource", ) - + # Mock all list methods - mock_session.list_resources.return_value = types.ListResourcesResult(resources=[test_resource]) + mock_session.list_resources.return_value = types.ListResourcesResult( + resources=[test_resource] + ) mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[]) mock_session.list_tools.return_value = types.ListToolsResult(tools=[]) - + # Mock the session's read_resource method mock_read_result = mock.AsyncMock(spec=types.ReadResourceResult) - mock_read_result.content = [types.TextContent(type="text", text="Resource content")] + mock_read_result.content = [ + types.TextContent(type="text", text="Resource content") + ] mock_session.read_resource.return_value = mock_read_result - + # --- Test Setup --- group = ClientSessionGroup() - group._session_exit_stacks[mock_session] = mock.AsyncMock(spec=contextlib.AsyncExitStack) + group._session_exit_stacks[mock_session] = mock.AsyncMock( + spec=contextlib.AsyncExitStack + ) await group.connect_with_session( - types.Implementation(name="test_server", version="1.0.0"), - mock_session + types.Implementation(name="test_server", version="1.0.0"), mock_session ) - + # Verify resource was added assert "test_resource" in group._resources assert group._resources["test_resource"] == test_resource assert "test_resource" in group._resource_to_session assert group._resource_to_session["test_resource"] == mock_session - + # --- Test Execution --- result = await group.read_resource(AnyUrl("test://resource/1")) - + # --- Assertions --- - assert result.content == [types.TextContent(type="text", text="Resource content")] + assert result.content == [ + types.TextContent(type="text", text="Resource content") + ] mock_session.read_resource.assert_called_once_with(AnyUrl("test://resource/1")) From a65292570574295526221bd7cdd99189458b45f0 Mon Sep 17 00:00:00 2001 From: Aishwarya Date: Mon, 9 Jun 2025 10:57:11 +0545 Subject: [PATCH 3/4] pyright test fix --- tests/client/test_session_group.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 5dcd72953..4761c9dd4 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -448,7 +448,7 @@ async def test_read_resource_success(self): # Mock the session's read_resource method mock_read_result = mock.AsyncMock(spec=types.ReadResourceResult) - mock_read_result.content = [ + mock_read_result.contents = [ types.TextContent(type="text", text="Resource content") ] mock_session.read_resource.return_value = mock_read_result @@ -472,7 +472,7 @@ async def test_read_resource_success(self): result = await group.read_resource(AnyUrl("test://resource/1")) # --- Assertions --- - assert result.content == [ + assert result.contents == [ types.TextContent(type="text", text="Resource content") ] mock_session.read_resource.assert_called_once_with(AnyUrl("test://resource/1")) From 0dc3067f82c749a3b64e8b309a6a675f5660af34 Mon Sep 17 00:00:00 2001 From: Aishwarya Date: Thu, 12 Jun 2025 19:56:14 +0545 Subject: [PATCH 4/4] added few more functions, need to take a look again --- src/mcp/client/session_group.py | 57 +++++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 906b76ab3..a822cc98c 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -187,6 +187,7 @@ async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResu session_tool_name = self.tools[name].name return await session.call_tool(session_tool_name, args) + async def disconnect_from_server(self, session: mcp.ClientSession) -> None: """Disconnects from a single MCP server.""" @@ -382,11 +383,61 @@ def _component_name(self, name: str, server_info: types.Implementation) -> str: async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: """Read a resource from the appropriate session based on the URI.""" - print(self._resources) - print(self._resource_to_session) for name, resource in self._resources.items(): if resource.uri == uri: session = self._resource_to_session.get(name) if session: return await session.read_resource(uri) - raise ValueError(f"Resource not found: {uri}") + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"No session found for resource with URI '{uri}'", + ) + ) + + async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/subscribe request.""" + for name, resource in self._resources.items(): + if resource.uri == uri: + session = self._resource_to_session[name] + if session: + return await session.subscribe_resource(uri) + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"No session found for resource with URI '{uri}'", + ) + ) + + async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/unsubscribe request.""" + # Find the session that owns this resource + for name, resource in self._resources.items(): + if resource.uri == uri: + session = self._resource_to_session.get(name) + if session: + return await session.unsubscribe_resource(uri) + + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"No resource found with URI '{uri}'", + ) + ) + + async def get_prompt( + self, name: str, arguments: dict[str, str] | None = None + ) -> types.GetPromptResult: + """Send a prompts/get request.""" + if name in self._prompts: + prompt = self._prompts[name] + session = self._tool_to_session.get(name) + if session: + return await session.get_prompt(prompt.name, arguments) + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"No prompt found with name '{name}'", + ) + ) +