diff --git a/src/client/clientGroup.test.ts b/src/client/clientGroup.test.ts new file mode 100644 index 000000000..f53c682a7 --- /dev/null +++ b/src/client/clientGroup.test.ts @@ -0,0 +1,109 @@ +import { ClientGroup } from "./clientGroup.js"; +import { Client } from "./index.js"; +import { Tool, CallToolRequest, CallToolResultSchema, Implementation } from "../types.js"; + +// Mock Client class for testing ClientGroup +export class MockClient extends Client { + mockListTools = jest.fn(); + mockCallTool = jest.fn(); + mockClose = jest.fn(); + + constructor(clientInfo: Implementation) { + super(clientInfo); + } + + listTools = this.mockListTools.mockImplementation(async () => { + return []; + }); + + callTool = this.mockCallTool.mockImplementation(async (params) => { + return { result: `mock result for ${params.name}` }; + }); + + close = this.mockClose.mockImplementation(async () => { + // Do nothing + }); + + // Needed for the base class constructor but not used in these tests + override async connect() {} + override assertCapability() {} + override assertCapabilityForMethod() {} + override assertNotificationCapability() {} + override assertRequestHandlerCapability() {} +} + + +describe("ClientGroup", () => { + let mockClient1: MockClient; + let mockClient2: MockClient; + + beforeEach(() => { + mockClient1 = new MockClient({ name: "client1", version: "1.0" }); + mockClient2 = new MockClient({ name: "client2", version: "1.0" }); + }); + + test("should list tools from all clients", async () => { + const tool1: Tool = { name: "tool1", description: "description1", parameters: {}, inputSchema: { type: 'object' } }; + const tool2: Tool = { name: "tool2", description: "description2", parameters: {}, inputSchema: { type: 'object' } }; + mockClient1.mockListTools.mockResolvedValueOnce({ tools: [tool1] }); + mockClient2.mockListTools.mockResolvedValueOnce({ tools: [tool2] }); + + const clientGroup = await ClientGroup.create([mockClient1, mockClient2]); + + const tools = await clientGroup.listTools(); + expect(mockClient1.mockListTools).toHaveBeenCalled(); + expect(mockClient2.mockListTools).toHaveBeenCalled(); + expect(tools).toHaveLength(2); + expect(tools).toEqual(expect.arrayContaining([tool1, tool2])); + }); + + test("should call the correct tool on the correct client", async () => { + const tool1: Tool = { name: "tool1", description: "description1", parameters: {}, inputSchema: { type: 'object' } }; + const tool2: Tool = { name: "tool2", description: "description2", parameters: {}, inputSchema: { type: 'object' } }; + mockClient1.mockListTools.mockResolvedValueOnce({ tools: [tool1] }); + mockClient2.mockListTools.mockResolvedValueOnce({ tools: [tool2] }); + + const clientGroup = await ClientGroup.create([mockClient1, mockClient2]); + + const params: CallToolRequest["params"] = { + name: "tool1", + parameters: { arg: "value" }, + }; + const result = await clientGroup.callTool(params, CallToolResultSchema); + + expect(mockClient1.mockCallTool).toHaveBeenCalledWith( + params, + CallToolResultSchema, + undefined + ); + expect(mockClient2.mockCallTool).not.toHaveBeenCalled(); + expect(result).toEqual({ result: "mock result for tool1" }); + }); + + test("should throw error if tool is not found", async () => { + mockClient1.mockListTools.mockResolvedValueOnce({ tools: [] }); + mockClient2.mockListTools.mockResolvedValueOnce({ tools: [] }); + + const clientGroup = await ClientGroup.create([mockClient1, mockClient2]); + + const params: CallToolRequest["params"] = { + name: "nonExistentTool", + parameters: {}, + }; + + await expect(clientGroup.callTool(params, CallToolResultSchema)).rejects.toThrow( + "Trying to call too nonExistentTool which is not provided by the client group" + ); + }); + + test("should call close on all clients", async () => { + mockClient1.mockListTools.mockResolvedValueOnce({ tools: [] }); + mockClient2.mockListTools.mockResolvedValueOnce({ tools: [] }); + + const clientGroup = await ClientGroup.create([mockClient1, mockClient2]); + await clientGroup.close(); + + expect(mockClient1.mockClose).toHaveBeenCalled(); + expect(mockClient2.mockClose).toHaveBeenCalled(); + }); +}); diff --git a/src/client/clientGroup.ts b/src/client/clientGroup.ts new file mode 100644 index 000000000..c9821848e --- /dev/null +++ b/src/client/clientGroup.ts @@ -0,0 +1,113 @@ +import { RequestOptions } from "../shared/protocol.js"; +import { Tool, CallToolRequest, CallToolResultSchema, CompatibilityCallToolResultSchema } from "../types.js"; +import { Client } from "./index.js"; + +/** + * A group of MCP clients. + * + * This class makes it easier to manage multiple MCP server connections. + * + * Example: + * + * ```typescript + * + * // Create a client group + * const clientGroup = await ClientGroup.create([client1, client2]); + * + * // List tools from all clients + * const tools = await clientGroup.listTools(); + * + * // Call a tool by name + * const result = await clientGroup.callTool({ name: "myTool", params: {} }); + * + * // Close all clients + * await clientGroup.close(); + * ``` + */ +export class ClientGroup { + private _clients: Client[]; + private _allTools: Tool[]; + private _toolToClient: { [key: string]: Client; } = {}; + + private constructor( + clients: Client[] + ) { + this._clients = clients; + this._allTools = []; + } + + /** + * Creates a new ClientGroup. + * + * @param clients The list of clients to include in the group. + */ + static async create( + clients: Client[], + options?: RequestOptions + ): Promise { + const group = new ClientGroup(clients); + await group.update(options); + return group; + } + + private async update(options?: RequestOptions) { + this._allTools = []; + this._toolToClient = {}; + for (const client of this._clients) { + for (const tool of (await client.listTools(options)).tools) { + if (this._toolToClient[tool.name]) { + // TODO(amirh): we should allow the users to configure tool renames. + console.warn( + `Tool name: ${tool.name} is available on multiple servers, picking an arbitrary one` + ); + } + this._toolToClient[tool.name] = client; + this._allTools.push(tool); + } + } + } + + /** + * Lists all tools available from all clients in the group. + * + * @param options Optional request options. + * @returns A promise that resolves with a list of tools. + */ + async listTools(): Promise { + return structuredClone(this._allTools); + } + + /** + * Calls a tool provided by one of the clients in the group. + * + * @param params The parameters for the tool call. + * @param resultSchema The schema to use for validating the tool result. + * @param options Optional request options. + * @returns A promise that resolves with the tool result. + * @throws An error if no client provides the requested tool. + */ + async callTool( + params: CallToolRequest["params"], + resultSchema: typeof CallToolResultSchema | + typeof CompatibilityCallToolResultSchema = CallToolResultSchema, + options?: RequestOptions + ) { + if (!this._toolToClient[params.name]) { + throw new Error( + `Trying to call too ${params.name} which is not provided by the client group` + ); + } + return this._toolToClient[params.name].callTool(params, resultSchema, options); + } + + /** + * Closes all clients in the group. + * + * @returns A promise that resolves when all clients are closed. + */ + async close() { + for (const client of this._clients) { + await client.close(); + } + } +} diff --git a/src/client/index.ts b/src/client/index.ts index a3edd0beb..69612b775 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -434,4 +434,4 @@ export class Client< async sendRootsListChanged() { return this.notification({ method: "notifications/roots/list_changed" }); } -} +} \ No newline at end of file diff --git a/src/examples/client/clientGroupSample.ts b/src/examples/client/clientGroupSample.ts index a1d2ed02f..86b7c0aa7 100644 --- a/src/examples/client/clientGroupSample.ts +++ b/src/examples/client/clientGroupSample.ts @@ -1,5 +1,6 @@ import { Tool } from "../../types.js"; import { Client } from "../../client/index.js"; +import { ClientGroup } from "../../client/clientGroup.js"; import { InMemoryTransport } from "../../inMemory.js"; import { McpServer, ToolCallback } from "../../server/mcp.js"; import { Transport } from "../../shared/transport.js"; @@ -29,38 +30,23 @@ async function main(): Promise { client3.connect(clientTransports[2]); const allClients = [client1, client2, client3]; - const toolToClient: { [key: string]: Client } = {}; - const allTools = []; - - for (const client of allClients) { - for (const tool of (await client.listTools()).tools) { - if (toolToClient[tool.name]) { - console.warn( - `Tool name: ${tool.name} is available on multiple servers, picking an arbitrary one`, - ); - } - toolToClient[tool.name] = client; - allTools.push(tool); - } - } + const clientGroup = await ClientGroup.create(allClients); const allResources = []; allResources.push(...(await client1.listResources()).resources); allResources.push(...(await client2.listResources()).resources); allResources.push(...(await client3.listResources()).resources); - const toolName = simulatePromptModel(allTools); + const toolName = simulatePromptModel(await clientGroup.listTools()); console.log(`Invoking tool: ${toolName}`); - const toolResult = await toolToClient[toolName].callTool({ + const toolResult = await clientGroup.callTool({ name: toolName, }); console.log(toolResult); - for (const client of allClients) { - await client.close(); - } + clientGroup.close(); } // Start the example