Skip to content

feat: ClientGroup for managing multiple MCP server connections #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: session_group_1_sample
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions src/client/clientGroup.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import { ClientGroup } from "./clientGroup.js";
import { Client } from "./index.js";
import { Tool, CallToolRequest, CallToolResultSchema, Implementation } from "../types.js";

// Mock Client class for testing ClientGroup
export class MockClient extends Client {
mockListTools = jest.fn();
mockCallTool = jest.fn();
mockClose = jest.fn();

constructor(clientInfo: Implementation) {
super(clientInfo);
}

listTools = this.mockListTools.mockImplementation(async () => {
return [];
});

callTool = this.mockCallTool.mockImplementation(async (params) => {
return { result: `mock result for ${params.name}` };
});

close = this.mockClose.mockImplementation(async () => {
// Do nothing
});

// Needed for the base class constructor but not used in these tests
override async connect() {}
override assertCapability() {}
override assertCapabilityForMethod() {}
override assertNotificationCapability() {}
override assertRequestHandlerCapability() {}
}


describe("ClientGroup", () => {
let mockClient1: MockClient;
let mockClient2: MockClient;

beforeEach(() => {
mockClient1 = new MockClient({ name: "client1", version: "1.0" });
mockClient2 = new MockClient({ name: "client2", version: "1.0" });
});

test("should list tools from all clients", async () => {
const tool1: Tool = { name: "tool1", description: "description1", parameters: {}, inputSchema: { type: 'object' } };
const tool2: Tool = { name: "tool2", description: "description2", parameters: {}, inputSchema: { type: 'object' } };
mockClient1.mockListTools.mockResolvedValueOnce({ tools: [tool1] });
mockClient2.mockListTools.mockResolvedValueOnce({ tools: [tool2] });

const clientGroup = await ClientGroup.create([mockClient1, mockClient2]);

const tools = await clientGroup.listTools();
expect(mockClient1.mockListTools).toHaveBeenCalled();
expect(mockClient2.mockListTools).toHaveBeenCalled();
expect(tools).toHaveLength(2);
expect(tools).toEqual(expect.arrayContaining([tool1, tool2]));
});

test("should call the correct tool on the correct client", async () => {
const tool1: Tool = { name: "tool1", description: "description1", parameters: {}, inputSchema: { type: 'object' } };
const tool2: Tool = { name: "tool2", description: "description2", parameters: {}, inputSchema: { type: 'object' } };
mockClient1.mockListTools.mockResolvedValueOnce({ tools: [tool1] });
mockClient2.mockListTools.mockResolvedValueOnce({ tools: [tool2] });

const clientGroup = await ClientGroup.create([mockClient1, mockClient2]);

const params: CallToolRequest["params"] = {
name: "tool1",
parameters: { arg: "value" },
};
const result = await clientGroup.callTool(params, CallToolResultSchema);

expect(mockClient1.mockCallTool).toHaveBeenCalledWith(
params,
CallToolResultSchema,
undefined
);
expect(mockClient2.mockCallTool).not.toHaveBeenCalled();
expect(result).toEqual({ result: "mock result for tool1" });
});

test("should throw error if tool is not found", async () => {
mockClient1.mockListTools.mockResolvedValueOnce({ tools: [] });
mockClient2.mockListTools.mockResolvedValueOnce({ tools: [] });

const clientGroup = await ClientGroup.create([mockClient1, mockClient2]);

const params: CallToolRequest["params"] = {
name: "nonExistentTool",
parameters: {},
};

await expect(clientGroup.callTool(params, CallToolResultSchema)).rejects.toThrow(
"Trying to call too nonExistentTool which is not provided by the client group"
);
});

test("should call close on all clients", async () => {
mockClient1.mockListTools.mockResolvedValueOnce({ tools: [] });
mockClient2.mockListTools.mockResolvedValueOnce({ tools: [] });

const clientGroup = await ClientGroup.create([mockClient1, mockClient2]);
await clientGroup.close();

expect(mockClient1.mockClose).toHaveBeenCalled();
expect(mockClient2.mockClose).toHaveBeenCalled();
});
});
113 changes: 113 additions & 0 deletions src/client/clientGroup.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import { RequestOptions } from "../shared/protocol.js";
import { Tool, CallToolRequest, CallToolResultSchema, CompatibilityCallToolResultSchema } from "../types.js";
import { Client } from "./index.js";

/**
* A group of MCP clients.
*
* This class makes it easier to manage multiple MCP server connections.
*
* Example:
*
* ```typescript
*
* // Create a client group
* const clientGroup = await ClientGroup.create([client1, client2]);
*
* // List tools from all clients
* const tools = await clientGroup.listTools();
*
* // Call a tool by name
* const result = await clientGroup.callTool({ name: "myTool", params: {} });
*
* // Close all clients
* await clientGroup.close();
* ```
*/
export class ClientGroup {
private _clients: Client[];
private _allTools: Tool[];
private _toolToClient: { [key: string]: Client; } = {};

private constructor(
clients: Client[]
) {
this._clients = clients;
this._allTools = [];
}

/**
* Creates a new ClientGroup.
*
* @param clients The list of clients to include in the group.
*/
static async create(
clients: Client[],
options?: RequestOptions
): Promise<ClientGroup> {
const group = new ClientGroup(clients);
await group.update(options);
return group;
}

private async update(options?: RequestOptions) {
this._allTools = [];
this._toolToClient = {};
for (const client of this._clients) {
for (const tool of (await client.listTools(options)).tools) {
if (this._toolToClient[tool.name]) {
// TODO(amirh): we should allow the users to configure tool renames.
console.warn(
`Tool name: ${tool.name} is available on multiple servers, picking an arbitrary one`
);
}
this._toolToClient[tool.name] = client;
this._allTools.push(tool);
}
}
}

/**
* Lists all tools available from all clients in the group.
*
* @param options Optional request options.
* @returns A promise that resolves with a list of tools.
*/
async listTools(): Promise<Tool[]> {
return structuredClone(this._allTools);
}

/**
* Calls a tool provided by one of the clients in the group.
*
* @param params The parameters for the tool call.
* @param resultSchema The schema to use for validating the tool result.
* @param options Optional request options.
* @returns A promise that resolves with the tool result.
* @throws An error if no client provides the requested tool.
*/
async callTool(
params: CallToolRequest["params"],
resultSchema: typeof CallToolResultSchema |
typeof CompatibilityCallToolResultSchema = CallToolResultSchema,
options?: RequestOptions
) {
if (!this._toolToClient[params.name]) {
throw new Error(
`Trying to call too ${params.name} which is not provided by the client group`
);
}
return this._toolToClient[params.name].callTool(params, resultSchema, options);
}

/**
* Closes all clients in the group.
*
* @returns A promise that resolves when all clients are closed.
*/
async close() {
for (const client of this._clients) {
await client.close();
}
}
}
2 changes: 1 addition & 1 deletion src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -434,4 +434,4 @@ export class Client<
async sendRootsListChanged() {
return this.notification({ method: "notifications/roots/list_changed" });
}
}
}
24 changes: 5 additions & 19 deletions src/examples/client/clientGroupSample.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Tool } from "../../types.js";
import { Client } from "../../client/index.js";
import { ClientGroup } from "../../client/clientGroup.js";
import { InMemoryTransport } from "../../inMemory.js";
import { McpServer, ToolCallback } from "../../server/mcp.js";
import { Transport } from "../../shared/transport.js";
Expand Down Expand Up @@ -29,38 +30,23 @@ async function main(): Promise<void> {
client3.connect(clientTransports[2]);

const allClients = [client1, client2, client3];
const toolToClient: { [key: string]: Client } = {};
const allTools = [];

for (const client of allClients) {
for (const tool of (await client.listTools()).tools) {
if (toolToClient[tool.name]) {
console.warn(
`Tool name: ${tool.name} is available on multiple servers, picking an arbitrary one`,
);
}
toolToClient[tool.name] = client;
allTools.push(tool);
}
}
const clientGroup = await ClientGroup.create(allClients);

const allResources = [];
allResources.push(...(await client1.listResources()).resources);
allResources.push(...(await client2.listResources()).resources);
allResources.push(...(await client3.listResources()).resources);

const toolName = simulatePromptModel(allTools);
const toolName = simulatePromptModel(await clientGroup.listTools());

console.log(`Invoking tool: ${toolName}`);
const toolResult = await toolToClient[toolName].callTool({
const toolResult = await clientGroup.callTool({
name: toolName,
});

console.log(toolResult);

for (const client of allClients) {
await client.close();
}
clientGroup.close();
}

// Start the example
Expand Down