From ded7a20f7da4e639804e5002359ce8e654e11fb3 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 07:10:36 -0800 Subject: [PATCH 01/22] added new `Middleware` --- genkit-tools/common/src/types/index.ts | 1 + genkit-tools/common/src/types/middleware.ts | 37 +++ genkit-tools/common/src/types/model.ts | 1 + go/ai/gen.go | 5 +- go/ai/generate.go | 125 ++++++++++- go/ai/middleware.go | 155 +++++++++++++ go/ai/middleware_test.go | 237 ++++++++++++++++++++ go/ai/option.go | 18 +- go/ai/prompt.go | 22 ++ go/genkit/genkit.go | 10 + go/genkit/reflection.go | 22 ++ 11 files changed, 627 insertions(+), 6 deletions(-) create mode 100644 genkit-tools/common/src/types/middleware.ts create mode 100644 go/ai/middleware.go create mode 100644 go/ai/middleware_test.go diff --git a/genkit-tools/common/src/types/index.ts b/genkit-tools/common/src/types/index.ts index acc8b6a11b..c8a040dd0b 100644 --- a/genkit-tools/common/src/types/index.ts +++ b/genkit-tools/common/src/types/index.ts @@ -23,6 +23,7 @@ export * from './document'; export * from './env'; export * from './eval'; export * from './evaluator'; +export * from './middleware'; export * from './model'; export * from './prompt'; export * from './reflection'; diff --git a/genkit-tools/common/src/types/middleware.ts b/genkit-tools/common/src/types/middleware.ts new file mode 100644 index 0000000000..7f41af991e --- /dev/null +++ b/genkit-tools/common/src/types/middleware.ts @@ -0,0 +1,37 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { z } from 'zod'; +import { JSONSchema7Schema } from './action'; + +/** Descriptor for a registered middleware, returned by reflection API. */ +export const MiddlewareDescSchema = z.object({ + /** Unique name of the middleware. */ + name: z.string(), + /** Human-readable description of what the middleware does. */ + description: z.string().optional(), + /** JSON Schema for the middleware's configuration. */ + configSchema: JSONSchema7Schema.optional(), +}); +export type MiddlewareDesc = z.infer<typeof MiddlewareDescSchema>; + +/** Reference to a registered middleware with optional configuration. */ +export const MiddlewareRefSchema = z.object({ + /** Name of the registered middleware. */ + name: z.string(), + /** Configuration for the middleware (schema defined by the middleware). */ + config: z.any().optional(), +}); +export type MiddlewareRef = z.infer<typeof MiddlewareRefSchema>; diff --git a/genkit-tools/common/src/types/model.ts b/genkit-tools/common/src/types/model.ts index fa4f07311a..072d446b6d 100644 --- a/genkit-tools/common/src/types/model.ts +++ b/genkit-tools/common/src/types/model.ts @@ -15,6 +15,7 @@ */ import { z } from 'zod'; import { DocumentDataSchema } from './document'; +import { MiddlewareRefSchema } from './middleware'; import { CustomPartSchema, DataPartSchema, diff --git a/go/ai/gen.go b/go/ai/gen.go index db6ced67b4..31135e5f03 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -111,8 +111,9 @@ type GenerateActionOptions struct { // the model to choose a tool, and none forces the model not to use any tools. Defaults to auto. ToolChoice ToolChoice `json:"toolChoice,omitempty"` // Tools is a list of registered tool names for this generation if supported. Tools []string `json:"tools,omitempty"` - Use []*MiddlewareRef `json:"use,omitempty"` + Tools []string `json:"tools,omitempty"` + // Use is middleware to apply to this generation, referenced by name with optional config. 
+ Use []*MiddlewareRef `json:"use,omitempty"` } // GenerateActionResume holds options for resuming an interrupted generation. diff --git a/go/ai/generate.go b/go/ai/generate.go index 7cc240d6ea..5ff23c22d4 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -67,6 +67,8 @@ type ModelFunc = core.StreamingFunc[*ModelRequest, *ModelResponse, *ModelRespons type ModelStreamCallback = func(context.Context, *ModelResponseChunk) error // ModelMiddleware is middleware for model generate requests that takes in a ModelFunc, does something, then returns another ModelFunc. +// +// Deprecated: Use [Middleware] interface with [WithUse] instead, which supports Generate, Model, and Tool hooks. type ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResponseChunk] // model is an action with functions specific to model generation such as Generate(). @@ -313,6 +315,27 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi Output: &outputCfg, } + // Resolve middleware from Use refs. 
+ var middlewareHandlers []Middleware + if len(opts.Use) > 0 { + middlewareHandlers = make([]Middleware, 0, len(opts.Use)) + for _, ref := range opts.Use { + desc := LookupMiddleware(r, ref.Name) + if desc == nil { + return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: middleware %q not found", ref.Name) + } + configJSON, err := json.Marshal(ref.Config) + if err != nil { + return nil, core.NewError(core.INTERNAL, "ai.GenerateWithRequest: failed to marshal config for middleware %q: %v", ref.Name, err) + } + handler, err := desc.configFromJSON(configJSON) + if err != nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: failed to create middleware %q: %v", ref.Name, err) + } + middlewareHandlers = append(middlewareHandlers, handler) + } + } + var fn ModelFunc if bm != nil { if cb != nil { @@ -322,6 +345,24 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi } else { fn = m.Generate } + + // Apply Model hooks from new middleware as a ModelMiddleware, then chain with legacy mw. + if len(middlewareHandlers) > 0 { + modelHook := func(next ModelFunc) ModelFunc { + wrapped := next + for i := len(middlewareHandlers) - 1; i >= 0; i-- { + h := middlewareHandlers[i] + inner := wrapped + wrapped = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return h.Model(ctx, &ModelState{Request: req, Callback: cb}, func(ctx context.Context, state *ModelState) (*ModelResponse, error) { + return inner(ctx, state.Request, state.Callback) + }) + } + } + return wrapped + } + mw = append([]ModelMiddleware{modelHook}, mw...) + } fn = core.ChainMiddleware(mw...)(fn) // Inline recursive helper function that captures variables from parent scope. 
@@ -390,7 +431,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi return nil, core.NewError(core.ABORTED, "exceeded maximum tool call iterations (%d)", maxTurns) } - newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex) + newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex, middlewareHandlers) if err != nil { return nil, err } @@ -408,6 +449,28 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi }) } + // Wrap generate with the Generate hook chain from middleware. + if len(middlewareHandlers) > 0 { + innerGenerate := generate + generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { + innerFn := func(ctx context.Context, state *GenerateState) (*ModelResponse, error) { + return innerGenerate(ctx, state.Request, currentTurn, messageIndex) + } + for i := len(middlewareHandlers) - 1; i >= 0; i-- { + h := middlewareHandlers[i] + next := innerFn + innerFn = func(ctx context.Context, state *GenerateState) (*ModelResponse, error) { + return h.Generate(ctx, state, next) + } + } + return innerFn(ctx, &GenerateState{ + Options: opts, + Request: req, + Iteration: currentTurn, + }) + } + } + return generate(ctx, req, 0, 0) } @@ -539,6 +602,28 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod } } + // Register dynamic middleware (like dynamic tools) and build MiddlewareRefs. 
+ if len(genOpts.Use) > 0 { + for _, mw := range genOpts.Use { + name := mw.Name() + if LookupMiddleware(r, name) == nil { + if !r.IsChild() { + r = r.NewChild() + } + NewMiddleware("", mw).Register(r) + } + configJSON, err := json.Marshal(mw) + if err != nil { + return nil, core.NewError(core.INTERNAL, "ai.Generate: failed to marshal middleware %q config: %v", name, err) + } + var config any + if err := json.Unmarshal(configJSON, &config); err != nil { + return nil, core.NewError(core.INTERNAL, "ai.Generate: failed to unmarshal middleware %q config: %v", name, err) + } + actionOpts.Use = append(actionOpts.Use, &MiddlewareRef{Name: name, Config: config}) + } + } + // Process resources in messages processedMessages, err := processResources(ctx, r, messages) if err != nil { @@ -777,7 +862,7 @@ func clone[T any](obj *T) *T { // handleToolRequests processes any tool requests in the response, returning // either a new request to continue the conversation or nil if no tool requests // need handling. 
-func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int) (*ModelRequest, *Message, error) { +func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int, middlewareHandlers []Middleware) (*ModelRequest, *Message, error) { toolCount := len(resp.ToolRequests()) if toolCount == 0 { return nil, nil, nil @@ -800,7 +885,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return } - multipartResp, err := tool.RunRawMultipart(ctx, toolReq.Input) + multipartResp, err := runToolWithMiddleware(ctx, tool, toolReq, middlewareHandlers) if err != nil { var tie *toolInterruptError if errors.As(err, &tie) { @@ -883,6 +968,39 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return newReq, nil, nil } +// runToolWithMiddleware runs a tool, wrapping the execution with Tool hooks from middleware. +func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, handlers []Middleware) (*MultipartToolResponse, error) { + if len(handlers) == 0 { + return tool.RunRawMultipart(ctx, toolReq.Input) + } + + inner := func(ctx context.Context, state *ToolState) (*ToolResponse, error) { + resp, err := state.Tool.RunRawMultipart(ctx, state.Request.Input) + if err != nil { + return nil, err + } + return &ToolResponse{ + Name: state.Request.Name, + Output: resp.Output, + }, nil + } + + for i := len(handlers) - 1; i >= 0; i-- { + h := handlers[i] + next := inner + inner = func(ctx context.Context, state *ToolState) (*ToolResponse, error) { + return h.Tool(ctx, state, next) + } + } + + toolResp, err := inner(ctx, &ToolState{Request: toolReq, Tool: tool}) + if err != nil { + return nil, err + } + + return &MultipartToolResponse{Output: toolResp.Output}, nil +} + // Text returns the contents of the first candidate in a // [ModelResponse] as a string. 
It returns an empty string if there // are no candidates or if the candidate has no message. @@ -1361,6 +1479,7 @@ func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateAc Docs: genOpts.Docs, ReturnToolRequests: genOpts.ReturnToolRequests, Output: genOpts.Output, + Use: genOpts.Use, }, toolMessage: toolMessage, }, nil diff --git a/go/ai/middleware.go b/go/ai/middleware.go new file mode 100644 index 0000000000..71d5b93d1a --- /dev/null +++ b/go/ai/middleware.go @@ -0,0 +1,155 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" +) + +// Middleware provides hooks for different stages of generation. +type Middleware interface { + // Name returns the middleware's unique identifier. + Name() string + // New returns a fresh instance for each ai.Generate() call, enabling per-invocation state. + New() Middleware + // Generate wraps each iteration of the tool loop. + Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) + // Model wraps each model API call. + Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) + // Tool wraps each tool execution. 
+ Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) +} + +// GenerateState holds state for the Generate hook. +type GenerateState struct { + // Options is the original options passed to [Generate]. + Options *GenerateActionOptions + // Request is the current model request for this iteration, with accumulated messages. + Request *ModelRequest + // Iteration is the current tool-loop iteration (0-indexed). + Iteration int +} + +// ModelState holds state for the Model hook. +type ModelState struct { + // Request is the model request about to be sent. + Request *ModelRequest + // Callback is the streaming callback, or nil if not streaming. + Callback ModelStreamCallback +} + +// ToolState holds state for the Tool hook. +type ToolState struct { + // Request is the tool request about to be executed. + Request *ToolRequest + // Tool is the resolved tool being called. + Tool Tool +} + +// GenerateNext is the next function in the Generate hook chain. +type GenerateNext = func(ctx context.Context, state *GenerateState) (*ModelResponse, error) + +// ModelNext is the next function in the Model hook chain. +type ModelNext = func(ctx context.Context, state *ModelState) (*ModelResponse, error) + +// ToolNext is the next function in the Tool hook chain. +type ToolNext = func(ctx context.Context, state *ToolState) (*ToolResponse, error) + +// BaseMiddleware provides default pass-through for the three hooks. +// Embed this so you only need to implement Name() and New(). 
+type BaseMiddleware struct{} + +func (b *BaseMiddleware) Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) { + return next(ctx, state) +} + +func (b *BaseMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { + return next(ctx, state) +} + +func (b *BaseMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) { + return next(ctx, state) +} + +// MiddlewareDesc is the registered descriptor for a middleware. +type MiddlewareDesc struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + ConfigSchema map[string]any `json:"configSchema,omitempty"` + configFromJSON func([]byte) (Middleware, error) +} + +// Register registers the descriptor with the registry. +func (d *MiddlewareDesc) Register(r api.Registry) { + r.RegisterValue("/middleware/"+d.Name, d) +} + +// NewMiddleware creates a middleware descriptor without registering it. +// The prototype carries stable state; configFromJSON calls prototype.New() +// then unmarshals user config on top. +func NewMiddleware[T Middleware](description string, prototype T) *MiddlewareDesc { + return &MiddlewareDesc{ + Name: prototype.Name(), + Description: description, + ConfigSchema: core.InferSchemaMap(*new(T)), + configFromJSON: func(configJSON []byte) (Middleware, error) { + inst := prototype.New() + if len(configJSON) > 0 { + if err := json.Unmarshal(configJSON, inst); err != nil { + return nil, fmt.Errorf("middleware %q: %w", prototype.Name(), err) + } + } + return inst, nil + }, + } +} + +// DefineMiddleware creates and registers a middleware descriptor. +func DefineMiddleware[T Middleware](r api.Registry, description string, prototype T) *MiddlewareDesc { + d := NewMiddleware(description, prototype) + d.Register(r) + return d +} + +// LookupMiddleware looks up a registered middleware descriptor by name. 
+func LookupMiddleware(r api.Registry, name string) *MiddlewareDesc { + v := r.LookupValue("/middleware/" + name) + if v == nil { + return nil + } + d, ok := v.(*MiddlewareDesc) + if !ok { + return nil + } + return d +} + +// MiddlewareRef is a serializable reference to a registered middleware with config. +type MiddlewareRef struct { + Name string `json:"name"` + Config any `json:"config,omitempty"` +} + +// MiddlewarePlugin is implemented by plugins that provide middleware. +type MiddlewarePlugin interface { + ListMiddleware(ctx context.Context) ([]*MiddlewareDesc, error) +} diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go new file mode 100644 index 0000000000..0613e3e63e --- /dev/null +++ b/go/ai/middleware_test.go @@ -0,0 +1,237 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "context" + "sync/atomic" + "testing" +) + +// testMiddleware is a simple middleware for testing that tracks hook invocations. 
+type testMiddleware struct { + BaseMiddleware + Label string `json:"label"` + generateCalls int + modelCalls int + toolCalls int32 // atomic since tool hooks run in parallel +} + +func (m *testMiddleware) Name() string { return "test" } + +func (m *testMiddleware) New() Middleware { + return &testMiddleware{Label: m.Label} +} + +func (m *testMiddleware) Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) { + m.generateCalls++ + return next(ctx, state) +} + +func (m *testMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { + m.modelCalls++ + return next(ctx, state) +} + +func (m *testMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) { + atomic.AddInt32(&m.toolCalls, 1) + return next(ctx, state) +} + +func TestNewMiddleware(t *testing.T) { + proto := &testMiddleware{Label: "original"} + desc := NewMiddleware("test middleware", proto) + + if desc.Name != "test" { + t.Errorf("got name %q, want %q", desc.Name, "test") + } + if desc.Description != "test middleware" { + t.Errorf("got description %q, want %q", desc.Description, "test middleware") + } +} + +func TestDefineAndLookupMiddleware(t *testing.T) { + r := newTestRegistry(t) + proto := &testMiddleware{Label: "original"} + DefineMiddleware(r, "test middleware", proto) + + found := LookupMiddleware(r, "test") + if found == nil { + t.Fatal("expected to find middleware, got nil") + } + if found.Name != "test" { + t.Errorf("got name %q, want %q", found.Name, "test") + } +} + +func TestLookupMiddlewareNotFound(t *testing.T) { + r := newTestRegistry(t) + found := LookupMiddleware(r, "nonexistent") + if found != nil { + t.Errorf("expected nil, got %v", found) + } +} + +func TestConfigFromJSON(t *testing.T) { + proto := &testMiddleware{Label: "stable"} + desc := NewMiddleware("test middleware", proto) + + handler, err := desc.configFromJSON([]byte(`{"label": "custom"}`)) + if err != nil { + 
t.Fatalf("configFromJSON failed: %v", err) + } + + tm, ok := handler.(*testMiddleware) + if !ok { + t.Fatalf("expected *testMiddleware, got %T", handler) + } + if tm.Label != "custom" { + t.Errorf("got label %q, want %q", tm.Label, "custom") + } + // Per-request state should be zeroed by New() + if tm.generateCalls != 0 { + t.Errorf("got generateCalls %d, want 0", tm.generateCalls) + } +} + +func TestConfigFromJSONPreservesStableState(t *testing.T) { + // Simulate a plugin middleware with unexported stable state + proto := &stableStateMiddleware{apiKey: "secret123"} + desc := NewMiddleware("middleware with stable state", proto) + + handler, err := desc.configFromJSON([]byte(`{"sampleRate": 0.5}`)) + if err != nil { + t.Fatalf("configFromJSON failed: %v", err) + } + + sm, ok := handler.(*stableStateMiddleware) + if !ok { + t.Fatalf("expected *stableStateMiddleware, got %T", handler) + } + if sm.apiKey != "secret123" { + t.Errorf("got apiKey %q, want %q", sm.apiKey, "secret123") + } + if sm.SampleRate != 0.5 { + t.Errorf("got SampleRate %f, want 0.5", sm.SampleRate) + } +} + +func TestMiddlewareModelHook(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + DefineMiddleware(r, "tracks calls", &testMiddleware{}) + + resp, err := Generate(ctx, r, + WithModel(m), + WithPrompt("hello"), + WithUse(&testMiddleware{}), + ) + assertNoError(t, err) + if resp == nil { + t.Fatal("expected response, got nil") + } +} + +func TestMiddlewareToolHook(t *testing.T) { + r := newTestRegistry(t) + defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolModel", + handler: toolCallingModelHandler("myTool", map[string]any{"value": "test"}, "done"), + }) + defineFakeTool(t, r, "myTool", "A test tool") + + mw := &testMiddleware{} + DefineMiddleware(r, "tracks calls", mw) + + _, err := Generate(ctx, r, + WithModelName("test/toolModel"), + WithPrompt("use the tool"), + WithTools(ToolName("myTool")), + WithUse(&testMiddleware{}), + ) + assertNoError(t, 
err) +} + +func TestMiddlewareOrdering(t *testing.T) { + // First middleware is outermost + var order []string + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + mwA := &orderMiddleware{label: "A", order: &order} + mwB := &orderMiddleware{label: "B", order: &order} + DefineMiddleware(r, "middleware A", mwA) + DefineMiddleware(r, "middleware B", mwB) + + _, err := Generate(ctx, r, + WithModel(m), + WithPrompt("hello"), + WithUse( + &orderMiddleware{label: "A", order: &order}, + &orderMiddleware{label: "B", order: &order}, + ), + ) + assertNoError(t, err) + + // Expect: A-before, B-before, B-after, A-after (first is outermost) + want := []string{"A-model-before", "B-model-before", "B-model-after", "A-model-after"} + if len(order) != len(want) { + t.Fatalf("got order %v, want %v", order, want) + } + for i := range want { + if order[i] != want[i] { + t.Errorf("order[%d] = %q, want %q", i, order[i], want[i]) + } + } +} + +// --- helper middleware types for tests --- + +// stableStateMiddleware has unexported stable state preserved by New(). +type stableStateMiddleware struct { + BaseMiddleware + SampleRate float64 `json:"sampleRate"` + apiKey string +} + +func (m *stableStateMiddleware) Name() string { return "stableState" } + +func (m *stableStateMiddleware) New() Middleware { + return &stableStateMiddleware{apiKey: m.apiKey} +} + +// orderMiddleware tracks the order of Model hook invocations. 
+type orderMiddleware struct { + BaseMiddleware + label string + order *[]string +} + +func (m *orderMiddleware) Name() string { return "order-" + m.label } + +func (m *orderMiddleware) New() Middleware { + return &orderMiddleware{label: m.label, order: m.order} +} + +func (m *orderMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { + *m.order = append(*m.order, m.label+"-model-before") + resp, err := next(ctx, state) + *m.order = append(*m.order, m.label+"-model-after") + return resp, err +} + +var ctx = context.Background() diff --git a/go/ai/option.go b/go/ai/option.go index d28c68e3e9..84019b11d7 100644 --- a/go/ai/option.go +++ b/go/ai/option.go @@ -109,7 +109,8 @@ type commonGenOptions struct { ToolChoice ToolChoice // Whether tool calls are required, disabled, or optional. MaxTurns int // Maximum number of tool call iterations. ReturnToolRequests *bool // Whether to return tool requests instead of making the tool calls and continuing the generation. - Middleware []ModelMiddleware // Middleware to apply to the model request and model response. + Middleware []ModelMiddleware // Deprecated: Use WithUse instead. Middleware to apply to the model request and model response. + Use []Middleware // Middleware to apply to generation (Generate, Model, and Tool hooks). } type CommonGenOption interface { @@ -181,6 +182,13 @@ func (o *commonGenOptions) applyCommonGen(opts *commonGenOptions) error { opts.Middleware = o.Middleware } + if o.Use != nil { + if opts.Use != nil { + return errors.New("cannot set middleware more than once (WithUse)") + } + opts.Use = o.Use + } + return nil } @@ -233,10 +241,18 @@ func WithModelName(name string) CommonGenOption { } // WithMiddleware sets middleware to apply to the model request. +// +// Deprecated: Use [WithUse] instead, which supports Generate, Model, and Tool hooks. 
func WithMiddleware(middleware ...ModelMiddleware) CommonGenOption { return &commonGenOptions{Middleware: middleware} } +// WithUse sets middleware to apply to generation. Middleware hooks wrap +// the generate loop, model calls, and tool executions. +func WithUse(middleware ...Middleware) CommonGenOption { + return &commonGenOptions{Use: middleware} +} + // WithMaxTurns sets the maximum number of tool call iterations before erroring. // A tool call happens when tools are provided in the request and a model decides to call one or more as a response. // Each round trip, including multiple tools in parallel, counts as one turn. diff --git a/go/ai/prompt.go b/go/ai/prompt.go index 4d0151c4c8..88c36e0cd7 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -249,6 +249,28 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod } } + // Register dynamic middleware and build MiddlewareRefs. + if len(execOpts.Use) > 0 { + for _, mw := range execOpts.Use { + name := mw.Name() + if LookupMiddleware(r, name) == nil { + if !r.IsChild() { + r = r.NewChild() + } + NewMiddleware("", mw).Register(r) + } + configJSON, err := json.Marshal(mw) + if err != nil { + return nil, fmt.Errorf("Prompt.Execute: failed to marshal middleware %q config: %w", name, err) + } + var config any + if err := json.Unmarshal(configJSON, &config); err != nil { + return nil, fmt.Errorf("Prompt.Execute: failed to unmarshal middleware %q config: %w", name, err) + } + actionOpts.Use = append(actionOpts.Use, &MiddlewareRef{Name: name, Config: config}) + } + } + return GenerateWithRequest(ctx, r, actionOpts, execOpts.Middleware, execOpts.Stream) } diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 476908a287..d2850f2278 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -221,6 +221,16 @@ func Init(ctx context.Context, opts ...GenkitOption) *Genkit { action.Register(r) } r.RegisterPlugin(plugin.Name(), plugin) + + if mp, ok := plugin.(ai.MiddlewarePlugin); ok { + 
descs, err := mp.ListMiddleware(ctx) + if err != nil { + panic(fmt.Errorf("genkit.Init: plugin %q ListMiddleware failed: %w", plugin.Name(), err)) + } + for _, d := range descs { + d.Register(r) + } + } } ai.ConfigureFormats(r) diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index 1bd675f75a..9936936e61 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -303,6 +303,7 @@ func serveMux(g *Genkit, s *reflectionServer) *http.ServeMux { mux.HandleFunc("POST /api/runAction", wrapReflectionHandler(handleRunAction(g, s.activeActions))) mux.HandleFunc("POST /api/notify", wrapReflectionHandler(handleNotify())) mux.HandleFunc("POST /api/cancelAction", wrapReflectionHandler(handleCancelAction(s.activeActions))) + mux.HandleFunc("GET /api/values", wrapReflectionHandler(handleListValues(g))) return mux } @@ -598,6 +599,27 @@ func handleListActions(g *Genkit) func(w http.ResponseWriter, r *http.Request) e } } +// handleListValues returns registered values filtered by type query parameter. +// Matches JS: GET /api/values?type=middleware +func handleListValues(g *Genkit) func(w http.ResponseWriter, r *http.Request) error { + return func(w http.ResponseWriter, r *http.Request) error { + valueType := r.URL.Query().Get("type") + if valueType == "" { + http.Error(w, `query parameter "type" is required`, http.StatusBadRequest) + return nil + } + prefix := "/" + valueType + "/" + result := map[string]any{} + for key, val := range g.reg.ListValues() { + if strings.HasPrefix(key, prefix) { + name := strings.TrimPrefix(key, prefix) + result[name] = val + } + } + return writeJSON(r.Context(), w, result) + } +} + // listActions lists all the registered actions. 
func listActions(g *Genkit) []api.ActionDesc { ads := []api.ActionDesc{} From 713d8b374ae98f3f075049b9f6cb503009ef71f4 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 07:49:54 -0800 Subject: [PATCH 02/22] updated Genkit schema --- genkit-tools/genkit-schema.json | 42 +++++++++++++++++++++++++ genkit-tools/scripts/schema-exporter.ts | 1 + 2 files changed, 43 insertions(+) diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index aecc932b1a..1426bba73f 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -270,6 +270,48 @@ ], "additionalProperties": false }, + "MiddlewareDesc": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "configSchema": { + "anyOf": [ + { + "type": "object", + "properties": {}, + "additionalProperties": false, + "description": "A JSON Schema Draft 7 (http://json-schema.org/draft-07/schema) object." + }, + { + "type": "null" + } + ], + "description": "A JSON Schema Draft 7 (http://json-schema.org/draft-07/schema) object." 
+ } + }, + "required": [ + "name" + ], + "additionalProperties": false + }, + "MiddlewareRef": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "config": {} + }, + "required": [ + "name" + ], + "additionalProperties": false + }, "CandidateError": { "type": "object", "properties": { diff --git a/genkit-tools/scripts/schema-exporter.ts b/genkit-tools/scripts/schema-exporter.ts index 1d7bedf119..c923d6c21f 100644 --- a/genkit-tools/scripts/schema-exporter.ts +++ b/genkit-tools/scripts/schema-exporter.ts @@ -26,6 +26,7 @@ const EXPORTED_TYPE_MODULES = [ '../common/src/types/embedder.ts', '../common/src/types/evaluator.ts', '../common/src/types/error.ts', + '../common/src/types/middleware.ts', '../common/src/types/model.ts', '../common/src/types/parts.ts', '../common/src/types/reranker.ts', From f60ebdf330fe1dcc18962be5ac37b9e89a788a49 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 08:02:14 -0800 Subject: [PATCH 03/22] updated common schema --- genkit-tools/common/src/types/middleware.ts | 3 +- genkit-tools/genkit-schema.json | 7 +--- go/ai/gen.go | 19 ++++++--- go/ai/middleware.go | 17 ++------ go/core/schemas.config | 43 +++++++++++++++++++++ 5 files changed, 62 insertions(+), 27 deletions(-) diff --git a/genkit-tools/common/src/types/middleware.ts b/genkit-tools/common/src/types/middleware.ts index 7f41af991e..4bb1297ede 100644 --- a/genkit-tools/common/src/types/middleware.ts +++ b/genkit-tools/common/src/types/middleware.ts @@ -14,7 +14,6 @@ * limitations under the License. */ import { z } from 'zod'; -import { JSONSchema7Schema } from './action'; /** Descriptor for a registered middleware, returned by reflection API. */ export const MiddlewareDescSchema = z.object({ @@ -23,7 +22,7 @@ export const MiddlewareDescSchema = z.object({ /** Human-readable description of what the middleware does. */ description: z.string().optional(), /** JSON Schema for the middleware's configuration. 
*/ - configSchema: JSONSchema7Schema.optional(), + configSchema: z.record(z.any()).nullish(), }); export type MiddlewareDesc = z.infer; diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 1426bba73f..3e305b253c 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -283,15 +283,12 @@ "anyOf": [ { "type": "object", - "properties": {}, - "additionalProperties": false, - "description": "A JSON Schema Draft 7 (http://json-schema.org/draft-07/schema) object." + "additionalProperties": {} }, { "type": "null" } - ], - "description": "A JSON Schema Draft 7 (http://json-schema.org/draft-07/schema) object." + ] } }, "required": [ diff --git a/go/ai/gen.go b/go/ai/gen.go index 31135e5f03..7051cb98b1 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -226,16 +226,23 @@ type Message struct { Role Role `json:"role,omitempty"` } +// MiddlewareDesc is the registered descriptor for a middleware. type MiddlewareDesc struct { - ConfigSchema any `json:"configSchema,omitempty"` - Description string `json:"description,omitempty"` - Metadata any `json:"metadata,omitempty"` - Name string `json:"name,omitempty"` + // ConfigSchema is a JSON Schema describing the middleware's configuration. + ConfigSchema map[string]any `json:"configSchema,omitempty"` + // Description explains what the middleware does. + Description string `json:"description,omitempty"` + // Name is the middleware's unique identifier. + Name string `json:"name,omitempty"` + configFromJSON middlewareConfigFunc } +// MiddlewareRef is a serializable reference to a registered middleware with config. type MiddlewareRef struct { - Config any `json:"config,omitempty"` - Name string `json:"name,omitempty"` + // Config contains the middleware configuration. + Config any `json:"config,omitempty"` + // Name is the name of the registered middleware. + Name string `json:"name,omitempty"` } // ModelInfo contains metadata about a model's capabilities and characteristics. 
diff --git a/go/ai/middleware.go b/go/ai/middleware.go index 71d5b93d1a..35b2faf37f 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -25,6 +25,9 @@ import ( "github.com/firebase/genkit/go/core/api" ) +// middlewareConfigFunc creates a Middleware instance from JSON config. +type middlewareConfigFunc = func([]byte) (Middleware, error) + // Middleware provides hooks for different stages of generation. type Middleware interface { // Name returns the middleware's unique identifier. @@ -90,14 +93,6 @@ func (b *BaseMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNe return next(ctx, state) } -// MiddlewareDesc is the registered descriptor for a middleware. -type MiddlewareDesc struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - ConfigSchema map[string]any `json:"configSchema,omitempty"` - configFromJSON func([]byte) (Middleware, error) -} - // Register registers the descriptor with the registry. func (d *MiddlewareDesc) Register(r api.Registry) { r.RegisterValue("/middleware/"+d.Name, d) @@ -143,12 +138,6 @@ func LookupMiddleware(r api.Registry, name string) *MiddlewareDesc { return d } -// MiddlewareRef is a serializable reference to a registered middleware with config. -type MiddlewareRef struct { - Name string `json:"name"` - Config any `json:"config,omitempty"` -} - // MiddlewarePlugin is implemented by plugins that provide middleware. type MiddlewarePlugin interface { ListMiddleware(ctx context.Context) ([]*MiddlewareDesc, error) diff --git a/go/core/schemas.config b/go/core/schemas.config index 1beb6f139e..e622eb6e20 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -732,6 +732,10 @@ StepName is a custom step name for this generate call to display in trace views. Defaults to "generate". . +GenerateActionOptions.use doc +Use is middleware to apply to this generation, referenced by name with optional config. +. 
+ GenerateActionOptionsResume doc GenerateActionResume holds options for resuming an interrupted generation. . @@ -840,6 +844,38 @@ PathMetadata.error doc Error contains error information if the path failed. . +# ---------------------------------------------------------------------------- +# Middleware Types +# ---------------------------------------------------------------------------- + +MiddlewareDesc doc +MiddlewareDesc is the registered descriptor for a middleware. +. + +MiddlewareDesc.name doc +Name is the middleware's unique identifier. +. + +MiddlewareDesc.description doc +Description explains what the middleware does. +. + +MiddlewareDesc.configSchema doc +ConfigSchema is a JSON Schema describing the middleware's configuration. +. + +MiddlewareRef doc +MiddlewareRef is a serializable reference to a registered middleware with config. +. + +MiddlewareRef.name doc +Name is the name of the registered middleware. +. + +MiddlewareRef.config doc +Config contains the middleware configuration. +. 
+ # ---------------------------------------------------------------------------- # Multipart Tool Response # ---------------------------------------------------------------------------- @@ -1061,6 +1097,7 @@ GenerateActionOptions.config type any GenerateActionOptions.output type *GenerateActionOutputConfig GenerateActionOptions.returnToolRequests type bool GenerateActionOptions.maxTurns type int +GenerateActionOptions.use type []*MiddlewareRef GenerateActionOptionsResume name GenerateActionResume # GenerateActionOutputConfig @@ -1104,6 +1141,12 @@ ModelResponseChunk.index type int ModelResponseChunk.role type Role ModelResponseChunk field formatHandler StreamingFormatHandler +# Middleware +MiddlewareDesc pkg ai +MiddlewareDesc.configSchema type map[string]any +MiddlewareDesc field configFromJSON middlewareConfigFunc +MiddlewareRef pkg ai + Score omit Embedding.embedding type []float32 From efc34d68486072ef6b6fb4a569b40825405b5cfa Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 08:03:47 -0800 Subject: [PATCH 04/22] Update generate.go --- go/ai/generate.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index 5ff23c22d4..e4663deead 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -980,8 +980,9 @@ func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, return nil, err } return &ToolResponse{ - Name: state.Request.Name, - Output: resp.Output, + Name: state.Request.Name, + Output: resp.Output, + Content: resp.Content, }, nil } @@ -998,7 +999,7 @@ func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, return nil, err } - return &MultipartToolResponse{Output: toolResp.Output}, nil + return &MultipartToolResponse{Output: toolResp.Output, Content: toolResp.Content}, nil } // Text returns the contents of the first candidate in a From a6076ddb3eb19fbd4a0e30a14c55618a0ec42050 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 
08:04:21 -0800 Subject: [PATCH 05/22] Update middleware_test.go --- go/ai/middleware_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go index 0613e3e63e..a0f9f935ea 100644 --- a/go/ai/middleware_test.go +++ b/go/ai/middleware_test.go @@ -25,7 +25,7 @@ import ( // testMiddleware is a simple middleware for testing that tracks hook invocations. type testMiddleware struct { BaseMiddleware - Label string `json:"label"` + Label string `json:"label"` generateCalls int modelCalls int toolCalls int32 // atomic since tool hooks run in parallel @@ -149,7 +149,7 @@ func TestMiddlewareModelHook(t *testing.T) { func TestMiddlewareToolHook(t *testing.T) { r := newTestRegistry(t) defineFakeModel(t, r, fakeModelConfig{ - name: "test/toolModel", + name: "test/toolModel", handler: toolCallingModelHandler("myTool", map[string]any{"value": "test"}, "done"), }) defineFakeTool(t, r, "myTool", "A test tool") From bfcf61e3d978aedb1fbd372d9809e871b463b1ec Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 08:26:52 -0800 Subject: [PATCH 06/22] fixes --- go/ai/generate.go | 6 +----- go/ai/prompt.go | 3 +-- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index e4663deead..be06c2b89b 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -315,7 +315,6 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi Output: &outputCfg, } - // Resolve middleware from Use refs. var middlewareHandlers []Middleware if len(opts.Use) > 0 { middlewareHandlers = make([]Middleware, 0, len(opts.Use)) @@ -346,7 +345,6 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi fn = m.Generate } - // Apply Model hooks from new middleware as a ModelMiddleware, then chain with legacy mw. 
if len(middlewareHandlers) > 0 { modelHook := func(next ModelFunc) ModelFunc { wrapped := next @@ -602,7 +600,6 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod } } - // Register dynamic middleware (like dynamic tools) and build MiddlewareRefs. if len(genOpts.Use) > 0 { for _, mw := range genOpts.Use { name := mw.Name() @@ -610,7 +607,7 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod if !r.IsChild() { r = r.NewChild() } - NewMiddleware("", mw).Register(r) + DefineMiddleware(r, "", mw) } configJSON, err := json.Marshal(mw) if err != nil { @@ -624,7 +621,6 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod } } - // Process resources in messages processedMessages, err := processResources(ctx, r, messages) if err != nil { return nil, core.NewError(core.INTERNAL, "ai.Generate: error processing resources: %v", err) diff --git a/go/ai/prompt.go b/go/ai/prompt.go index 88c36e0cd7..9e4dff9f14 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -249,7 +249,6 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod } } - // Register dynamic middleware and build MiddlewareRefs. 
if len(execOpts.Use) > 0 { for _, mw := range execOpts.Use { name := mw.Name() @@ -257,7 +256,7 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod if !r.IsChild() { r = r.NewChild() } - NewMiddleware("", mw).Register(r) + DefineMiddleware(r, "", mw) } configJSON, err := json.Marshal(mw) if err != nil { From d679f886f40ca0983fde04b91982938e4de747e3 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 9 Feb 2026 09:53:04 -0800 Subject: [PATCH 07/22] Update genkit-tools/common/src/types/middleware.ts Co-authored-by: Pavel Jbanov --- genkit-tools/common/src/types/middleware.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/genkit-tools/common/src/types/middleware.ts b/genkit-tools/common/src/types/middleware.ts index 4bb1297ede..6fd2dd9810 100644 --- a/genkit-tools/common/src/types/middleware.ts +++ b/genkit-tools/common/src/types/middleware.ts @@ -23,6 +23,8 @@ export const MiddlewareDescSchema = z.object({ description: z.string().optional(), /** JSON Schema for the middleware's configuration. */ configSchema: z.record(z.any()).nullish(), + /** User defined metadata for the middleware. 
*/ + metadata: z.record(z.any()).optional(), }); export type MiddlewareDesc = z.infer; From b158b2f4d9552ae00a22faee6eb17feb18534fe4 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 9 Feb 2026 10:00:40 -0800 Subject: [PATCH 08/22] added new fields --- genkit-tools/genkit-schema.json | 4 ++++ go/ai/gen.go | 2 ++ go/core/schemas.config | 4 ++++ 3 files changed, 10 insertions(+) diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 3e305b253c..d4da8a3704 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -289,6 +289,10 @@ "type": "null" } ] + }, + "metadata": { + "type": "object", + "additionalProperties": {} } }, "required": [ diff --git a/go/ai/gen.go b/go/ai/gen.go index 7051cb98b1..e3b1fb884b 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -232,6 +232,8 @@ type MiddlewareDesc struct { ConfigSchema map[string]any `json:"configSchema,omitempty"` // Description explains what the middleware does. Description string `json:"description,omitempty"` + // Metadata contains additional context for the middleware. + Metadata map[string]any `json:"metadata,omitempty"` // Name is the middleware's unique identifier. Name string `json:"name,omitempty"` configFromJSON middlewareConfigFunc diff --git a/go/core/schemas.config b/go/core/schemas.config index e622eb6e20..7f0488ffad 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -864,6 +864,10 @@ MiddlewareDesc.configSchema doc ConfigSchema is a JSON Schema describing the middleware's configuration. . +MiddlewareDesc.metadata doc +Metadata contains additional context for the middleware. +. + MiddlewareRef doc MiddlewareRef is a serializable reference to a registered middleware with config. . 
From 92d3d94c0ae48d88ae334bae6ca20097a251c27e Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 10 Feb 2026 09:17:28 -0800 Subject: [PATCH 09/22] added tools to middleware interface --- go/ai/generate.go | 8 ++++++++ go/ai/middleware.go | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/go/ai/generate.go b/go/ai/generate.go index be06c2b89b..24a47127f2 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -502,6 +502,14 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod return nil, err } + // Collect tools provided by middleware. + for _, mw := range genOpts.Use { + for _, t := range mw.Tools() { + dynamicTools = append(dynamicTools, t) + toolNames = append(toolNames, t.Name()) + } + } + if len(dynamicTools) > 0 { if !r.IsChild() { r = r.NewChild() diff --git a/go/ai/middleware.go b/go/ai/middleware.go index 35b2faf37f..d5bd63f792 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -40,6 +40,9 @@ type Middleware interface { Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) // Tool wraps each tool execution. Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) + // Tools returns additional tools to make available during generation. + // These tools are dynamically registered when the middleware is used via [WithUse]. + Tools() []Tool } // GenerateState holds state for the Generate hook. @@ -93,6 +96,8 @@ func (b *BaseMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNe return next(ctx, state) } +func (b *BaseMiddleware) Tools() []Tool { return nil } + // Register registers the descriptor with the registry. 
func (d *MiddlewareDesc) Register(r api.Registry) { r.RegisterValue("/middleware/"+d.Name, d) From ff1798cbd93fb6d3359f599ec5c19ad2df45ad3b Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 13 Feb 2026 14:31:21 -0800 Subject: [PATCH 10/22] renames --- go/ai/generate.go | 28 +++++++++++------------ go/ai/middleware.go | 48 ++++++++++++++++++++-------------------- go/ai/middleware_test.go | 18 +++++++-------- 3 files changed, 47 insertions(+), 47 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index 24a47127f2..ebe0aa68fb 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -352,8 +352,8 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi h := middlewareHandlers[i] inner := wrapped wrapped = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { - return h.Model(ctx, &ModelState{Request: req, Callback: cb}, func(ctx context.Context, state *ModelState) (*ModelResponse, error) { - return inner(ctx, state.Request, state.Callback) + return h.WrapModel(ctx, &ModelParams{Request: req, Callback: cb}, func(ctx context.Context, params *ModelParams) (*ModelResponse, error) { + return inner(ctx, params.Request, params.Callback) }) } } @@ -451,17 +451,17 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi if len(middlewareHandlers) > 0 { innerGenerate := generate generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { - innerFn := func(ctx context.Context, state *GenerateState) (*ModelResponse, error) { - return innerGenerate(ctx, state.Request, currentTurn, messageIndex) + innerFn := func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { + return innerGenerate(ctx, params.Request, currentTurn, messageIndex) } for i := len(middlewareHandlers) - 1; i >= 0; i-- { h := middlewareHandlers[i] next := innerFn - innerFn = func(ctx context.Context, state *GenerateState) 
(*ModelResponse, error) { - return h.Generate(ctx, state, next) + innerFn = func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { + return h.WrapGenerate(ctx, params, next) } } - return innerFn(ctx, &GenerateState{ + return innerFn(ctx, &GenerateParams{ Options: opts, Request: req, Iteration: currentTurn, @@ -972,19 +972,19 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return newReq, nil, nil } -// runToolWithMiddleware runs a tool, wrapping the execution with Tool hooks from middleware. +// runToolWithMiddleware runs a tool, wrapping the execution with WrapTool hooks from middleware. func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, handlers []Middleware) (*MultipartToolResponse, error) { if len(handlers) == 0 { return tool.RunRawMultipart(ctx, toolReq.Input) } - inner := func(ctx context.Context, state *ToolState) (*ToolResponse, error) { - resp, err := state.Tool.RunRawMultipart(ctx, state.Request.Input) + inner := func(ctx context.Context, params *ToolParams) (*ToolResponse, error) { + resp, err := params.Tool.RunRawMultipart(ctx, params.Request.Input) if err != nil { return nil, err } return &ToolResponse{ - Name: state.Request.Name, + Name: params.Request.Name, Output: resp.Output, Content: resp.Content, }, nil @@ -993,12 +993,12 @@ func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, for i := len(handlers) - 1; i >= 0; i-- { h := handlers[i] next := inner - inner = func(ctx context.Context, state *ToolState) (*ToolResponse, error) { - return h.Tool(ctx, state, next) + inner = func(ctx context.Context, params *ToolParams) (*ToolResponse, error) { + return h.WrapTool(ctx, params, next) } } - toolResp, err := inner(ctx, &ToolState{Request: toolReq, Tool: tool}) + toolResp, err := inner(ctx, &ToolParams{Request: toolReq, Tool: tool}) if err != nil { return nil, err } diff --git a/go/ai/middleware.go b/go/ai/middleware.go index 
d5bd63f792..aff3b063fc 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -34,19 +34,19 @@ type Middleware interface { Name() string // New returns a fresh instance for each ai.Generate() call, enabling per-invocation state. New() Middleware - // Generate wraps each iteration of the tool loop. - Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) - // Model wraps each model API call. - Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) - // Tool wraps each tool execution. - Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) + // WrapGenerate wraps each iteration of the tool loop. + WrapGenerate(ctx context.Context, params *GenerateParams, next GenerateNext) (*ModelResponse, error) + // WrapModel wraps each model API call. + WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) + // WrapTool wraps each tool execution. + WrapTool(ctx context.Context, params *ToolParams, next ToolNext) (*ToolResponse, error) // Tools returns additional tools to make available during generation. // These tools are dynamically registered when the middleware is used via [WithUse]. Tools() []Tool } -// GenerateState holds state for the Generate hook. -type GenerateState struct { +// GenerateParams holds params for the WrapGenerate hook. +type GenerateParams struct { // Options is the original options passed to [Generate]. Options *GenerateActionOptions // Request is the current model request for this iteration, with accumulated messages. @@ -55,45 +55,45 @@ type GenerateState struct { Iteration int } -// ModelState holds state for the Model hook. -type ModelState struct { +// ModelParams holds params for the WrapModel hook. +type ModelParams struct { // Request is the model request about to be sent. Request *ModelRequest // Callback is the streaming callback, or nil if not streaming. 
Callback ModelStreamCallback } -// ToolState holds state for the Tool hook. -type ToolState struct { +// ToolParams holds params for the WrapTool hook. +type ToolParams struct { // Request is the tool request about to be executed. Request *ToolRequest // Tool is the resolved tool being called. Tool Tool } -// GenerateNext is the next function in the Generate hook chain. -type GenerateNext = func(ctx context.Context, state *GenerateState) (*ModelResponse, error) +// GenerateNext is the next function in the WrapGenerate hook chain. +type GenerateNext = func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) -// ModelNext is the next function in the Model hook chain. -type ModelNext = func(ctx context.Context, state *ModelState) (*ModelResponse, error) +// ModelNext is the next function in the WrapModel hook chain. +type ModelNext = func(ctx context.Context, params *ModelParams) (*ModelResponse, error) -// ToolNext is the next function in the Tool hook chain. -type ToolNext = func(ctx context.Context, state *ToolState) (*ToolResponse, error) +// ToolNext is the next function in the WrapTool hook chain. +type ToolNext = func(ctx context.Context, params *ToolParams) (*ToolResponse, error) // BaseMiddleware provides default pass-through for the three hooks. // Embed this so you only need to implement Name() and New(). 
type BaseMiddleware struct{} -func (b *BaseMiddleware) Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) { - return next(ctx, state) +func (b *BaseMiddleware) WrapGenerate(ctx context.Context, params *GenerateParams, next GenerateNext) (*ModelResponse, error) { + return next(ctx, params) } -func (b *BaseMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { - return next(ctx, state) +func (b *BaseMiddleware) WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) { + return next(ctx, params) } -func (b *BaseMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) { - return next(ctx, state) +func (b *BaseMiddleware) WrapTool(ctx context.Context, params *ToolParams, next ToolNext) (*ToolResponse, error) { + return next(ctx, params) } func (b *BaseMiddleware) Tools() []Tool { return nil } diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go index a0f9f935ea..4361d00b58 100644 --- a/go/ai/middleware_test.go +++ b/go/ai/middleware_test.go @@ -37,19 +37,19 @@ func (m *testMiddleware) New() Middleware { return &testMiddleware{Label: m.Label} } -func (m *testMiddleware) Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) { +func (m *testMiddleware) WrapGenerate(ctx context.Context, params *GenerateParams, next GenerateNext) (*ModelResponse, error) { m.generateCalls++ - return next(ctx, state) + return next(ctx, params) } -func (m *testMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { +func (m *testMiddleware) WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) { m.modelCalls++ - return next(ctx, state) + return next(ctx, params) } -func (m *testMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) { +func (m *testMiddleware) 
WrapTool(ctx context.Context, params *ToolParams, next ToolNext) (*ToolResponse, error) { atomic.AddInt32(&m.toolCalls, 1) - return next(ctx, state) + return next(ctx, params) } func TestNewMiddleware(t *testing.T) { @@ -214,7 +214,7 @@ func (m *stableStateMiddleware) New() Middleware { return &stableStateMiddleware{apiKey: m.apiKey} } -// orderMiddleware tracks the order of Model hook invocations. +// orderMiddleware tracks the order of WrapModel hook invocations. type orderMiddleware struct { BaseMiddleware label string @@ -227,9 +227,9 @@ func (m *orderMiddleware) New() Middleware { return &orderMiddleware{label: m.label, order: m.order} } -func (m *orderMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { +func (m *orderMiddleware) WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) { *m.order = append(*m.order, m.label+"-model-before") - resp, err := next(ctx, state) + resp, err := next(ctx, params) *m.order = append(*m.order, m.label+"-model-after") return resp, err } From 0c7562dced0bb0350144eba945800c01a19d7eec Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 23 Mar 2026 14:28:15 -0700 Subject: [PATCH 11/22] minor refactor --- go/ai/generate.go | 18 ++++-------------- go/ai/middleware.go | 22 ++++++++++++++++++++++ go/ai/prompt.go | 18 ++++-------------- go/genkit/reflection.go | 3 +-- 4 files changed, 31 insertions(+), 30 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index ebe0aa68fb..3818742dd8 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -610,22 +610,12 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod if len(genOpts.Use) > 0 { for _, mw := range genOpts.Use { - name := mw.Name() - if LookupMiddleware(r, name) == nil { - if !r.IsChild() { - r = r.NewChild() - } - DefineMiddleware(r, "", mw) - } - configJSON, err := json.Marshal(mw) + ref, newR, err := middlewareToRef(r, mw) if err != nil { - return 
nil, core.NewError(core.INTERNAL, "ai.Generate: failed to marshal middleware %q config: %v", name, err) - } - var config any - if err := json.Unmarshal(configJSON, &config); err != nil { - return nil, core.NewError(core.INTERNAL, "ai.Generate: failed to unmarshal middleware %q config: %v", name, err) + return nil, core.NewError(core.INTERNAL, "ai.Generate: %v", err) } - actionOpts.Use = append(actionOpts.Use, &MiddlewareRef{Name: name, Config: config}) + r = newR + actionOpts.Use = append(actionOpts.Use, ref) } } diff --git a/go/ai/middleware.go b/go/ai/middleware.go index aff3b063fc..d6016d8762 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -143,6 +143,28 @@ func LookupMiddleware(r api.Registry, name string) *MiddlewareDesc { return d } +// middlewareToRef registers a Middleware instance (if not already registered) and +// returns a MiddlewareRef for the action layer. If registration requires a child +// registry, the returned registry may differ from the input. +func middlewareToRef(r api.Registry, mw Middleware) (*MiddlewareRef, api.Registry, error) { + name := mw.Name() + if LookupMiddleware(r, name) == nil { + if !r.IsChild() { + r = r.NewChild() + } + DefineMiddleware(r, "", mw) + } + configJSON, err := json.Marshal(mw) + if err != nil { + return nil, r, fmt.Errorf("failed to marshal middleware %q config: %w", name, err) + } + var config any + if err := json.Unmarshal(configJSON, &config); err != nil { + return nil, r, fmt.Errorf("failed to unmarshal middleware %q config: %w", name, err) + } + return &MiddlewareRef{Name: name, Config: config}, r, nil +} + // MiddlewarePlugin is implemented by plugins that provide middleware. 
type MiddlewarePlugin interface { ListMiddleware(ctx context.Context) ([]*MiddlewareDesc, error) diff --git a/go/ai/prompt.go b/go/ai/prompt.go index 9e4dff9f14..1e42be4d2f 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -251,22 +251,12 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod if len(execOpts.Use) > 0 { for _, mw := range execOpts.Use { - name := mw.Name() - if LookupMiddleware(r, name) == nil { - if !r.IsChild() { - r = r.NewChild() - } - DefineMiddleware(r, "", mw) - } - configJSON, err := json.Marshal(mw) + ref, newR, err := middlewareToRef(r, mw) if err != nil { - return nil, fmt.Errorf("Prompt.Execute: failed to marshal middleware %q config: %w", name, err) - } - var config any - if err := json.Unmarshal(configJSON, &config); err != nil { - return nil, fmt.Errorf("Prompt.Execute: failed to unmarshal middleware %q config: %w", name, err) + return nil, fmt.Errorf("Prompt.Execute: %w", err) } - actionOpts.Use = append(actionOpts.Use, &MiddlewareRef{Name: name, Config: config}) + r = newR + actionOpts.Use = append(actionOpts.Use, ref) } } diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index 9936936e61..0f6fcca890 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -605,8 +605,7 @@ func handleListValues(g *Genkit) func(w http.ResponseWriter, r *http.Request) er return func(w http.ResponseWriter, r *http.Request) error { valueType := r.URL.Query().Get("type") if valueType == "" { - http.Error(w, `query parameter "type" is required`, http.StatusBadRequest) - return nil + return core.NewError(core.INVALID_ARGUMENT, `query parameter "type" is required`) } prefix := "/" + valueType + "/" result := map[string]any{} From 7e35d4e36c50dcedb97e9d97610262d2cb0679c6 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 24 Mar 2026 10:49:21 -0700 Subject: [PATCH 12/22] cleaned up duplicate types --- genkit-tools/common/src/types/model.ts | 26 -------------- genkit-tools/genkit-schema.json | 50 
-------------------------- 2 files changed, 76 deletions(-) diff --git a/genkit-tools/common/src/types/model.ts b/genkit-tools/common/src/types/model.ts index 072d446b6d..6f5cb9b89f 100644 --- a/genkit-tools/common/src/types/model.ts +++ b/genkit-tools/common/src/types/model.ts @@ -65,32 +65,6 @@ export { // IMPORTANT: Keep this file in sync with genkit/ai/src/model-types.ts! // -/** Descriptor for a registered middleware, returned by reflection API. */ -export const MiddlewareDescSchema = z.object({ - /** Unique name of the middleware. */ - name: z.string(), - /** Human-readable description of what the middleware does. */ - description: z.string().optional(), - /** JSON Schema for the middleware's configuration. */ - configSchema: z.record(z.any()).nullish(), - /** User defined metadata for the middleware. */ - metadata: z.record(z.any()).nullish(), -}); -export type MiddlewareDesc = z.infer; - -/** - * Zod schema of middleware reference. - */ -export const MiddlewareRefSchema = z.object({ - name: z.string(), - config: z.any().optional(), -}); - -/** - * Middleware reference. - */ -export type MiddlewareRef = z.infer; - /** * Zod schema of an opration representing a model reference. 
*/ diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index d4da8a3704..74db4c8fd1 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -799,56 +799,6 @@ ], "additionalProperties": false }, - "MiddlewareDesc": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "description": { - "type": "string" - }, - "configSchema": { - "anyOf": [ - { - "type": "object", - "additionalProperties": {} - }, - { - "type": "null" - } - ] - }, - "metadata": { - "anyOf": [ - { - "type": "object", - "additionalProperties": {} - }, - { - "type": "null" - } - ] - } - }, - "required": [ - "name" - ], - "additionalProperties": false - }, - "MiddlewareRef": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "config": {} - }, - "required": [ - "name" - ], - "additionalProperties": false - }, "ModelInfo": { "type": "object", "properties": { From a87149cce92d30085dc968b98566c905346c65e5 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 24 Mar 2026 11:19:55 -0700 Subject: [PATCH 13/22] Update _typing.py --- .../genkit/src/genkit/_core/_typing.py | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index 34f0736427..d19eea1913 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -175,6 +175,24 @@ class GenkitError(GenkitModel): data: Data | None = None +class MiddlewareDesc(GenkitModel): + """Model for middlewaredesc data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + name: str = Field(...) 
+ description: str | None = None + config_schema: Any | ConfigSchema | None = Field(default=None) + metadata: Metadata | None = None + + +class MiddlewareRef(GenkitModel): + """Model for middlewareref data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + name: str = Field(...) + config: Any | None = Field(default=None) + + class CandidateError(GenkitModel): """Model for candidateerror data.""" @@ -326,24 +344,6 @@ class MessageData(GenkitModel): metadata: Metadata | None = None -class MiddlewareDesc(GenkitModel): - """Model for middlewaredesc data.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - name: str = Field(...) - description: str | None = None - config_schema: Any | ConfigSchema | None = Field(default=None) - metadata: Any | Metadata | None = Field(default=None) - - -class MiddlewareRef(GenkitModel): - """Model for middlewareref data.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - name: str = Field(...) 
- config: Any | None = Field(default=None) - - class ModelInfo(GenkitModel): """Model for modelinfo data.""" From 175c74726b6152bd5dca0aa1c64fb092c2f134ea Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 24 Mar 2026 13:24:01 -0700 Subject: [PATCH 14/22] fixes --- genkit-tools/common/src/types/middleware.ts | 2 +- genkit-tools/genkit-schema.json | 11 +++- go/ai/generate.go | 8 ++- go/ai/middleware.go | 19 +++++- go/ai/middleware_test.go | 63 ++++++++++++++++--- .../genkit/src/genkit/_core/_typing.py | 2 +- 6 files changed, 92 insertions(+), 13 deletions(-) diff --git a/genkit-tools/common/src/types/middleware.ts b/genkit-tools/common/src/types/middleware.ts index 6fd2dd9810..861227627e 100644 --- a/genkit-tools/common/src/types/middleware.ts +++ b/genkit-tools/common/src/types/middleware.ts @@ -24,7 +24,7 @@ export const MiddlewareDescSchema = z.object({ /** JSON Schema for the middleware's configuration. */ configSchema: z.record(z.any()).nullish(), /** User defined metadata for the middleware. */ - metadata: z.record(z.any()).optional(), + metadata: z.record(z.any()).nullish(), }); export type MiddlewareDesc = z.infer; diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 74db4c8fd1..e41280bfb5 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -291,8 +291,15 @@ ] }, "metadata": { - "type": "object", - "additionalProperties": {} + "anyOf": [ + { + "type": "object", + "additionalProperties": {} + }, + { + "type": "null" + } + ] } }, "required": [ diff --git a/go/ai/generate.go b/go/ai/generate.go index 3818742dd8..4a80583d16 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -968,13 +968,19 @@ func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, return tool.RunRawMultipart(ctx, toolReq.Input) } + // Capture metadata from the raw tool response so it isn't lost through + // the ToolResponse conversion (ToolResponse has no Metadata field). 
+ var toolMetadata map[string]any + inner := func(ctx context.Context, params *ToolParams) (*ToolResponse, error) { resp, err := params.Tool.RunRawMultipart(ctx, params.Request.Input) if err != nil { return nil, err } + toolMetadata = resp.Metadata return &ToolResponse{ Name: params.Request.Name, + Ref: params.Request.Ref, Output: resp.Output, Content: resp.Content, }, nil @@ -993,7 +999,7 @@ func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, return nil, err } - return &MultipartToolResponse{Output: toolResp.Output, Content: toolResp.Content}, nil + return &MultipartToolResponse{Output: toolResp.Output, Content: toolResp.Content, Metadata: toolMetadata}, nil } // Text returns the contents of the first candidate in a diff --git a/go/ai/middleware.go b/go/ai/middleware.go index d6016d8762..cbeb5c5354 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -39,6 +39,8 @@ type Middleware interface { // WrapModel wraps each model API call. WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) // WrapTool wraps each tool execution. + // WrapTool may be called concurrently when multiple tools execute in parallel. + // Implementations must be safe for concurrent use. WrapTool(ctx context.Context, params *ToolParams, next ToolNext) (*ToolResponse, error) // Tools returns additional tools to make available during generation. // These tools are dynamically registered when the middleware is used via [WithUse]. @@ -152,7 +154,22 @@ func middlewareToRef(r api.Registry, mw Middleware) (*MiddlewareRef, api.Registr if !r.IsChild() { r = r.NewChild() } - DefineMiddleware(r, "", mw) + // Register directly instead of via DefineMiddleware to avoid generic + // type inference losing the concrete type (mw is typed as Middleware + // interface here, so InferSchemaMap would receive a nil interface). 
+ desc := &MiddlewareDesc{ + Name: name, + configFromJSON: func(configJSON []byte) (Middleware, error) { + inst := mw.New() + if len(configJSON) > 0 { + if err := json.Unmarshal(configJSON, inst); err != nil { + return nil, fmt.Errorf("middleware %q: %w", name, err) + } + } + return inst, nil + }, + } + desc.Register(r) } configJSON, err := json.Marshal(mw) if err != nil { diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go index 4361d00b58..cfcb075617 100644 --- a/go/ai/middleware_test.go +++ b/go/ai/middleware_test.go @@ -135,15 +135,23 @@ func TestMiddlewareModelHook(t *testing.T) { m := defineFakeModel(t, r, fakeModelConfig{}) DefineMiddleware(r, "tracks calls", &testMiddleware{}) - resp, err := Generate(ctx, r, + var modelHookCalled bool + mw := &hookTrackingMiddleware{ + onModel: func() { modelHookCalled = true }, + } + + resp, err := Generate(testCtx, r, WithModel(m), WithPrompt("hello"), - WithUse(&testMiddleware{}), + WithUse(mw), ) assertNoError(t, err) if resp == nil { t.Fatal("expected response, got nil") } + if !modelHookCalled { + t.Error("expected model hook to be called") + } } func TestMiddlewareToolHook(t *testing.T) { @@ -154,16 +162,22 @@ func TestMiddlewareToolHook(t *testing.T) { }) defineFakeTool(t, r, "myTool", "A test tool") - mw := &testMiddleware{} + var toolHookCalled int32 + mw := &hookTrackingMiddleware{ + onTool: func() { atomic.AddInt32(&toolHookCalled, 1) }, + } DefineMiddleware(r, "tracks calls", mw) - _, err := Generate(ctx, r, + _, err := Generate(testCtx, r, WithModelName("test/toolModel"), WithPrompt("use the tool"), WithTools(ToolName("myTool")), - WithUse(&testMiddleware{}), + WithUse(mw), ) assertNoError(t, err) + if atomic.LoadInt32(&toolHookCalled) == 0 { + t.Error("expected tool hook to be called at least once") + } } func TestMiddlewareOrdering(t *testing.T) { @@ -177,7 +191,7 @@ func TestMiddlewareOrdering(t *testing.T) { DefineMiddleware(r, "middleware A", mwA) DefineMiddleware(r, "middleware B", mwB) - _, 
err := Generate(ctx, r, + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hello"), WithUse( @@ -201,6 +215,41 @@ func TestMiddlewareOrdering(t *testing.T) { // --- helper middleware types for tests --- +// hookTrackingMiddleware uses callbacks to verify hooks are actually invoked. +type hookTrackingMiddleware struct { + BaseMiddleware + onGenerate func() + onModel func() + onTool func() +} + +func (m *hookTrackingMiddleware) Name() string { return "hookTracking" } + +func (m *hookTrackingMiddleware) New() Middleware { + return &hookTrackingMiddleware{onGenerate: m.onGenerate, onModel: m.onModel, onTool: m.onTool} +} + +func (m *hookTrackingMiddleware) WrapGenerate(ctx context.Context, params *GenerateParams, next GenerateNext) (*ModelResponse, error) { + if m.onGenerate != nil { + m.onGenerate() + } + return next(ctx, params) +} + +func (m *hookTrackingMiddleware) WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) { + if m.onModel != nil { + m.onModel() + } + return next(ctx, params) +} + +func (m *hookTrackingMiddleware) WrapTool(ctx context.Context, params *ToolParams, next ToolNext) (*ToolResponse, error) { + if m.onTool != nil { + m.onTool() + } + return next(ctx, params) +} + // stableStateMiddleware has unexported stable state preserved by New(). type stableStateMiddleware struct { BaseMiddleware @@ -234,4 +283,4 @@ func (m *orderMiddleware) WrapModel(ctx context.Context, params *ModelParams, ne return resp, err } -var ctx = context.Background() +var testCtx = context.Background() diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index d19eea1913..56f6c8d535 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -182,7 +182,7 @@ class MiddlewareDesc(GenkitModel): name: str = Field(...) 
description: str | None = None config_schema: Any | ConfigSchema | None = Field(default=None) - metadata: Metadata | None = None + metadata: Any | Metadata | None = Field(default=None) class MiddlewareRef(GenkitModel): From 8a84f7585400bec82ee539fec3ad5123c73b97d3 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 24 Mar 2026 13:31:32 -0700 Subject: [PATCH 15/22] reverted nullish --- genkit-tools/common/src/types/middleware.ts | 2 +- genkit-tools/genkit-schema.json | 11 ++--------- py/packages/genkit/src/genkit/_core/_typing.py | 2 +- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/genkit-tools/common/src/types/middleware.ts b/genkit-tools/common/src/types/middleware.ts index 861227627e..6fd2dd9810 100644 --- a/genkit-tools/common/src/types/middleware.ts +++ b/genkit-tools/common/src/types/middleware.ts @@ -24,7 +24,7 @@ export const MiddlewareDescSchema = z.object({ /** JSON Schema for the middleware's configuration. */ configSchema: z.record(z.any()).nullish(), /** User defined metadata for the middleware. */ - metadata: z.record(z.any()).nullish(), + metadata: z.record(z.any()).optional(), }); export type MiddlewareDesc = z.infer; diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index e41280bfb5..74db4c8fd1 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -291,15 +291,8 @@ ] }, "metadata": { - "anyOf": [ - { - "type": "object", - "additionalProperties": {} - }, - { - "type": "null" - } - ] + "type": "object", + "additionalProperties": {} } }, "required": [ diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index 56f6c8d535..d19eea1913 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -182,7 +182,7 @@ class MiddlewareDesc(GenkitModel): name: str = Field(...) 
description: str | None = None config_schema: Any | ConfigSchema | None = Field(default=None) - metadata: Any | Metadata | None = Field(default=None) + metadata: Metadata | None = None class MiddlewareRef(GenkitModel): From 6e0f6497e2ba85c858ce59a1c6500d6031f62879 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 24 Mar 2026 17:09:33 -0700 Subject: [PATCH 16/22] added correct handling of generate > model > tool order even for restarted tools --- go/ai/gen.go | 4 +- go/ai/generate.go | 129 ++++++++++++++++++++++----------------- go/ai/middleware.go | 10 +-- go/ai/middleware_test.go | 10 +-- go/core/schemas.config | 2 +- 5 files changed, 87 insertions(+), 68 deletions(-) diff --git a/go/ai/gen.go b/go/ai/gen.go index e3b1fb884b..f8ee4f9a7d 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -235,8 +235,8 @@ type MiddlewareDesc struct { // Metadata contains additional context for the middleware. Metadata map[string]any `json:"metadata,omitempty"` // Name is the middleware's unique identifier. - Name string `json:"name,omitempty"` - configFromJSON middlewareConfigFunc + Name string `json:"name,omitempty"` + newFromJSON middlewareFactory } // MiddlewareRef is a serializable reference to a registered middleware with config. diff --git a/go/ai/generate.go b/go/ai/generate.go index 4a80583d16..8a09da3fde 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -203,7 +203,7 @@ func LookupModel(r api.Registry, name string) Model { } // GenerateWithRequest is the central generation implementation for ai.Generate(), prompt.Execute(), and the GenerateAction direct call. 
-func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActionOptions, mw []ModelMiddleware, cb ModelStreamCallback) (*ModelResponse, error) { +func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActionOptions, mmws []ModelMiddleware, cb ModelStreamCallback) (*ModelResponse, error) { if opts.Model == "" { if defaultModel, ok := r.LookupValue(api.DefaultModelKey).(string); ok && defaultModel != "" { opts.Model = defaultModel @@ -219,26 +219,23 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: model %q not found", opts.Model) } - resumeOutput, err := handleResumeOption(ctx, r, opts) - if err != nil { - return nil, err - } - - if resumeOutput.interruptedResponse != nil { - return nil, core.NewError(core.FAILED_PRECONDITION, - "One or more tools triggered an interrupt during a restarted execution.") - } - - opts = resumeOutput.revisedRequest - - if resumeOutput.toolMessage != nil && cb != nil { - err := cb(ctx, &ModelResponseChunk{ - Content: resumeOutput.toolMessage.Content, - Role: RoleTool, - Index: 0, - }) - if err != nil { - return nil, fmt.Errorf("streaming callback failed for resumed tool message: %w", err) + var mws []Middleware + if len(opts.Use) > 0 { + mws = make([]Middleware, 0, len(opts.Use)) + for _, ref := range opts.Use { + desc := LookupMiddleware(r, ref.Name) + if desc == nil { + return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: middleware %q not found", ref.Name) + } + configJSON, err := json.Marshal(ref.Config) + if err != nil { + return nil, core.NewError(core.INTERNAL, "ai.GenerateWithRequest: failed to marshal config for middleware %q: %v", ref.Name, err) + } + mw, err := desc.newFromJSON(configJSON) + if err != nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: failed to create middleware %q: %v", ref.Name, err) + } + mws = append(mws, mw) } } @@ -315,26 
+312,6 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi Output: &outputCfg, } - var middlewareHandlers []Middleware - if len(opts.Use) > 0 { - middlewareHandlers = make([]Middleware, 0, len(opts.Use)) - for _, ref := range opts.Use { - desc := LookupMiddleware(r, ref.Name) - if desc == nil { - return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: middleware %q not found", ref.Name) - } - configJSON, err := json.Marshal(ref.Config) - if err != nil { - return nil, core.NewError(core.INTERNAL, "ai.GenerateWithRequest: failed to marshal config for middleware %q: %v", ref.Name, err) - } - handler, err := desc.configFromJSON(configJSON) - if err != nil { - return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: failed to create middleware %q: %v", ref.Name, err) - } - middlewareHandlers = append(middlewareHandlers, handler) - } - } - var fn ModelFunc if bm != nil { if cb != nil { @@ -345,11 +322,11 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi fn = m.Generate } - if len(middlewareHandlers) > 0 { + if len(mws) > 0 { modelHook := func(next ModelFunc) ModelFunc { wrapped := next - for i := len(middlewareHandlers) - 1; i >= 0; i-- { - h := middlewareHandlers[i] + for i := len(mws) - 1; i >= 0; i-- { + h := mws[i] inner := wrapped wrapped = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { return h.WrapModel(ctx, &ModelParams{Request: req, Callback: cb}, func(ctx context.Context, params *ModelParams) (*ModelResponse, error) { @@ -359,14 +336,51 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi } return wrapped } - mw = append([]ModelMiddleware{modelHook}, mw...) + mmws = append([]ModelMiddleware{modelHook}, mmws...) } - fn = core.ChainMiddleware(mw...)(fn) + fn = core.ChainMiddleware(mmws...)(fn) // Inline recursive helper function that captures variables from parent scope. 
var generate func(context.Context, *ModelRequest, int, int) (*ModelResponse, error) generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { + // Handle resume on first iteration so restarted tool execution is + // wrapped by WrapGenerate (lifecycle: generate > tool > generate > model > tool). + if currentTurn == 0 && opts.Resume != nil && (len(opts.Resume.Respond) > 0 || len(opts.Resume.Restart) > 0) { + resumeOutput, err := handleResumeOption(ctx, r, opts, mws) + if err != nil { + return nil, err + } + + if resumeOutput.interruptedResponse != nil { + return nil, core.NewError(core.FAILED_PRECONDITION, + "One or more tools triggered an interrupt during a restarted execution.") + } + + opts = resumeOutput.revisedRequest + + if resumeOutput.toolMessage != nil && cb != nil { + err := cb(ctx, &ModelResponseChunk{ + Content: resumeOutput.toolMessage.Content, + Role: RoleTool, + Index: messageIndex, + }) + if err != nil { + return nil, fmt.Errorf("streaming callback failed for resumed tool message: %w", err) + } + } + + resumeReq := &ModelRequest{ + Messages: opts.Messages, + Config: req.Config, + Docs: req.Docs, + ToolChoice: req.ToolChoice, + Tools: req.Tools, + Output: req.Output, + } + return generate(ctx, resumeReq, currentTurn+1, messageIndex+1) + } + spanMetadata := &tracing.SpanMetadata{ Name: "generate", Type: "util", @@ -429,7 +443,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi return nil, core.NewError(core.ABORTED, "exceeded maximum tool call iterations (%d)", maxTurns) } - newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex, middlewareHandlers) + newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex, mws) if err != nil { return nil, err } @@ -448,14 +462,14 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi } // Wrap generate with the Generate hook 
chain from middleware. - if len(middlewareHandlers) > 0 { + if len(mws) > 0 { innerGenerate := generate generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { innerFn := func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { return innerGenerate(ctx, params.Request, currentTurn, messageIndex) } - for i := len(middlewareHandlers) - 1; i >= 0; i-- { - h := middlewareHandlers[i] + for i := len(mws) - 1; i >= 0; i-- { + h := mws[i] next := innerFn innerFn = func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { return h.WrapGenerate(ctx, params, next) @@ -1241,7 +1255,7 @@ func (m ModelRef) Config() any { // handleResumedToolRequest resolves a tool request from a previous, interrupted model turn, // when generation is being resumed. It determines the outcome of the tool request based on // pending output, or explicit 'respond' or 'restart' directives in the resume options. -func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, p *Part) (*resumedToolRequestOutput, error) { +func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, p *Part, middlewareHandlers []Middleware) (*resumedToolRequestOutput, error) { if p == nil || !p.IsToolRequest() { return nil, core.NewError(core.INVALID_ARGUMENT, "handleResumedToolRequest: part is not a tool request") } @@ -1329,7 +1343,12 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene resumedCtx = origInputCtxKey.NewContext(resumedCtx, originalInputVal) } - output, err := tool.RunRaw(resumedCtx, restartPart.ToolRequest.Input) + restartToolReq := &ToolRequest{ + Name: restartPart.ToolRequest.Name, + Ref: restartPart.ToolRequest.Ref, + Input: restartPart.ToolRequest.Input, + } + multipartResp, err := runToolWithMiddleware(resumedCtx, tool, restartToolReq, middlewareHandlers) if err != nil { var tie *toolInterruptError 
if errors.As(err, &tie) { @@ -1358,7 +1377,7 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene newToolResp := NewToolResponsePart(&ToolResponse{ Name: restartPart.ToolRequest.Name, Ref: restartPart.ToolRequest.Ref, - Output: output, + Output: multipartResp.Output, }) return &resumedToolRequestOutput{ @@ -1378,7 +1397,7 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene // handleResumeOption amends message history to handle `resume` arguments. // It returns the amended history. -func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions) (*resumeOptionOutput, error) { +func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, middlewareHandlers []Middleware) (*resumeOptionOutput, error) { if genOpts.Resume == nil || (len(genOpts.Resume.Respond) == 0 && len(genOpts.Resume.Restart) == 0) { return &resumeOptionOutput{revisedRequest: genOpts}, nil } @@ -1414,7 +1433,7 @@ func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateAc toolReqCount++ go func(idx int, p *Part) { - output, err := handleResumedToolRequest(ctx, r, genOpts, p) + output, err := handleResumedToolRequest(ctx, r, genOpts, p, middlewareHandlers) resultChan <- result[*resumedToolRequestOutput]{ index: idx, value: output, diff --git a/go/ai/middleware.go b/go/ai/middleware.go index cbeb5c5354..a90b36077a 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -25,8 +25,8 @@ import ( "github.com/firebase/genkit/go/core/api" ) -// middlewareConfigFunc creates a Middleware instance from JSON config. -type middlewareConfigFunc = func([]byte) (Middleware, error) +// middlewareFactory creates a Middleware instance from JSON config. +type middlewareFactory = func(configJSON []byte) (Middleware, error) // Middleware provides hooks for different stages of generation. 
type Middleware interface { @@ -106,14 +106,14 @@ func (d *MiddlewareDesc) Register(r api.Registry) { } // NewMiddleware creates a middleware descriptor without registering it. -// The prototype carries stable state; configFromJSON calls prototype.New() +// The prototype carries stable state; newFromJSON calls prototype.New() // then unmarshals user config on top. func NewMiddleware[T Middleware](description string, prototype T) *MiddlewareDesc { return &MiddlewareDesc{ Name: prototype.Name(), Description: description, ConfigSchema: core.InferSchemaMap(*new(T)), - configFromJSON: func(configJSON []byte) (Middleware, error) { + newFromJSON: func(configJSON []byte) (Middleware, error) { inst := prototype.New() if len(configJSON) > 0 { if err := json.Unmarshal(configJSON, inst); err != nil { @@ -159,7 +159,7 @@ func middlewareToRef(r api.Registry, mw Middleware) (*MiddlewareRef, api.Registr // interface here, so InferSchemaMap would receive a nil interface). desc := &MiddlewareDesc{ Name: name, - configFromJSON: func(configJSON []byte) (Middleware, error) { + newFromJSON: func(configJSON []byte) (Middleware, error) { inst := mw.New() if len(configJSON) > 0 { if err := json.Unmarshal(configJSON, inst); err != nil { diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go index cfcb075617..f08d8a9e7e 100644 --- a/go/ai/middleware_test.go +++ b/go/ai/middleware_test.go @@ -52,7 +52,7 @@ func (m *testMiddleware) WrapTool(ctx context.Context, params *ToolParams, next return next(ctx, params) } -func TestNewMiddleware(t *testing.T) { +func TestNewMiddlewareDesc(t *testing.T) { proto := &testMiddleware{Label: "original"} desc := NewMiddleware("test middleware", proto) @@ -90,9 +90,9 @@ func TestConfigFromJSON(t *testing.T) { proto := &testMiddleware{Label: "stable"} desc := NewMiddleware("test middleware", proto) - handler, err := desc.configFromJSON([]byte(`{"label": "custom"}`)) + handler, err := desc.newFromJSON([]byte(`{"label": "custom"}`)) if err != nil { - 
t.Fatalf("configFromJSON failed: %v", err) + t.Fatalf("newFromJSON failed: %v", err) } tm, ok := handler.(*testMiddleware) @@ -113,9 +113,9 @@ func TestConfigFromJSONPreservesStableState(t *testing.T) { proto := &stableStateMiddleware{apiKey: "secret123"} desc := NewMiddleware("middleware with stable state", proto) - handler, err := desc.configFromJSON([]byte(`{"sampleRate": 0.5}`)) + handler, err := desc.newFromJSON([]byte(`{"sampleRate": 0.5}`)) if err != nil { - t.Fatalf("configFromJSON failed: %v", err) + t.Fatalf("newFromJSON failed: %v", err) } sm, ok := handler.(*stableStateMiddleware) diff --git a/go/core/schemas.config b/go/core/schemas.config index 7f0488ffad..369db20dba 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1148,7 +1148,7 @@ ModelResponseChunk field formatHandler StreamingFormatHandler # Middleware MiddlewareDesc pkg ai MiddlewareDesc.configSchema type map[string]any -MiddlewareDesc field configFromJSON middlewareConfigFunc +MiddlewareDesc field newFromJSON middlewareFactory MiddlewareRef pkg ai Score omit From 521974d08bb4687ee3a7618b62c76c1e323197a8 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 31 Mar 2026 10:19:17 -0700 Subject: [PATCH 17/22] Update generate.go --- go/ai/generate.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index 8a09da3fde..e80393a9ea 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -870,7 +870,7 @@ func clone[T any](obj *T) *T { // handleToolRequests processes any tool requests in the response, returning // either a new request to continue the conversation or nil if no tool requests // need handling. 
-func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int, middlewareHandlers []Middleware) (*ModelRequest, *Message, error) { +func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int, mws []Middleware) (*ModelRequest, *Message, error) { toolCount := len(resp.ToolRequests()) if toolCount == 0 { return nil, nil, nil @@ -893,7 +893,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return } - multipartResp, err := runToolWithMiddleware(ctx, tool, toolReq, middlewareHandlers) + multipartResp, err := runToolWithMiddleware(ctx, tool, toolReq, mws) if err != nil { var tie *toolInterruptError if errors.As(err, &tie) { @@ -977,8 +977,8 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, } // runToolWithMiddleware runs a tool, wrapping the execution with WrapTool hooks from middleware. -func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, handlers []Middleware) (*MultipartToolResponse, error) { - if len(handlers) == 0 { +func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, mws []Middleware) (*MultipartToolResponse, error) { + if len(mws) == 0 { return tool.RunRawMultipart(ctx, toolReq.Input) } @@ -1000,8 +1000,8 @@ func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, }, nil } - for i := len(handlers) - 1; i >= 0; i-- { - h := handlers[i] + for i := len(mws) - 1; i >= 0; i-- { + h := mws[i] next := inner inner = func(ctx context.Context, params *ToolParams) (*ToolResponse, error) { return h.WrapTool(ctx, params, next) @@ -1255,7 +1255,7 @@ func (m ModelRef) Config() any { // handleResumedToolRequest resolves a tool request from a previous, interrupted model turn, // when generation is being resumed. 
It determines the outcome of the tool request based on // pending output, or explicit 'respond' or 'restart' directives in the resume options. -func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, p *Part, middlewareHandlers []Middleware) (*resumedToolRequestOutput, error) { +func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, p *Part, mws []Middleware) (*resumedToolRequestOutput, error) { if p == nil || !p.IsToolRequest() { return nil, core.NewError(core.INVALID_ARGUMENT, "handleResumedToolRequest: part is not a tool request") } @@ -1348,7 +1348,7 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene Ref: restartPart.ToolRequest.Ref, Input: restartPart.ToolRequest.Input, } - multipartResp, err := runToolWithMiddleware(resumedCtx, tool, restartToolReq, middlewareHandlers) + multipartResp, err := runToolWithMiddleware(resumedCtx, tool, restartToolReq, mws) if err != nil { var tie *toolInterruptError if errors.As(err, &tie) { From 064185e4ec2193fb5f249850fe97dec4fff4077b Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 15 Apr 2026 09:13:00 -0700 Subject: [PATCH 18/22] Update generate.go --- go/ai/generate.go | 75 ++++++++++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index e80393a9ea..6f3f88467d 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -344,43 +344,6 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi var generate func(context.Context, *ModelRequest, int, int) (*ModelResponse, error) generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { - // Handle resume on first iteration so restarted tool execution is - // wrapped by WrapGenerate (lifecycle: generate > tool > generate > model > tool). 
- if currentTurn == 0 && opts.Resume != nil && (len(opts.Resume.Respond) > 0 || len(opts.Resume.Restart) > 0) { - resumeOutput, err := handleResumeOption(ctx, r, opts, mws) - if err != nil { - return nil, err - } - - if resumeOutput.interruptedResponse != nil { - return nil, core.NewError(core.FAILED_PRECONDITION, - "One or more tools triggered an interrupt during a restarted execution.") - } - - opts = resumeOutput.revisedRequest - - if resumeOutput.toolMessage != nil && cb != nil { - err := cb(ctx, &ModelResponseChunk{ - Content: resumeOutput.toolMessage.Content, - Role: RoleTool, - Index: messageIndex, - }) - if err != nil { - return nil, fmt.Errorf("streaming callback failed for resumed tool message: %w", err) - } - } - - resumeReq := &ModelRequest{ - Messages: opts.Messages, - Config: req.Config, - Docs: req.Docs, - ToolChoice: req.ToolChoice, - Tools: req.Tools, - Output: req.Output, - } - return generate(ctx, resumeReq, currentTurn+1, messageIndex+1) - } - spanMetadata := &tracing.SpanMetadata{ Name: "generate", Type: "util", @@ -412,6 +375,44 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi } } + // Handle resume on the first iteration inside this span so that + // restarted tool execution is both wrapped by WrapGenerate (from + // the outer middleware chain) and recorded as a child of the + // current generate span (lifecycle: generate > tool > generate > + // model > tool). 
+ if currentTurn == 0 && opts.Resume != nil && (len(opts.Resume.Respond) > 0 || len(opts.Resume.Restart) > 0) { + resumeOutput, err := handleResumeOption(ctx, r, opts, mws) + if err != nil { + return nil, err + } + + if resumeOutput.interruptedResponse != nil { + return nil, core.NewError(core.FAILED_PRECONDITION, + "One or more tools triggered an interrupt during a restarted execution.") + } + + opts = resumeOutput.revisedRequest + + if resumeOutput.toolMessage != nil && wrappedCb != nil { + if err := wrappedCb(ctx, &ModelResponseChunk{ + Content: resumeOutput.toolMessage.Content, + Role: RoleTool, + }); err != nil { + return nil, fmt.Errorf("streaming callback failed for resumed tool message: %w", err) + } + } + + resumeReq := &ModelRequest{ + Messages: opts.Messages, + Config: req.Config, + Docs: req.Docs, + ToolChoice: req.ToolChoice, + Tools: req.Tools, + Output: req.Output, + } + return generate(ctx, resumeReq, currentTurn+1, currentIndex) + } + resp, err := fn(ctx, req, wrappedCb) if err != nil { return nil, err From 590dfebbe72aff6e8b651e5d8e732025eb5e41bd Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 16 Apr 2026 18:02:39 -0700 Subject: [PATCH 19/22] refactor(go/ai): simplify middleware architecture Rework the Go middleware primitives introduced in PR #4464 to collapse configuration and behavior into a single "config struct is the middleware" model and remove the descriptor/factory/prototype scaffolding. - Drop the Middleware interface (Name/New/WrapGenerate/WrapModel/WrapTool/Tools) and the BaseMiddleware embedding helper. Introduce Hooks as a plain struct of optional hook func fields (WrapGenerate, WrapModel, WrapTool, Tools); nil hooks pass through. - Repurpose Middleware as an interface with just Name() + New(ctx), which a user-facing config struct implements directly. Passing a config value to WithUse runs its New on the local fast path with no registry lookup, so pure-Go code works without plugin registration. 
- NewMiddleware[M](description, prototype) captures the typed prototype in a closure stored on MiddlewareDesc.buildFromJSON, preserving unexported plugin-level state across JSON-dispatched calls via value-copy. - MiddlewareDesc returns to being the shared schemas.config-generated type with the private factory added via the existing `field` directive. - Rename MiddlewarePlugin.ListMiddleware to Middlewares to align with the upcoming V2 naming conventions. - Replace Inline with MiddlewareFunc, a canonical Go adapter type that satisfies Middleware for ad-hoc closure-based middleware. - Add genkit.DefineMiddleware and genkit.LookupMiddleware wrappers with complete godoc matching the DefineTool/LookupTool style. Fixes carried over from the initial review: - Preserve MultipartToolResponse.Content through the resume path in handleResumedToolRequest (previously dropped). - Change WrapTool return type to *MultipartToolResponse so metadata and content flow through without an out-of-band capture hack. - Reject duplicate middleware-contributed tool names explicitly in GenerateWithRequest instead of panicking at registry registration. - Build the WrapGenerate, WrapModel, and WrapTool hook chains once per GenerateWithRequest rather than rebuilding them on every tool-loop turn. - Export NewToolInterruptError so WrapTool hooks can interrupt tools without constructing a ToolContext. Tests rewritten against the new shape and expanded to cover: plugin-state value-copy, call-level state isolation, MiddlewareFunc adapter, nil hooks, stream chunk accumulation, tool contribution, duplicate-tool rejection, factory error propagation, WrapTool interrupts, per-iteration WrapGenerate, and metadata round-trip through WrapTool. All green under -race. 
--- go/ai/gen.go | 4 +- go/ai/generate.go | 270 +++++++++--------- go/ai/middleware.go | 261 +++++++++++------ go/ai/middleware_test.go | 602 +++++++++++++++++++++++++++++---------- go/ai/option.go | 8 +- go/ai/prompt.go | 15 +- go/ai/testutil_test.go | 34 +++ go/ai/tools.go | 7 + go/core/schemas.config | 2 +- go/genkit/genkit.go | 66 ++++- 10 files changed, 878 insertions(+), 391 deletions(-) diff --git a/go/ai/gen.go b/go/ai/gen.go index 2c7a72bf71..c9dac20735 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -235,8 +235,8 @@ type MiddlewareDesc struct { // Metadata contains additional context for the middleware. Metadata map[string]any `json:"metadata,omitempty"` // Name is the middleware's unique identifier. - Name string `json:"name,omitempty"` - newFromJSON middlewareFactory + Name string `json:"name,omitempty"` + buildFromJSON middlewareFactoryFunc } // MiddlewareRef is a serializable reference to a registered middleware with config. diff --git a/go/ai/generate.go b/go/ai/generate.go index 6f3f88467d..7fc74ddd29 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -219,26 +219,14 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: model %q not found", opts.Model) } - var mws []Middleware - if len(opts.Use) > 0 { - mws = make([]Middleware, 0, len(opts.Use)) - for _, ref := range opts.Use { - desc := LookupMiddleware(r, ref.Name) - if desc == nil { - return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: middleware %q not found", ref.Name) - } - configJSON, err := json.Marshal(ref.Config) - if err != nil { - return nil, core.NewError(core.INTERNAL, "ai.GenerateWithRequest: failed to marshal config for middleware %q: %v", ref.Name, err) - } - mw, err := desc.newFromJSON(configJSON) - if err != nil { - return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: failed to create middleware %q: %v", ref.Name, err) - } - mws = append(mws, 
mw) - } + mws, err := resolveRefs(ctx, r, opts.Use) + if err != nil { + return nil, err } + // Tools contributed by middleware bundles are registered on a child + // registry so this Generate() call sees them while outer callers do not. + // Duplicate names across multiple middleware are rejected explicitly. toolDefMap := make(map[string]*ToolDefinition) for _, t := range opts.Tools { if _, ok := toolDefMap[t]; ok { @@ -252,6 +240,27 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi toolDefMap[t] = tool.Definition() } + var middlewareTools []Tool + for _, mw := range mws { + if mw == nil { + continue + } + for _, t := range mw.Tools { + if _, ok := toolDefMap[t.Name()]; ok { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: tool %q is contributed by middleware but already declared elsewhere", t.Name()) + } + toolDefMap[t.Name()] = t.Definition() + middlewareTools = append(middlewareTools, t) + } + } + if len(middlewareTools) > 0 { + if !r.IsChild() { + r = r.NewChild() + } + for _, t := range middlewareTools { + t.Register(r) + } + } toolDefs := make([]*ToolDefinition, 0, len(toolDefMap)) for _, t := range toolDefMap { toolDefs = append(toolDefs, t) @@ -322,28 +331,28 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi fn = m.Generate } - if len(mws) > 0 { - modelHook := func(next ModelFunc) ModelFunc { - wrapped := next - for i := len(mws) - 1; i >= 0; i-- { - h := mws[i] - inner := wrapped - wrapped = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { - return h.WrapModel(ctx, &ModelParams{Request: req, Callback: cb}, func(ctx context.Context, params *ModelParams) (*ModelResponse, error) { - return inner(ctx, params.Request, params.Callback) - }) - } - } - return wrapped - } - mmws = append([]ModelMiddleware{modelHook}, mmws...) 
- } + // Build the full hook chains once: wrapping the model function with + // WrapModel hooks from middleware, and wrapping the generate iteration + // with WrapGenerate hooks. These chains are reused across every tool-loop + // iteration rather than rebuilt each turn. + fn = buildModelChain(mws, fn) fn = core.ChainMiddleware(mmws...)(fn) + var streamingHandler StreamingFormatHandler + if sfh, ok := formatHandler.(StreamingFormatHandler); ok { + streamingHandler = sfh + } + + runTool := buildToolRunner(mws) + // Inline recursive helper function that captures variables from parent scope. var generate func(context.Context, *ModelRequest, int, int) (*ModelResponse, error) + var runGenerate func(context.Context, *GenerateParams) (*ModelResponse, error) - generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { + runGenerate = func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { + req := params.Request + currentTurn := params.Iteration + messageIndex := params.MessageIndex spanMetadata := &tracing.SpanMetadata{ Name: "generate", Type: "util", @@ -355,11 +364,6 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi currentRole := RoleModel currentIndex := messageIndex - var streamingHandler StreamingFormatHandler - if sfh, ok := formatHandler.(StreamingFormatHandler); ok { - streamingHandler = sfh - } - if cb != nil { wrappedCb = func(ctx context.Context, chunk *ModelResponseChunk) error { if chunk.Role != currentRole && chunk.Role != "" { @@ -381,7 +385,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi // current generate span (lifecycle: generate > tool > generate > // model > tool). 
if currentTurn == 0 && opts.Resume != nil && (len(opts.Resume.Respond) > 0 || len(opts.Resume.Restart) > 0) { - resumeOutput, err := handleResumeOption(ctx, r, opts, mws) + resumeOutput, err := handleResumeOption(ctx, r, opts, runTool) if err != nil { return nil, err } @@ -444,7 +448,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi return nil, core.NewError(core.ABORTED, "exceeded maximum tool call iterations (%d)", maxTurns) } - newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex, mws) + newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex, runTool) if err != nil { return nil, err } @@ -462,31 +466,87 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi }) } - // Wrap generate with the Generate hook chain from middleware. - if len(mws) > 0 { - innerGenerate := generate - generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { - innerFn := func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { - return innerGenerate(ctx, params.Request, currentTurn, messageIndex) - } - for i := len(mws) - 1; i >= 0; i-- { - h := mws[i] - next := innerFn - innerFn = func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { - return h.WrapGenerate(ctx, params, next) - } - } - return innerFn(ctx, &GenerateParams{ - Options: opts, - Request: req, - Iteration: currentTurn, - }) - } + // Compose WrapGenerate hooks once; this chain is invoked for every + // tool-loop iteration. 
+ hookedGenerate := buildGenerateChain(mws, runGenerate) + + generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { + return hookedGenerate(ctx, &GenerateParams{ + Options: opts, + Request: req, + Iteration: currentTurn, + MessageIndex: messageIndex, + Callback: cb, + }) } return generate(ctx, req, 0, 0) } +// buildGenerateChain composes the WrapGenerate hooks from mws (outer-to-inner) +// around run. Middleware with a nil WrapGenerate hook is skipped. +func buildGenerateChain(mws []*Hooks, run func(ctx context.Context, params *GenerateParams) (*ModelResponse, error)) func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { + chain := run + for i := len(mws) - 1; i >= 0; i-- { + mw := mws[i] + if mw == nil || mw.WrapGenerate == nil { + continue + } + hook := mw.WrapGenerate + next := chain + chain = func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { + return hook(ctx, params, next) + } + } + return chain +} + +// buildModelChain composes the WrapModel hooks from mws (outer-to-inner) +// around fn. Middleware with a nil WrapModel hook is skipped. +func buildModelChain(mws []*Hooks, fn ModelFunc) ModelFunc { + chain := fn + for i := len(mws) - 1; i >= 0; i-- { + mw := mws[i] + if mw == nil || mw.WrapModel == nil { + continue + } + hook := mw.WrapModel + next := chain + chain = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return hook(ctx, &ModelParams{Request: req, Callback: cb}, + func(ctx context.Context, params *ModelParams) (*ModelResponse, error) { + return next(ctx, params.Request, params.Callback) + }) + } + } + return chain +} + +// buildToolRunner composes the WrapTool hooks from mws (outer-to-inner) into +// a single function that executes a tool. The returned function is safe to +// invoke from concurrent goroutines; each invocation threads its own params +// through the shared hook chain. 
+func buildToolRunner(mws []*Hooks) func(ctx context.Context, tool Tool, req *ToolRequest) (*MultipartToolResponse, error) { + base := func(ctx context.Context, params *ToolParams) (*MultipartToolResponse, error) { + return params.Tool.RunRawMultipart(ctx, params.Request.Input) + } + chain := base + for i := len(mws) - 1; i >= 0; i-- { + mw := mws[i] + if mw == nil || mw.WrapTool == nil { + continue + } + hook := mw.WrapTool + next := chain + chain = func(ctx context.Context, params *ToolParams) (*MultipartToolResponse, error) { + return hook(ctx, params, next) + } + } + return func(ctx context.Context, tool Tool, req *ToolRequest) (*MultipartToolResponse, error) { + return chain(ctx, &ToolParams{Request: req, Tool: tool}) + } +} + // Generate generates a model response based on the provided options. func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*ModelResponse, error) { genOpts := &generateOptions{} @@ -517,14 +577,6 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod return nil, err } - // Collect tools provided by middleware. 
- for _, mw := range genOpts.Use { - for _, t := range mw.Tools() { - dynamicTools = append(dynamicTools, t) - toolNames = append(toolNames, t.Name()) - } - } - if len(dynamicTools) > 0 { if !r.IsChild() { r = r.NewChild() @@ -623,15 +675,12 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod } } - if len(genOpts.Use) > 0 { - for _, mw := range genOpts.Use { - ref, newR, err := middlewareToRef(r, mw) - if err != nil { - return nil, core.NewError(core.INTERNAL, "ai.Generate: %v", err) - } - r = newR - actionOpts.Use = append(actionOpts.Use, ref) - } + refs, err := configsToRefs(genOpts.Use) + if err != nil { + return nil, err + } + if len(refs) > 0 { + actionOpts.Use = refs } processedMessages, err := processResources(ctx, r, messages) @@ -868,10 +917,14 @@ func clone[T any](obj *T) *T { return &newObj } +// toolRunnerFunc runs a tool through the WrapTool hook chain and returns the +// raw [MultipartToolResponse]. Returned by [buildToolRunner]. +type toolRunnerFunc = func(ctx context.Context, tool Tool, req *ToolRequest) (*MultipartToolResponse, error) + // handleToolRequests processes any tool requests in the response, returning // either a new request to continue the conversation or nil if no tool requests // need handling. 
-func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int, mws []Middleware) (*ModelRequest, *Message, error) { +func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int, runTool toolRunnerFunc) (*ModelRequest, *Message, error) { toolCount := len(resp.ToolRequests()) if toolCount == 0 { return nil, nil, nil @@ -894,7 +947,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return } - multipartResp, err := runToolWithMiddleware(ctx, tool, toolReq, mws) + multipartResp, err := runTool(ctx, tool, toolReq) if err != nil { var tie *toolInterruptError if errors.As(err, &tie) { @@ -977,46 +1030,6 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return newReq, nil, nil } -// runToolWithMiddleware runs a tool, wrapping the execution with WrapTool hooks from middleware. -func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, mws []Middleware) (*MultipartToolResponse, error) { - if len(mws) == 0 { - return tool.RunRawMultipart(ctx, toolReq.Input) - } - - // Capture metadata from the raw tool response so it isn't lost through - // the ToolResponse conversion (ToolResponse has no Metadata field). 
- var toolMetadata map[string]any - - inner := func(ctx context.Context, params *ToolParams) (*ToolResponse, error) { - resp, err := params.Tool.RunRawMultipart(ctx, params.Request.Input) - if err != nil { - return nil, err - } - toolMetadata = resp.Metadata - return &ToolResponse{ - Name: params.Request.Name, - Ref: params.Request.Ref, - Output: resp.Output, - Content: resp.Content, - }, nil - } - - for i := len(mws) - 1; i >= 0; i-- { - h := mws[i] - next := inner - inner = func(ctx context.Context, params *ToolParams) (*ToolResponse, error) { - return h.WrapTool(ctx, params, next) - } - } - - toolResp, err := inner(ctx, &ToolParams{Request: toolReq, Tool: tool}) - if err != nil { - return nil, err - } - - return &MultipartToolResponse{Output: toolResp.Output, Content: toolResp.Content, Metadata: toolMetadata}, nil -} - // Text returns the contents of the first candidate in a // [ModelResponse] as a string. It returns an empty string if there // are no candidates or if the candidate has no message. @@ -1256,7 +1269,7 @@ func (m ModelRef) Config() any { // handleResumedToolRequest resolves a tool request from a previous, interrupted model turn, // when generation is being resumed. It determines the outcome of the tool request based on // pending output, or explicit 'respond' or 'restart' directives in the resume options. 
-func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, p *Part, mws []Middleware) (*resumedToolRequestOutput, error) { +func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, p *Part, runTool toolRunnerFunc) (*resumedToolRequestOutput, error) { if p == nil || !p.IsToolRequest() { return nil, core.NewError(core.INVALID_ARGUMENT, "handleResumedToolRequest: part is not a tool request") } @@ -1349,7 +1362,7 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene Ref: restartPart.ToolRequest.Ref, Input: restartPart.ToolRequest.Input, } - multipartResp, err := runToolWithMiddleware(resumedCtx, tool, restartToolReq, mws) + multipartResp, err := runTool(resumedCtx, tool, restartToolReq) if err != nil { var tie *toolInterruptError if errors.As(err, &tie) { @@ -1376,9 +1389,10 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene } newToolResp := NewToolResponsePart(&ToolResponse{ - Name: restartPart.ToolRequest.Name, - Ref: restartPart.ToolRequest.Ref, - Output: multipartResp.Output, + Name: restartPart.ToolRequest.Name, + Ref: restartPart.ToolRequest.Ref, + Output: multipartResp.Output, + Content: multipartResp.Content, }) return &resumedToolRequestOutput{ @@ -1398,7 +1412,7 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene // handleResumeOption amends message history to handle `resume` arguments. // It returns the amended history. 
-func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, middlewareHandlers []Middleware) (*resumeOptionOutput, error) { +func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, runTool toolRunnerFunc) (*resumeOptionOutput, error) { if genOpts.Resume == nil || (len(genOpts.Resume.Respond) == 0 && len(genOpts.Resume.Restart) == 0) { return &resumeOptionOutput{revisedRequest: genOpts}, nil } @@ -1434,7 +1448,7 @@ func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateAc toolReqCount++ go func(idx int, p *Part) { - output, err := handleResumedToolRequest(ctx, r, genOpts, p, middlewareHandlers) + output, err := handleResumedToolRequest(ctx, r, genOpts, p, runTool) resultChan <- result[*resumedToolRequestOutput]{ index: idx, value: output, diff --git a/go/ai/middleware.go b/go/ai/middleware.go index a90b36077a..6ebeea885d 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -25,26 +25,25 @@ import ( "github.com/firebase/genkit/go/core/api" ) -// middlewareFactory creates a Middleware instance from JSON config. -type middlewareFactory = func(configJSON []byte) (Middleware, error) - -// Middleware provides hooks for different stages of generation. -type Middleware interface { - // Name returns the middleware's unique identifier. - Name() string - // New returns a fresh instance for each ai.Generate() call, enabling per-invocation state. - New() Middleware - // WrapGenerate wraps each iteration of the tool loop. - WrapGenerate(ctx context.Context, params *GenerateParams, next GenerateNext) (*ModelResponse, error) - // WrapModel wraps each model API call. - WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) - // WrapTool wraps each tool execution. - // WrapTool may be called concurrently when multiple tools execute in parallel. - // Implementations must be safe for concurrent use. 
- WrapTool(ctx context.Context, params *ToolParams, next ToolNext) (*ToolResponse, error) - // Tools returns additional tools to make available during generation. - // These tools are dynamically registered when the middleware is used via [WithUse]. - Tools() []Tool +// Hooks is the per-call bundle of hook functions produced by a [Middleware]'s +// New method. Each field is optional; a nil hook is treated as a pass-through. +type Hooks struct { + // Tools are additional tools to register during the generation this + // middleware is attached to. They are available to the model alongside + // any user-supplied tools. + Tools []Tool + // WrapGenerate wraps each iteration of the tool loop. It sees the + // accumulated request, the iteration index, and the streaming callback. + // A single Generate() with N tool-call turns invokes this hook N+1 times. + WrapGenerate func(ctx context.Context, params *GenerateParams, next GenerateNext) (*ModelResponse, error) + // WrapModel wraps each model API call. Retry, fallback, and caching + // middleware typically hook here. + WrapModel func(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) + // WrapTool wraps each tool execution. It may be called concurrently when + // multiple tools execute in parallel for the same Generate() call; any + // state closed over from the enclosing scope that this hook mutates must + // be guarded with sync primitives. + WrapTool func(ctx context.Context, params *ToolParams, next ToolNext) (*MultipartToolResponse, error) } // GenerateParams holds params for the WrapGenerate hook. @@ -55,6 +54,15 @@ type GenerateParams struct { Request *ModelRequest // Iteration is the current tool-loop iteration (0-indexed). Iteration int + // MessageIndex is the index of the next message in the streamed response sequence. + // Middleware that streams extra messages (e.g. 
injected user content) should emit + // chunks at this index and advance it so downstream middleware and the model + // receive the shifted value. + MessageIndex int + // Callback is the streaming callback supplied to [Generate], or nil if not streaming. + // Middleware may invoke it to emit chunks, setting [ModelResponseChunk.Role] and + // [ModelResponseChunk.Index] explicitly. + Callback ModelStreamCallback } // ModelParams holds params for the WrapModel hook. @@ -80,109 +88,180 @@ type GenerateNext = func(ctx context.Context, params *GenerateParams) (*ModelRes type ModelNext = func(ctx context.Context, params *ModelParams) (*ModelResponse, error) // ToolNext is the next function in the WrapTool hook chain. -type ToolNext = func(ctx context.Context, params *ToolParams) (*ToolResponse, error) - -// BaseMiddleware provides default pass-through for the three hooks. -// Embed this so you only need to implement Name() and New(). -type BaseMiddleware struct{} +type ToolNext = func(ctx context.Context, params *ToolParams) (*MultipartToolResponse, error) -func (b *BaseMiddleware) WrapGenerate(ctx context.Context, params *GenerateParams, next GenerateNext) (*ModelResponse, error) { - return next(ctx, params) +// Middleware is the contract every value passed to [WithUse] satisfies. The +// config struct both identifies the middleware (via [Name]) and produces a +// per-call [Hooks] bundle (via [New]). +// +// Plugin-level state belongs on unexported fields of the config type. A +// plugin's [MiddlewarePlugin.Middlewares] sets those fields on a prototype +// that is preserved across JSON dispatch by value-copy inside the descriptor. +type Middleware interface { + // Name returns the registered middleware's unique identifier. Must be a + // stable constant, since it is read from a zero value of the config type + // during descriptor creation. + Name() string + // New produces a fresh [Hooks] bundle for one Generate() call. 
It is + // invoked per-Generate, so any state the bundle's hooks need to share + // (counters, caches) may be allocated in this method and closed over by + // the returned hooks. + New(ctx context.Context) (*Hooks, error) } -func (b *BaseMiddleware) WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) { - return next(ctx, params) -} +// middlewareFactoryFunc is the closure stored on [MiddlewareDesc] that +// materializes a [Hooks] bundle from JSON config. It is produced by +// [NewMiddleware] and captures the prototype so value-copy preserves any +// unexported plugin-level state across JSON-dispatched calls. +type middlewareFactoryFunc = func(ctx context.Context, configJSON []byte) (*Hooks, error) -func (b *BaseMiddleware) WrapTool(ctx context.Context, params *ToolParams, next ToolNext) (*ToolResponse, error) { - return next(ctx, params) -} +// middlewareRegistryPrefix is the registry-key prefix under which middleware +// descriptors are stored. The reflection API lists values under this prefix. +const middlewareRegistryPrefix = "/middleware/" -func (b *BaseMiddleware) Tools() []Tool { return nil } +// middlewareRegistryKey returns the registry key for a middleware with the +// given name. +func middlewareRegistryKey(name string) string { + return middlewareRegistryPrefix + name +} -// Register registers the descriptor with the registry. +// Register records this descriptor in the registry under its name so it can +// be resolved during JSON dispatch and surfaced to the Dev UI. func (d *MiddlewareDesc) Register(r api.Registry) { - r.RegisterValue("/middleware/"+d.Name, d) + r.RegisterValue(middlewareRegistryKey(d.Name), d) } -// NewMiddleware creates a middleware descriptor without registering it. -// The prototype carries stable state; newFromJSON calls prototype.New() -// then unmarshals user config on top. 
-func NewMiddleware[T Middleware](description string, prototype T) *MiddlewareDesc { +// NewMiddleware constructs a descriptor without registering it. Useful for +// [MiddlewarePlugin.Middlewares] implementations that defer registration +// to [genkit.Init]. The prototype argument supplies both the registered name +// (via its [Middleware.Name] method) and any plugin-level state that should +// flow into JSON-dispatched invocations via unexported fields preserved by +// value-copy. +func NewMiddleware[M Middleware](description string, prototype M) *MiddlewareDesc { + name := prototype.Name() return &MiddlewareDesc{ - Name: prototype.Name(), + Name: name, Description: description, - ConfigSchema: core.InferSchemaMap(*new(T)), - newFromJSON: func(configJSON []byte) (Middleware, error) { - inst := prototype.New() + ConfigSchema: core.InferSchemaMap(prototype), + buildFromJSON: func(ctx context.Context, configJSON []byte) (*Hooks, error) { + cfg := prototype // value copy preserves unexported fields, shares pointers if len(configJSON) > 0 { - if err := json.Unmarshal(configJSON, inst); err != nil { - return nil, fmt.Errorf("middleware %q: %w", prototype.Name(), err) + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return nil, fmt.Errorf("middleware %q: %w", name, err) } } - return inst, nil + return cfg.New(ctx) }, } } -// DefineMiddleware creates and registers a middleware descriptor. -func DefineMiddleware[T Middleware](r api.Registry, description string, prototype T) *MiddlewareDesc { +// DefineMiddleware creates and registers a middleware descriptor in one step. +func DefineMiddleware[M Middleware](r api.Registry, description string, prototype M) *MiddlewareDesc { d := NewMiddleware(description, prototype) d.Register(r) return d } -// LookupMiddleware looks up a registered middleware descriptor by name. 
+// MiddlewareFunc adapts a per-call factory closure to the [Middleware] +// interface for ad-hoc inline use, without a registered descriptor or plugin +// wiring. The adapted middleware does not appear in the Dev UI. +// +// Example: +// +// ai.WithUse(ai.MiddlewareFunc(func(ctx context.Context) (*ai.Hooks, error) { +// return &ai.Hooks{WrapModel: ...}, nil +// })) +type MiddlewareFunc func(ctx context.Context) (*Hooks, error) + +// Name returns the placeholder name shared by all [MiddlewareFunc] values. +// Uniqueness is unnecessary: inline middleware is resolved via the fast path +// in [resolveRefs] and never goes through a name-keyed registry lookup. +func (MiddlewareFunc) Name() string { return "inline" } + +// New invokes the adapted factory to produce a fresh [Hooks] bundle. +func (f MiddlewareFunc) New(ctx context.Context) (*Hooks, error) { return f(ctx) } + +// LookupMiddleware returns the registered middleware descriptor with the +// given name, or nil if no such descriptor exists in the registry or any +// ancestor. Primarily useful for inspection and for the reflection API; +// callers dispatching middleware should do so through [WithUse]. func LookupMiddleware(r api.Registry, name string) *MiddlewareDesc { - v := r.LookupValue("/middleware/" + name) + v := r.LookupValue(middlewareRegistryKey(name)) if v == nil { return nil } - d, ok := v.(*MiddlewareDesc) - if !ok { - return nil - } + d, _ := v.(*MiddlewareDesc) return d } -// middlewareToRef registers a Middleware instance (if not already registered) and -// returns a MiddlewareRef for the action layer. If registration requires a child -// registry, the returned registry may differ from the input. 
-func middlewareToRef(r api.Registry, mw Middleware) (*MiddlewareRef, api.Registry, error) { - name := mw.Name() - if LookupMiddleware(r, name) == nil { - if !r.IsChild() { - r = r.NewChild() - } - // Register directly instead of via DefineMiddleware to avoid generic - // type inference losing the concrete type (mw is typed as Middleware - // interface here, so InferSchemaMap would receive a nil interface). - desc := &MiddlewareDesc{ - Name: name, - newFromJSON: func(configJSON []byte) (Middleware, error) { - inst := mw.New() - if len(configJSON) > 0 { - if err := json.Unmarshal(configJSON, inst); err != nil { - return nil, fmt.Errorf("middleware %q: %w", name, err) - } - } - return inst, nil - }, - } - desc.Register(r) - } - configJSON, err := json.Marshal(mw) - if err != nil { - return nil, r, fmt.Errorf("failed to marshal middleware %q config: %w", name, err) +// MiddlewarePlugin is implemented by plugins that provide middleware. The +// returned descriptors are registered in the registry during [genkit.Init], +// with any plugin-level state captured by the descriptor's build closure via +// the prototype passed to [NewMiddleware]. +type MiddlewarePlugin interface { + Middlewares(ctx context.Context) ([]*MiddlewareDesc, error) +} + +// configsToRefs converts a user-supplied slice of [Middleware] values into +// the [MiddlewareRef] entries carried on [GenerateActionOptions.Use]. The Go +// value is stored on each ref so [resolveRefs] can build the hooks bundle +// directly without a registry round trip for local calls. 
+func configsToRefs(configs []Middleware) ([]*MiddlewareRef, error) { + if len(configs) == 0 { + return nil, nil } - var config any - if err := json.Unmarshal(configJSON, &config); err != nil { - return nil, r, fmt.Errorf("failed to unmarshal middleware %q config: %w", name, err) + refs := make([]*MiddlewareRef, 0, len(configs)) + for _, c := range configs { + if c == nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai: nil middleware") + } + refs = append(refs, &MiddlewareRef{Name: c.Name(), Config: c}) } - return &MiddlewareRef{Name: name, Config: config}, r, nil + return refs, nil } -// MiddlewarePlugin is implemented by plugins that provide middleware. -type MiddlewarePlugin interface { - ListMiddleware(ctx context.Context) ([]*MiddlewareDesc, error) +// resolveRefs resolves [MiddlewareRef] entries to [Hooks] bundles. If +// ref.Config is a [Middleware] value, its New method is invoked directly +// (local fast path). Otherwise the descriptor is looked up in the registry +// and its build closure is invoked with the marshaled config (JSON dispatch, +// used for cross-runtime / Dev UI calls). 
+func resolveRefs(ctx context.Context, r api.Registry, refs []*MiddlewareRef) ([]*Hooks, error) { + if len(refs) == 0 { + return nil, nil + } + bundles := make([]*Hooks, 0, len(refs)) + for _, ref := range refs { + if mw, ok := ref.Config.(Middleware); ok { + h, err := mw.New(ctx) + if err != nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai: failed to build middleware %q: %v", ref.Name, err) + } + if h == nil { + return nil, core.NewError(core.INTERNAL, "ai: middleware %q returned nil hooks", ref.Name) + } + bundles = append(bundles, h) + continue + } + d := LookupMiddleware(r, ref.Name) + if d == nil { + return nil, core.NewError(core.NOT_FOUND, "ai: middleware %q not registered (is the providing plugin installed?)", ref.Name) + } + var configJSON []byte + if ref.Config != nil { + b, err := json.Marshal(ref.Config) + if err != nil { + return nil, core.NewError(core.INTERNAL, "ai: failed to marshal config for middleware %q: %v", ref.Name, err) + } + configJSON = b + } + h, err := d.buildFromJSON(ctx, configJSON) + if err != nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai: failed to build middleware %q: %v", ref.Name, err) + } + if h == nil { + return nil, core.NewError(core.INTERNAL, "ai: middleware %q factory returned nil", ref.Name) + } + bundles = append(bundles, h) + } + return bundles, nil } diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go index f08d8a9e7e..faf0bd814e 100644 --- a/go/ai/middleware_test.go +++ b/go/ai/middleware_test.go @@ -18,138 +18,183 @@ package ai import ( "context" + "errors" + "sync" "sync/atomic" "testing" ) -// testMiddleware is a simple middleware for testing that tracks hook invocations. 
-type testMiddleware struct { - BaseMiddleware - Label string `json:"label"` - generateCalls int - modelCalls int - toolCalls int32 // atomic since tool hooks run in parallel -} - -func (m *testMiddleware) Name() string { return "test" } +// --- counter: a config whose New method returns hooks that track invocations --- -func (m *testMiddleware) New() Middleware { - return &testMiddleware{Label: m.Label} -} - -func (m *testMiddleware) WrapGenerate(ctx context.Context, params *GenerateParams, next GenerateNext) (*ModelResponse, error) { - m.generateCalls++ - return next(ctx, params) -} +type counterConfig struct { + Label string `json:"label,omitempty"` -func (m *testMiddleware) WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) { - m.modelCalls++ - return next(ctx, params) + // Plugin-level state lives on unexported fields and is preserved by + // the descriptor's value-copy across JSON-dispatch calls. + sharedGenerateCalls *int32 + sharedModelCalls *int32 + sharedToolCalls *int32 } -func (m *testMiddleware) WrapTool(ctx context.Context, params *ToolParams, next ToolNext) (*ToolResponse, error) { - atomic.AddInt32(&m.toolCalls, 1) - return next(ctx, params) +func (counterConfig) Name() string { return "test/counter" } + +func (c counterConfig) New(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapGenerate: func(ctx context.Context, p *GenerateParams, next GenerateNext) (*ModelResponse, error) { + if c.sharedGenerateCalls != nil { + atomic.AddInt32(c.sharedGenerateCalls, 1) + } + return next(ctx, p) + }, + WrapModel: func(ctx context.Context, p *ModelParams, next ModelNext) (*ModelResponse, error) { + if c.sharedModelCalls != nil { + atomic.AddInt32(c.sharedModelCalls, 1) + } + return next(ctx, p) + }, + WrapTool: func(ctx context.Context, p *ToolParams, next ToolNext) (*MultipartToolResponse, error) { + if c.sharedToolCalls != nil { + atomic.AddInt32(c.sharedToolCalls, 1) + } + return next(ctx, p) + }, + }, nil } -func
TestNewMiddlewareDesc(t *testing.T) { - proto := &testMiddleware{Label: "original"} - desc := NewMiddleware("test middleware", proto) +// --- core descriptor tests --- - if desc.Name != "test" { - t.Errorf("got name %q, want %q", desc.Name, "test") +func TestNewMiddleware_NameFromPrototype(t *testing.T) { + desc := NewMiddleware("tracks calls", counterConfig{}) + if desc.Name != "test/counter" { + t.Errorf("got name %q, want %q", desc.Name, "test/counter") } - if desc.Description != "test middleware" { - t.Errorf("got description %q, want %q", desc.Description, "test middleware") + if desc.Description != "tracks calls" { + t.Errorf("got description %q, want %q", desc.Description, "tracks calls") } } -func TestDefineAndLookupMiddleware(t *testing.T) { - r := newTestRegistry(t) - proto := &testMiddleware{Label: "original"} - DefineMiddleware(r, "test middleware", proto) - - found := LookupMiddleware(r, "test") - if found == nil { - t.Fatal("expected to find middleware, got nil") +func TestBuildFromJSON(t *testing.T) { + desc := NewMiddleware("desc", counterConfig{}) + mw, err := desc.buildFromJSON(testCtx, []byte(`{"label": "custom"}`)) + if err != nil { + t.Fatalf("buildFromJSON failed: %v", err) } - if found.Name != "test" { - t.Errorf("got name %q, want %q", found.Name, "test") + if mw == nil || mw.WrapModel == nil { + t.Fatal("expected middleware with WrapModel hook") } } -func TestLookupMiddlewareNotFound(t *testing.T) { - r := newTestRegistry(t) - found := LookupMiddleware(r, "nonexistent") - if found != nil { - t.Errorf("expected nil, got %v", found) +func TestBuildFromJSON_InvalidJSON(t *testing.T) { + desc := NewMiddleware("desc", counterConfig{}) + _, err := desc.buildFromJSON(testCtx, []byte(`not-json`)) + if err == nil { + t.Fatal("expected error from invalid JSON") } } -func TestConfigFromJSON(t *testing.T) { - proto := &testMiddleware{Label: "stable"} - desc := NewMiddleware("test middleware", proto) +// --- plugin-level state: prototype unexported 
fields preserved across calls --- - handler, err := desc.newFromJSON([]byte(`{"label": "custom"}`)) - if err != nil { - t.Fatalf("newFromJSON failed: %v", err) - } +func TestPluginStateCarriedThroughPrototype(t *testing.T) { + // Simulate the plugin's Middlewares() building a prototype that holds + // an "expensive client" (here, a shared counter). JSON dispatch must + // preserve this plugin-level state across invocations. + var shared int32 + desc := NewMiddleware("desc", counterConfig{sharedModelCalls: &shared}) - tm, ok := handler.(*testMiddleware) - if !ok { - t.Fatalf("expected *testMiddleware, got %T", handler) - } - if tm.Label != "custom" { - t.Errorf("got label %q, want %q", tm.Label, "custom") + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + desc.Register(r) + + // Simulate Dev UI JSON dispatch: ref.Config is a map, not a typed config. + refs := []*MiddlewareRef{{Name: "test/counter", Config: map[string]any{"label": "dev-ui"}}} + for i := 0; i < 3; i++ { + _, err := GenerateWithRequest(testCtx, r, &GenerateActionOptions{ + Model: m.Name(), + Messages: []*Message{NewUserTextMessage("go")}, + Use: refs, + }, nil, nil) + assertNoError(t, err) } - // Per-request state should be zeroed by New() - if tm.generateCalls != 0 { - t.Errorf("got generateCalls %d, want 0", tm.generateCalls) + if got := atomic.LoadInt32(&shared); got != 3 { + t.Errorf("shared counter = %d, want 3 (plugin state should persist across JSON dispatches)", got) } } -func TestConfigFromJSONPreservesStableState(t *testing.T) { - // Simulate a plugin middleware with unexported stable state - proto := &stableStateMiddleware{apiKey: "secret123"} - desc := NewMiddleware("middleware with stable state", proto) +// --- call-level state: each Generate gets a fresh New scope --- - handler, err := desc.newFromJSON([]byte(`{"sampleRate": 0.5}`)) - if err != nil { - t.Fatalf("newFromJSON failed: %v", err) - } +type perCallConfig struct { + checker func(n int32) +} -
sm, ok := handler.(*stableStateMiddleware) - if !ok { - t.Fatalf("expected *stableStateMiddleware, got %T", handler) - } - if sm.apiKey != "secret123" { - t.Errorf("got apiKey %q, want %q", sm.apiKey, "secret123") +func (perCallConfig) Name() string { return "test/per-call" } + +func (c perCallConfig) New(ctx context.Context) (*Hooks, error) { + var counter int32 + return &Hooks{ + WrapModel: func(ctx context.Context, p *ModelParams, next ModelNext) (*ModelResponse, error) { + n := atomic.AddInt32(&counter, 1) + if c.checker != nil { + c.checker(n) + } + return next(ctx, p) + }, + }, nil +} + +func TestCallLevelStateIsolation(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + cfg := perCallConfig{checker: func(n int32) { + if n != 1 { + t.Errorf("call-level counter leaked: got %d, want 1", n) + } + }} + for i := 0; i < 3; i++ { + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("go"), WithUse(cfg)) + assertNoError(t, err) } - if sm.SampleRate != 0.5 { - t.Errorf("got SampleRate %f, want 0.5", sm.SampleRate) +} + +// --- pure Go usage: no registration required for local calls --- + +func TestWithUseNoRegistrationNeeded(t *testing.T) { + // The whole point: user creates a Genkit with no middleware plugins and + // still calls WithUse(middleware.Retry{...}). The config's New + // method runs directly; the registry is never consulted. + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + var called int32 + cfg := counterConfig{sharedModelCalls: &called} + // Note: no Register call anywhere.
+ + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hi"), WithUse(cfg)) + assertNoError(t, err) + if atomic.LoadInt32(&called) != 1 { + t.Errorf("expected 1 model-hook call, got %d", called) } } +// --- hook invocation: model, tool, generate --- + func TestMiddlewareModelHook(t *testing.T) { r := newTestRegistry(t) m := defineFakeModel(t, r, fakeModelConfig{}) - DefineMiddleware(r, "tracks calls", &testMiddleware{}) - var modelHookCalled bool - mw := &hookTrackingMiddleware{ - onModel: func() { modelHookCalled = true }, - } + var called int32 + tracker := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapModel: func(ctx context.Context, p *ModelParams, next ModelNext) (*ModelResponse, error) { + atomic.AddInt32(&called, 1) + return next(ctx, p) + }, + }, nil + }) - resp, err := Generate(testCtx, r, - WithModel(m), - WithPrompt("hello"), - WithUse(mw), - ) + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hello"), WithUse(tracker)) assertNoError(t, err) - if resp == nil { - t.Fatal("expected response, got nil") - } - if !modelHookCalled { + if atomic.LoadInt32(&called) == 0 { t.Error("expected model hook to be called") } } @@ -162,49 +207,85 @@ func TestMiddlewareToolHook(t *testing.T) { }) defineFakeTool(t, r, "myTool", "A test tool") - var toolHookCalled int32 - mw := &hookTrackingMiddleware{ - onTool: func() { atomic.AddInt32(&toolHookCalled, 1) }, - } - DefineMiddleware(r, "tracks calls", mw) + var called int32 + tracker := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapTool: func(ctx context.Context, p *ToolParams, next ToolNext) (*MultipartToolResponse, error) { + atomic.AddInt32(&called, 1) + return next(ctx, p) + }, + }, nil + }) _, err := Generate(testCtx, r, WithModelName("test/toolModel"), WithPrompt("use the tool"), WithTools(ToolName("myTool")), - WithUse(mw), + WithUse(tracker), ) assertNoError(t, err) - if atomic.LoadInt32(&toolHookCalled) == 0 { + if 
atomic.LoadInt32(&called) == 0 { t.Error("expected tool hook to be called at least once") } } +func TestMiddlewareGenerateHook(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + var called int32 + tracker := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapGenerate: func(ctx context.Context, p *GenerateParams, next GenerateNext) (*ModelResponse, error) { + atomic.AddInt32(&called, 1) + return next(ctx, p) + }, + }, nil + }) + + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hello"), WithUse(tracker)) + assertNoError(t, err) + if atomic.LoadInt32(&called) == 0 { + t.Error("expected generate hook to be called") + } +} + +// --- ordering: first middleware wraps outermost --- + func TestMiddlewareOrdering(t *testing.T) { - // First middleware is outermost + var mu sync.Mutex var order []string + appendOrder := func(s string) { + mu.Lock() + defer mu.Unlock() + order = append(order, s) + } + tracker := func(label string) Middleware { + return MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapModel: func(ctx context.Context, p *ModelParams, next ModelNext) (*ModelResponse, error) { + appendOrder(label + "-before") + resp, err := next(ctx, p) + appendOrder(label + "-after") + return resp, err + }, + }, nil + }) + } + r := newTestRegistry(t) m := defineFakeModel(t, r, fakeModelConfig{}) - mwA := &orderMiddleware{label: "A", order: &order} - mwB := &orderMiddleware{label: "B", order: &order} - DefineMiddleware(r, "middleware A", mwA) - DefineMiddleware(r, "middleware B", mwB) - _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hello"), - WithUse( - &orderMiddleware{label: "A", order: &order}, - &orderMiddleware{label: "B", order: &order}, - ), + WithUse(tracker("A"), tracker("B")), ) assertNoError(t, err) - // Expect: A-before, B-before, B-after, A-after (first is outermost) - want := []string{"A-model-before", "B-model-before", "B-model-after", 
"A-model-after"} + want := []string{"A-before", "B-before", "B-after", "A-after"} if len(order) != len(want) { - t.Fatalf("got order %v, want %v", order, want) + t.Fatalf("got %v, want %v", order, want) } for i := range want { if order[i] != want[i] { @@ -213,74 +294,283 @@ func TestMiddlewareOrdering(t *testing.T) { } } -// --- helper middleware types for tests --- +// --- MiddlewareFunc adapter basics --- -// hookTrackingMiddleware uses callbacks to verify hooks are actually invoked. -type hookTrackingMiddleware struct { - BaseMiddleware - onGenerate func() - onModel func() - onTool func() -} +func TestMiddlewareFunc(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) -func (m *hookTrackingMiddleware) Name() string { return "hookTracking" } + var called bool + mw := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapGenerate: func(ctx context.Context, p *GenerateParams, next GenerateNext) (*ModelResponse, error) { + called = true + return next(ctx, p) + }, + }, nil + }) -func (m *hookTrackingMiddleware) New() Middleware { - return &hookTrackingMiddleware{onGenerate: m.onGenerate, onModel: m.onModel, onTool: m.onTool} + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hello"), WithUse(mw)) + assertNoError(t, err) + if !called { + t.Error("inline middleware hook not called") + } } -func (m *hookTrackingMiddleware) WrapGenerate(ctx context.Context, params *GenerateParams, next GenerateNext) (*ModelResponse, error) { - if m.onGenerate != nil { - m.onGenerate() +func TestMiddlewareFuncCoexist(t *testing.T) { + // Two MiddlewareFunc adapter instances should be able to coexist in a + // single WithUse call. 
+ r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + var a, b int32 + useA := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{WrapModel: func(ctx context.Context, p *ModelParams, next ModelNext) (*ModelResponse, error) { + atomic.AddInt32(&a, 1) + return next(ctx, p) + }}, nil + }) + useB := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{WrapModel: func(ctx context.Context, p *ModelParams, next ModelNext) (*ModelResponse, error) { + atomic.AddInt32(&b, 1) + return next(ctx, p) + }}, nil + }) + + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hi"), WithUse(useA, useB)) + assertNoError(t, err) + if atomic.LoadInt32(&a) != 1 || atomic.LoadInt32(&b) != 1 { + t.Errorf("expected both hooks called once, got a=%d b=%d", a, b) } - return next(ctx, params) } -func (m *hookTrackingMiddleware) WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) { - if m.onModel != nil { - m.onModel() +// --- optional hooks: nil hook fields must pass through --- + +func TestNilHookFieldsPassThrough(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + passthrough := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{}, nil + }) + + resp, err := Generate(testCtx, r, WithModel(m), WithPrompt("hi"), WithUse(passthrough)) + assertNoError(t, err) + if resp == nil { + t.Fatal("expected response") } - return next(ctx, params) } -func (m *hookTrackingMiddleware) WrapTool(ctx context.Context, params *ToolParams, next ToolNext) (*ToolResponse, error) { - if m.onTool != nil { - m.onTool() +// --- streaming: middleware-emitted chunks accumulate with model chunks --- + +func TestMiddlewareStreamsAccumulateWithModel(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{ + name: "test/streamModel", + handler: streamingModelHandler([]string{"model chunk"}, "done"), + }) + + midChunk := 
MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapGenerate: func(ctx context.Context, p *GenerateParams, next GenerateNext) (*ModelResponse, error) { + if p.Callback != nil { + if err := p.Callback(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart("middleware chunk ")}, + }); err != nil { + return nil, err + } + } + return next(ctx, p) + }, + }, nil + }) + + var chunks []*ModelResponseChunk + _, err := Generate(testCtx, r, + WithModel(m), + WithPrompt("go"), + WithUse(midChunk), + WithStreaming(func(_ context.Context, c *ModelResponseChunk) error { + chunks = append(chunks, c) + return nil + }), + ) + assertNoError(t, err) + + if len(chunks) != 2 { + t.Fatalf("got %d chunks, want 2", len(chunks)) } - return next(ctx, params) } -// stableStateMiddleware has unexported stable state preserved by New(). -type stableStateMiddleware struct { - BaseMiddleware - SampleRate float64 `json:"sampleRate"` - apiKey string +// --- tool contribution: Tools on *Hooks --- + +func TestMiddlewareContributesTool(t *testing.T) { + r := newTestRegistry(t) + defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolModel", + handler: toolCallingModelHandler("mw/tool", map[string]any{"value": "x"}, "done"), + }) + + injectTool := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + Tools: []Tool{NewTool("mw/tool", "injected", + func(tc *ToolContext, in struct { + Value string `json:"value"` + }) (string, error) { + return "ok", nil + })}, + }, nil + }) + + _, err := Generate(testCtx, r, + WithModelName("test/toolModel"), + WithPrompt("use it"), + WithUse(injectTool), + ) + assertNoError(t, err) } -func (m *stableStateMiddleware) Name() string { return "stableState" } +// --- duplicate tool collision: two middleware with same tool name --- + +func TestDuplicateMiddlewareToolRejected(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + makeInjector := func() Middleware { + return
MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + Tools: []Tool{NewTool("dup/tool", "d", + func(tc *ToolContext, in struct{}) (string, error) { return "x", nil })}, + }, nil + }) + } -func (m *stableStateMiddleware) New() Middleware { - return &stableStateMiddleware{apiKey: m.apiKey} + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hi"), + WithUse(makeInjector(), makeInjector())) + if err == nil { + t.Fatal("expected duplicate tool error, got nil") + } } -// orderMiddleware tracks the order of WrapModel hook invocations. -type orderMiddleware struct { - BaseMiddleware - label string - order *[]string +// --- error propagation from Middleware.New --- + +func TestBuildMiddlewareErrorPropagates(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + bad := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return nil, errors.New("boom") + }) + + _, err := Generate(testCtx, r, WithModel(m), WithPrompt("hi"), WithUse(bad)) + if err == nil { + t.Fatal("expected BuildMiddleware error, got nil") + } } -func (m *orderMiddleware) Name() string { return "order-" + m.label } +// --- tool interrupt from WrapTool --- + +func TestWrapToolInterrupts(t *testing.T) { + r := newTestRegistry(t) + defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolModel", + handler: toolCallingModelHandler("myTool", map[string]any{"value": "x"}, "done"), + }) + defineFakeTool(t, r, "myTool", "A test tool") -func (m *orderMiddleware) New() Middleware { - return &orderMiddleware{label: m.label, order: m.order} + interrupter := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapTool: func(ctx context.Context, p *ToolParams, next ToolNext) (*MultipartToolResponse, error) { + return nil, NewToolInterruptError(map[string]any{"reason": "blocked"}) + }, + }, nil + }) + + resp, err := Generate(testCtx, r, + WithModelName("test/toolModel"), + WithPrompt("use it"), +
WithTools(ToolName("myTool")), + WithUse(interrupter), + ) + assertNoError(t, err) + if resp.FinishReason != "interrupted" { + t.Errorf("expected FinishReason=interrupted, got %q", resp.FinishReason) + } + if len(resp.Interrupts()) == 0 { + t.Error("expected at least one interrupt part in response") + } } -func (m *orderMiddleware) WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) { - *m.order = append(*m.order, m.label+"-model-before") - resp, err := next(ctx, params) - *m.order = append(*m.order, m.label+"-model-after") - return resp, err +// --- WrapGenerate fires per tool-loop iteration --- + +func TestGenerateHookFiresEachIteration(t *testing.T) { + r := newTestRegistry(t) + defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolLoop", + handler: toolCallingModelHandler("myTool", map[string]any{"value": "x"}, "done"), + }) + defineFakeTool(t, r, "myTool", "A test tool") + + var iters int32 + tracker := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapGenerate: func(ctx context.Context, p *GenerateParams, next GenerateNext) (*ModelResponse, error) { + atomic.AddInt32(&iters, 1) + return next(ctx, p) + }, + }, nil + }) + + _, err := Generate(testCtx, r, + WithModelName("test/toolLoop"), + WithPrompt("use it"), + WithTools(ToolName("myTool")), + WithUse(tracker), + ) + assertNoError(t, err) + if got := atomic.LoadInt32(&iters); got < 2 { + t.Errorf("expected WrapGenerate to fire >=2 times, got %d", got) + } +} + +// --- WrapTool metadata survives round trip --- + +func TestWrapToolPreservesMetadata(t *testing.T) { + r := newTestRegistry(t) + defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolModel", + handler: toolCallingModelHandler("myTool", map[string]any{"value": "x"}, "done"), + }) + defineFakeTool(t, r, "myTool", "A test tool") + + var sawMetadata map[string]any + reader := MiddlewareFunc(func(ctx context.Context) (*Hooks, error) { + return &Hooks{ + WrapTool: func(ctx 
context.Context, p *ToolParams, next ToolNext) (*MultipartToolResponse, error) { + resp, err := next(ctx, p) + if err != nil { + return nil, err + } + if resp.Metadata == nil { + resp.Metadata = map[string]any{} + } + resp.Metadata["traced"] = true + sawMetadata = resp.Metadata + return resp, nil + }, + }, nil + }) + + _, err := Generate(testCtx, r, + WithModelName("test/toolModel"), + WithPrompt("use it"), + WithTools(ToolName("myTool")), + WithUse(reader), + ) + assertNoError(t, err) + if sawMetadata == nil || sawMetadata["traced"] != true { + t.Errorf("metadata not threaded, got %v", sawMetadata) + } } var testCtx = context.Background() diff --git a/go/ai/option.go b/go/ai/option.go index 84019b11d7..dc9dbc7816 100644 --- a/go/ai/option.go +++ b/go/ai/option.go @@ -109,8 +109,8 @@ type commonGenOptions struct { ToolChoice ToolChoice // Whether tool calls are required, disabled, or optional. MaxTurns int // Maximum number of tool call iterations. ReturnToolRequests *bool // Whether to return tool requests instead of making the tool calls and continuing the generation. - Middleware []ModelMiddleware // Deprecated: Use WithUse instead. Middleware to apply to the model request and model response. - Use []Middleware // Middleware to apply to generation (Generate, Model, and Tool hooks). + Middleware []ModelMiddleware // Deprecated: Use WithUse instead. Middleware to apply to the model request and model response. + Use []Middleware // Middleware to apply to generation (Generate, Model, and Tool hooks). } type CommonGenOption interface { @@ -249,6 +249,10 @@ func WithMiddleware(middleware ...ModelMiddleware) CommonGenOption { // WithUse sets middleware to apply to generation. Middleware hooks wrap // the generate loop, model calls, and tool executions. +// +// Accepts either a middleware config struct (produced by a plugin) or an +// inline adapter via [MiddlewareFunc]. The chain applies outer-to-inner, so +// WithUse(A, B) expands to A { B { ... } }. 
func WithUse(middleware ...Middleware) CommonGenOption { return &commonGenOptions{Use: middleware} } diff --git a/go/ai/prompt.go b/go/ai/prompt.go index 1e42be4d2f..de62751ea9 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -249,15 +249,12 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod } } - if len(execOpts.Use) > 0 { - for _, mw := range execOpts.Use { - ref, newR, err := middlewareToRef(r, mw) - if err != nil { - return nil, fmt.Errorf("Prompt.Execute: %w", err) - } - r = newR - actionOpts.Use = append(actionOpts.Use, ref) - } + refs, err := configsToRefs(execOpts.Use) + if err != nil { + return nil, fmt.Errorf("Prompt.Execute: %w", err) + } + if len(refs) > 0 { + actionOpts.Use = refs } return GenerateWithRequest(ctx, r, actionOpts, execOpts.Middleware, execOpts.Stream) diff --git a/go/ai/testutil_test.go b/go/ai/testutil_test.go index 6c606a28a8..165f0cd0a9 100644 --- a/go/ai/testutil_test.go +++ b/go/ai/testutil_test.go @@ -196,6 +196,40 @@ func toolCallingModelHandler(toolName string, toolInput map[string]any, finalRes } } +// parallelToolCallingModelHandler returns count copies of the same tool call +// in a single model response, so the tool loop dispatches them in parallel. +// Used to verify WrapTool concurrency safety. 
+func parallelToolCallingModelHandler(toolName string, count int) func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + hasToolResponse := false + for _, msg := range req.Messages { + for _, part := range msg.Content { + if part.IsToolResponse() { + hasToolResponse = true + break + } + } + } + if !hasToolResponse && len(req.Tools) > 0 { + parts := make([]*Part, count) + for i := range parts { + parts[i] = NewToolRequestPart(&ToolRequest{ + Name: toolName, + Input: map[string]any{"value": fmt.Sprintf("req-%d", i)}, + }) + } + return &ModelResponse{ + Request: req, + Message: &Message{Role: RoleModel, Content: parts}, + }, nil + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("done"), + }, nil + } +} + // cmpPartEqual is a Part comparator for cmp.Diff that compares essential fields. func cmpPartEqual(a, b *Part) bool { if a == nil || b == nil { diff --git a/go/ai/tools.go b/go/ai/tools.go index 453ad779a3..b88712e091 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -106,6 +106,13 @@ func IsToolInterruptError(err error) (bool, map[string]any) { return false, nil } +// NewToolInterruptError creates a tool interrupt error with the given metadata. +// This is intended for use in middleware that needs to interrupt tool execution +// without calling the tool itself. +func NewToolInterruptError(metadata map[string]any) error { + return &toolInterruptError{Metadata: metadata} +} + // InterruptOptions provides configuration for tool interruption. 
type InterruptOptions struct { Metadata map[string]any diff --git a/go/core/schemas.config b/go/core/schemas.config index 369db20dba..4a27007821 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1148,7 +1148,7 @@ ModelResponseChunk field formatHandler StreamingFormatHandler # Middleware MiddlewareDesc pkg ai MiddlewareDesc.configSchema type map[string]any -MiddlewareDesc field newFromJSON middlewareFactory +MiddlewareDesc field buildFromJSON middlewareFactoryFunc MiddlewareRef pkg ai Score omit diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index d2850f2278..c0055de7a7 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -223,9 +223,9 @@ func Init(ctx context.Context, opts ...GenkitOption) *Genkit { r.RegisterPlugin(plugin.Name(), plugin) if mp, ok := plugin.(ai.MiddlewarePlugin); ok { - descs, err := mp.ListMiddleware(ctx) + descs, err := mp.Middlewares(ctx) if err != nil { - panic(fmt.Errorf("genkit.Init: plugin %q ListMiddleware failed: %w", plugin.Name(), err)) + panic(fmt.Errorf("genkit.Init: plugin %q Middlewares failed: %w", plugin.Name(), err)) } for _, d := range descs { d.Register(r) @@ -690,6 +690,68 @@ func LookupTool(g *Genkit, name string) ai.Tool { return ai.LookupTool(g.reg, name) } +// DefineMiddleware registers a middleware descriptor with the Genkit instance +// and returns the resulting [*ai.MiddlewareDesc]. Registered middleware is +// surfaced to the Dev UI and addressable by name for cross-runtime dispatch. +// +// This is the path for application code that declares its own middleware +// directly. Plugins should instead construct descriptors with [ai.NewMiddleware] +// (no registration) and return them from [ai.MiddlewarePlugin.Middlewares]; +// [Init] registers those descriptors during plugin setup. +// +// The `description` is a human-readable explanation shown in the Dev UI. The +// `prototype` is a value of a type that implements [ai.Middleware]. 
Its +// [ai.Middleware.Name] method supplies the registered name, and its fields +// (both exported JSON config and unexported plugin-level state) are captured +// by a value-copy inside the descriptor so JSON-dispatched invocations +// preserve prototype state across calls. +// +// For pure Go use, registration is not strictly required: passing a middleware +// config directly to [ai.WithUse] invokes its [ai.Middleware.New] method on +// the local fast path without a registry lookup. Registration is what makes +// the middleware visible to the Dev UI and callable from other runtimes. For +// ad-hoc one-off middleware that doesn't need Dev UI visibility, use +// [ai.MiddlewareFunc] instead of defining a type. +// +// Example: +// +// type Trace struct { +// Label string `json:"label,omitempty"` +// } +// +// func (Trace) Name() string { return "mine/trace" } +// +// func (t Trace) New(ctx context.Context) (*ai.Hooks, error) { +// return &ai.Hooks{ +// WrapModel: func(ctx context.Context, p *ai.ModelParams, next ai.ModelNext) (*ai.ModelResponse, error) { +// start := time.Now() +// resp, err := next(ctx, p) +// log.Printf("[%s] model call took %s", t.Label, time.Since(start)) +// return resp, err +// }, +// }, nil +// } +// +// // Register so it appears in the Dev UI and can be called by name: +// genkit.DefineMiddleware(g, "logs model call latency", Trace{}) +// +// // Use it per-call: +// resp, err := genkit.Generate(ctx, g, +// ai.WithPrompt("hello"), +// ai.WithUse(Trace{Label: "debug"}), +// ) +func DefineMiddleware[M ai.Middleware](g *Genkit, description string, prototype M) *ai.MiddlewareDesc { + return ai.DefineMiddleware(g.reg, description, prototype) +} + +// LookupMiddleware retrieves a registered middleware descriptor by its name. +// It returns the descriptor if found, or `nil` if no middleware with the +// given name is registered (e.g., via [DefineMiddleware] or through a +// plugin's [ai.MiddlewarePlugin.Middlewares] method). 
+func LookupMiddleware(g *Genkit, name string) *ai.MiddlewareDesc { + return ai.LookupMiddleware(g.reg, name) +} + // DefinePrompt defines a prompt programmatically, registers it as a [core.Action] // of type Prompt, and returns an executable [ai.Prompt]. // From 0bb8499b1c4fdd7a3dc3295569dd6e31e9051458 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 16 Apr 2026 18:08:46 -0700 Subject: [PATCH 20/22] Update option.go --- go/ai/option.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go/ai/option.go b/go/ai/option.go index dc9dbc7816..5b2bcb905a 100644 --- a/go/ai/option.go +++ b/go/ai/option.go @@ -109,8 +109,8 @@ type commonGenOptions struct { ToolChoice ToolChoice // Whether tool calls are required, disabled, or optional. MaxTurns int // Maximum number of tool call iterations. ReturnToolRequests *bool // Whether to return tool requests instead of making the tool calls and continuing the generation. - Middleware []ModelMiddleware // Deprecated: Use WithUse instead. Middleware to apply to the model request and model response. - Use []Middleware // Middleware to apply to generation (Generate, Model, and Tool hooks). + Middleware []ModelMiddleware // Deprecated: Use WithUse instead. Middleware to apply to the model request and model response. + Use []Middleware // Middleware to apply to generation (Generate, Model, and Tool hooks). 
} type CommonGenOption interface { From e284aacb8f2a0434998b8c5ded32bd89d2d173a3 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 16 Apr 2026 18:26:50 -0700 Subject: [PATCH 21/22] Update middleware.go --- go/ai/middleware.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/go/ai/middleware.go b/go/ai/middleware.go index 6ebeea885d..a46e99216c 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -19,7 +19,6 @@ package ai import ( "context" "encoding/json" - "fmt" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" @@ -147,7 +146,7 @@ func NewMiddleware[M Middleware](description string, prototype M) *MiddlewareDes cfg := prototype // value copy preserves unexported fields, shares pointers if len(configJSON) > 0 { if err := json.Unmarshal(configJSON, &cfg); err != nil { - return nil, fmt.Errorf("middleware %q: %w", name, err) + return nil, core.NewError(core.INVALID_ARGUMENT, "middleware %q: %w", name, err) } } return cfg.New(ctx) From f1fccdfcd44a54a78f527b56c44f0ca25c9a6b27 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 16 Apr 2026 18:47:35 -0700 Subject: [PATCH 22/22] minor fixes --- go/ai/generate.go | 19 +++++++++++++++---- go/ai/middleware.go | 3 --- go/ai/testutil_test.go | 34 ---------------------------------- 3 files changed, 15 insertions(+), 41 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index 7fc74ddd29..307f82f002 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -345,7 +345,6 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi runTool := buildToolRunner(mws) - // Inline recursive helper function that captures variables from parent scope. 
var generate func(context.Context, *ModelRequest, int, int) (*ModelResponse, error) var runGenerate func(context.Context, *GenerateParams) (*ModelResponse, error) @@ -525,12 +524,24 @@ func buildModelChain(mws []*Hooks, fn ModelFunc) ModelFunc { // buildToolRunner composes the WrapTool hooks from mws (outer-to-inner) into // a single function that executes a tool. The returned function is safe to // invoke from concurrent goroutines; each invocation threads its own params -// through the shared hook chain. +// through the shared hook chain. When no WrapTool hooks are configured, the +// tool is invoked directly without allocating a ToolParams wrapper. func buildToolRunner(mws []*Hooks) func(ctx context.Context, tool Tool, req *ToolRequest) (*MultipartToolResponse, error) { - base := func(ctx context.Context, params *ToolParams) (*MultipartToolResponse, error) { + hasHook := false + for _, mw := range mws { + if mw != nil && mw.WrapTool != nil { + hasHook = true + break + } + } + if !hasHook { + return func(ctx context.Context, tool Tool, req *ToolRequest) (*MultipartToolResponse, error) { + return tool.RunRawMultipart(ctx, req.Input) + } + } + chain := func(ctx context.Context, params *ToolParams) (*MultipartToolResponse, error) { return params.Tool.RunRawMultipart(ctx, params.Request.Input) } - chain := base for i := len(mws) - 1; i >= 0; i-- { mw := mws[i] if mw == nil || mw.WrapTool == nil { diff --git a/go/ai/middleware.go b/go/ai/middleware.go index a46e99216c..1b28012b65 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -118,8 +118,6 @@ type middlewareFactoryFunc = func(ctx context.Context, configJSON []byte) (*Hook // descriptors are stored. The reflection API lists values under this prefix. const middlewareRegistryPrefix = "/middleware/" -// middlewareRegistryKey returns the registry key for a middleware with the -// given name. 
func middlewareRegistryKey(name string) string { return middlewareRegistryPrefix + name } @@ -177,7 +175,6 @@ type MiddlewareFunc func(ctx context.Context) (*Hooks, error) // in [resolveRefs] and never goes through a name-keyed registry lookup. func (MiddlewareFunc) Name() string { return "inline" } -// New invokes the adapted factory to produce a fresh [Hooks] bundle. func (f MiddlewareFunc) New(ctx context.Context) (*Hooks, error) { return f(ctx) } // LookupMiddleware returns the registered middleware descriptor with the diff --git a/go/ai/testutil_test.go b/go/ai/testutil_test.go index 165f0cd0a9..6c606a28a8 100644 --- a/go/ai/testutil_test.go +++ b/go/ai/testutil_test.go @@ -196,40 +196,6 @@ func toolCallingModelHandler(toolName string, toolInput map[string]any, finalRes } } -// parallelToolCallingModelHandler returns count copies of the same tool call -// in a single model response, so the tool loop dispatches them in parallel. -// Used to verify WrapTool concurrency safety. -func parallelToolCallingModelHandler(toolName string, count int) func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { - return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { - hasToolResponse := false - for _, msg := range req.Messages { - for _, part := range msg.Content { - if part.IsToolResponse() { - hasToolResponse = true - break - } - } - } - if !hasToolResponse && len(req.Tools) > 0 { - parts := make([]*Part, count) - for i := range parts { - parts[i] = NewToolRequestPart(&ToolRequest{ - Name: toolName, - Input: map[string]any{"value": fmt.Sprintf("req-%d", i)}, - }) - } - return &ModelResponse{ - Request: req, - Message: &Message{Role: RoleModel, Content: parts}, - }, nil - } - return &ModelResponse{ - Request: req, - Message: NewModelTextMessage("done"), - }, nil - } -} - // cmpPartEqual is a Part comparator for cmp.Diff that compares essential fields. 
func cmpPartEqual(a, b *Part) bool { if a == nil || b == nil {