diff --git a/src/core/MCPServer.ts b/src/core/MCPServer.ts index 3d55492..51dab98 100644 --- a/src/core/MCPServer.ts +++ b/src/core/MCPServer.ts @@ -424,6 +424,7 @@ export class MCPServer { { name: this.serverName, version: this.serverVersion }, { capabilities: this.capabilities } ); + tools.forEach((tool) => tool.injectServer(this.server)); logger.debug( `SDK Server instance created with capabilities: ${JSON.stringify(this.capabilities)}` ); diff --git a/src/tools/BaseTool.ts b/src/tools/BaseTool.ts index 794ecb5..0107737 100644 --- a/src/tools/BaseTool.ts +++ b/src/tools/BaseTool.ts @@ -1,17 +1,8 @@ import { z } from 'zod'; -import { Tool as SDKTool } from '@modelcontextprotocol/sdk/types.js'; +import { CreateMessageRequest, CreateMessageResult, Tool as SDKTool } from '@modelcontextprotocol/sdk/types.js'; import { ImageContent } from '../transports/utils/image-handler.js'; - -// Type to check if a Zod type has a description -type HasDescription = T extends { _def: { description: string } } ? T : never; - -// Type to ensure all properties in a Zod object have descriptions -type AllFieldsHaveDescriptions = { - [K in keyof T]: HasDescription; -}; - -// Strict Zod object type that requires all fields to have descriptions -type StrictZodObject = z.ZodObject>; +import { Server } from '@modelcontextprotocol/sdk/server/index.js'; +import { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js'; export type ToolInputSchema = { [K in keyof T]: { @@ -68,6 +59,7 @@ export interface ToolProtocol extends SDKTool { toolCall(request: { params: { name: string; arguments?: Record }; }): Promise; + injectServer(server: Server): void; } /** @@ -104,6 +96,37 @@ export abstract class MCPTool = any, TSchema protected useStringify: boolean = true; [key: string]: unknown; + private server: Server | undefined; + + /** + * Injects the server into this tool to allow sampling requests. + * Automatically called by the MCP server when registering the tool. + * Calling this method manually will result in an error. + */ + public injectServer(server: Server): void { + if (this.server) { + throw new Error(`Server reference has already been injected into '${this.name}' tool.`); + } + this.server = server; + } + + /** + * Submit a sampling request to the client + * @example + * ```typescript + * const result = await this.samplingRequest({ + * messages: [{ role: "user", content: { type: "text", text: "Hello!" } }], + * maxTokens: 100 + * }); + * ``` + */ + public readonly samplingRequest = async (request: CreateMessageRequest['params'], options?: RequestOptions): Promise => { + if (!this.server) { + throw new Error(`Server reference has not been injected into '${this.name}' tool.`); + } + return await this.server.createMessage(request, options); + }; + /** * Validates the tool schema. This is called automatically when the tool is registered * with an MCP server, but can also be called manually for testing. diff --git a/tests/tools/BaseTool.test.ts b/tests/tools/BaseTool.test.ts index aaa20b5..e9808f2 100644 --- a/tests/tools/BaseTool.test.ts +++ b/tests/tools/BaseTool.test.ts @@ -1,6 +1,16 @@ -import { describe, it, expect, beforeEach } from '@jest/globals'; +import { describe, it, expect, beforeEach, jest } from '@jest/globals'; import { z } from 'zod'; import { MCPTool } from '../../src/tools/BaseTool.js'; +import { Server } from '@modelcontextprotocol/sdk/server/index.js'; +import { CreateMessageRequest, CreateMessageResult } from '@modelcontextprotocol/sdk/types.js'; +import {RequestOptions} from '@modelcontextprotocol/sdk/shared/protocol.js'; + +// Mock the Server class +jest.mock('@modelcontextprotocol/sdk/server/index.js', () => ({ + Server: jest.fn().mockImplementation(() => ({ + createMessage: jest.fn(), + })), +})); describe('BaseTool', () => { describe('Legacy Pattern (Separate Schema Definition)', () => { @@ -488,4 +498,147 @@ describe('BaseTool', () => { console.log(JSON.stringify(definition, null, 2)); }); }); + + describe('Sampling Functionality', () => { + class SamplingTool extends MCPTool { + name = 'sampling_tool'; + description = 'A tool that uses sampling'; + schema = z.object({ + prompt: z.string().describe('The prompt to sample'), + }); + + protected async execute(input: { prompt: string }): Promise { + const result = await this.samplingRequest({ + messages: [ + { + role: 'user', + content: { + type: 'text', + text: input.prompt, + }, + }, + ], + maxTokens: 100, + }); + + return { sampledText: result.content.text }; + } + } + + let samplingTool: SamplingTool; + let mockServer: jest.Mocked; + beforeEach(() => { + samplingTool = new SamplingTool(); + mockServer = new Server( + { name: 'test-server', version: '1.0.0' }, + { capabilities: {} } + ) as jest.Mocked; + mockServer.createMessage = jest.fn(); + }); + + describe('Server Injection', () => { + it('should allow server injection', () => { + expect(() => samplingTool.injectServer(mockServer)).not.toThrow(); + }); + + it('should prevent double injection', () => { + samplingTool.injectServer(mockServer); + expect(() => samplingTool.injectServer(mockServer)).toThrow( + "Server reference has already been injected into 'sampling_tool' tool." + ); + }); + + it('should throw error when sampling without server injection', async () => { + await expect( + samplingTool.samplingRequest({ + messages: [{ role: 'user', content: { type: 'text', text: 'test' } }], + maxTokens: 100, + }) + ).rejects.toThrow("Server reference has not been injected into 'sampling_tool' tool."); + }); + }); + + describe('Sampling Requests', () => { + beforeEach(() => { + samplingTool.injectServer(mockServer); + }); + + it('should make sampling requests with correct parameters', async () => { + const mockResult: CreateMessageResult = { + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'Sampled response' }, + }; + mockServer.createMessage.mockResolvedValue(mockResult); + + const request: CreateMessageRequest['params'] = { + messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], + maxTokens: 100, + temperature: 0.7, + systemPrompt: 'Be helpful', + }; + + const result = await samplingTool.samplingRequest(request); + + expect(mockServer.createMessage).toHaveBeenCalledWith(request, undefined); + expect(result).toEqual(mockResult); + }); + + it('should handle sampling errors gracefully', async () => { + mockServer.createMessage.mockRejectedValue(new Error('Sampling failed')); + + await expect( + samplingTool.samplingRequest({ + messages: [{ role: 'user', content: { type: 'text', text: 'test' } }], + maxTokens: 100, + }) + ).rejects.toThrow('Sampling failed'); + }); + + it('should support complex sampling requests with all parameters', async () => { + const mockResult: CreateMessageResult = { + model: 'claude-3-sonnet', + role: 'assistant', + content: { type: 'text', text: 'Complex response' }, + stopReason: 'endTurn', + }; + mockServer.createMessage.mockResolvedValue(mockResult); + + const complexRequest: CreateMessageRequest['params'] = { + messages: [ + { role: 'user', content: { type: 'text', text: 'First message' } }, + { role: 'assistant', content: { type: 'text', text: 'Assistant response' } }, + { role: 'user', content: { type: 'text', text: 'Follow up' } }, + ], + maxTokens: 500, + temperature: 0.8, + systemPrompt: 'You are a helpful assistant', + includeContext: 'thisServer', + modelPreferences: { + hints: [{ name: 'claude-3' }], + costPriority: 0.3, + speedPriority: 0.7, + intelligencePriority: 0.9, + }, + stopSequences: ['END', 'STOP'], + metadata: { taskType: 'analysis' }, + }; + + const options: RequestOptions = { + timeout: 5000, + maxTotalTimeout: 10000, + signal: new AbortController().signal, + resetTimeoutOnProgress: true, + onprogress: (progress) => { + console.log('Progress:', progress); + }, + } + + const result = await samplingTool.samplingRequest(complexRequest, options); + + expect(mockServer.createMessage).toHaveBeenCalledWith(complexRequest, options); + expect(result).toEqual(mockResult); + }); + }); + }); });