Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 22 additions & 51 deletions packages/app/src/server/gradio-endpoint-connector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { createGradioToolName } from './utils/gradio-utils.js';
import { createAudioPlayerUIResource } from './utils/ui/audio-player.js';
import { spaceMetadataCache, CACHE_CONFIG } from './utils/gradio-cache.js';
import { callGradioTool, applyResultPostProcessing, type GradioToolCallOptions } from './utils/gradio-tool-caller.js';
import { parseGradioSchemaResponse } from '@llmindset/hf-mcp';

// Define types for JSON Schema
interface JsonSchemaProperty {
Expand All @@ -31,12 +32,6 @@ interface JsonSchema {
}

// Define type for array format schema
interface ArrayFormatTool {
name: string;
description?: string;
inputSchema: JsonSchema;
}

interface EndpointConnection {
endpointId: string;
originalIndex: number;
Expand Down Expand Up @@ -78,62 +73,38 @@ function createTimeout(ms: number): Promise<never> {
});
}

/**
* Parses schema response and extracts tools based on format (array or object)
*/
// Kept export for callers; now delegates to shared helper and tracks metrics.
export function parseSchemaResponse(
schemaResponse: unknown,
endpointId: string,
subdomain: string
): Array<{ name: string; description?: string; inputSchema: JsonSchema }> {
// Handle both array and object schema formats
let tools: Array<{ name: string; description?: string; inputSchema: JsonSchema }> = [];

if (Array.isArray(schemaResponse)) {
// NEW-- Array format: [{ name: "toolName", description: "...", inputSchema: {...} }, ...]
tools = (schemaResponse as ArrayFormatTool[]).filter(
(tool): tool is ArrayFormatTool =>
typeof tool === 'object' &&
tool !== null &&
'name' in tool &&
typeof tool.name === 'string' &&
'inputSchema' in tool
);
logger.debug(
{
endpointId,
toolCount: tools.length,
tools: tools.map((t) => t.name),
},
'Retrieved schema (array format)'
);
} else if (typeof schemaResponse === 'object' && schemaResponse !== null) {
// Object format: { "toolName": { properties: {...}, required: [...] }, ... }
const schema = schemaResponse as Record<string, JsonSchema>;
tools = Object.entries(schema).map(([name, toolSchema]) => ({
name,
description: typeof toolSchema.description === 'string' ? toolSchema.description : undefined,
inputSchema: toolSchema,
}));
): Array<{ name: string; description?: string; inputSchema: JsonSchema }> {
try {
const parsed = parseGradioSchemaResponse(schemaResponse);
gradioMetrics.recordSchemaFormat(parsed.format);

logger.debug(
{
endpointId,
toolCount: tools.length,
tools: tools.map((t) => t.name),
toolCount: parsed.tools.length,
tools: parsed.tools.map((t) => t.name),
format: parsed.format,
},
'Retrieved schema (object format)'
'Retrieved schema'
);
} else {
logger.error({ endpointId, subdomain, schemaType: typeof schemaResponse }, 'Invalid schema format');
throw new Error('Invalid schema format: expected array or object');
}

if (tools.length === 0) {
logger.error({ endpointId, subdomain }, 'No tools found in schema');
throw new Error('No tools found in schema');
return parsed.tools as Array<{ name: string; description?: string; inputSchema: JsonSchema }>;
} catch (error) {
if (error instanceof Error && error.message.includes('no tools found')) {
// Preserve legacy error wording expected by tests/callers
throw new Error('No tools found in schema');
}
logger.error(
{ endpointId, subdomain, schemaType: typeof schemaResponse, error: error instanceof Error ? error.message : String(error) },
'Invalid schema format'
);
throw error;
}

return tools;
}

/**
Expand Down
27 changes: 17 additions & 10 deletions packages/app/src/server/utils/gradio-discovery.ts
Original file line number Diff line number Diff line change
Expand Up @@ -312,16 +312,23 @@ async function fetchSchema(
// Convert to Tool format
const tools: Tool[] = parsed
.filter((parsedTool) => !parsedTool.name.toLowerCase().includes('<lambda'))
.map((parsedTool) => ({
name: parsedTool.name,
description: parsedTool.description || `${parsedTool.name} tool`,
inputSchema: {
type: 'object',
properties: parsedTool.inputSchema.properties || {},
required: parsedTool.inputSchema.required || [],
description: parsedTool.inputSchema.description,
},
}));
.map((parsedTool) => {
const inputSchema = parsedTool.inputSchema as {
properties?: Record<string, object>;
required?: string[];
description?: string;
};
return {
name: parsedTool.name,
description: parsedTool.description || `${parsedTool.name} tool`,
inputSchema: {
type: 'object',
properties: inputSchema.properties || {},
required: inputSchema.required || [],
description: inputSchema.description,
},
};
});

// Create schema object
const schema: CachedSchema = {
Expand Down
22 changes: 22 additions & 0 deletions packages/app/src/server/utils/gradio-metrics.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ export interface GradioToolMetrics {
failure: number;
/** Breakdown by tool name */
byTool: Record<string, { success: number; failure: number }>;
/** Count of schema formats seen */
schemaFormats: {
array: number;
object: number;
};
}

export class GradioMetricsCollector {
Expand All @@ -24,6 +29,10 @@ export class GradioMetricsCollector {
success: 0,
failure: 0,
byTool: {},
schemaFormats: {
array: 0,
object: 0,
},
};
schemaFetchErrors: Set<string> = new Set();

Expand Down Expand Up @@ -85,6 +94,10 @@ export class GradioMetricsCollector {
success: 0,
failure: 0,
byTool: {},
schemaFormats: {
array: 0,
object: 0,
},
};
}

Expand All @@ -105,6 +118,15 @@ export class GradioMetricsCollector {
this.schemaFetchErrors.add(toolName);
return true;
}

/** track whether schema was array or object */
public recordSchemaFormat(format: 'array' | 'object'): void {
if (format === 'array') {
this.metrics.schemaFormats.array++;
} else if (format === 'object') {
this.metrics.schemaFormats.object++;
}
}
}

// Export singleton instance
Expand Down
130 changes: 27 additions & 103 deletions packages/app/src/server/utils/gradio-tool-caller.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport, type SSEClientTransportOptions } from '@modelcontextprotocol/sdk/client/sse.js';
import {
CallToolResultSchema,
type ServerNotification,
type ServerRequest,
} from '@modelcontextprotocol/sdk/types.js';
import type { RequestHandlerExtra, RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js';
import type { CallToolResultSchema} from '@modelcontextprotocol/sdk/types.js';
import { type ServerNotification, type ServerRequest } from '@modelcontextprotocol/sdk/types.js';
import type { RequestHandlerExtra } from '@modelcontextprotocol/sdk/shared/protocol.js';
import { callGradioToolWithHeaders } from '@llmindset/hf-mcp';
import { logger } from './logger.js';
import { stripImageContentFromResult, extractUrlFromContent } from './gradio-result-processor.js';

Expand All @@ -31,58 +27,6 @@ export interface GradioToolCallOptions {
spaceName?: string;
}

/**
* Creates SSE connection to a Gradio endpoint
*/
async function createGradioConnection(sseUrl: string, hfToken?: string): Promise<Client> {
logger.debug({ url: sseUrl }, 'Creating SSE connection to Gradio endpoint');

// Create MCP client
const remoteClient = new Client(
{
name: 'hf-mcp-gradio-client',
version: '1.0.0',
},
{
capabilities: {},
}
);

// Create SSE transport with HF token if available
const transportOptions: SSEClientTransportOptions = {};
if (hfToken) {
const headerName = 'X-HF-Authorization';
const customHeaders = {
[headerName]: `Bearer ${hfToken}`,
};
logger.trace('Creating Gradio connection with authorization header');

// Headers for POST requests
transportOptions.requestInit = {
headers: customHeaders,
};

// Headers for SSE connection
transportOptions.eventSourceInit = {
fetch: (url, init) => {
const headers = new Headers(init.headers);
Object.entries(customHeaders).forEach(([key, value]) => {
headers.set(key, value);
});
return fetch(url.toString(), { ...init, headers });
},
};
}

const transport = new SSEClientTransport(new URL(sseUrl), transportOptions);

// Connect the client to the transport
await remoteClient.connect(transport);
logger.debug('SSE connection established');

return remoteClient;
}

/**
* Unified Gradio tool caller that handles:
* - SSE connection management
Expand All @@ -104,53 +48,33 @@ export async function callGradioTool(
): Promise<typeof CallToolResultSchema._type> {
logger.info({ tool: toolName, params: parameters }, 'Calling Gradio tool via unified caller');

const client = await createGradioConnection(sseUrl, hfToken);

try {
// Check if the client is requesting progress notifications
const progressToken = extra?._meta?.progressToken;
const requestOptions: RequestOptions = {};

if (progressToken !== undefined && extra) {
logger.debug({ tool: toolName, progressToken }, 'Progress notifications requested');

// Set up progress relay from remote tool to our client
requestOptions.onprogress = async (progress) => {
logger.trace({ tool: toolName, progressToken, progress }, 'Relaying progress notification');

// Relay the progress notification to our client
await extra.sendNotification({
method: 'notifications/progress',
params: {
progressToken,
progress: progress.progress,
total: progress.total,
message: progress.message,
},
});
};

// Keep long-running tool calls alive while progress is flowing
requestOptions.resetTimeoutOnProgress = true;
}
// Call the remote tool via shared helper (handles SSE, progress relay, header capture)
const { result, capturedHeaders } = await callGradioToolWithHeaders(
sseUrl,
toolName,
parameters,
hfToken,
extra,
{ logProxiedReplica: true }
);

// Call the remote tool and return raw result
return await client.request(
{
method: 'tools/call',
params: {
name: toolName,
arguments: parameters,
_meta: progressToken !== undefined ? { progressToken } : undefined,
// Attach captured headers (e.g., X-Proxied-Replica) to the result meta so callers can inspect them
const proxiedReplica = capturedHeaders['x-proxied-replica'];
if (proxiedReplica) {
logger.debug({ tool: toolName, proxiedReplica }, 'Captured Gradio response header');
return {
...result,
_meta: {
...(result as { _meta?: Record<string, unknown> })._meta,
responseHeaders: {
...(result as { _meta?: { responseHeaders?: Record<string, unknown> } })._meta?.responseHeaders,
'x-proxied-replica': proxiedReplica,
},
},
CallToolResultSchema,
requestOptions
);
} finally {
// Always clean up the connection
await client.close();
} as typeof CallToolResultSchema._type;
}

return result;
}

/**
Expand Down
2 changes: 2 additions & 0 deletions packages/mcp/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ export * from './readme-utils.js';
export * from './use-space.js';
export * from './jobs/jobs-tool.js';
export * from './space/dynamic-space-tool.js';
export * from './space/utils/gradio-caller.js';
export * from './space/utils/gradio-schema.js';

// Export shared types
export * from './types/tool-result.js';
Expand Down
2 changes: 1 addition & 1 deletion packages/mcp/src/space/commands/discover.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { ToolResult } from '../../types/tool-result.js';
import { escapeMarkdown } from '../../utilities.js';
import { VIEW_PARAMETERS } from '../dynamic-space-tool.js';
import { VIEW_PARAMETERS } from '../types.js';

/**
* Prompt configuration for discover operation (from DYNAMIC_SPACE_DATA)
Expand Down
Loading