diff --git a/docs/reflection-v2-protocol.md b/docs/reflection-v2-protocol.md new file mode 100644 index 0000000000..32b131fce3 --- /dev/null +++ b/docs/reflection-v2-protocol.md @@ -0,0 +1,290 @@ +# Genkit Reflection Protocol V2 (WebSocket) + +This document outlines the design for the V2 Reflection API, which uses WebSockets for bidirectional communication between the Genkit CLI (Runtime Manager) and Genkit Runtimes (User Applications). + +## Overview + +In V2, the connection direction is reversed compared to V1: +- **Server**: The Genkit CLI (`RuntimeManagerV2`) starts a WebSocket server. +- **Client**: The Genkit Runtime connects to the CLI's WebSocket server. + +This architecture allows the CLI to easily manage multiple runtimes (e.g., for multi-service projects) and eliminates the need for runtimes to manage their own HTTP servers and ports for reflection. + +## Transport + +- **Protocol**: WebSocket +- **Data Format**: JSON +- **Message Structure**: JSON-RPC 2.0 (modified for streaming) + +## Message Format + +All messages follow the JSON-RPC 2.0 specification. + +### Request +```json +{ + "jsonrpc": "2.0", + "method": "methodName", + "params": { ... }, + "id": 1 +} +``` +*Note: The `id` is generated by the sender (Manager). It can be a number (auto-incrementing) or a string (UUID). It must be unique for the pending request within the WebSocket session.* + +### Response (Success) +```json +{ + "jsonrpc": "2.0", + "result": { ... }, + "id": 1 +} +``` + +### Response (Error) +```json +{ + "jsonrpc": "2.0", + "error": { + "code": -32000, + "message": "Error message", + "data": { ... } // Optional details (stack trace, etc.) + }, + "id": 1 +} +``` + +### Notification +A request without an `id`. +```json +{ + "jsonrpc": "2.0", + "method": "methodName", + "params": { ... } +} +``` + +## Streaming Extension + +JSON-RPC 2.0 does not natively support streaming. We extend it by using Notifications from the Runtime to the Manager associated with a specific Request ID. + +### Stream Chunk Notification +Sent by the Runtime during a streaming `runAction` request. + +```json +{ + "jsonrpc": "2.0", + "method": "streamChunk", + "params": { + "requestId": 1, // Matches the ID of the runAction request + "chunk": { ... } // The chunk data + } +} +``` + +### Run Action State Notification +Sent by the Runtime to provide status updates or metadata (like trace ID) while the action is running, before the result is ready. + +```json +{ + "jsonrpc": "2.0", + "method": "runActionState", + "params": { + "requestId": 1, // Matches the ID of the runAction request + "state": { + "traceId": "..." + } + } +} +``` + +## Protocol Flow + +### 1. Registration (Runtime -> Manager) + +Upon connection, the Runtime must register itself. + +**Request (Runtime -> Manager):** +- **Method**: `register` +- **Params**: + ```typescript + interface RegisterParams { + id: string; // Unique Runtime ID + pid: number; // Process ID + name?: string; // App name + genkitVersion: string; // e.g., "0.9.0" + reflectionApiSpecVersion: number; + envs?: string[]; // Configured environments + } + ``` + +**Response (Manager -> Runtime):** +- **Result**: `void` (null) + +### 2. Configuration (Manager -> Runtime) + +The Manager may push configuration updates to the Runtime, such as the Telemetry Server URL. + +**Notification (Manager -> Runtime):** +- **Method**: `configure` +- **Params**: + ```typescript + interface ConfigureParams { + telemetryServerUrl?: string; + } + ``` + +### 3. List Actions (Manager -> Runtime) + +The Manager requests the list of available actions/flows. + +**Request (Manager -> Runtime):** +- **Method**: `listActions` +- **Params**: `void` (empty object or null) + +**Response (Runtime -> Manager):** +- **Result**: `Record` (Same schema as V1 `/api/actions`) + +### 4. Run Action (Manager -> Runtime) + +The Manager requests the execution of an action. + +**Request (Manager -> Runtime):** +- **Method**: `runAction` +- **Params**: + ```typescript + interface RunActionParams { + key: string; // Action key (e.g., "flowName") + input: any; // Input payload + context?: any; // Context data + telemetryLabels?: Record; + stream?: boolean; // Whether to stream results + } + ``` + +**Scenario A: Non-Streaming Response** + +1. **Notification (Runtime -> Manager)**: `runActionState` (optional, repeated) + - Used to send early trace info or status updates. + - `params.requestId`: Matches request ID. + - `params.state`: The state update (e.g., traceId). + +2. **Response (Runtime -> Manager)**: + - **Result**: + ```typescript + interface RunActionResult { + result: any; // The return value + telemetry?: { + traceId?: string; + }; + } + ``` + +**Scenario B: Streaming Response** + +1. **Notification (Runtime -> Manager)**: `runActionState` (optional, repeated) + - Used to send early trace info or status updates. + +2. **Notification (Runtime -> Manager)**: `streamChunk` (repeated) + - `params.requestId`: Matches request ID. + - `params.chunk`: The partial result. + +3. **Response (Runtime -> Manager)**: Final result. + - **Result**: Same as Non-Streaming (`RunActionResult`). Signals the end of the stream. + +### 5. Health Checks + +The WebSocket connection state itself serves as a basic health check. +- **Heartbeats**: Standard WebSocket Ping/Pong frames should be used to maintain the connection and detect timeouts. + +## Compatibility + +- **V1**: HTTP Server on Runtime, Polling/Request from CLI. +- **V2**: WebSocket Server on CLI, Persistent Connection from Runtime. + +The CLI will determine which mode to use based on the `--experimental-reflection-v2` flag. + +## Example: Streaming Flow Execution + +Below is an example sequence of messages for running a flow named `myFlow` with streaming enabled. + +**1. Manager Requests Execution** +```json +// Request (Manager -> Runtime) +{ + "jsonrpc": "2.0", + "method": "runAction", + "params": { + "key": "/flow/myFlow", + "input": "Describe a cat", + "stream": true + }, + "id": 100 +} +``` + +**2. Runtime Sends Early Trace ID** +```json +// Notification (Runtime -> Manager) +{ + "jsonrpc": "2.0", + "method": "runActionState", + "params": { + "requestId": 100, + "state": { + "traceId": "abc-123-trace-id" + } + } +} +``` + +**3. Runtime Sends Stream Chunks** +```json +// Notification (Runtime -> Manager) +{ + "jsonrpc": "2.0", + "method": "streamChunk", + "params": { + "requestId": 100, + "chunk": { "content": [{ "text": "A cat is "}] } + } +} +``` + +```json +// Notification (Runtime -> Manager) +{ + "jsonrpc": "2.0", + "method": "streamChunk", + "params": { + "requestId": 100, + "chunk": { "content": [{ "text": "a small "}] } + } +} +``` + +```json +// Notification (Runtime -> Manager) +{ + "jsonrpc": "2.0", + "method": "streamChunk", + "params": { + "requestId": 100, + "chunk": { "content": [{ "text": "feline."}] } + } +} +``` + +**4. Runtime Sends Final Result** +```json +// Response (Runtime -> Manager) +{ + "jsonrpc": "2.0", + "result": { + "result": "A cat is a small feline.", + "telemetry": { + "traceId": "abc-123-trace-id" + } + }, + "id": 100 +} +``` diff --git a/genkit-tools/cli/src/commands/start.ts b/genkit-tools/cli/src/commands/start.ts index ca8bda48b2..57868c1961 100644 --- a/genkit-tools/cli/src/commands/start.ts +++ b/genkit-tools/cli/src/commands/start.ts @@ -26,6 +26,8 @@ interface RunOptions { noui?: boolean; port?: string; open?: boolean; + experimentalReflectionV2?: boolean; + allowedTelemetryCorsHostnames?: string[]; } /** Command to run code in dev mode and/or the Dev UI. */ @@ -34,6 +36,15 @@ export const start = new Command('start') .option('-n, --noui', 'do not start the Dev UI', false) .option('-p, --port ', 'port for the Dev UI') .option('-o, --open', 'Open the browser on UI start up') + .option( + '--experimental-reflection-v2', + 'start the experimental reflection server (WebSocket)' + ) + .option( + '--allowed-telemetry-cors-hostnames ', + 'comma separated list of allowed telemetry CORS hostnames', + (value) => value.split(',') + ) .action(async (options: RunOptions) => { const projectRoot = await findProjectRoot(); if (projectRoot.includes('/.Trash/')) { @@ -49,14 +60,22 @@ export const start = new Command('start') const result = await startDevProcessManager( projectRoot, start.args[0], - start.args.slice(1) + start.args.slice(1), + options.experimentalReflectionV2, + options.allowedTelemetryCorsHostnames ); manager = result.manager; processPromise = result.processPromise; } else { - manager = await startManager(projectRoot, true); + manager = await startManager( + projectRoot, + true, + options.experimentalReflectionV2, + options.allowedTelemetryCorsHostnames + ); processPromise = new Promise(() => {}); } + if (!options.noui) { let port: number; if (options.port) { diff --git a/genkit-tools/cli/src/utils/manager-utils.ts b/genkit-tools/cli/src/utils/manager-utils.ts index 7badf39a11..949c4298c2 100644 --- a/genkit-tools/cli/src/utils/manager-utils.ts +++ b/genkit-tools/cli/src/utils/manager-utils.ts @@ -33,7 +33,8 @@ import getPort, { makeRange } from 'get-port'; * This function is not idempotent. Typically you want to make sure it's called only once per cli instance. */ export async function resolveTelemetryServer( - projectRoot: string + projectRoot: string, + allowedTelemetryCorsHostnames?: string[] ): Promise { let telemetryServerUrl = process.env.GENKIT_TELEMETRY_SERVER; if (!telemetryServerUrl) { @@ -45,6 +46,7 @@ export async function resolveTelemetryServer( storeRoot: projectRoot, indexRoot: projectRoot, }), + allowedCorsHostnames: allowedTelemetryCorsHostnames, }); } return telemetryServerUrl; @@ -55,13 +57,19 @@ export async function resolveTelemetryServer( */ export async function startManager( projectRoot: string, - manageHealth?: boolean -): Promise { - const telemetryServerUrl = await resolveTelemetryServer(projectRoot); + manageHealth?: boolean, + experimentalReflectionV2?: boolean, + allowedTelemetryCorsHostnames?: string[] +): Promise { + const telemetryServerUrl = await resolveTelemetryServer( + projectRoot, + allowedTelemetryCorsHostnames + ); const manager = RuntimeManager.create({ telemetryServerUrl, manageHealth, projectRoot, + experimentalReflectionV2, }); return manager; } @@ -69,18 +77,36 @@ export async function startManager( export async function startDevProcessManager( projectRoot: string, command: string, - args: string[] -): Promise<{ manager: RuntimeManager; processPromise: Promise }> { - const telemetryServerUrl = await resolveTelemetryServer(projectRoot); - const processManager = new ProcessManager(command, args, { + args: string[], + experimentalReflectionV2?: boolean, + allowedTelemetryCorsHostnames?: string[] +): Promise<{ + manager: RuntimeManager | any; + processPromise: Promise; +}> { + const telemetryServerUrl = await resolveTelemetryServer( + projectRoot, + allowedTelemetryCorsHostnames + ); + const env: Record = { GENKIT_TELEMETRY_SERVER: telemetryServerUrl, GENKIT_ENV: 'dev', - }); + }; + + let reflectionV2Port: number | undefined; + if (experimentalReflectionV2) { + reflectionV2Port = await getPort({ port: makeRange(3200, 3400) }); + env['GENKIT_REFLECTION_V2_SERVER'] = `ws://localhost:${reflectionV2Port}`; + } + + const processManager = new ProcessManager(command, args, env); const manager = await RuntimeManager.create({ telemetryServerUrl, manageHealth: true, projectRoot, processManager, + experimentalReflectionV2, + reflectionV2Port, }); const processPromise = processManager.start(); return { manager, processPromise }; diff --git a/genkit-tools/common/package.json b/genkit-tools/common/package.json index 57a25d1b16..c22c590ecd 100644 --- a/genkit-tools/common/package.json +++ b/genkit-tools/common/package.json @@ -10,6 +10,7 @@ }, "dependencies": { "@asteasolutions/zod-to-openapi": "^7.0.0", + "@inquirer/prompts": "^7.8.0", "@trpc/server": "^10.45.2", "adm-zip": "^0.5.12", "ajv": "^8.12.0", @@ -23,7 +24,6 @@ "express": "^4.21.0", "get-port": "5.1.1", "glob": "^10.3.12", - "@inquirer/prompts": "^7.8.0", "js-yaml": "^4.1.0", "json-2-csv": "^5.5.1", "json-schema": "^0.4.0", @@ -31,6 +31,7 @@ "tsx": "^4.19.2", "uuid": "^9.0.1", "winston": "^3.11.0", + "ws": "^8.18.3", "yaml": "^2.4.1", "zod": "^3.22.4", "zod-to-json-schema": "^3.22.4" @@ -48,6 +49,7 @@ "@types/json-schema": "^7.0.15", "@types/node": "^20.11.19", "@types/uuid": "^9.0.8", + "@types/ws": "^8.18.1", "bun-types": "^1.2.16", "genversion": "^3.2.0", "jest": "^29.7.0", diff --git a/genkit-tools/common/src/manager/manager-v2.ts b/genkit-tools/common/src/manager/manager-v2.ts new file mode 100644 index 0000000000..66b9afebe7 --- /dev/null +++ b/genkit-tools/common/src/manager/manager-v2.ts @@ -0,0 +1,375 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import EventEmitter from 'events'; +import getPort, { makeRange } from 'get-port'; +import { WebSocket, WebSocketServer } from 'ws'; +import { + Action, + RunActionResponse, + RunActionResponseSchema, +} from '../types/action'; +import * as apis from '../types/apis'; +import { DevToolsInfo } from '../utils/utils'; +import { BaseRuntimeManager } from './manager'; +import { ProcessManager } from './process-manager'; +import { + GenkitToolsError, + RuntimeEvent, + RuntimeInfo, + StreamingCallback, +} from './types'; + +interface JsonRpcRequest { + jsonrpc: '2.0'; + method: string; + params?: any; + id?: number | string; +} + +interface JsonRpcResponse { + jsonrpc: '2.0'; + result?: any; + error?: { + code: number; + message: string; + data?: any; + }; + id: number | string; +} + +type JsonRpcMessage = JsonRpcRequest | JsonRpcResponse; + +interface ConnectedRuntime { + ws: WebSocket; + info: RuntimeInfo; +} + +export class RuntimeManagerV2 extends BaseRuntimeManager { + private _port?: number; + private wss?: WebSocketServer; + private runtimes: Map = new Map(); + + get port(): number | undefined { + return this._port; + } + private pendingRequests: Map< + number | string, + { resolve: (value: any) => void; reject: (reason?: any) => void } + > = new Map(); + private streamCallbacks: Map> = + new Map(); + private eventEmitter = new EventEmitter(); + private requestIdCounter = 0; + + constructor( + telemetryServerUrl: string | undefined, + readonly manageHealth: boolean, + readonly projectRoot: string, + override readonly processManager?: ProcessManager + ) { + super(telemetryServerUrl, processManager); + } + + static async create(options: { + telemetryServerUrl?: string; + manageHealth?: boolean; + projectRoot: string; + processManager?: ProcessManager; + reflectionV2Port?: number; + }): Promise { + const manager = new RuntimeManagerV2( + options.telemetryServerUrl, + options.manageHealth ?? true, + options.projectRoot, + options.processManager + ); + await manager.startWebSocketServer(options.reflectionV2Port); + return manager; + } + + /** + * Starts a WebSocket server. + */ + private async startWebSocketServer(port?: number): Promise<{ port: number }> { + if (!port) { + port = await getPort({ port: makeRange(3200, 3400) }); + } + this.wss = new WebSocketServer({ port }); + + this._port = port; + console.error(`Starting reflection server: ws://localhost:${port}`); + + this.wss.on('connection', (ws) => { + ws.on('error', console.error); + + ws.on('message', (data) => { + try { + const message = JSON.parse(data.toString()) as JsonRpcMessage; + this.handleMessage(ws, message); + } catch (error) { + console.error('Failed to parse WebSocket message:', error); + } + }); + + ws.on('close', () => { + this.handleDisconnect(ws); + }); + }); + return { port }; + } + + private handleMessage(ws: WebSocket, message: JsonRpcMessage) { + if ('method' in message) { + this.handleRequest(ws, message as JsonRpcRequest); + } else { + this.handleResponse(message as JsonRpcResponse); + } + } + + private handleRequest(ws: WebSocket, request: JsonRpcRequest) { + switch (request.method) { + case 'register': + this.handleRegister(ws, request); + break; + case 'streamChunk': + this.handleStreamChunk(request); + break; + case 'runActionState': + // TODO: Handle runActionState for early trace info + break; + default: + console.warn(`Unknown method: ${request.method}`); + } + } + + private handleRegister(ws: WebSocket, request: JsonRpcRequest) { + const params = request.params; + const runtimeInfo: RuntimeInfo = { + id: params.id, + pid: params.pid, + name: params.name, + genkitVersion: params.genkitVersion, + reflectionApiSpecVersion: params.reflectionApiSpecVersion, + reflectionServerUrl: `ws://localhost:${this.port}`, // Virtual URL for compatibility + timestamp: new Date().toISOString(), + projectName: params.name || 'Unknown', // Or derive from other means if needed + }; + + this.runtimes.set(runtimeInfo.id, { ws, info: runtimeInfo }); + this.eventEmitter.emit(RuntimeEvent.ADD, runtimeInfo); + + // Send success response + if (request.id) { + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + result: null, + id: request.id, + }) + ); + } + + // Configure the runtime immediately + this.notifyRuntime(runtimeInfo.id); + } + + private handleStreamChunk(notification: JsonRpcRequest) { + const { requestId, chunk } = notification.params; + const callback = this.streamCallbacks.get(requestId); + if (callback) { + callback(chunk); + } + } + + private handleResponse(response: JsonRpcResponse) { + const pending = this.pendingRequests.get(response.id); + if (pending) { + if (response.error) { + const errorData = response.error.data || {}; + const massagedData = { + ...errorData, + stack: errorData.details?.stack, + data: { + genkitErrorMessage: errorData.message, + genkitErrorDetails: errorData.details, + }, + }; + const error = new GenkitToolsError(response.error.message); + error.data = massagedData; + pending.reject(error); + } else { + pending.resolve(response.result); + } + this.pendingRequests.delete(response.id); + } + } + + private handleDisconnect(ws: WebSocket) { + for (const [id, runtime] of this.runtimes.entries()) { + if (runtime.ws === ws) { + this.runtimes.delete(id); + this.eventEmitter.emit(RuntimeEvent.REMOVE, runtime.info); + break; + } + } + } + + private async sendRequest( + runtimeId: string, + method: string, + params?: any + ): Promise { + const runtime = this.runtimes.get(runtimeId); + if (!runtime) { + throw new Error(`Runtime ${runtimeId} not found`); + } + + const id = ++this.requestIdCounter; + const message: JsonRpcRequest = { + jsonrpc: '2.0', + method, + params, + id, + }; + + return new Promise((resolve, reject) => { + this.pendingRequests.set(id, { resolve, reject }); + runtime.ws.send(JSON.stringify(message)); + + // Timeout cleanup + setTimeout(() => { + if (this.pendingRequests.has(id)) { + this.pendingRequests.delete(id); + reject(new Error(`Request ${id} timed out`)); + } + }, 30000); + }); + } + + private sendNotification(runtimeId: string, method: string, params?: any) { + const runtime = this.runtimes.get(runtimeId); + if (!runtime) { + console.warn(`Runtime ${runtimeId} not found, cannot send notification`); + return; + } + const message: JsonRpcRequest = { + jsonrpc: '2.0', + method, + params, + }; + runtime.ws.send(JSON.stringify(message)); + } + + private notifyRuntime(runtimeId: string) { + this.sendNotification(runtimeId, 'configure', { + telemetryServerUrl: this.telemetryServerUrl, + }); + } + + listRuntimes(): RuntimeInfo[] { + return Array.from(this.runtimes.values()).map((r) => r.info); + } + + getRuntimeById(id: string): RuntimeInfo | undefined { + return this.runtimes.get(id)?.info; + } + + getMostRecentRuntime(): RuntimeInfo | undefined { + const runtimes = this.listRuntimes(); + if (runtimes.length === 0) return undefined; + // Sort by timestamp descending? Or simply last added? + // Map iteration order is insertion order, so last one is likely most recent if we just added them. + // But let's trust the array. + return runtimes[runtimes.length - 1]; + } + + getMostRecentDevUI(): DevToolsInfo | undefined { + // Not applicable for V2 yet, or maybe handled differently + return undefined; + } + + onRuntimeEvent( + listener: (eventType: RuntimeEvent, runtime: RuntimeInfo) => void + ) { + Object.values(RuntimeEvent).forEach((event) => + this.eventEmitter.on(event, (rt) => listener(event, rt)) + ); + } + + async listActions( + input?: apis.ListActionsRequest + ): Promise> { + const runtimeId = input?.runtimeId || this.getMostRecentRuntime()?.id; + if (!runtimeId) { + // No runtimes connected + return {}; + } + return this.sendRequest(runtimeId, 'listActions'); + } + + async close() { + if (this.wss) { + this.wss.close(); + } + } + + async runAction( + input: apis.RunActionRequest, + streamingCallback?: StreamingCallback + ): Promise { + const runtimeId = input.runtimeId || this.getMostRecentRuntime()?.id; + if (!runtimeId) { + throw new Error('No runtime found'); + } + + const runtime = this.runtimes.get(runtimeId); + if (!runtime) { + throw new Error(`Runtime ${runtimeId} not found`); + } + + const id = ++this.requestIdCounter; + + if (streamingCallback) { + this.streamCallbacks.set(id, streamingCallback); + } + + const message: JsonRpcRequest = { + jsonrpc: '2.0', + method: 'runAction', + params: { + ...input, + stream: !!streamingCallback, + }, + id, + }; + + return new Promise((resolve, reject) => { + this.pendingRequests.set(id, { resolve, reject }); + runtime.ws.send(JSON.stringify(message)); + + // Timeout cleanup? Maybe longer for actions. + }) + .then((result) => { + return RunActionResponseSchema.parse(result); + }) + .finally(() => { + if (streamingCallback) { + this.streamCallbacks.delete(id); + } + }); + } +} diff --git a/genkit-tools/common/src/manager/manager.ts b/genkit-tools/common/src/manager/manager.ts index b166095939..6f82b81692 100644 --- a/genkit-tools/common/src/manager/manager.ts +++ b/genkit-tools/common/src/manager/manager.ts @@ -60,28 +60,141 @@ interface RuntimeManagerOptions { projectRoot: string; /** An optional process manager for the main application process. */ processManager?: ProcessManager; + /** Whether to use the experimental reflection V2 (WebSocket). */ + experimentalReflectionV2?: boolean; + /** Port for the reflection V2 server. */ + reflectionV2Port?: number; } -export class RuntimeManager { - readonly processManager?: ProcessManager; +export abstract class BaseRuntimeManager { + constructor( + readonly telemetryServerUrl: string | undefined, + readonly processManager?: ProcessManager + ) {} + + abstract listRuntimes(): RuntimeInfo[]; + abstract getRuntimeById(id: string): RuntimeInfo | undefined; + abstract getMostRecentRuntime(): RuntimeInfo | undefined; + abstract getMostRecentDevUI(): DevToolsInfo | undefined; + abstract onRuntimeEvent( + listener: (eventType: RuntimeEvent, runtime: RuntimeInfo) => void + ): void; + abstract listActions( + input?: apis.ListActionsRequest + ): Promise>; + abstract runAction( + input: apis.RunActionRequest, + streamingCallback?: StreamingCallback + ): Promise; + + /** + * Retrieves all traces + */ + async listTraces( + input: apis.ListTracesRequest + ): Promise { + const { limit, continuationToken, filter } = input; + let query = ''; + if (limit) { + query += `limit=${limit}`; + } + if (continuationToken) { + if (query !== '') { + query += '&'; + } + query += `continuationToken=${continuationToken}`; + } + if (filter) { + if (query !== '') { + query += '&'; + } + query += `filter=${encodeURI(JSON.stringify(filter))}`; + } + + const response = await axios + .get(`${this.telemetryServerUrl}/api/traces?${query}`) + .catch((err) => + this.httpErrorHandler(err, `Error listing traces for query='${query}'.`) + ); + + return apis.ListTracesResponseSchema.parse(response.data); + } + + /** + * Retrieves a trace for a given ID. + */ + async getTrace(input: apis.GetTraceRequest): Promise { + const { traceId } = input; + const response = await axios + .get(`${this.telemetryServerUrl}/api/traces/${traceId}`) + .catch((err) => + this.httpErrorHandler( + err, + `Error getting trace for traceId='${traceId}'` + ) + ); + + return response.data as TraceData; + } + + /** + * Adds a trace to the trace store + */ + async addTrace(input: TraceData): Promise { + await axios + .post(`${this.telemetryServerUrl}/api/traces/`, input) + .catch((err) => + this.httpErrorHandler(err, 'Error writing trace to store.') + ); + } + + /** + * Handles an HTTP error. + */ + protected httpErrorHandler(error: AxiosError, message?: string): any { + const newError = new GenkitToolsError(message || 'Internal Error'); + + if (error.response) { + if ((error.response?.data as any).message) { + newError.message = (error.response?.data as any).message; + } + // we got a non-200 response; copy the payload and rethrow + newError.data = error.response.data as GenkitError; + throw newError; + } + + // We actually have an exception; wrap it and re-throw. + throw new GenkitToolsError(message || 'Internal Error', { + cause: error.cause, + }); + } +} + +export class RuntimeManager extends BaseRuntimeManager { private filenameToRuntimeMap: Record = {}; private filenameToDevUiMap: Record = {}; private idToFileMap: Record = {}; private eventEmitter = new EventEmitter(); private constructor( - readonly telemetryServerUrl: string | undefined, + telemetryServerUrl: string | undefined, private manageHealth: boolean, readonly projectRoot: string, processManager?: ProcessManager ) { - this.processManager = processManager; + super(telemetryServerUrl, processManager); } /** * Creates a new runtime manager. */ - static async create(options: RuntimeManagerOptions) { + static async create( + options: RuntimeManagerOptions + ): Promise { + if (options.experimentalReflectionV2) { + const { RuntimeManagerV2 } = await import('./manager-v2'); + return RuntimeManagerV2.create(options); + } const manager = new RuntimeManager( options.telemetryServerUrl, options.manageHealth ?? true, @@ -282,67 +395,6 @@ export class RuntimeManager { } } - /** - * Retrieves all traces - */ - async listTraces( - input: apis.ListTracesRequest - ): Promise { - const { limit, continuationToken, filter } = input; - let query = ''; - if (limit) { - query += `limit=${limit}`; - } - if (continuationToken) { - if (query !== '') { - query += '&'; - } - query += `continuationToken=${continuationToken}`; - } - if (filter) { - if (query !== '') { - query += '&'; - } - query += `filter=${encodeURI(JSON.stringify(filter))}`; - } - - const response = await axios - .get(`${this.telemetryServerUrl}/api/traces?${query}`) - .catch((err) => - this.httpErrorHandler(err, `Error listing traces for query='${query}'.`) - ); - - return apis.ListTracesResponseSchema.parse(response.data); - } - - /** - * Retrieves a trace for a given ID. - */ - async getTrace(input: apis.GetTraceRequest): Promise { - const { traceId } = input; - const response = await axios - .get(`${this.telemetryServerUrl}/api/traces/${traceId}`) - .catch((err) => - this.httpErrorHandler( - err, - `Error getting trace for traceId='${traceId}'` - ) - ); - - return response.data as TraceData; - } - - /** - * Adds a trace to the trace store - */ - async addTrace(input: TraceData): Promise { - await axios - .post(`${this.telemetryServerUrl}/api/traces/`, input) - .catch((err) => - this.httpErrorHandler(err, 'Error writing trace to store.') - ); - } - /** * Notifies the runtime of dependencies it may need (e.g. telemetry server URL). */ @@ -535,27 +587,6 @@ export class RuntimeManager { } } - /** - * Handles an HTTP error. - */ - private httpErrorHandler(error: AxiosError, message?: string): any { - const newError = new GenkitToolsError(message || 'Internal Error'); - - if (error.response) { - if ((error.response?.data as any).message) { - newError.message = (error.response?.data as any).message; - } - // we got a non-200 response; copy the payload and rethrow - newError.data = error.response.data as GenkitError; - throw newError; - } - - // We actually have an exception; wrap it and re-throw. - throw new GenkitToolsError(message || 'Internal Error', { - cause: error.cause, - }); - } - /** * Performs health checks on all runtimes. */ diff --git a/genkit-tools/common/tests/manager-v2_test.ts b/genkit-tools/common/tests/manager-v2_test.ts new file mode 100644 index 0000000000..66963d4c98 --- /dev/null +++ b/genkit-tools/common/tests/manager-v2_test.ts @@ -0,0 +1,259 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { afterEach, beforeEach, describe, expect, it } from '@jest/globals'; +import WebSocket from 'ws'; +import { RuntimeManagerV2 } from '../src/manager/manager-v2'; +import { RuntimeEvent } from '../src/manager/types'; + +describe('RuntimeManagerV2', () => { + let manager: RuntimeManagerV2; + let wsClient: WebSocket; + let port: number; + + beforeEach(async () => { + manager = await RuntimeManagerV2.create({ + projectRoot: './', + }); + port = manager.port!; + }); + + afterEach(async () => { + if (wsClient) { + wsClient.close(); + } + // Clean up server + await manager.close(); + }); + + it('should accept connections and handle registration', (done) => { + wsClient = new WebSocket(`ws://localhost:${port}`); + + wsClient.on('open', () => { + const registerMessage = { + jsonrpc: '2.0', + method: 'register', + params: { + id: 'test-runtime-1', + pid: 1234, + name: 'Test Runtime', + genkitVersion: '0.0.1', + reflectionApiSpecVersion: 1, + }, + id: 1, + }; + wsClient.send(JSON.stringify(registerMessage)); + }); + + manager.onRuntimeEvent((event, runtime) => { + if (event === RuntimeEvent.ADD) { + expect(runtime.id).toBe('test-runtime-1'); + expect(runtime.pid).toBe(1234); + expect(manager.listRuntimes().length).toBe(1); + done(); + } + }); + }); + + it('should send requests and handle responses', async () => { + wsClient = new WebSocket(`ws://localhost:${port}`); + + await new Promise((resolve) => { + wsClient.on('open', () => { + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'register', + params: { id: 'test-runtime-2', pid: 1234 }, + id: 1, + }) + ); + // Wait for server to acknowledge or just wait a bit + setTimeout(resolve, 100); + }); + }); + + // Mock runtime response to runAction + wsClient.on('message', (data) => { + const message = JSON.parse(data.toString()); + if (message.method === 'runAction') { + const response = { + jsonrpc: '2.0', + result: { + result: 'Hello World', + telemetry: { + traceId: '1234', + }, + }, + id: message.id, + }; + wsClient.send(JSON.stringify(response)); + } + }); + + const response = await manager.runAction({ + key: 'testAction', + input: {}, + }); + + expect(response.result).toBe('Hello World'); + expect(response.telemetry).toStrictEqual({ + traceId: '1234', + }); + }); + + it('should handle streaming', async () => { + wsClient = new WebSocket(`ws://localhost:${port}`); + + await new Promise((resolve) => { + wsClient.on('open', () => { + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'register', + params: { id: 'test-runtime-3', pid: 1234 }, + id: 1, + }) + ); + setTimeout(resolve, 100); + }); + }); + + wsClient.on('message', (data) => { + const message = JSON.parse(data.toString()); + if (message.method === 'runAction' && message.params.stream) { + // Send chunk 1 + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'streamChunk', + params: { requestId: message.id, chunk: { content: 'Hello' } }, + }) + ); + // Send chunk 2 + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'streamChunk', + params: { requestId: message.id, chunk: { content: ' World' } }, + }) + ); + // Send final result + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + result: { result: 'Hello World', telemetry: {} }, + id: message.id, + }) + ); + } + }); + + const chunks: any[] = []; + const response = await manager.runAction( + { + key: 'testAction', + input: {}, + }, + (chunk) => { + chunks.push(chunk); + } + ); + + expect(chunks).toHaveLength(2); + expect(chunks[0]).toEqual({ content: 'Hello' }); + expect(chunks[1]).toEqual({ content: ' World' }); + expect(response.result).toBe('Hello World'); + expect(response.telemetry).toBeDefined(); + }); + + it('should handle streaming errors and massage the error object', async () => { + wsClient = new WebSocket(`ws://localhost:${port}`); + + await new Promise((resolve) => { + wsClient.on('open', () => { + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'register', + params: { id: 'test-runtime-error', pid: 1234 }, + id: 1, + }) + ); + setTimeout(resolve, 100); + }); + }); + + wsClient.on('message', (data) => { + const message = JSON.parse(data.toString()); + if (message.method === 'runAction' && message.params.stream) { + // Send chunk 1 + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'streamChunk', + params: { requestId: message.id, chunk: { content: 'Hello' } }, + }) + ); + // Send error + const errorResponse = { + code: -32000, + message: 'Test Error', + data: { + code: 13, + message: 'Test Error', + details: { + stack: 'Error stack...', + traceId: 'trace-123', + }, + }, + }; + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + error: errorResponse, + id: message.id, + }) + ); + } + }); + + const chunks: any[] = []; + try { + await manager.runAction( + { + key: 'testAction', + input: {}, + }, + (chunk) => { + chunks.push(chunk); + } + ); + throw new Error('Should have thrown'); + } catch (err: any) { + expect(chunks).toHaveLength(1); + expect(chunks[0]).toEqual({ content: 'Hello' }); + expect(err.message).toBe('Test Error'); + expect(err.data).toBeDefined(); + expect(err.data.data.genkitErrorMessage).toBe('Test Error'); + expect(err.data.stack).toBe('Error stack...'); + expect(err.data.data.genkitErrorDetails).toEqual({ + stack: 'Error stack...', + traceId: 'trace-123', + }); + } + }); +}); diff --git a/genkit-tools/pnpm-lock.yaml b/genkit-tools/pnpm-lock.yaml index 2e7f0b05d6..480988fd52 100644 --- a/genkit-tools/pnpm-lock.yaml +++ b/genkit-tools/pnpm-lock.yaml @@ -168,6 +168,9 @@ importers: winston: specifier: ^3.11.0 version: 3.17.0 + ws: + specifier: ^8.18.3 + version: 8.18.3 yaml: specifier: ^2.4.1 version: 2.8.0 @@ -214,6 +217,9 @@ importers: '@types/uuid': specifier: ^9.0.8 version: 9.0.8 + '@types/ws': + specifier: ^8.18.1 + version: 8.18.1 bun-types: specifier: ^1.2.16 version: 1.2.16 @@ -271,6 +277,9 @@ importers: async-mutex: specifier: ^0.5.0 version: 0.5.0 + cors: + specifier: ^2.8.5 + version: 2.8.5 express: specifier: ^4.21.0 version: 4.21.2 @@ -281,6 +290,9 @@ importers: specifier: ^3.22.4 version: 3.25.67 devDependencies: + '@types/cors': + specifier: ^2.8.19 + version: 2.8.19 '@types/express': specifier: ~4.17.21 version: 4.17.23 @@ -1108,6 +1120,9 @@ packages: '@types/connect@3.4.38': resolution: {integrity: sha512-K6uROf1LD88uDQqJCktA4yzL1YYAK6NgfsI0v/mTgyPKWsX1CnJ0XPSDhViejru1GcRkLWb8RlzFYJRqGUbaug==} + '@types/cors@2.8.19': + resolution: {integrity: sha512-mFNylyeyqN93lfe/9CSxOGREz8cpzAhH+E93xJ4xWQf62V8sQ/24reV2nyzUWM6H6Xji+GGHpkbLe7pVoUEskg==} + '@types/express-serve-static-core@4.19.0': resolution: {integrity: sha512-bGyep3JqPCRry1wq+O5n7oiBgGWmeIJXPjXXCo8EK0u8duZGSYar7cGqd3ML2JUsLGeB7fmc06KYo9fLGWqPvQ==} @@ -1150,9 +1165,6 @@ packages: '@types/mime@1.3.5': resolution: {integrity: sha512-/pyBZWSLD2n0dcHE3hq8s8ZvcETHtEuF+3E7XVt0Ig2nvsVQXdghHVcEkIWjy9A0wKfTn97a/PSDYohKIlnP/w==} - '@types/node@20.19.0': - resolution: {integrity: sha512-hfrc+1tud1xcdVTABC2JiomZJEklMcXYNTVtZLAeqTVWD+qL5jkHKT+1lOtqDdGxt+mB53DTtiz673vfjU8D1Q==} - '@types/node@20.19.1': resolution: {integrity: sha512-jJD50LtlD2dodAEO653i3YF04NWak6jN3ky+Ri3Em3mGR39/glWiboM/IePaRbgwSfqM1TpGXfAg8ohn/4dTgA==} @@ -1192,6 +1204,9 @@ packages: '@types/uuid@9.0.8': resolution: {integrity: sha512-jg+97EGIcY9AGHJJRaaPVgetKDsrTgbRjQ5Msgjh/DQKEFl0DtyRr/VCOyD1T2R1MNeWPK/u7JoGhlDZnKBAfA==} + '@types/ws@8.18.1': + resolution: {integrity: sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==} + '@types/yargs-parser@21.0.3': resolution: {integrity: sha512-I4q9QU9MQv4oEOz4tAHJtNz1cwuLxn2F3xcc2iV5WdqLPpUnj30aUuxt1mAxYTG+oe8CZMV/+6rU4S4gRDzqtQ==} @@ -3464,6 +3479,18 @@ packages: resolution: {integrity: sha512-7KxauUdBmSdWnmpaGFg+ppNjKF8uNLry8LyzjauQDOVONfFLNKrKvQOxZ/VuTIcS/gge/YNahf5RIIQWTSarlg==} engines: {node: ^12.13.0 || ^14.15.0 || >=16.0.0} + ws@8.18.3: + resolution: {integrity: sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==} + engines: {node: '>=10.0.0'} + peerDependencies: + bufferutil: ^4.0.1 + utf-8-validate: '>=5.0.2' + peerDependenciesMeta: + bufferutil: + optional: true + utf-8-validate: + optional: true + xdg-basedir@4.0.0: resolution: {integrity: sha512-PSNhEJDejZYV7h50BohL09Er9VaIefr2LMAf3OEmpCkjOi34eYyQYAXUTjEQtZJTKcF0E2UKTh+osDLsgNim9Q==} engines: {node: '>=8'} @@ -3990,14 +4017,14 @@ snapshots: '@jest/test-result': 29.7.0 '@jest/transform': 29.7.0 '@jest/types': 29.6.3 - '@types/node': 20.19.0 + '@types/node': 20.19.1 ansi-escapes: 4.3.2 chalk: 4.1.2 ci-info: 3.9.0 exit: 0.1.2 graceful-fs: 4.2.11 jest-changed-files: 29.7.0 - jest-config: 29.7.0(@types/node@20.19.0)(ts-node@10.9.2(@types/node@20.19.1)(typescript@5.8.3)) + jest-config: 29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@5.8.3)) jest-haste-map: 29.7.0 jest-message-util: 29.7.0 jest-regex-util: 29.6.3 @@ -4420,6 +4447,10 @@ snapshots: dependencies: '@types/node': 20.19.1 + '@types/cors@2.8.19': + dependencies: + '@types/node': 20.19.1 + '@types/express-serve-static-core@4.19.0': dependencies: '@types/node': 20.19.1 @@ -4470,10 +4501,6 @@ snapshots: '@types/mime@1.3.5': {} - '@types/node@20.19.0': - dependencies: - undici-types: 6.21.0 - '@types/node@20.19.1': dependencies: undici-types: 6.21.0 @@ -4516,6 +4543,10 @@ snapshots: '@types/uuid@9.0.8': {} + '@types/ws@8.18.1': + dependencies: + '@types/node': 20.19.1 + '@types/yargs-parser@21.0.3': {} '@types/yargs@17.0.32': @@ -5910,37 +5941,6 @@ snapshots: - supports-color - ts-node - jest-config@29.7.0(@types/node@20.19.0)(ts-node@10.9.2(@types/node@20.19.1)(typescript@5.8.3)): - dependencies: - '@babel/core': 7.24.5 - '@jest/test-sequencer': 29.7.0 - '@jest/types': 29.6.3 - babel-jest: 29.7.0(@babel/core@7.24.5) - chalk: 4.1.2 - ci-info: 3.9.0 - deepmerge: 4.3.1 - glob: 7.2.3 - graceful-fs: 4.2.11 - jest-circus: 29.7.0 - jest-environment-node: 29.7.0 - jest-get-type: 29.6.3 - jest-regex-util: 29.6.3 - jest-resolve: 29.7.0 - jest-runner: 29.7.0 - jest-util: 29.7.0 - jest-validate: 29.7.0 - micromatch: 4.0.5 - parse-json: 5.2.0 - pretty-format: 29.7.0 - slash: 3.0.0 - strip-json-comments: 3.1.1 - optionalDependencies: - '@types/node': 20.19.0 - ts-node: 10.9.2(@types/node@20.19.1)(typescript@5.8.3) - transitivePeerDependencies: - - babel-plugin-macros - - supports-color - jest-config@29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@5.8.3)): dependencies: '@babel/core': 7.24.5 @@ -6182,7 +6182,7 @@ snapshots: jest-worker@29.7.0: dependencies: - '@types/node': 20.19.0 + '@types/node': 20.19.1 jest-util: 29.7.0 merge-stream: 2.0.0 supports-color: 8.1.1 @@ -7273,6 +7273,8 @@ snapshots: imurmurhash: 0.1.4 signal-exit: 3.0.7 + ws@8.18.3: {} + xdg-basedir@4.0.0: {} y18n@5.0.8: {} diff --git a/genkit-tools/telemetry-server/package.json b/genkit-tools/telemetry-server/package.json index d99da89c93..d192d7f6e1 100644 --- a/genkit-tools/telemetry-server/package.json +++ b/genkit-tools/telemetry-server/package.json @@ -25,8 +25,8 @@ "author": "genkit", "license": "Apache-2.0", "dependencies": { - "@genkit-ai/tools-common": "workspace:*", "@asteasolutions/zod-to-openapi": "^7.0.0", + "@genkit-ai/tools-common": "workspace:*", "@google-cloud/firestore": "^7.6.0", "@opentelemetry/api": "~1.9.0", "@opentelemetry/context-async-hooks": "~1.25.0", @@ -35,18 +35,20 @@ "@opentelemetry/sdk-node": "^0.52.0", "@opentelemetry/sdk-trace-base": "~1.25.0", "async-mutex": "^0.5.0", + "cors": "^2.8.5", "express": "^4.21.0", "lockfile": "^1.0.4", "zod": "^3.22.4" }, "devDependencies": { + "@types/cors": "^2.8.19", "@types/express": "~4.17.21", "@types/lockfile": "^1.0.4", "@types/node": "^20.11.30", + "genversion": "^3.2.0", "get-port": "^7.1.0", "npm-run-all": "^4.1.5", "rimraf": "^6.0.1", - "genversion": "^3.2.0", "tsx": "^4.19.2", "typescript": "^4.9.0" }, diff --git a/genkit-tools/telemetry-server/src/index.ts b/genkit-tools/telemetry-server/src/index.ts index 6c038495fc..903ed4959b 100644 --- a/genkit-tools/telemetry-server/src/index.ts +++ b/genkit-tools/telemetry-server/src/index.ts @@ -19,6 +19,7 @@ import { TraceQueryFilterSchema, } from '@genkit-ai/tools-common'; import { logger } from '@genkit-ai/tools-common/utils'; +import cors from 'cors'; import express from 'express'; import type * as http from 'http'; import type { TraceStore } from './types'; @@ -43,11 +44,24 @@ export async function startTelemetryServer(params: { * Defaults to '5mb'. */ maxRequestBodySize?: string | number; + allowedCorsHostnames?: string[]; }) { await params.traceStore.init(); const api = express(); api.use(express.json({ limit: params.maxRequestBodySize ?? '100mb' })); + api.use( + cors((req, callback) => { + if ( + req.hostname === 'localhost' || + params.allowedCorsHostnames?.includes(req.hostname) + ) { + callback(null, { origin: true }); // Allow the request + } else { + callback(null, { origin: false }); // Deny the request + } + }) + ); api.get('/api/__health', async (_, response) => { response.status(200).send('OK'); diff --git a/js/core/package.json b/js/core/package.json index dd028eda07..5226e2b3c6 100644 --- a/js/core/package.json +++ b/js/core/package.json @@ -29,26 +29,28 @@ "@opentelemetry/api": "^1.9.0", "@opentelemetry/context-async-hooks": "~1.25.0", "@opentelemetry/core": "~1.25.0", + "@opentelemetry/exporter-jaeger": "^1.25.0", "@opentelemetry/sdk-metrics": "~1.25.0", "@opentelemetry/sdk-node": "^0.52.0", "@opentelemetry/sdk-trace-base": "~1.25.0", - "@opentelemetry/exporter-jaeger": "^1.25.0", "@types/json-schema": "^7.0.15", "ajv": "^8.12.0", "ajv-formats": "^3.0.1", "async-mutex": "^0.5.0", "body-parser": "^1.20.3", "cors": "^2.8.5", + "dotprompt": "^1.1.1", "express": "^4.21.0", "get-port": "^5.1.0", "json-schema": "^0.4.0", + "ws": "^8.18.3", "zod": "^3.23.8", - "zod-to-json-schema": "^3.22.4", - "dotprompt": "^1.1.1" + "zod-to-json-schema": "^3.22.4" }, "devDependencies": { "@types/express": "^4.17.21", "@types/node": "^20.11.30", + "@types/ws": "^8.18.1", "genversion": "^3.2.0", "npm-run-all": "^4.1.5", "rimraf": "^6.0.1", diff --git a/js/core/src/index.ts b/js/core/src/index.ts index 632b9b30fc..c5f4758d4b 100644 --- a/js/core/src/index.ts +++ b/js/core/src/index.ts @@ -76,6 +76,7 @@ export { type FlowSideChannel, } from './flow.js'; export * from './plugin.js'; +export * from './reflection-v2.js'; export * from './reflection.js'; export { defineJsonSchema, defineSchema, type JSONSchema } from './schema.js'; export * from './telemetryTypes.js'; diff --git a/js/core/src/reflection-v2.ts b/js/core/src/reflection-v2.ts new file mode 100644 index 0000000000..36069414be --- /dev/null +++ b/js/core/src/reflection-v2.ts @@ -0,0 +1,286 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import WebSocket from 'ws'; +import { StatusCodes, type Status } from './action.js'; +import { GENKIT_REFLECTION_API_SPEC_VERSION, GENKIT_VERSION } from './index.js'; +import { logger } from './logging.js'; +import type { Registry } from './registry.js'; +import { toJsonSchema } from './schema.js'; +import { flushTracing, setTelemetryServerUrl } from './tracing.js'; + +let apiIndex = 0; + +interface JsonRpcRequest { + jsonrpc: '2.0'; + method: string; + params?: any; + id?: number | string; +} + +interface JsonRpcResponse { + jsonrpc: '2.0'; + result?: any; + error?: { + code: number; + message: string; + data?: any; + }; + id: number | string; +} + +type JsonRpcMessage = JsonRpcRequest | JsonRpcResponse; + +export interface ReflectionServerV2Options { + configuredEnvs?: string[]; + name?: string; + url: string; +} + +export class ReflectionServerV2 { + private registry: Registry; + private options: ReflectionServerV2Options; + private ws: WebSocket | null = null; + private url: string; + private index = apiIndex++; + + constructor(registry: Registry, options: ReflectionServerV2Options) { + this.registry = registry; + this.options = { + configuredEnvs: ['dev'], + ...options, + }; + // The URL should be provided via environment variable by the CLI manager + this.url = this.options.url; + } + + async start() { + logger.debug(`Connecting to Reflection V2 server at ${this.url}`); + this.ws = new WebSocket(this.url); + + this.ws.on('open', () => { + logger.debug('Connected to Reflection V2 server.'); + this.register(); + }); + + this.ws.on('message', async (data) => { + try { + const message = JSON.parse(data.toString()) as JsonRpcMessage; + if ('method' in message) { + await this.handleRequest(message as JsonRpcRequest); + } + } catch (error) { + logger.error(`Failed to parse message: ${error}`); + } + }); + + this.ws.on('error', (error) => { + logger.error(`Reflection V2 WebSocket error: ${error}`); + }); + + this.ws.on('close', () => { + logger.debug('Reflection V2 WebSocket closed.'); + }); + } + + async stop() { + if (this.ws) { + this.ws.close(); + this.ws = null; + } + } + + private send(message: JsonRpcMessage) { + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.send(JSON.stringify(message)); + } + } + + private sendResponse(id: number | string, result: any) { + this.send({ + jsonrpc: '2.0', + result, + id, + }); + } + + private sendError( + id: number | string, + code: number, + message: string, + data?: any + ) { + this.send({ + jsonrpc: '2.0', + error: { code, message, data }, + id, + }); + } + + private sendNotification(method: string, params: any) { + this.send({ + jsonrpc: '2.0', + method, + params, + }); + } + + private register() { + const params = { + id: process.env.GENKIT_RUNTIME_ID || this.runtimeId, + pid: process.pid, + name: this.options.name || this.runtimeId, + genkitVersion: GENKIT_VERSION, + reflectionApiSpecVersion: GENKIT_REFLECTION_API_SPEC_VERSION, + envs: this.options.configuredEnvs, + }; + this.sendNotification('register', params); + } + + get runtimeId() { + return `${process.pid}${this.index ? `-${this.index}` : ''}`; + } + + private async handleRequest(request: JsonRpcRequest) { + try { + switch (request.method) { + case 'listActions': + await this.handleListActions(request); + break; + case 'runAction': + await this.handleRunAction(request); + break; + case 'configure': + this.handleConfigure(request); + break; + default: + if (request.id) { + this.sendError( + request.id, + -32601, + `Method not found: ${request.method}` + ); + } + } + } catch (error: any) { + if (request.id) { + this.sendError(request.id, -32000, error.message, { + stack: error.stack, + }); + } + } + } + + private async handleListActions(request: JsonRpcRequest) { + if (!request.id) return; // Should be a request + const actions = await this.registry.listResolvableActions(); + const convertedActions: Record = {}; + + Object.keys(actions).forEach((key) => { + const action = actions[key]; + convertedActions[key] = { + key, + name: action.name, + description: action.description, + metadata: action.metadata, + }; + if (action.inputSchema || action.inputJsonSchema) { + convertedActions[key].inputSchema = toJsonSchema({ + schema: action.inputSchema, + jsonSchema: action.inputJsonSchema, + }); + } + if (action.outputSchema || action.outputJsonSchema) { + convertedActions[key].outputSchema = toJsonSchema({ + schema: action.outputSchema, + jsonSchema: action.outputJsonSchema, + }); + } + }); + + this.sendResponse(request.id, convertedActions); + } + + private async handleRunAction(request: JsonRpcRequest) { + if (!request.id) return; + + const { key, input, context, telemetryLabels, stream } = request.params; + const action = await this.registry.lookupAction(key); + + if (!action) { + this.sendError(request.id, 404, `action ${key} not found`); + return; + } + + try { + if (stream) { + const callback = (chunk: any) => { + this.sendNotification('streamChunk', { + requestId: request.id, + chunk, + }); + }; + + const result = await action.run(input, { + context, + onChunk: callback, + telemetryLabels, + }); + + await flushTracing(); + + // Send final result + this.sendResponse(request.id, { + result: result.result, + telemetry: { + traceId: result.telemetry.traceId, + }, + }); + } else { + const result = await action.run(input, { context, telemetryLabels }); + await flushTracing(); + + this.sendResponse(request.id, { + result: result.result, + telemetry: { + traceId: result.telemetry.traceId, + }, + }); + } + } catch (err: any) { + const errorResponse: Status = { + code: StatusCodes.INTERNAL, + message: err.message, + details: { + stack: err.stack, + }, + }; + if (err.traceId) { + errorResponse.details.traceId = err.traceId; + } + + this.sendError(request.id, -32000, err.message, errorResponse); + } + } + + private handleConfigure(request: JsonRpcRequest) { + const { telemetryServerUrl } = request.params; + if (telemetryServerUrl && !process.env.GENKIT_TELEMETRY_SERVER) { + setTelemetryServerUrl(telemetryServerUrl); + logger.debug(`Connected to telemetry server on ${telemetryServerUrl}`); + } + } +} diff --git a/js/core/tests/reflection-v2_test.ts b/js/core/tests/reflection-v2_test.ts new file mode 100644 index 0000000000..4897cbd917 --- /dev/null +++ b/js/core/tests/reflection-v2_test.ts @@ -0,0 +1,254 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import * as assert from 'assert'; +import { afterEach, beforeEach, describe, it } from 'node:test'; +import { WebSocketServer } from 'ws'; +import { z } from 'zod'; +import { action } from '../src/action.js'; +import { initNodeFeatures } from '../src/node.js'; +import { ReflectionServerV2 } from '../src/reflection-v2.js'; +import { Registry } from '../src/registry.js'; + +initNodeFeatures(); + +describe('ReflectionServerV2', () => { + let wss: WebSocketServer; + let server: ReflectionServerV2; + let registry: Registry; + let port: number; + let serverWs: any; + + beforeEach(() => { + return new Promise((resolve) => { + wss = new WebSocketServer({ port: 0 }); + wss.on('listening', () => { + port = (wss.address() as any).port; + resolve(); + }); + wss.on('connection', (ws) => { + serverWs = ws; + }); + registry = new Registry(); + }); + }); + + afterEach(async () => { + if (server) { + await server.stop(); + } + if (serverWs) { + serverWs.terminate(); + } + await new Promise((resolve) => { + wss.close(() => resolve()); + }); + }); + + it('should connect to the server and register', async () => { + const connected = new Promise((resolve) => { + wss.on('connection', (ws) => { + ws.on('message', (data) => { + const msg = JSON.parse(data.toString()); + if (msg.method === 'register') { + assert.strictEqual(msg.params.name, 'test-app'); + resolve(); + } + }); + }); + }); + + server = new ReflectionServerV2(registry, { + url: `ws://localhost:${port}`, + name: 'test-app', + }); + await server.start(); + await connected; + }); + + it('should handle listActions', async () => { + // Register a dummy action + const testAction = action( + { + name: 'testAction', + description: 'A test action', + inputSchema: z.object({ foo: z.string() }), + outputSchema: z.object({ bar: z.string() }), + actionType: 'custom', + }, + async (input) => ({ bar: input.foo }) + ); + registry.registerAction('custom', testAction); + + const gotListActions = new Promise((resolve) => { + wss.on('connection', (ws) => { + ws.on('message', (data) => { + const msg = JSON.parse(data.toString()); + if (msg.method === 'register') { + // After registration, request listActions + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'listActions', + id: '123', + }) + ); + } else if (msg.id === '123') { + assert.ok(msg.result['/custom/testAction']); + assert.strictEqual( + msg.result['/custom/testAction'].name, + 'testAction' + ); + resolve(); + } + }); + }); + }); + + server = new ReflectionServerV2(registry, { + url: `ws://localhost:${port}`, + }); + await server.start(); + await gotListActions; + }); + + it('should handle runAction', async () => { + const testAction = action( + { + name: 'testAction', + inputSchema: z.object({ foo: z.string() }), + outputSchema: z.object({ bar: z.string() }), + actionType: 'custom', + }, + async (input) => ({ bar: input.foo }) + ); + registry.registerAction('custom', testAction); + + const actionRun = new Promise((resolve, reject) => { + const timeout = setTimeout( + () => reject(new Error('runAction timeout')), + 2000 + ); + wss.on('connection', (ws) => { + ws.on('message', (data) => { + try { + const msg = JSON.parse(data.toString()); + if (msg.method === 'register') { + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'runAction', + params: { + key: '/custom/testAction', + input: { foo: 'baz' }, + }, + id: '456', + }) + ); + } else if (msg.id === '456') { + if (msg.error) { + reject( + new Error(`runAction error: ${JSON.stringify(msg.error)}`) + ); + return; + } + assert.strictEqual(msg.result.result.bar, 'baz'); + clearTimeout(timeout); + resolve(); + } + } catch (e) { + clearTimeout(timeout); + reject(e); + } + }); + }); + }); + + server = new ReflectionServerV2(registry, { + url: `ws://localhost:${port}`, + }); + await server.start(); + await actionRun; + }); + + it('should handle streaming runAction', async () => { + const streamAction = action( + { + name: 'streamAction', + inputSchema: z.object({ foo: z.string() }), + outputSchema: z.string(), + actionType: 'custom', + }, + async (input, { sendChunk }) => { + sendChunk('chunk1'); + sendChunk('chunk2'); + return 'done'; + } + ); + registry.registerAction('custom', streamAction); + + const chunks: any[] = []; + const actionRun = new Promise((resolve, reject) => { + const timeout = setTimeout( + () => reject(new Error('streamAction timeout')), + 2000 + ); + wss.on('connection', (ws) => { + ws.on('message', (data) => { + try { + const msg = JSON.parse(data.toString()); + if (msg.method === 'register') { + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'runAction', + params: { + key: '/custom/streamAction', + input: { foo: 'baz' }, + stream: true, + }, + id: '789', + }) + ); + } else if (msg.method === 'streamChunk') { + chunks.push(msg.params.chunk); + } else if (msg.id === '789') { + if (msg.error) { + reject( + new Error(`streamAction error: ${JSON.stringify(msg.error)}`) + ); + return; + } + assert.strictEqual(msg.result.result, 'done'); + assert.deepStrictEqual(chunks, ['chunk1', 'chunk2']); + clearTimeout(timeout); + resolve(); + } + } catch (e) { + clearTimeout(timeout); + reject(e); + } + }); + }); + }); + + server = new ReflectionServerV2(registry, { + url: `ws://localhost:${port}`, + }); + await server.start(); + await actionRun; + }); +}); diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 040d640113..dc6612ec34 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -103,6 +103,7 @@ import { GenkitError, Operation, ReflectionServer, + ReflectionServerV2, defineDynamicActionProvider, defineFlow, defineJsonSchema, @@ -179,7 +180,7 @@ export class Genkit implements HasRegistry { /** Registry instance that is exclusively modified by this Genkit instance. */ readonly registry: Registry; /** Reflection server for this registry. May be null if not started. */ - private reflectionServer: ReflectionServer | null = null; + private reflectionServer: ReflectionServer | ReflectionServerV2 | null = null; /** List of flows that have been registered in this instance. */ readonly flows: Action[] = []; @@ -195,10 +196,18 @@ export class Genkit implements HasRegistry { } this.configure(); if (isDevEnv() && !disableReflectionApi) { - this.reflectionServer = new ReflectionServer(this.registry, { - configuredEnvs: ['dev'], - name: this.options.name, - }); + if (process.env.GENKIT_REFLECTION_V2_SERVER) { + this.reflectionServer = new ReflectionServerV2(this.registry, { + configuredEnvs: ['dev'], + name: this.options.name, + url: process.env.GENKIT_REFLECTION_V2_SERVER, + }); + } else { + this.reflectionServer = new ReflectionServer(this.registry, { + configuredEnvs: ['dev'], + name: this.options.name, + }); + } this.reflectionServer.start().catch((e) => logger.error); } if (options?.clientHeader) { diff --git a/js/genkit/tests/genkit_test.ts b/js/genkit/tests/genkit_test.ts index faff9243cc..92343673f1 100644 --- a/js/genkit/tests/genkit_test.ts +++ b/js/genkit/tests/genkit_test.ts @@ -30,4 +30,26 @@ describe('genkit', () => { assert.ok(getClientHeader().includes('genkit-node/')); assert.ok(getClientHeader().includes('foo')); }); + + it('initializes ReflectionServerV2 when GENKIT_REFLECTION_V2_SERVER is set', async () => { + process.env.GENKIT_REFLECTION_V2_SERVER = 'ws://localhost:1234'; + // Ensure we are in dev env for reflection server to start + const originalEnv = process.env.GENKIT_ENV; + process.env.GENKIT_ENV = 'dev'; + + try { + const instance = genkit({}); + // reflectionServer is private, cast to any to inspect + const reflectionServer = (instance as any).reflectionServer; + + assert.ok(reflectionServer, 'Reflection server should be initialized'); + assert.strictEqual(reflectionServer.url, 'ws://localhost:1234'); + + // Clean up + await instance.stopServers(); + } finally { + delete process.env.GENKIT_REFLECTION_V2_SERVER; + process.env.GENKIT_ENV = originalEnv; + } + }); }); diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index c1918c9103..ce499d6baa 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -147,6 +147,9 @@ importers: json-schema: specifier: ^0.4.0 version: 0.4.0 + ws: + specifier: ^8.18.3 + version: 8.18.3 zod: specifier: ^3.23.8 version: 3.25.67 @@ -160,6 +163,9 @@ importers: '@types/node': specifier: ^20.11.30 version: 20.19.1 + '@types/ws': + specifier: ^8.18.1 + version: 8.18.1 genversion: specifier: ^3.2.0 version: 3.2.0 @@ -4299,6 +4305,9 @@ packages: '@types/whatwg-url@11.0.5': resolution: {integrity: sha512-coYR071JRaHa+xoEvvYqvnIHaVqaYrLPbsufM9BF63HkwI5Lgmy2QR8Q5K/lYDYo5AK82wOvSOS0UsLTpTG7uQ==} + '@types/ws@8.18.1': + resolution: {integrity: sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==} + '@types/yargs-parser@21.0.3': resolution: {integrity: sha512-I4q9QU9MQv4oEOz4tAHJtNz1cwuLxn2F3xcc2iV5WdqLPpUnj30aUuxt1mAxYTG+oe8CZMV/+6rU4S4gRDzqtQ==} @@ -7188,6 +7197,7 @@ packages: source-map@0.8.0-beta.0: resolution: {integrity: sha512-2ymg6oRBpebeZi9UUNsgQ89bhx01TcTkmNTGnNO88imTmbSgy4nfujrgVEFKWpMTEGA11EDkTt7mqObTPdigIA==} engines: {node: '>= 8'} + deprecated: The work that was done in this beta branch won't be included in future versions spdx-correct@3.2.0: resolution: {integrity: sha512-kN9dJbvnySHULIluDHy32WHRUu3Og7B9sbY7tsFLctQkIqnMh3hErYgdMjTYuqmcXX+lK5T1lnUt3G7zNswmZA==} @@ -10512,6 +10522,10 @@ snapshots: dependencies: '@types/webidl-conversions': 7.0.3 + '@types/ws@8.18.1': + dependencies: + '@types/node': 20.19.1 + '@types/yargs-parser@21.0.3': {} '@types/yargs@17.0.33':