diff --git a/genkit-tools/common/src/types/model.ts b/genkit-tools/common/src/types/model.ts index c052d9ab46..553f1a67e9 100644 --- a/genkit-tools/common/src/types/model.ts +++ b/genkit-tools/common/src/types/model.ts @@ -399,5 +399,7 @@ export const GenerateActionOptionsSchema = z.object({ maxTurns: z.number().optional(), /** Custom step name for this generate call to display in trace views. Defaults to "generate". */ stepName: z.string().optional(), + /** Registered middleware to be used with this model call. */ + middleware: z.array(z.string()).optional(), }); export type GenerateActionOptions = z.infer; diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 0ee8e76519..43105e1781 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -466,6 +466,12 @@ }, "stepName": { "type": "string" + }, + "middleware": { + "type": "array", + "items": { + "type": "string" + } } }, "required": [ diff --git a/go/ai/gen.go b/go/ai/gen.go index 25616eb973..f59df3f5c1 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -125,6 +125,7 @@ type GenerateActionOptions struct { Docs []*Document `json:"docs,omitempty"` MaxTurns int `json:"maxTurns,omitempty"` Messages []*Message `json:"messages,omitempty"` + Middleware []string `json:"middleware,omitempty"` Model string `json:"model,omitempty"` Output *GenerateActionOutputConfig `json:"output,omitempty"` Resume *GenerateActionResume `json:"resume,omitempty"` diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index a2abfe8a91..9f15bf672e 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -51,7 +51,7 @@ import { type GenerationCommonConfigSchema, type MessageData, type ModelArgument, - type ModelMiddleware, + type ModelMiddlewareArgument, type Part, type ToolRequestPart, type ToolResponsePart, @@ -170,7 +170,7 @@ export interface GenerateOptions< */ streamingCallback?: StreamingCallback; /** Middleware to be used with this model call. */ - use?: ModelMiddleware[]; + use?: ModelMiddlewareArgument[]; /** Additional context (data, like e.g. auth) to be passed down to tools, prompts and other sub actions. */ context?: ActionContext; /** Abort signal for the generate request. */ @@ -538,6 +538,7 @@ export async function toGenerateActionOptions< returnToolRequests: options.returnToolRequests, maxTurns: options.maxTurns, stepName: options.stepName, + middleware: options.use?.filter((m): m is string => typeof m === 'string'), }; // if config is empty and it was not explicitly passed in, we delete it, don't want {} if (Object.keys(params.config).length === 0 && !options.config) { diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index a922404681..c613c455da 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -15,6 +15,7 @@ */ import { + ActionRunOptions, GenkitError, StreamingCallback, defineAction, @@ -42,6 +43,8 @@ import { GenerateResponseChunkSchema, GenerateResponseSchema, MessageData, + ModelMiddlewareArgument, + ModelMiddlewareWithOptions, resolveModel, type GenerateActionOptions, type GenerateActionOutputConfig, @@ -85,7 +88,7 @@ export function defineGenerateAction(registry: Registry): GenerateAction { outputSchema: GenerateResponseSchema, streamSchema: GenerateResponseChunkSchema, }, - async (request, { streamingRequested, sendChunk }) => { + async (request, { streamingRequested, sendChunk, context }) => { const generateFn = ( sendChunk?: StreamingCallback ) => @@ -93,9 +96,8 @@ export function defineGenerateAction(registry: Registry): GenerateAction { rawRequest: request, currentTurn: 0, messageIndex: 0, - // Generate util action does not support middleware. Maybe when we add named/registered middleware.... - middleware: [], streamingCallback: sendChunk, + context, }); return streamingRequested ? generateFn((c: GenerateResponseChunk) => @@ -113,18 +115,18 @@ export async function generateHelper( registry: Registry, options: { rawRequest: GenerateActionOptions; - middleware?: ModelMiddleware[]; + middleware?: ModelMiddlewareArgument[]; currentTurn?: number; messageIndex?: number; abortSignal?: AbortSignal; streamingCallback?: StreamingCallback; + context?: Record; } ): Promise { const currentTurn = options.currentTurn ?? 0; const messageIndex = options.messageIndex ?? 0; // do tracing return await runInNewSpan( - registry, { metadata: { name: options.rawRequest.stepName || 'generate', @@ -143,6 +145,7 @@ export async function generateHelper( messageIndex, abortSignal: options.abortSignal, streamingCallback: options.streamingCallback, + context: options.context, }); metadata.output = JSON.stringify(output); return output; @@ -247,13 +250,15 @@ async function generate( messageIndex, abortSignal, streamingCallback, + context, }: { rawRequest: GenerateActionOptions; - middleware: ModelMiddleware[] | undefined; + middleware?: ModelMiddlewareArgument[] | undefined; currentTurn: number; messageIndex: number; abortSignal?: AbortSignal; streamingCallback?: StreamingCallback; + context?: Record; } ): Promise { const { model, tools, resources, format } = await resolveParameters( @@ -319,30 +324,62 @@ async function generate( streamingCallback(makeChunk('tool', resumedToolMessage)); } + const rawMiddleware = rawRequest.middleware || []; + const argMiddleware = middleware || []; + const effectiveRawMiddleware = rawMiddleware.filter( + (m) => !argMiddleware.includes(m) + ); + const allMiddleware = [...argMiddleware, ...effectiveRawMiddleware]; + var response: GenerateResponse; + const sendChunk = + streamingCallback && + (((chunk: GenerateResponseChunkData) => + streamingCallback && + streamingCallback(makeChunk('model', chunk))) as any); const dispatch = async ( index: number, - req: z.infer + req: z.infer, + actionOpts: ActionRunOptions ) => { - if (!middleware || index === middleware.length) { + if (index === allMiddleware.length) { // end of the chain, call the original model action - return await model(req, { - abortSignal, - onChunk: - streamingCallback && - (((chunk: GenerateResponseChunkData) => - streamingCallback && - streamingCallback(makeChunk('model', chunk))) as any), - }); + return await model(req, actionOpts); } - const currentMiddleware = middleware[index]; - return currentMiddleware(req, async (modifiedReq) => - dispatch(index + 1, modifiedReq || req) - ); + let currentMiddleware = allMiddleware[index]; + if (typeof currentMiddleware === 'string') { + const resolvedMiddleware = await registry.lookupValue< + ModelMiddleware | ModelMiddlewareWithOptions + >('modelMiddleware', currentMiddleware); + if (!resolvedMiddleware) { + throw new GenkitError({ + status: 'NOT_FOUND', + message: `Middleware '${currentMiddleware}' not found.`, + }); + } + currentMiddleware = resolvedMiddleware; + } + + if (currentMiddleware.length === 3) { + return (currentMiddleware as ModelMiddlewareWithOptions)( + req, + actionOpts, + async (modifiedReq, opts) => + dispatch(index + 1, modifiedReq || req, opts || actionOpts) + ); + } else { + return (currentMiddleware as ModelMiddleware)(req, async (modifiedReq) => + dispatch(index + 1, modifiedReq || req, actionOpts) + ); + } }; - const modelResponse = await dispatch(0, request); + const modelResponse = await dispatch(0, request, { + abortSignal, + context, + onChunk: sendChunk, + }); if (model.__action.actionType === 'background-model') { response = new GenerateResponse( @@ -416,7 +453,28 @@ async function generate( // then recursively call for another loop return await generateHelper(registry, { rawRequest: nextRequest, - middleware: middleware, + middleware: allMiddleware, // Pass the combined middleware to the next recursion to avoid re-combining logic issues if any (but we re-evaluate rawRequest here) + // Wait, if we pass 'allMiddleware' here, we are passing functions and strings. + // 'generate' function expects that. + // However, we are also passing 'rawRequest' which is 'nextRequest'. + // 'nextRequest' is derived from 'rawRequest'. Does it keep 'middleware' property? + // Yes, spread operator `{...rawRequest, ...}` copies it. + // So 'nextRequest' has 'middleware' strings. + // 'allMiddleware' has functions + unique strings. + // In recursive call, 'generate' will combine them AGAIN. + // 'allMiddleware' (from arg) will be 'argMiddleware' in next call. + // 'rawRequest.middleware' will be 'rawMiddleware' in next call. + // 'effectiveRaw' will filter out strings present in 'allMiddleware'. + // If 'allMiddleware' contains the strings (which it does, from effectiveRaw), then they are filtered out. + // If 'allMiddleware' contains functions (resolved), they are not filtered. + // So we should be fine? + // Actually, 'allMiddleware' passed to 'generateHelper' becomes 'middleware' arg. + // 'middleware' arg will contain everything. + // 'rawRequest.middleware' will contain original strings. + // 'effectiveRaw' = raw.filter(m => !all.includes(m)). + // If 'all' contains the strings, effectiveRaw is empty. + // So we just use 'all'. + // This seems correct recursion-wise. currentTurn: currentTurn + 1, messageIndex: messageIndex + 1, streamingCallback, diff --git a/js/ai/src/model-types.ts b/js/ai/src/model-types.ts index be2a18fb7c..d2f8e89017 100644 --- a/js/ai/src/model-types.ts +++ b/js/ai/src/model-types.ts @@ -410,5 +410,7 @@ export const GenerateActionOptionsSchema = z.object({ maxTurns: z.number().optional(), /** Custom step name for this generate call to display in trace views. Defaults to "generate". */ stepName: z.string().optional(), + /** Registered middleware to be used with this model call. */ + middleware: z.array(z.string()).optional(), }); export type GenerateActionOptions = z.infer; diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 7877133ba3..ae58694754 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -18,6 +18,7 @@ import { ActionFnArg, BackgroundAction, GenkitError, + MiddlewareWithOptions, Operation, OperationSchema, action, @@ -108,6 +109,17 @@ export type ModelMiddleware = SimpleMiddleware< z.infer >; +export type ModelMiddlewareWithOptions = MiddlewareWithOptions< + z.infer, + z.infer, + z.infer +>; + +export type ModelMiddlewareArgument = + | ModelMiddleware + | ModelMiddlewareWithOptions + | string; + export type DefineModelOptions< CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny, > = { @@ -121,7 +133,7 @@ export type DefineModelOptions< /** Descriptive name for this model e.g. 'Google AI - Gemini Pro'. */ label?: string; /** Middleware to be used with this model. */ - use?: ModelMiddleware[]; + use?: ModelMiddlewareArgument[]; }; export function model( @@ -324,11 +336,11 @@ export function backgroundModel< } function getModelMiddleware(options: { - use?: ModelMiddleware[]; + use?: ModelMiddlewareArgument[]; name: string; supports?: ModelInfo['supports']; }) { - const middleware: ModelMiddleware[] = options.use || []; + const middleware: ModelMiddlewareArgument[] = options.use || []; if (!options?.supports?.context) middleware.push(augmentWithContext()); const constratedSimulator = simulateConstrainedGeneration(); middleware.push((req, next) => { diff --git a/js/ai/src/model/middleware.ts b/js/ai/src/model/middleware.ts index f7da057f78..b92d7a59d6 100644 --- a/js/ai/src/model/middleware.ts +++ b/js/ai/src/model/middleware.ts @@ -308,6 +308,7 @@ const DEFAULT_RETRY_STATUSES: StatusName[] = [ ]; const DEFAULT_FALLBACK_STATUSES: StatusName[] = [ + 'UNKNOWN', 'UNAVAILABLE', 'DEADLINE_EXCEEDED', 'RESOURCE_EXHAUSTED', diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index 052acee642..b91d3d38e9 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -29,6 +29,7 @@ import { defineModel, type ModelAction, type ModelMiddleware, + type ModelMiddlewareWithOptions, } from '../../src/model.js'; import { defineResource } from '../../src/resource.js'; import { defineTool } from '../../src/tool.js'; @@ -788,4 +789,159 @@ describe('generate', () => { }, ]); }); + + it('middleware can intercept streaming callback', async () => { + const registry = new Registry(); + const echoModel = defineModel( + registry, + { + apiVersion: 'v2', + name: 'echoModel', + supports: { tools: true }, + }, + async (_, { sendChunk }) => { + if (sendChunk) { + sendChunk({ content: [{ text: 'chunk1' }] }); + sendChunk({ content: [{ text: 'chunk2' }] }); + } + return { + message: { + role: 'model', + content: [{ text: 'done' }], + }, + finishReason: 'stop', + }; + } + ); + + const interceptMiddleware: ModelMiddlewareWithOptions = async ( + req, + opts, + next + ) => { + const originalOnChunk = opts!.onChunk; + return next(req, { + ...opts, + onChunk: (chunk) => { + if (originalOnChunk) { + const text = chunk.content?.[0]?.text; + originalOnChunk({ + ...chunk, + content: [{ text: `intercepted: ${text}` }], + }); + } + }, + }); + }; + + const { response, stream } = generateStream(registry, { + model: echoModel, + prompt: 'test', + use: [interceptMiddleware], + }); + + const streamed: any[] = []; + for await (const chunk of stream) { + streamed.push(chunk.content[0].text); + } + + assert.deepStrictEqual(streamed, [ + 'intercepted: chunk1', + 'intercepted: chunk2', + ]); + await response; + }); + + it('middleware can modify context', async () => { + const registry = new Registry(); + const checkContextModel = defineModel( + registry, + { + apiVersion: 'v2', + name: 'checkContextModel', + supports: { context: true }, + }, + async (request, { context }) => { + return { + message: { + role: 'model', + content: [{ text: `Context: ${context?.myValue}` }], + }, + finishReason: 'stop', + }; + } + ); + + const contextMiddleware: ModelMiddlewareWithOptions = async ( + req, + opts, + next + ) => { + return next(req, { + ...opts, + context: { + ...opts?.context, + myValue: 'foo', + }, + }); + }; + + const response = await generate(registry, { + model: checkContextModel, + prompt: 'test', + use: [contextMiddleware], + }); + + assert.strictEqual(response.text, 'Context: foo'); + }); + + it('middleware can chain option modifications', async () => { + const registry = new Registry(); + const checkContextModel = defineModel( + registry, + { + apiVersion: 'v2', + name: 'checkContextModel', + supports: { context: true }, + }, + async (request, { context }) => { + return { + message: { + role: 'model', + content: [{ text: `Context: ${JSON.stringify(context)}` }], + }, + finishReason: 'stop', + }; + } + ); + + const middleware1: ModelMiddlewareWithOptions = async (req, opts, next) => { + return next(req, { + ...opts, + context: { + ...opts?.context, + val: [...(opts?.context?.val ?? []), 'A'], + }, + }); + }; + + const middleware2: ModelMiddlewareWithOptions = async (req, opts, next) => { + return next(req, { + ...opts, + context: { + ...opts?.context, + val: [...(opts?.context?.val ?? []), 'B'], + }, + }); + }; + + const response = await generate(registry, { + model: checkContextModel, + prompt: 'test', + use: [middleware1, middleware2], + }); + + const context = JSON.parse(response.text.substring('Context: '.length)); + assert.deepStrictEqual(context.val, ['A', 'B']); + }); }); diff --git a/js/core/src/action.ts b/js/core/src/action.ts index a767dcedf0..1d296b5326 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -209,7 +209,8 @@ export type MiddlewareWithOptions = ( */ export type Middleware = | SimpleMiddleware - | MiddlewareWithOptions; + | MiddlewareWithOptions + | string; /** * Creates an action with provided middleware. @@ -246,7 +247,25 @@ export function actionWithMiddleware< return result.result; } - const currentMiddleware = middleware[index]; + let currentMiddleware = middleware[index]; + if (typeof currentMiddleware === 'string') { + const registry = wrapped.__registry; + if (!registry) { + throw new Error( + `Cannot resolve middleware '${currentMiddleware}' for action '${action.__action.name}' because the action is not registered.` + ); + } + const resolvedMiddleware = await registry.lookupValue< + SimpleMiddleware | MiddlewareWithOptions + >('modelMiddleware', currentMiddleware); + if (!resolvedMiddleware) { + throw new Error( + `Middleware '${currentMiddleware}' not found in registry.` + ); + } + currentMiddleware = resolvedMiddleware; + } + if (currentMiddleware.length === 3) { return (currentMiddleware as MiddlewareWithOptions>)( req, diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 2cd34a4c7a..42c56b9ca3 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -80,6 +80,8 @@ import { type DefineModelOptions, type GenerateResponseChunkData, type ModelAction, + type ModelMiddleware, + type ModelMiddlewareWithOptions, } from '@genkit-ai/ai/model'; import { defineReranker, @@ -341,6 +343,16 @@ export class Genkit implements HasRegistry { return defineBackgroundModel(this.registry, options); } + /** + * Registers a middleware with a name. + */ + defineMiddleware( + name: string, + middleware: ModelMiddleware | ModelMiddlewareWithOptions + ) { + this.registry.registerValue('modelMiddleware', name, middleware); + } + /** * Looks up a prompt by `name` (and optionally `variant`). Can be used to lookup * .prompt files or prompts previously defined with {@link Genkit.definePrompt} diff --git a/js/genkit/src/model.ts b/js/genkit/src/model.ts index 09f8c7bd65..786aaa4b47 100644 --- a/js/genkit/src/model.ts +++ b/js/genkit/src/model.ts @@ -58,6 +58,8 @@ export { type ModelArgument, type ModelInfo, type ModelMiddleware, + type ModelMiddlewareArgument, + type ModelMiddlewareWithOptions, type ModelReference, type ModelRequest, type ModelResponseChunkData, diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index ecdcb783a8..63b85f67f1 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -1494,4 +1494,53 @@ describe('generate', () => { }); }); }); + + describe('middleware', () => { + let ai: GenkitBeta; + + beforeEach(() => { + ai = genkit({ + model: 'programmableModel', + }); + defineProgrammableModel(ai); + }); + + it('resolves registered middleware', async () => { + defineEchoModel(ai); + ai.defineMiddleware('simple-wrapper', async (req, next) => { + const response = await next(req); + response.message!.content[0].text = `wrapped: ${response.message!.content[0].text}`; + return response; + }); + + const response = await ai.generate({ + model: 'echoModel', + prompt: 'hi', + use: ['simple-wrapper'], + }); + + assert.strictEqual(response.text, 'wrapped: Echo: hi; config: {}'); + }); + + it('resolves registered middleware via action input', async () => { + defineEchoModel(ai); + ai.defineMiddleware('simple-wrapper', async (req, next) => { + const response = await next(req); + response.message!.content[0].text = `wrapped: ${response.message!.content[0].text}`; + return response; + }); + + const generateAction = await ai.registry.lookupAction('/util/generate'); + const response = await generateAction({ + model: 'echoModel', + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + middleware: ['simple-wrapper'], + }); + + assert.strictEqual( + response.message!.content[0].text, + 'wrapped: Echo: hi' + ); + }); + }); }); diff --git a/js/testapps/basic-gemini/src/index.ts b/js/testapps/basic-gemini/src/index.ts index d9bfc9848d..96817a87fc 100644 --- a/js/testapps/basic-gemini/src/index.ts +++ b/js/testapps/basic-gemini/src/index.ts @@ -50,31 +50,38 @@ ai.defineFlow('basic-hi', async () => { return text; }); +ai.defineMiddleware( + 'basic-retry', + retry({ + maxRetries: 2, + onError: (e, attempt) => console.log('--- oops ', attempt, e), + }) +); + ai.defineFlow('basic-hi-with-retry', async () => { const { text } = await ai.generate({ model: googleAI.model('gemini-2.5-pro'), prompt: 'You are a helpful AI assistant named Walt, say hello', - use: [ - retry({ - maxRetries: 2, - onError: (e, attempt) => console.log('--- oops ', attempt, e), - }), - ], + use: ['basic-retry'], }); return text; }); +ai.defineMiddleware( + 'basic-fallback', + fallback(ai, { + models: [googleAI.model('gemini-2.5-flash')], + statuses: ['UNKNOWN'], + onError: (e) => console.log('--- oops fallback', e), + }) +); + ai.defineFlow('basic-hi-with-fallback', async () => { const { text } = await ai.generate({ model: googleAI.model('gemini-2.5-something-that-does-not-exist'), prompt: 'You are a helpful AI assistant named Walt, say hello', - use: [ - fallback(ai, { - models: [googleAI.model('gemini-2.5-flash')], - statuses: ['UNKNOWN'], - }), - ], + use: ['basic-fallback'], }); return text; diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index 40baffe314..7756598355 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -962,6 +962,7 @@ class GenerateActionOptions(BaseModel): return_tool_requests: bool | None = Field(None, alias='returnToolRequests') max_turns: float | None = Field(None, alias='maxTurns') step_name: str | None = Field(None, alias='stepName') + middleware: list[str] | None = None class GenerateRequest(BaseModel):