diff --git a/genkit-tools/common/src/types/index.ts b/genkit-tools/common/src/types/index.ts index acc8b6a11b..c8a040dd0b 100644 --- a/genkit-tools/common/src/types/index.ts +++ b/genkit-tools/common/src/types/index.ts @@ -23,6 +23,7 @@ export * from './document'; export * from './env'; export * from './eval'; export * from './evaluator'; +export * from './middleware'; export * from './model'; export * from './prompt'; export * from './reflection'; diff --git a/genkit-tools/common/src/types/middleware.ts b/genkit-tools/common/src/types/middleware.ts new file mode 100644 index 0000000000..6fd2dd9810 --- /dev/null +++ b/genkit-tools/common/src/types/middleware.ts @@ -0,0 +1,38 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { z } from 'zod'; + +/** Descriptor for a registered middleware, returned by reflection API. */ +export const MiddlewareDescSchema = z.object({ + /** Unique name of the middleware. */ + name: z.string(), + /** Human-readable description of what the middleware does. */ + description: z.string().optional(), + /** JSON Schema for the middleware's configuration. */ + configSchema: z.record(z.any()).nullish(), + /** User defined metadata for the middleware. */ + metadata: z.record(z.any()).optional(), +}); +export type MiddlewareDesc = z.infer; + +/** Reference to a registered middleware with optional configuration. */ +export const MiddlewareRefSchema = z.object({ + /** Name of the registered middleware. */ + name: z.string(), + /** Configuration for the middleware (schema defined by the middleware). */ + config: z.any().optional(), +}); +export type MiddlewareRef = z.infer; diff --git a/genkit-tools/common/src/types/model.ts b/genkit-tools/common/src/types/model.ts index df6067bce2..31c666a775 100644 --- a/genkit-tools/common/src/types/model.ts +++ b/genkit-tools/common/src/types/model.ts @@ -15,6 +15,7 @@ */ import { z } from 'zod'; import { DocumentDataSchema } from './document'; +import { MiddlewareRefSchema } from './middleware'; import { CustomPartSchema, DataPartSchema, @@ -64,32 +65,6 @@ export { // IMPORTANT: Keep this file in sync with genkit/ai/src/model-types.ts! // -/** Descriptor for a registered middleware, returned by reflection API. */ -export const MiddlewareDescSchema = z.object({ - /** Unique name of the middleware. */ - name: z.string(), - /** Human-readable description of what the middleware does. */ - description: z.string().optional(), - /** JSON Schema for the middleware's configuration. */ - configSchema: z.record(z.any()).nullish(), - /** User defined metadata for the middleware. */ - metadata: z.record(z.any()).nullish(), -}); -export type MiddlewareDesc = z.infer; - -/** - * Zod schema of middleware reference. - */ -export const MiddlewareRefSchema = z.object({ - name: z.string(), - config: z.any().optional(), -}); - -/** - * Middleware reference. - */ -export type MiddlewareRef = z.infer; - /** * Zod schema of an opration representing a model reference. */ diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index eb07c00232..4506c45796 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -270,6 +270,49 @@ ], "additionalProperties": false }, + "MiddlewareDesc": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "configSchema": { + "anyOf": [ + { + "type": "object", + "additionalProperties": {} + }, + { + "type": "null" + } + ] + }, + "metadata": { + "type": "object", + "additionalProperties": {} + } + }, + "required": [ + "name" + ], + "additionalProperties": false + }, + "MiddlewareRef": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "config": {} + }, + "required": [ + "name" + ], + "additionalProperties": false + }, "CandidateError": { "type": "object", "properties": { @@ -756,56 +799,6 @@ ], "additionalProperties": false }, - "MiddlewareDesc": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "description": { - "type": "string" - }, - "configSchema": { - "anyOf": [ - { - "type": "object", - "additionalProperties": {} - }, - { - "type": "null" - } - ] - }, - "metadata": { - "anyOf": [ - { - "type": "object", - "additionalProperties": {} - }, - { - "type": "null" - } - ] - } - }, - "required": [ - "name" - ], - "additionalProperties": false - }, - "MiddlewareRef": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "config": {} - }, - "required": [ - "name" - ], - "additionalProperties": false - }, "ModelInfo": { "type": "object", "properties": { diff --git a/genkit-tools/scripts/schema-exporter.ts b/genkit-tools/scripts/schema-exporter.ts index 1d7bedf119..c923d6c21f 100644 --- a/genkit-tools/scripts/schema-exporter.ts +++ b/genkit-tools/scripts/schema-exporter.ts @@ -26,6 +26,7 @@ const EXPORTED_TYPE_MODULES = [ '../common/src/types/embedder.ts', '../common/src/types/evaluator.ts', '../common/src/types/error.ts', + '../common/src/types/middleware.ts', '../common/src/types/model.ts', '../common/src/types/parts.ts', '../common/src/types/reranker.ts', diff --git a/go/ai/gen.go b/go/ai/gen.go index 89dbe1dc32..c9dac20735 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -111,8 +111,9 @@ type GenerateActionOptions struct { // the model to choose a tool, and none forces the model not to use any tools. Defaults to auto. ToolChoice ToolChoice `json:"toolChoice,omitempty"` // Tools is a list of registered tool names for this generation if supported. - Tools []string `json:"tools,omitempty"` - Use []*MiddlewareRef `json:"use,omitempty"` + Tools []string `json:"tools,omitempty"` + // Use is middleware to apply to this generation, referenced by name with optional config. + Use []*MiddlewareRef `json:"use,omitempty"` } // GenerateActionResume holds options for resuming an interrupted generation. @@ -225,16 +226,25 @@ type Message struct { Role Role `json:"role,omitempty"` } +// MiddlewareDesc is the registered descriptor for a middleware. type MiddlewareDesc struct { - ConfigSchema any `json:"configSchema,omitempty"` - Description string `json:"description,omitempty"` - Metadata any `json:"metadata,omitempty"` - Name string `json:"name,omitempty"` + // ConfigSchema is a JSON Schema describing the middleware's configuration. + ConfigSchema map[string]any `json:"configSchema,omitempty"` + // Description explains what the middleware does. + Description string `json:"description,omitempty"` + // Metadata contains additional context for the middleware. + Metadata map[string]any `json:"metadata,omitempty"` + // Name is the middleware's unique identifier. + Name string `json:"name,omitempty"` + buildFromJSON middlewareFactoryFunc } +// MiddlewareRef is a serializable reference to a registered middleware with config. type MiddlewareRef struct { - Config any `json:"config,omitempty"` - Name string `json:"name,omitempty"` + // Config contains the middleware configuration. + Config any `json:"config,omitempty"` + // Name is the name of the registered middleware. + Name string `json:"name,omitempty"` } // ModelInfo contains metadata about a model's capabilities and characteristics. diff --git a/go/ai/generate.go b/go/ai/generate.go index 7cc240d6ea..307f82f002 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -67,6 +67,8 @@ type ModelFunc = core.StreamingFunc[*ModelRequest, *ModelResponse, *ModelRespons type ModelStreamCallback = func(context.Context, *ModelResponseChunk) error // ModelMiddleware is middleware for model generate requests that takes in a ModelFunc, does something, then returns another ModelFunc. +// +// Deprecated: Use [Middleware] interface with [WithUse] instead, which supports Generate, Model, and Tool hooks. type ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResponseChunk] // model is an action with functions specific to model generation such as Generate(). @@ -201,7 +203,7 @@ func LookupModel(r api.Registry, name string) Model { } // GenerateWithRequest is the central generation implementation for ai.Generate(), prompt.Execute(), and the GenerateAction direct call. -func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActionOptions, mw []ModelMiddleware, cb ModelStreamCallback) (*ModelResponse, error) { +func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActionOptions, mmws []ModelMiddleware, cb ModelStreamCallback) (*ModelResponse, error) { if opts.Model == "" { if defaultModel, ok := r.LookupValue(api.DefaultModelKey).(string); ok && defaultModel != "" { opts.Model = defaultModel @@ -217,29 +219,14 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: model %q not found", opts.Model) } - resumeOutput, err := handleResumeOption(ctx, r, opts) + mws, err := resolveRefs(ctx, r, opts.Use) if err != nil { return nil, err } - if resumeOutput.interruptedResponse != nil { - return nil, core.NewError(core.FAILED_PRECONDITION, - "One or more tools triggered an interrupt during a restarted execution.") - } - - opts = resumeOutput.revisedRequest - - if resumeOutput.toolMessage != nil && cb != nil { - err := cb(ctx, &ModelResponseChunk{ - Content: resumeOutput.toolMessage.Content, - Role: RoleTool, - Index: 0, - }) - if err != nil { - return nil, fmt.Errorf("streaming callback failed for resumed tool message: %w", err) - } - } - + // Tools contributed by middleware bundles are registered on a child + // registry so this Generate() call sees them while outer callers do not. + // Duplicate names across multiple middleware are rejected explicitly. toolDefMap := make(map[string]*ToolDefinition) for _, t := range opts.Tools { if _, ok := toolDefMap[t]; ok { @@ -253,6 +240,27 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi toolDefMap[t] = tool.Definition() } + var middlewareTools []Tool + for _, mw := range mws { + if mw == nil { + continue + } + for _, t := range mw.Tools { + if _, ok := toolDefMap[t.Name()]; ok { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: tool %q is contributed by middleware but already declared elsewhere", t.Name()) + } + toolDefMap[t.Name()] = t.Definition() + middlewareTools = append(middlewareTools, t) + } + } + if len(middlewareTools) > 0 { + if !r.IsChild() { + r = r.NewChild() + } + for _, t := range middlewareTools { + t.Register(r) + } + } toolDefs := make([]*ToolDefinition, 0, len(toolDefMap)) for _, t := range toolDefMap { toolDefs = append(toolDefs, t) @@ -322,12 +330,28 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi } else { fn = m.Generate } - fn = core.ChainMiddleware(mw...)(fn) - // Inline recursive helper function that captures variables from parent scope. + // Build the full hook chains once: wrapping the model function with + // WrapModel hooks from middleware, and wrapping the generate iteration + // with WrapGenerate hooks. These chains are reused across every tool-loop + // iteration rather than rebuilt each turn. + fn = buildModelChain(mws, fn) + fn = core.ChainMiddleware(mmws...)(fn) + + var streamingHandler StreamingFormatHandler + if sfh, ok := formatHandler.(StreamingFormatHandler); ok { + streamingHandler = sfh + } + + runTool := buildToolRunner(mws) + var generate func(context.Context, *ModelRequest, int, int) (*ModelResponse, error) + var runGenerate func(context.Context, *GenerateParams) (*ModelResponse, error) - generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { + runGenerate = func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { + req := params.Request + currentTurn := params.Iteration + messageIndex := params.MessageIndex spanMetadata := &tracing.SpanMetadata{ Name: "generate", Type: "util", @@ -339,11 +363,6 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi currentRole := RoleModel currentIndex := messageIndex - var streamingHandler StreamingFormatHandler - if sfh, ok := formatHandler.(StreamingFormatHandler); ok { - streamingHandler = sfh - } - if cb != nil { wrappedCb = func(ctx context.Context, chunk *ModelResponseChunk) error { if chunk.Role != currentRole && chunk.Role != "" { @@ -359,6 +378,44 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi } } + // Handle resume on the first iteration inside this span so that + // restarted tool execution is both wrapped by WrapGenerate (from + // the outer middleware chain) and recorded as a child of the + // current generate span (lifecycle: generate > tool > generate > + // model > tool). + if currentTurn == 0 && opts.Resume != nil && (len(opts.Resume.Respond) > 0 || len(opts.Resume.Restart) > 0) { + resumeOutput, err := handleResumeOption(ctx, r, opts, runTool) + if err != nil { + return nil, err + } + + if resumeOutput.interruptedResponse != nil { + return nil, core.NewError(core.FAILED_PRECONDITION, + "One or more tools triggered an interrupt during a restarted execution.") + } + + opts = resumeOutput.revisedRequest + + if resumeOutput.toolMessage != nil && wrappedCb != nil { + if err := wrappedCb(ctx, &ModelResponseChunk{ + Content: resumeOutput.toolMessage.Content, + Role: RoleTool, + }); err != nil { + return nil, fmt.Errorf("streaming callback failed for resumed tool message: %w", err) + } + } + + resumeReq := &ModelRequest{ + Messages: opts.Messages, + Config: req.Config, + Docs: req.Docs, + ToolChoice: req.ToolChoice, + Tools: req.Tools, + Output: req.Output, + } + return generate(ctx, resumeReq, currentTurn+1, currentIndex) + } + resp, err := fn(ctx, req, wrappedCb) if err != nil { return nil, err @@ -390,7 +447,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi return nil, core.NewError(core.ABORTED, "exceeded maximum tool call iterations (%d)", maxTurns) } - newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex) + newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex, runTool) if err != nil { return nil, err } @@ -408,9 +465,99 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi }) } + // Compose WrapGenerate hooks once; this chain is invoked for every + // tool-loop iteration. + hookedGenerate := buildGenerateChain(mws, runGenerate) + + generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { + return hookedGenerate(ctx, &GenerateParams{ + Options: opts, + Request: req, + Iteration: currentTurn, + MessageIndex: messageIndex, + Callback: cb, + }) + } + return generate(ctx, req, 0, 0) } +// buildGenerateChain composes the WrapGenerate hooks from mws (outer-to-inner) +// around run. Middleware with a nil WrapGenerate hook is skipped. +func buildGenerateChain(mws []*Hooks, run func(ctx context.Context, params *GenerateParams) (*ModelResponse, error)) func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { + chain := run + for i := len(mws) - 1; i >= 0; i-- { + mw := mws[i] + if mw == nil || mw.WrapGenerate == nil { + continue + } + hook := mw.WrapGenerate + next := chain + chain = func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { + return hook(ctx, params, next) + } + } + return chain +} + +// buildModelChain composes the WrapModel hooks from mws (outer-to-inner) +// around fn. Middleware with a nil WrapModel hook is skipped. +func buildModelChain(mws []*Hooks, fn ModelFunc) ModelFunc { + chain := fn + for i := len(mws) - 1; i >= 0; i-- { + mw := mws[i] + if mw == nil || mw.WrapModel == nil { + continue + } + hook := mw.WrapModel + next := chain + chain = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return hook(ctx, &ModelParams{Request: req, Callback: cb}, + func(ctx context.Context, params *ModelParams) (*ModelResponse, error) { + return next(ctx, params.Request, params.Callback) + }) + } + } + return chain +} + +// buildToolRunner composes the WrapTool hooks from mws (outer-to-inner) into +// a single function that executes a tool. The returned function is safe to +// invoke from concurrent goroutines; each invocation threads its own params +// through the shared hook chain. When no WrapTool hooks are configured, the +// tool is invoked directly without allocating a ToolParams wrapper. +func buildToolRunner(mws []*Hooks) func(ctx context.Context, tool Tool, req *ToolRequest) (*MultipartToolResponse, error) { + hasHook := false + for _, mw := range mws { + if mw != nil && mw.WrapTool != nil { + hasHook = true + break + } + } + if !hasHook { + return func(ctx context.Context, tool Tool, req *ToolRequest) (*MultipartToolResponse, error) { + return tool.RunRawMultipart(ctx, req.Input) + } + } + chain := func(ctx context.Context, params *ToolParams) (*MultipartToolResponse, error) { + return params.Tool.RunRawMultipart(ctx, params.Request.Input) + } + for i := len(mws) - 1; i >= 0; i-- { + mw := mws[i] + if mw == nil || mw.WrapTool == nil { + continue + } + hook := mw.WrapTool + next := chain + chain = func(ctx context.Context, params *ToolParams) (*MultipartToolResponse, error) { + return hook(ctx, params, next) + } + } + return func(ctx context.Context, tool Tool, req *ToolRequest) (*MultipartToolResponse, error) { + return chain(ctx, &ToolParams{Request: req, Tool: tool}) + } +} + // Generate generates a model response based on the provided options. func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*ModelResponse, error) { genOpts := &generateOptions{} @@ -539,7 +686,14 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod } } - // Process resources in messages + refs, err := configsToRefs(genOpts.Use) + if err != nil { + return nil, err + } + if len(refs) > 0 { + actionOpts.Use = refs + } + processedMessages, err := processResources(ctx, r, messages) if err != nil { return nil, core.NewError(core.INTERNAL, "ai.Generate: error processing resources: %v", err) @@ -774,10 +928,14 @@ func clone[T any](obj *T) *T { return &newObj } +// toolRunnerFunc runs a tool through the WrapTool hook chain and returns the +// raw [MultipartToolResponse]. Returned by [buildToolRunner]. +type toolRunnerFunc = func(ctx context.Context, tool Tool, req *ToolRequest) (*MultipartToolResponse, error) + // handleToolRequests processes any tool requests in the response, returning // either a new request to continue the conversation or nil if no tool requests // need handling. -func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int) (*ModelRequest, *Message, error) { +func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int, runTool toolRunnerFunc) (*ModelRequest, *Message, error) { toolCount := len(resp.ToolRequests()) if toolCount == 0 { return nil, nil, nil @@ -800,7 +958,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return } - multipartResp, err := tool.RunRawMultipart(ctx, toolReq.Input) + multipartResp, err := runTool(ctx, tool, toolReq) if err != nil { var tie *toolInterruptError if errors.As(err, &tie) { @@ -1122,7 +1280,7 @@ func (m ModelRef) Config() any { // handleResumedToolRequest resolves a tool request from a previous, interrupted model turn, // when generation is being resumed. It determines the outcome of the tool request based on // pending output, or explicit 'respond' or 'restart' directives in the resume options. -func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, p *Part) (*resumedToolRequestOutput, error) { +func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, p *Part, runTool toolRunnerFunc) (*resumedToolRequestOutput, error) { if p == nil || !p.IsToolRequest() { return nil, core.NewError(core.INVALID_ARGUMENT, "handleResumedToolRequest: part is not a tool request") } @@ -1210,7 +1368,12 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene resumedCtx = origInputCtxKey.NewContext(resumedCtx, originalInputVal) } - output, err := tool.RunRaw(resumedCtx, restartPart.ToolRequest.Input) + restartToolReq := &ToolRequest{ + Name: restartPart.ToolRequest.Name, + Ref: restartPart.ToolRequest.Ref, + Input: restartPart.ToolRequest.Input, + } + multipartResp, err := runTool(resumedCtx, tool, restartToolReq) if err != nil { var tie *toolInterruptError if errors.As(err, &tie) { @@ -1237,9 +1400,10 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene } newToolResp := NewToolResponsePart(&ToolResponse{ - Name: restartPart.ToolRequest.Name, - Ref: restartPart.ToolRequest.Ref, - Output: output, + Name: restartPart.ToolRequest.Name, + Ref: restartPart.ToolRequest.Ref, + Output: multipartResp.Output, + Content: multipartResp.Content, }) return &resumedToolRequestOutput{ @@ -1259,7 +1423,7 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene // handleResumeOption amends message history to handle `resume` arguments. // It returns the amended history. -func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions) (*resumeOptionOutput, error) { +func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, runTool toolRunnerFunc) (*resumeOptionOutput, error) { if genOpts.Resume == nil || (len(genOpts.Resume.Respond) == 0 && len(genOpts.Resume.Restart) == 0) { return &resumeOptionOutput{revisedRequest: genOpts}, nil } @@ -1295,7 +1459,7 @@ func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateAc toolReqCount++ go func(idx int, p *Part) { - output, err := handleResumedToolRequest(ctx, r, genOpts, p) + output, err := handleResumedToolRequest(ctx, r, genOpts, p, runTool) resultChan <- result[*resumedToolRequestOutput]{ index: idx, value: output, @@ -1361,6 +1525,7 @@ func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateAc Docs: genOpts.Docs, ReturnToolRequests: genOpts.ReturnToolRequests, Output: genOpts.Output, + Use: genOpts.Use, }, toolMessage: toolMessage, }, nil diff --git a/go/ai/middleware.go b/go/ai/middleware.go new file mode 100644 index 0000000000..1b28012b65 --- /dev/null +++ b/go/ai/middleware.go @@ -0,0 +1,263 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "context" + "encoding/json" + + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" +) + +// Hooks is the per-call bundle of hook functions produced by a [Middleware]'s +// New method. Each field is optional; a nil hook is treated as a pass-through. +type Hooks struct { + // Tools are additional tools to register during the generation this + // middleware is attached to. They are available to the model alongside + // any user-supplied tools. + Tools []Tool + // WrapGenerate wraps each iteration of the tool loop. It sees the + // accumulated request, the iteration index, and the streaming callback. + // A single Generate() with N tool-call turns invokes this hook N+1 times. + WrapGenerate func(ctx context.Context, params *GenerateParams, next GenerateNext) (*ModelResponse, error) + // WrapModel wraps each model API call. Retry, fallback, and caching + // middleware typically hook here. + WrapModel func(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) + // WrapTool wraps each tool execution. It may be called concurrently when + // multiple tools execute in parallel for the same Generate() call; any + // state closed over from the enclosing scope that this hook mutates must + // be guarded with sync primitives. + WrapTool func(ctx context.Context, params *ToolParams, next ToolNext) (*MultipartToolResponse, error) +} + +// GenerateParams holds params for the WrapGenerate hook. +type GenerateParams struct { + // Options is the original options passed to [Generate]. + Options *GenerateActionOptions + // Request is the current model request for this iteration, with accumulated messages. + Request *ModelRequest + // Iteration is the current tool-loop iteration (0-indexed). + Iteration int + // MessageIndex is the index of the next message in the streamed response sequence. + // Middleware that streams extra messages (e.g. injected user content) should emit + // chunks at this index and advance it so downstream middleware and the model + // receive the shifted value. + MessageIndex int + // Callback is the streaming callback supplied to [Generate], or nil if not streaming. + // Middleware may invoke it to emit chunks, setting [ModelResponseChunk.Role] and + // [ModelResponseChunk.Index] explicitly. + Callback ModelStreamCallback +} + +// ModelParams holds params for the WrapModel hook. +type ModelParams struct { + // Request is the model request about to be sent. + Request *ModelRequest + // Callback is the streaming callback, or nil if not streaming. + Callback ModelStreamCallback +} + +// ToolParams holds params for the WrapTool hook. +type ToolParams struct { + // Request is the tool request about to be executed. + Request *ToolRequest + // Tool is the resolved tool being called. + Tool Tool +} + +// GenerateNext is the next function in the WrapGenerate hook chain. +type GenerateNext = func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) + +// ModelNext is the next function in the WrapModel hook chain. +type ModelNext = func(ctx context.Context, params *ModelParams) (*ModelResponse, error) + +// ToolNext is the next function in the WrapTool hook chain. +type ToolNext = func(ctx context.Context, params *ToolParams) (*MultipartToolResponse, error) + +// Middleware is the contract every value passed to [WithUse] satisfies. The +// config struct both identifies the middleware (via [Name]) and produces a +// per-call [Hooks] bundle (via [New]). +// +// Plugin-level state belongs on unexported fields of the config type. A +// plugin's [MiddlewarePlugin.Middlewares] sets those fields on a prototype +// that is preserved across JSON dispatch by value-copy inside the descriptor. +type Middleware interface { + // Name returns the registered middleware's unique identifier. Must be a + // stable constant, since it is read from a zero value of the config type + // during descriptor creation. + Name() string + // New produces a fresh [Hooks] bundle for one Generate() call. It is + // invoked per-Generate, so any state the bundle's hooks need to share + // (counters, caches) may be allocated in this method and closed over by + // the returned hooks. + New(ctx context.Context) (*Hooks, error) +} + +// middlewareFactoryFunc is the closure stored on [MiddlewareDesc] that +// materializes a [Hooks] bundle from JSON config. It is produced by +// [NewMiddleware] and captures the prototype so value-copy preserves any +// unexported plugin-level state across JSON-dispatched calls. +type middlewareFactoryFunc = func(ctx context.Context, configJSON []byte) (*Hooks, error) + +// middlewareRegistryPrefix is the registry-key prefix under which middleware +// descriptors are stored. The reflection API lists values under this prefix. +const middlewareRegistryPrefix = "/middleware/" + +func middlewareRegistryKey(name string) string { + return middlewareRegistryPrefix + name +} + +// Register records this descriptor in the registry under its name so it can +// be resolved during JSON dispatch and surfaced to the Dev UI. +func (d *MiddlewareDesc) Register(r api.Registry) { + r.RegisterValue(middlewareRegistryKey(d.Name), d) +} + +// NewMiddleware constructs a descriptor without registering it. Useful for +// [MiddlewarePlugin.Middlewares] implementations that defer registration +// to [genkit.Init]. The prototype argument supplies both the registered name +// (via its [Middleware.Name] method) and any plugin-level state that should +// flow into JSON-dispatched invocations via unexported fields preserved by +// value-copy. +func NewMiddleware[M Middleware](description string, prototype M) *MiddlewareDesc { + name := prototype.Name() + return &MiddlewareDesc{ + Name: name, + Description: description, + ConfigSchema: core.InferSchemaMap(prototype), + buildFromJSON: func(ctx context.Context, configJSON []byte) (*Hooks, error) { + cfg := prototype // value copy preserves unexported fields, shares pointers + if len(configJSON) > 0 { + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "middleware %q: %w", name, err) + } + } + return cfg.New(ctx) + }, + } +} + +// DefineMiddleware creates and registers a middleware descriptor in one step. +func DefineMiddleware[M Middleware](r api.Registry, description string, prototype M) *MiddlewareDesc { + d := NewMiddleware(description, prototype) + d.Register(r) + return d +} + +// MiddlewareFunc adapts a per-call factory closure to the [Middleware] +// interface for ad-hoc inline use, without a registered descriptor or plugin +// wiring. The adapted middleware does not appear in the Dev UI. +// +// Example: +// +// ai.WithUse(ai.MiddlewareFunc(func(ctx context.Context) (*ai.Hooks, error) { +// return &ai.Hooks{WrapModel: ...}, nil +// })) +type MiddlewareFunc func(ctx context.Context) (*Hooks, error) + +// Name returns the placeholder name shared by all [MiddlewareFunc] values. +// Uniqueness is unnecessary: inline middleware is resolved via the fast path +// in [resolveRefs] and never goes through a name-keyed registry lookup. +func (MiddlewareFunc) Name() string { return "inline" } + +func (f MiddlewareFunc) New(ctx context.Context) (*Hooks, error) { return f(ctx) } + +// LookupMiddleware returns the registered middleware descriptor with the +// given name, or nil if no such descriptor exists in the registry or any +// ancestor. Primarily useful for inspection and for the reflection API; +// callers dispatching middleware should do so through [WithUse]. +func LookupMiddleware(r api.Registry, name string) *MiddlewareDesc { + v := r.LookupValue(middlewareRegistryKey(name)) + if v == nil { + return nil + } + d, _ := v.(*MiddlewareDesc) + return d +} + +// MiddlewarePlugin is implemented by plugins that provide middleware. The +// returned descriptors are registered in the registry during [genkit.Init], +// with any plugin-level state captured by the descriptor's build closure via +// the prototype passed to [NewMiddleware]. +type MiddlewarePlugin interface { + Middlewares(ctx context.Context) ([]*MiddlewareDesc, error) +} + +// configsToRefs converts a user-supplied slice of [Middleware] values into +// the [MiddlewareRef] entries carried on [GenerateActionOptions.Use]. The Go +// value is stored on each ref so [resolveRefs] can build the hooks bundle +// directly without a registry round trip for local calls. +func configsToRefs(configs []Middleware) ([]*MiddlewareRef, error) { + if len(configs) == 0 { + return nil, nil + } + refs := make([]*MiddlewareRef, 0, len(configs)) + for _, c := range configs { + if c == nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai: nil middleware") + } + refs = append(refs, &MiddlewareRef{Name: c.Name(), Config: c}) + } + return refs, nil +} + +// resolveRefs resolves [MiddlewareRef] entries to [Hooks] bundles. If +// ref.Config is a [Middleware] value, its New method is invoked directly +// (local fast path). Otherwise the descriptor is looked up in the registry +// and its build closure is invoked with the marshaled config (JSON dispatch, +// used for cross-runtime / Dev UI calls). +func resolveRefs(ctx context.Context, r api.Registry, refs []*MiddlewareRef) ([]*Hooks, error) { + if len(refs) == 0 { + return nil, nil + } + bundles := make([]*Hooks, 0, len(refs)) + for _, ref := range refs { + if mw, ok := ref.Config.(Middleware); ok { + h, err := mw.New(ctx) + if err != nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai: failed to build middleware %q: %v", ref.Name, err) + } + if h == nil { + return nil, core.NewError(core.INTERNAL, "ai: middleware %q returned nil hooks", ref.Name) + } + bundles = append(bundles, h) + continue + } + d := LookupMiddleware(r, ref.Name) + if d == nil { + return nil, core.NewError(core.NOT_FOUND, "ai: middleware %q not registered (is the providing plugin installed?)", ref.Name) + } + var configJSON []byte + if ref.Config != nil { + b, err := json.Marshal(ref.Config) + if err != nil { + return nil, core.NewError(core.INTERNAL, "ai: failed to marshal config for middleware %q: %v", ref.Name, err) + } + configJSON = b + } + h, err := d.buildFromJSON(ctx, configJSON) + if err != nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai: failed to build middleware %q: %v", ref.Name, err) + } + if h == nil { + return nil, core.NewError(core.INTERNAL, "ai: middleware %q factory returned nil", ref.Name) + } + bundles = append(bundles, h) + } + return bundles, nil +} diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go new file mode 100644 index 0000000000..faf0bd814e --- /dev/null +++ b/go/ai/middleware_test.go @@ -0,0 +1,576 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" +) + +// --- counter: a config whose BuildMiddleware tracks hook invocations --- + +type counterConfig struct { + Label string `json:"label,omitempty"` + + // Plugin-level state lives on unexported fields and is preserved by + // the descriptor's value-copy across JSON-dispatch calls. + sharedGenerateCalls *int32 + sharedModelCalls *int32 + sharedToolCalls *int32 +} + +func (counterConfig) Name() string { return "test/counter" } + +func (c counterConfig) New(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapGenerate: func(ctx context.Context, p *GenerateParams, next GenerateNext) (*ModelResponse, error) { + if c.sharedGenerateCalls != nil { + atomic.AddInt32(c.sharedGenerateCalls, 1) + } + return next(ctx, p) + }, + WrapModel: func(ctx context.Context, p *ModelParams, next ModelNext) (*ModelResponse, error) { + if c.sharedModelCalls != nil { + atomic.AddInt32(c.sharedModelCalls, 1) + } + return next(ctx, p) + }, + WrapTool: func(ctx context.Context, p *ToolParams, next ToolNext) (*MultipartToolResponse, error) { + if c.sharedToolCalls != nil { + atomic.AddInt32(c.sharedToolCalls, 1) + } + return next(ctx, p) + }, + }, nil +} + +// --- core descriptor tests --- + +func TestNewMiddleware_NameFromPrototype(t *testing.T) { + desc := NewMiddleware("tracks calls", counterConfig{}) + if desc.Name != "test/counter" { + t.Errorf("got name %q, want %q", desc.Name, "test/counter") + } + if desc.Description != "tracks calls" { + t.Errorf("got description %q, want %q", desc.Description, "tracks calls") + } +} + +func TestBuildFromJSON(t *testing.T) { + desc := NewMiddleware("desc", counterConfig{}) + mw, err := desc.buildFromJSON(testCtx, []byte(`{"label": "custom"}`)) + if err != nil { + t.Fatalf("buildFromJSON failed: %v", err) + } + if mw == nil || mw.WrapModel == nil { + t.Fatal("expected middleware with WrapModel hook") + } +} + +func TestBuildFromJSON_InvalidJSON(t *testing.T) { + desc := NewMiddleware("desc", counterConfig{}) + _, err := desc.buildFromJSON(testCtx, []byte(`not-json`)) + if err == nil { + t.Fatal("expected error from invalid JSON") + } +} + +// --- plugin-level state: prototype unexported fields preserved across calls --- + +func TestPluginStateCarriedThroughPrototype(t *testing.T) { + // Simulate the plugin's Middlewares() building a prototype that holds + // an "expensive client" (here, a shared counter). JSON dispatch must + // preserve this plugin-level state across invocations. + var shared int32 + desc := NewMiddleware("desc", counterConfig{sharedModelCalls: &shared}) + + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + desc.Register(r) + + // Simulate Dev UI JSON dispatch: ref.Config is a map, not a typed config. + refs := []*MiddlewareRef{{Name: "test/counter", Config: map[string]any{"label": "dev-ui"}}} + for i := 0; i < 3; i++ { + _, err := GenerateWithRequest(testCtx, r, &GenerateActionOptions{ + Model: m.Name(), + Messages: []*Message{NewUserTextMessage("go")}, + Use: refs, + }, nil, nil) + assertNoError(t, err) + } + if got := atomic.LoadInt32(&shared); got != 3 { + t.Errorf("shared counter = %d, want 3 (plugin state should persist across JSON dispatches)", got) + } +} + +// --- call-level state: each Generate gets fresh BuildMiddleware scope --- + +type perCallConfig struct { + checker func(n int32) +} + +func (perCallConfig) Name() string { return "test/per-call" } + +func (c perCallConfig) New(ctx context.Context) (*Hooks, error) { + var counter int32 + return &Hooks{ + WrapModel: func(ctx context.Context, p *ModelParams, next ModelNext) (*ModelResponse, error) { + n := atomic.AddInt32(&counter, 1) + if c.checker != nil { + c.checker(n) + } + return next(ctx, p) + }, + }, nil +} + +func TestCallLevelStateIsolation(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + cfg := perCallConfig{checker: func(n int32) { + if n != 1 { + t.Errorf("call-level counter leaked: got %d, want 1", n) + } + }} + for i := 0; i < 3; i++ { + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("go"), WithUse(cfg)) + assertNoError(t, err) + } +} + +// --- pure Go usage: no registration required for local calls --- + +func TestWithUseNoRegistrationNeeded(t *testing.T) { + // The whole point: user creates a Genkit with no middleware plugins and + // still calls WithUse(middleware.Retry{...}). The config's BuildMiddleware + // method runs directly; the registry is never consulted. + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + var called int32 + cfg := counterConfig{sharedModelCalls: &called} + // Note: no Register call anywhere. + + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hi"), WithUse(cfg)) + assertNoError(t, err) + if atomic.LoadInt32(&called) != 1 { + t.Errorf("expected 1 model-hook call, got %d", called) + } +} + +// --- hook invocation: model, tool, generate --- + +func TestMiddlewareModelHook(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + var called int32 + tracker := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapModel: func(ctx context.Context, p *ModelParams, next ModelNext) (*ModelResponse, error) { + atomic.AddInt32(&called, 1) + return next(ctx, p) + }, + }, nil + }) + + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hello"), WithUse(tracker)) + assertNoError(t, err) + if atomic.LoadInt32(&called) == 0 { + t.Error("expected model hook to be called") + } +} + +func TestMiddlewareToolHook(t *testing.T) { + r := newTestRegistry(t) + defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolModel", + handler: toolCallingModelHandler("myTool", map[string]any{"value": "test"}, "done"), + }) + defineFakeTool(t, r, "myTool", "A test tool") + + var called int32 + tracker := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapTool: func(ctx context.Context, p *ToolParams, next ToolNext) (*MultipartToolResponse, error) { + atomic.AddInt32(&called, 1) + return next(ctx, p) + }, + }, nil + }) + + _, err := Generate(testCtx, r, + WithModelName("test/toolModel"), + WithPrompt("use the tool"), + WithTools(ToolName("myTool")), + WithUse(tracker), + ) + assertNoError(t, err) + if atomic.LoadInt32(&called) == 0 { + t.Error("expected tool hook to be called at least once") + } +} + +func TestMiddlewareGenerateHook(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + var called int32 + tracker := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapGenerate: func(ctx context.Context, p *GenerateParams, next GenerateNext) (*ModelResponse, error) { + atomic.AddInt32(&called, 1) + return next(ctx, p) + }, + }, nil + }) + + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hello"), WithUse(tracker)) + assertNoError(t, err) + if atomic.LoadInt32(&called) == 0 { + t.Error("expected generate hook to be called") + } +} + +// --- ordering: first middleware wraps outermost --- + +func TestMiddlewareOrdering(t *testing.T) { + var mu sync.Mutex + var order []string + appendOrder := func(s string) { + mu.Lock() + defer mu.Unlock() + order = append(order, s) + } + tracker := func(label string) Middleware { + return MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapModel: func(ctx context.Context, p *ModelParams, next ModelNext) (*ModelResponse, error) { + appendOrder(label + "-before") + resp, err := next(ctx, p) + appendOrder(label + "-after") + return resp, err + }, + }, nil + }) + } + + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + _, err := Generate(testCtx, r, + WithModel(m), + WithPrompt("hello"), + WithUse(tracker("A"), tracker("B")), + ) + assertNoError(t, err) + + want := []string{"A-before", "B-before", "B-after", "A-after"} + if len(order) != len(want) { + t.Fatalf("got %v, want %v", order, want) + } + for i := range want { + if order[i] != want[i] { + t.Errorf("order[%d] = %q, want %q", i, order[i], want[i]) + } + } +} + +// --- MiddlewareFunc adapter basics --- + +func TestMiddlewareFunc(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + var called bool + mw := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapGenerate: func(ctx context.Context, p *GenerateParams, next GenerateNext) (*ModelResponse, error) { + called = true + return next(ctx, p) + }, + }, nil + }) + + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hello"), WithUse(mw)) + assertNoError(t, err) + if !called { + t.Error("inline middleware hook not called") + } +} + +func TestMiddlewareFuncCoexist(t *testing.T) { + // Two MiddlewareFunc adapter instances should be able to coexist in a + // single WithUse call. + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + var a, b int32 + useA := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{WrapModel: func(ctx context.Context, p *ModelParams, next ModelNext) (*ModelResponse, error) { + atomic.AddInt32(&a, 1) + return next(ctx, p) + }}, nil + }) + useB := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{WrapModel: func(ctx context.Context, p *ModelParams, next ModelNext) (*ModelResponse, error) { + atomic.AddInt32(&b, 1) + return next(ctx, p) + }}, nil + }) + + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hi"), WithUse(useA, useB)) + assertNoError(t, err) + if atomic.LoadInt32(&a) != 1 || atomic.LoadInt32(&b) != 1 { + t.Errorf("expected both hooks called once, got a=%d b=%d", a, b) + } +} + +// --- optional hooks: nil hook fields must pass through --- + +func TestNilHookFieldsPassThrough(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + passthrough := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{}, nil + }) + + resp, err := Generate(testCtx, r, WithModel(m), WithPrompt("hi"), WithUse(passthrough)) + assertNoError(t, err) + if resp == nil { + t.Fatal("expected response") + } +} + +// --- streaming: middleware-emitted chunks accumulate with model chunks --- + +func TestMiddlewareStreamsAccumulateWithModel(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{ + name: "test/streamModel", + handler: streamingModelHandler([]string{"model chunk"}, "done"), + }) + + midChunk := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapGenerate: func(ctx context.Context, p *GenerateParams, next GenerateNext) (*ModelResponse, error) { + if p.Callback != nil { + if err := p.Callback(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart("middleware chunk ")}, + }); err != nil { + return nil, err + } + } + return next(ctx, p) + }, + }, nil + }) + + var chunks []*ModelResponseChunk + _, err := Generate(testCtx, r, + WithModel(m), + WithPrompt("go"), + WithUse(midChunk), + WithStreaming(func(_ context.Context, c *ModelResponseChunk) error { + chunks = append(chunks, c) + return nil + }), + ) + assertNoError(t, err) + + if len(chunks) != 2 { + t.Fatalf("got %d chunks, want 2", len(chunks)) + } +} + +// --- tool contribution: Tools on *Middleware --- + +func TestMiddlewareContributesTool(t *testing.T) { + r := newTestRegistry(t) + defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolModel", + handler: toolCallingModelHandler("mw/tool", map[string]any{"value": "x"}, "done"), + }) + + injectTool := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + Tools: []Tool{NewTool("mw/tool", "injected", + func(tc *ToolContext, in struct { + Value string `json:"value"` + }) (string, error) { + return "ok", nil + })}, + }, nil + }) + + _, err := Generate(testCtx, r, + WithModelName("test/toolModel"), + WithPrompt("use it"), + WithUse(injectTool), + ) + assertNoError(t, err) +} + +// --- duplicate tool collision: two middleware with same tool name --- + +func TestDuplicateMiddlewareToolRejected(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + makeInjector := func() Middleware { + return MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + Tools: []Tool{NewTool("dup/tool", "d", + func(tc *ToolContext, in struct{}) (string, error) { return "x", nil })}, + }, nil + }) + } + + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hi"), + WithUse(makeInjector(), makeInjector())) + if err == nil { + t.Fatal("expected duplicate tool error, got nil") + } +} + +// --- error propagation from BuildMiddleware --- + +func TestBuildMiddlewareErrorPropagates(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + bad := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return nil, errors.New("boom") + }) + + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hi"), WithUse(bad)) + if err == nil { + t.Fatal("expected BuildMiddleware error, got nil") + } +} + +// --- tool interrupt from WrapTool --- + +func TestWrapToolInterrupts(t *testing.T) { + r := newTestRegistry(t) + defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolModel", + handler: toolCallingModelHandler("myTool", map[string]any{"value": "x"}, "done"), + }) + defineFakeTool(t, r, "myTool", "A test tool") + + interrupter := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapTool: func(ctx context.Context, p *ToolParams, next ToolNext) (*MultipartToolResponse, error) { + return nil, NewToolInterruptError(map[string]any{"reason": "blocked"}) + }, + }, nil + }) + + resp, err := Generate(testCtx, r, + WithModelName("test/toolModel"), + WithPrompt("use it"), + WithTools(ToolName("myTool")), + WithUse(interrupter), + ) + assertNoError(t, err) + if resp.FinishReason != "interrupted" { + t.Errorf("expected FinishReason=interrupted, got %q", resp.FinishReason) + } + if len(resp.Interrupts()) == 0 { + t.Error("expected at least one interrupt part in response") + } +} + +// --- WrapGenerate fires per tool-loop iteration --- + +func TestGenerateHookFiresEachIteration(t *testing.T) { + r := newTestRegistry(t) + defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolLoop", + handler: toolCallingModelHandler("myTool", map[string]any{"value": "x"}, "done"), + }) + defineFakeTool(t, r, "myTool", "A test tool") + + var iters int32 + tracker := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapGenerate: func(ctx context.Context, p *GenerateParams, next GenerateNext) (*ModelResponse, error) { + atomic.AddInt32(&iters, 1) + return next(ctx, p) + }, + }, nil + }) + + _, err := Generate(testCtx, r, + WithModelName("test/toolLoop"), + WithPrompt("use it"), + WithTools(ToolName("myTool")), + WithUse(tracker), + ) + assertNoError(t, err) + if got := atomic.LoadInt32(&iters); got < 2 { + t.Errorf("expected WrapGenerate to fire >=2 times, got %d", got) + } +} + +// --- WrapTool metadata survives round trip --- + +func TestWrapToolPreservesMetadata(t *testing.T) { + r := newTestRegistry(t) + defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolModel", + handler: toolCallingModelHandler("myTool", map[string]any{"value": "x"}, "done"), + }) + defineFakeTool(t, r, "myTool", "A test tool") + + var sawMetadata map[string]any + reader := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapTool: func(ctx context.Context, p *ToolParams, next ToolNext) (*MultipartToolResponse, error) { + resp, err := next(ctx, p) + if err != nil { + return nil, err + } + if resp.Metadata == nil { + resp.Metadata = map[string]any{} + } + resp.Metadata["traced"] = true + sawMetadata = resp.Metadata + return resp, nil + }, + }, nil + }) + + _, err := Generate(testCtx, r, + WithModelName("test/toolModel"), + WithPrompt("use it"), + WithTools(ToolName("myTool")), + WithUse(reader), + ) + assertNoError(t, err) + if sawMetadata == nil || sawMetadata["traced"] != true { + t.Errorf("metadata not threaded, got %v", sawMetadata) + } +} + +var testCtx = context.Background() diff --git a/go/ai/option.go b/go/ai/option.go index d28c68e3e9..5b2bcb905a 100644 --- a/go/ai/option.go +++ b/go/ai/option.go @@ -109,7 +109,8 @@ type commonGenOptions struct { ToolChoice ToolChoice // Whether tool calls are required, disabled, or optional. MaxTurns int // Maximum number of tool call iterations. ReturnToolRequests *bool // Whether to return tool requests instead of making the tool calls and continuing the generation. - Middleware []ModelMiddleware // Middleware to apply to the model request and model response. + Middleware []ModelMiddleware // Deprecated: Use WithUse instead. Middleware to apply to the model request and model response. + Use []Middleware // Middleware to apply to generation (Generate, Model, and Tool hooks). } type CommonGenOption interface { @@ -181,6 +182,13 @@ func (o *commonGenOptions) applyCommonGen(opts *commonGenOptions) error { opts.Middleware = o.Middleware } + if o.Use != nil { + if opts.Use != nil { + return errors.New("cannot set middleware more than once (WithUse)") + } + opts.Use = o.Use + } + return nil } @@ -233,10 +241,22 @@ func WithModelName(name string) CommonGenOption { } // WithMiddleware sets middleware to apply to the model request. +// +// Deprecated: Use [WithUse] instead, which supports Generate, Model, and Tool hooks. func WithMiddleware(middleware ...ModelMiddleware) CommonGenOption { return &commonGenOptions{Middleware: middleware} } +// WithUse sets middleware to apply to generation. Middleware hooks wrap +// the generate loop, model calls, and tool executions. +// +// Accepts either a middleware config struct (produced by a plugin) or an +// inline adapter via [MiddlewareFunc]. The chain applies outer-to-inner, so +// WithUse(A, B) expands to A { B { ... } }. +func WithUse(middleware ...Middleware) CommonGenOption { + return &commonGenOptions{Use: middleware} +} + // WithMaxTurns sets the maximum number of tool call iterations before erroring. // A tool call happens when tools are provided in the request and a model decides to call one or more as a response. // Each round trip, including multiple tools in parallel, counts as one turn. diff --git a/go/ai/prompt.go b/go/ai/prompt.go index 4d0151c4c8..de62751ea9 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -249,6 +249,14 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod } } + refs, err := configsToRefs(execOpts.Use) + if err != nil { + return nil, fmt.Errorf("Prompt.Execute: %w", err) + } + if len(refs) > 0 { + actionOpts.Use = refs + } + return GenerateWithRequest(ctx, r, actionOpts, execOpts.Middleware, execOpts.Stream) } diff --git a/go/ai/tools.go b/go/ai/tools.go index 453ad779a3..b88712e091 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -106,6 +106,13 @@ func IsToolInterruptError(err error) (bool, map[string]any) { return false, nil } +// NewToolInterruptError creates a tool interrupt error with the given metadata. +// This is intended for use in middleware that needs to interrupt tool execution +// without calling the tool itself. +func NewToolInterruptError(metadata map[string]any) error { + return &toolInterruptError{Metadata: metadata} +} + // InterruptOptions provides configuration for tool interruption. type InterruptOptions struct { Metadata map[string]any diff --git a/go/core/schemas.config b/go/core/schemas.config index 1beb6f139e..4a27007821 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -732,6 +732,10 @@ StepName is a custom step name for this generate call to display in trace views. Defaults to "generate". . +GenerateActionOptions.use doc +Use is middleware to apply to this generation, referenced by name with optional config. +. + GenerateActionOptionsResume doc GenerateActionResume holds options for resuming an interrupted generation. . @@ -840,6 +844,42 @@ PathMetadata.error doc Error contains error information if the path failed. . +# ---------------------------------------------------------------------------- +# Middleware Types +# ---------------------------------------------------------------------------- + +MiddlewareDesc doc +MiddlewareDesc is the registered descriptor for a middleware. +. + +MiddlewareDesc.name doc +Name is the middleware's unique identifier. +. + +MiddlewareDesc.description doc +Description explains what the middleware does. +. + +MiddlewareDesc.configSchema doc +ConfigSchema is a JSON Schema describing the middleware's configuration. +. + +MiddlewareDesc.metadata doc +Metadata contains additional context for the middleware. +. + +MiddlewareRef doc +MiddlewareRef is a serializable reference to a registered middleware with config. +. + +MiddlewareRef.name doc +Name is the name of the registered middleware. +. + +MiddlewareRef.config doc +Config contains the middleware configuration. +. + # ---------------------------------------------------------------------------- # Multipart Tool Response # ---------------------------------------------------------------------------- @@ -1061,6 +1101,7 @@ GenerateActionOptions.config type any GenerateActionOptions.output type *GenerateActionOutputConfig GenerateActionOptions.returnToolRequests type bool GenerateActionOptions.maxTurns type int +GenerateActionOptions.use type []*MiddlewareRef GenerateActionOptionsResume name GenerateActionResume # GenerateActionOutputConfig @@ -1104,6 +1145,12 @@ ModelResponseChunk.index type int ModelResponseChunk.role type Role ModelResponseChunk field formatHandler StreamingFormatHandler +# Middleware +MiddlewareDesc pkg ai +MiddlewareDesc.configSchema type map[string]any +MiddlewareDesc field buildFromJSON middlewareFactoryFunc +MiddlewareRef pkg ai + Score omit Embedding.embedding type []float32 diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 476908a287..c0055de7a7 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -221,6 +221,16 @@ func Init(ctx context.Context, opts ...GenkitOption) *Genkit { action.Register(r) } r.RegisterPlugin(plugin.Name(), plugin) + + if mp, ok := plugin.(ai.MiddlewarePlugin); ok { + descs, err := mp.Middlewares(ctx) + if err != nil { + panic(fmt.Errorf("genkit.Init: plugin %q Middlewares failed: %w", plugin.Name(), err)) + } + for _, d := range descs { + d.Register(r) + } + } } ai.ConfigureFormats(r) @@ -680,6 +690,68 @@ func LookupTool(g *Genkit, name string) ai.Tool { return ai.LookupTool(g.reg, name) } +// DefineMiddleware registers a middleware descriptor with the Genkit instance +// and returns the resulting [*ai.MiddlewareDesc]. Registered middleware is +// surfaced to the Dev UI and addressable by name for cross-runtime dispatch. +// +// This is the path for application code that declares its own middleware +// directly. Plugins should instead construct descriptors with [ai.NewMiddleware] +// (no registration) and return them from [ai.MiddlewarePlugin.Middlewares]; +// [Init] registers those descriptors during plugin setup. +// +// The `description` is a human-readable explanation shown in the Dev UI. The +// `prototype` is a value of a type that implements [ai.Middleware]. Its +// [ai.Middleware.Name] method supplies the registered name, and its fields +// (both exported JSON config and unexported plugin-level state) are captured +// by a value-copy inside the descriptor so JSON-dispatched invocations +// preserve prototype state across calls. +// +// For pure Go use, registration is not strictly required: passing a middleware +// config directly to [ai.WithUse] invokes its [ai.Middleware.New] method on +// the local fast path without a registry lookup. Registration is what makes +// the middleware visible to the Dev UI and callable from other runtimes. For +// ad-hoc one-off middleware that doesn't need Dev UI visibility, use +// [ai.MiddlewareFunc] instead of defining a type. +// +// Example: +// +// type Trace struct { +// Label string `json:"label,omitempty"` +// } +// +// func (Trace) Name() string { return "mine/trace" } +// +// func (t Trace) New(ctx context.Context) (*ai.Hooks, error) { +// return &ai.Hooks{ +// WrapModel: func(ctx context.Context, p *ai.ModelParams, next ai.ModelNext) (*ai.ModelResponse, error) { +// start := time.Now() +// resp, err := next(ctx, p) +// log.Printf("[%s] model call took %s", t.Label, time.Since(start)) +// return resp, err +// }, +// }, nil +// } +// +// // Register so it appears in the Dev UI and can be called by name: +// genkit.DefineMiddleware(g, "logs model call latency", Trace{}) +// +// // Use it per-call: +// resp, err := genkit.Generate(ctx, g, +// ai.WithPrompt("hello"), +// ai.WithUse(Trace{Label: "debug"}), +// ) +func DefineMiddleware[M ai.Middleware](g *Genkit, description string, prototype M) *ai.MiddlewareDesc { + return ai.DefineMiddleware(g.reg, description, prototype) +} + +// LookupMiddleware retrieves a registered middleware descriptor by its name. +// It returns the descriptor if found, or `nil` if no middleware with the +// given name is registered (e.g., via [DefineMiddleware] or through a +// plugin's [ai.MiddlewarePlugin.Middlewares] method). +func LookupMiddleware(g *Genkit, name string) *ai.MiddlewareDesc { + return ai.LookupMiddleware(g.reg, name) +} + // DefinePrompt defines a prompt programmatically, registers it as a [core.Action] // of type Prompt, and returns an executable [ai.Prompt]. // diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index 1bd675f75a..0f6fcca890 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -303,6 +303,7 @@ func serveMux(g *Genkit, s *reflectionServer) *http.ServeMux { mux.HandleFunc("POST /api/runAction", wrapReflectionHandler(handleRunAction(g, s.activeActions))) mux.HandleFunc("POST /api/notify", wrapReflectionHandler(handleNotify())) mux.HandleFunc("POST /api/cancelAction", wrapReflectionHandler(handleCancelAction(s.activeActions))) + mux.HandleFunc("GET /api/values", wrapReflectionHandler(handleListValues(g))) return mux } @@ -598,6 +599,26 @@ func handleListActions(g *Genkit) func(w http.ResponseWriter, r *http.Request) e } } +// handleListValues returns registered values filtered by type query parameter. +// Matches JS: GET /api/values?type=middleware +func handleListValues(g *Genkit) func(w http.ResponseWriter, r *http.Request) error { + return func(w http.ResponseWriter, r *http.Request) error { + valueType := r.URL.Query().Get("type") + if valueType == "" { + return core.NewError(core.INVALID_ARGUMENT, `query parameter "type" is required`) + } + prefix := "/" + valueType + "/" + result := map[string]any{} + for key, val := range g.reg.ListValues() { + if strings.HasPrefix(key, prefix) { + name := strings.TrimPrefix(key, prefix) + result[name] = val + } + } + return writeJSON(r.Context(), w, result) + } +} + // listActions lists all the registered actions. func listActions(g *Genkit) []api.ActionDesc { ads := []api.ActionDesc{} diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index 7b64dae4a4..16cafd092e 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -175,6 +175,24 @@ class GenkitError(GenkitModel): data: Data | None = None +class MiddlewareDesc(GenkitModel): + """Model for middlewaredesc data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + name: str = Field(...) + description: str | None = None + config_schema: Any | ConfigSchema | None = Field(default=None) + metadata: Metadata | None = None + + +class MiddlewareRef(GenkitModel): + """Model for middlewareref data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + name: str = Field(...) + config: Any | None = Field(default=None) + + class CandidateError(GenkitModel): """Model for candidateerror data.""" @@ -325,24 +343,6 @@ class MessageData(GenkitModel): metadata: Metadata | None = None -class MiddlewareDesc(GenkitModel): - """Model for middlewaredesc data.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - name: str = Field(...) - description: str | None = None - config_schema: Any | ConfigSchema | None = Field(default=None) - metadata: Any | Metadata | None = Field(default=None) - - -class MiddlewareRef(GenkitModel): - """Model for middlewareref data.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - name: str = Field(...) - config: Any | None = Field(default=None) - - class ModelInfo(GenkitModel): """Model for modelinfo data."""