diff --git a/.gitignore b/.gitignore index 38a67cf65..5c2d3f4f6 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ node_modules *.log hs_err_pid* samples/google-genai/generated_media/ +.astro \ No newline at end of file diff --git a/ai/src/main/java/com/google/genkit/ai/GenerateOptions.java b/ai/src/main/java/com/google/genkit/ai/GenerateOptions.java index 3017243b5..ad3a38571 100644 --- a/ai/src/main/java/com/google/genkit/ai/GenerateOptions.java +++ b/ai/src/main/java/com/google/genkit/ai/GenerateOptions.java @@ -18,6 +18,7 @@ package com.google.genkit.ai; +import com.google.genkit.ai.middleware.GenerationMiddleware; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -43,6 +44,7 @@ public class GenerateOptions { private final Integer maxTurns; private final ResumeOptions resume; private final Class outputClass; + private final List use; /** * Creates new GenerateOptions. @@ -74,7 +76,8 @@ public GenerateOptions( Map context, Integer maxTurns, ResumeOptions resume, - Class outputClass) { + Class outputClass, + List use) { this.model = model; this.prompt = prompt; this.messages = messages; @@ -88,6 +91,7 @@ public GenerateOptions( this.maxTurns = maxTurns; this.resume = resume; this.outputClass = outputClass; + this.use = use; } /** @@ -286,6 +290,15 @@ public Class getOutputClass() { return outputClass; } + /** + * Gets the V2 middleware to apply to this generation. + * + * @return the middleware list, or null if not set + */ + public List getUse() { + return use; + } + /** * Builder for GenerateOptions. * @@ -305,6 +318,7 @@ public static class Builder { private Integer maxTurns; private ResumeOptions resume; private Class outputClass; + private List use; public Builder model(String model) { this.model = model; @@ -407,6 +421,29 @@ public Builder resume(ResumeOptions resume) { return this; } + /** + * Sets V2 middleware to apply to this generation. Middleware hooks wrap the generate loop, + * model calls, and tool executions. + * + * @param use the middleware to apply + * @return this builder + */ + public Builder use(List use) { + this.use = use; + return this; + } + + /** + * Sets V2 middleware to apply to this generation. + * + * @param middleware the middleware to apply + * @return this builder + */ + public Builder use(GenerationMiddleware... middleware) { + this.use = List.of(middleware); + return this; + } + public GenerateOptions build() { return new GenerateOptions<>( model, @@ -421,7 +458,8 @@ public GenerateOptions build() { context, maxTurns, resume, - outputClass); + outputClass, + use); } } } diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/BaseGenerationMiddleware.java b/ai/src/main/java/com/google/genkit/ai/middleware/BaseGenerationMiddleware.java new file mode 100644 index 000000000..0009468c8 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/BaseGenerationMiddleware.java @@ -0,0 +1,77 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Part; +import com.google.genkit.ai.Tool; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; +import java.util.Collections; +import java.util.List; + +/** + * BaseGenerationMiddleware provides default pass-through implementations for all three hooks. + * Extend this class and override only the hooks you need. + * + *

Example: + * + *

{@code
+ * public class TimingMiddleware extends BaseGenerationMiddleware {
+ *   @Override
+ *   public String name() { return "timing"; }
+ *
+ *   @Override
+ *   public GenerationMiddleware newInstance() { return new TimingMiddleware(); }
+ *
+ *   @Override
+ *   public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next)
+ *       throws GenkitException {
+ *     long start = System.currentTimeMillis();
+ *     ModelResponse resp = next.apply(ctx, params);
+ *     System.out.println("Model call took " + (System.currentTimeMillis() - start) + "ms");
+ *     return resp;
+ *   }
+ * }
+ * }
+ */ +public abstract class BaseGenerationMiddleware implements GenerationMiddleware { + + @Override + public ModelResponse wrapGenerate(ActionContext ctx, GenerateParams params, GenerateNext next) + throws GenkitException { + return next.apply(ctx, params); + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + return next.apply(ctx, params); + } + + @Override + public Part wrapTool(ActionContext ctx, ToolParams params, ToolNext next) throws GenkitException { + return next.apply(ctx, params); + } + + @Override + public List> tools() { + return Collections.emptyList(); + } +} diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/GenerateNext.java b/ai/src/main/java/com/google/genkit/ai/middleware/GenerateNext.java new file mode 100644 index 000000000..5301e8d50 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/GenerateNext.java @@ -0,0 +1,38 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** Next function in the {@link GenerationMiddleware#wrapGenerate} hook chain. */ +@FunctionalInterface +public interface GenerateNext { + + /** + * Calls the next handler in the generate chain. + * + * @param ctx the action context + * @param params the generate parameters + * @return the model response + * @throws GenkitException if processing fails + */ + ModelResponse apply(ActionContext ctx, GenerateParams params) throws GenkitException; +} diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/GenerateParams.java b/ai/src/main/java/com/google/genkit/ai/middleware/GenerateParams.java new file mode 100644 index 000000000..7508ec467 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/GenerateParams.java @@ -0,0 +1,54 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.ModelRequest; + +/** Holds parameters for the {@link GenerationMiddleware#wrapGenerate} hook. */ +public class GenerateParams { + + private final ModelRequest request; + private final int iteration; + + /** + * Creates GenerateParams. + * + * @param request the current model request for this iteration + * @param iteration the current tool-loop iteration (0-indexed) + */ + public GenerateParams(ModelRequest request, int iteration) { + this.request = request; + this.iteration = iteration; + } + + /** Returns the current model request with accumulated messages. */ + public ModelRequest getRequest() { + return request; + } + + /** Returns the current tool-loop iteration (0-indexed). */ + public int getIteration() { + return iteration; + } + + /** Returns a new GenerateParams with the given request, preserving the iteration. */ + public GenerateParams withRequest(ModelRequest request) { + return new GenerateParams(request, this.iteration); + } +} diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/GenerationMiddleware.java b/ai/src/main/java/com/google/genkit/ai/middleware/GenerationMiddleware.java new file mode 100644 index 000000000..bb751be28 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/GenerationMiddleware.java @@ -0,0 +1,126 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Part; +import com.google.genkit.ai.Tool; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; +import java.util.Collections; +import java.util.List; + +/** + * GenerationMiddleware provides hooks for different stages of the generation pipeline. + * + *

This is the V2 middleware interface that replaces the generic {@code Middleware}. It + * provides three distinct hooks: + * + *

    + *
  • {@link #wrapGenerate} - wraps each iteration of the tool loop + *
  • {@link #wrapModel} - wraps each model API call + *
  • {@link #wrapTool} - wraps each tool execution + *
+ * + *

Each {@code generate()} call creates a fresh instance via {@link #newInstance()}, enabling + * per-invocation state (e.g., counters, timers) without shared mutable state across requests. + * + *

Example: + * + *

{@code
+ * public class LoggingMiddleware extends BaseGenerationMiddleware {
+ *   private int modelCalls = 0;
+ *
+ *   @Override
+ *   public String name() { return "logging"; }
+ *
+ *   @Override
+ *   public GenerationMiddleware newInstance() { return new LoggingMiddleware(); }
+ *
+ *   @Override
+ *   public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next)
+ *       throws GenkitException {
+ *     modelCalls++;
+ *     System.out.println("Model call #" + modelCalls);
+ *     ModelResponse resp = next.apply(ctx, params);
+ *     System.out.println("Model responded with " + resp.getText());
+ *     return resp;
+ *   }
+ * }
+ * }
+ */ +public interface GenerationMiddleware { + + /** Returns the middleware's unique identifier. */ + String name(); + + /** + * Returns a fresh instance for each {@code generate()} call, enabling per-invocation state. + * + *

Stable state (e.g., API keys, configuration) should be preserved. Per-request state (e.g., + * counters) should be reset. + */ + GenerationMiddleware newInstance(); + + /** + * Wraps each iteration of the generate tool loop. + * + * @param ctx the action context + * @param params the generate parameters including the current request and iteration + * @param next the next function in the chain + * @return the model response + * @throws GenkitException if processing fails + */ + ModelResponse wrapGenerate(ActionContext ctx, GenerateParams params, GenerateNext next) + throws GenkitException; + + /** + * Wraps each model API call. + * + * @param ctx the action context + * @param params the model parameters including the request + * @param next the next function in the chain + * @return the model response + * @throws GenkitException if processing fails + */ + ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException; + + /** + * Wraps each tool execution. May be called concurrently when multiple tools execute in parallel. + * Implementations must be safe for concurrent use. + * + * @param ctx the action context + * @param params the tool parameters including the request part and resolved tool + * @param next the next function in the chain + * @return the tool response part (includes part-level metadata) + * @throws GenkitException if processing fails + */ + Part wrapTool(ActionContext ctx, ToolParams params, ToolNext next) throws GenkitException; + + /** + * Returns additional tools to make available during generation. These tools are dynamically added + * when the middleware is used. + * + * @return the list of additional tools, or empty list if none + */ + default List> tools() { + return Collections.emptyList(); + } +} diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/ModelNext.java b/ai/src/main/java/com/google/genkit/ai/middleware/ModelNext.java new file mode 100644 index 000000000..227b87f56 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/ModelNext.java @@ -0,0 +1,38 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** Next function in the {@link GenerationMiddleware#wrapModel} hook chain. */ +@FunctionalInterface +public interface ModelNext { + + /** + * Calls the next handler in the model chain. + * + * @param ctx the action context + * @param params the model parameters + * @return the model response + * @throws GenkitException if processing fails + */ + ModelResponse apply(ActionContext ctx, ModelParams params) throws GenkitException; +} diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/ModelParams.java b/ai/src/main/java/com/google/genkit/ai/middleware/ModelParams.java new file mode 100644 index 000000000..0e2aa1633 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/ModelParams.java @@ -0,0 +1,56 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.ModelRequest; +import com.google.genkit.ai.ModelResponseChunk; +import java.util.function.Consumer; + +/** Holds parameters for the {@link GenerationMiddleware#wrapModel} hook. */ +public class ModelParams { + + private final ModelRequest request; + private final Consumer streamCallback; + + /** + * Creates ModelParams. + * + * @param request the model request about to be sent + * @param streamCallback the streaming callback, or null if not streaming + */ + public ModelParams(ModelRequest request, Consumer streamCallback) { + this.request = request; + this.streamCallback = streamCallback; + } + + /** Returns the model request about to be sent. */ + public ModelRequest getRequest() { + return request; + } + + /** Returns the streaming callback, or null if not streaming. */ + public Consumer getStreamCallback() { + return streamCallback; + } + + /** Returns a new ModelParams with the given request, preserving the stream callback. */ + public ModelParams withRequest(ModelRequest request) { + return new ModelParams(request, this.streamCallback); + } +} diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/ToolNext.java b/ai/src/main/java/com/google/genkit/ai/middleware/ToolNext.java new file mode 100644 index 000000000..fddbe6b71 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/ToolNext.java @@ -0,0 +1,38 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.Part; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** Next function in the {@link GenerationMiddleware#wrapTool} hook chain. */ +@FunctionalInterface +public interface ToolNext { + + /** + * Calls the next handler in the tool chain. + * + * @param ctx the action context + * @param params the tool parameters + * @return the tool response part (includes part-level metadata) + * @throws GenkitException if processing fails + */ + Part apply(ActionContext ctx, ToolParams params) throws GenkitException; +} diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/ToolParams.java b/ai/src/main/java/com/google/genkit/ai/middleware/ToolParams.java new file mode 100644 index 000000000..5b2e4b92d --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/ToolParams.java @@ -0,0 +1,56 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.Part; +import com.google.genkit.ai.Tool; +import com.google.genkit.ai.ToolRequest; + +/** Holds parameters for the {@link GenerationMiddleware#wrapTool} hook. */ +public class ToolParams { + + private final Part requestPart; + private final Tool tool; + + /** + * Creates ToolParams. + * + * @param requestPart the tool request part (includes metadata) about to be executed + * @param tool the resolved tool being called + */ + public ToolParams(Part requestPart, Tool tool) { + this.requestPart = requestPart; + this.tool = tool; + } + + /** Returns the full tool request part, including part-level metadata. */ + public Part getRequestPart() { + return requestPart; + } + + /** Convenience method: returns the tool request from the request part. */ + public ToolRequest getRequest() { + return requestPart.getToolRequest(); + } + + /** Returns the resolved tool being called. */ + public Tool getTool() { + return tool; + } +} diff --git a/ai/src/test/java/com/google/genkit/ai/middleware/GenerationMiddlewareTest.java b/ai/src/test/java/com/google/genkit/ai/middleware/GenerationMiddlewareTest.java new file mode 100644 index 000000000..fd4184fe7 --- /dev/null +++ b/ai/src/test/java/com/google/genkit/ai/middleware/GenerationMiddlewareTest.java @@ -0,0 +1,667 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import static org.junit.jupiter.api.Assertions.*; + +import com.google.genkit.ai.Candidate; +import com.google.genkit.ai.Message; +import com.google.genkit.ai.ModelRequest; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Part; +import com.google.genkit.ai.Tool; +import com.google.genkit.ai.ToolRequest; +import com.google.genkit.ai.ToolResponse; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.DefaultRegistry; +import com.google.genkit.core.GenkitException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** Tests for V2 GenerationMiddleware hooks: GenerateNext, ModelNext, ToolNext. */ +class GenerationMiddlewareTest { + + private ActionContext ctx; + + @BeforeEach + void setUp() { + ctx = new ActionContext(new DefaultRegistry()); + } + + // ========================================================================= + // Helper: build a simple ModelResponse with text + // ========================================================================= + + private static ModelResponse responseWithText(String text) { + Message msg = Message.model(text); + Candidate candidate = new Candidate(msg); + return ModelResponse.builder().addCandidate(candidate).build(); + } + + // ========================================================================= + // GenerateNext tests + // ========================================================================= + + @Test + void testGenerateNext_passThrough() { + ModelRequest request = ModelRequest.builder().addUserMessage("hello").build(); + GenerateParams params = new GenerateParams(request, 0); + ModelResponse expected = responseWithText("world"); + + GenerateNext next = (c, p) -> expected; + + ModelResponse result = next.apply(ctx, params); + assertSame(expected, result); + } + + @Test + void testGenerateNext_chainOrder() { + List order = new ArrayList<>(); + + // Core function + GenerateNext core = + (c, p) -> { + order.add("core"); + return responseWithText("response"); + }; + + // Outer middleware wrapping core + GenerateNext outer = + (c, p) -> { + order.add("outer-before"); + ModelResponse resp = core.apply(c, p); + order.add("outer-after"); + return resp; + }; + + ModelRequest request = ModelRequest.builder().addUserMessage("test").build(); + outer.apply(ctx, new GenerateParams(request, 0)); + + assertEquals(List.of("outer-before", "core", "outer-after"), order); + } + + @Test + void testGenerateNext_canModifyParams() { + ModelRequest original = ModelRequest.builder().addUserMessage("original").build(); + ModelRequest modified = ModelRequest.builder().addUserMessage("modified").build(); + + AtomicInteger iterationSeen = new AtomicInteger(-1); + GenerateNext core = + (c, p) -> { + iterationSeen.set(p.getIteration()); + assertEquals(modified, p.getRequest()); + return responseWithText("ok"); + }; + + // Middleware that replaces the request + GenerateNext wrapper = + (c, p) -> { + GenerateParams newParams = p.withRequest(modified); + return core.apply(c, newParams); + }; + + wrapper.apply(ctx, new GenerateParams(original, 5)); + assertEquals(5, iterationSeen.get()); // iteration preserved by withRequest + } + + @Test + void testGenerateNext_exceptionPropagates() { + GenerateNext failing = + (c, p) -> { + throw new GenkitException("boom"); + }; + + ModelRequest request = ModelRequest.builder().build(); + assertThrows(GenkitException.class, () -> failing.apply(ctx, new GenerateParams(request, 0))); + } + + // ========================================================================= + // ModelNext tests + // ========================================================================= + + @Test + void testModelNext_passThrough() { + ModelRequest request = ModelRequest.builder().addUserMessage("hello").build(); + ModelParams params = new ModelParams(request, null); + ModelResponse expected = responseWithText("model output"); + + ModelNext next = (c, p) -> expected; + + ModelResponse result = next.apply(ctx, params); + assertSame(expected, result); + } + + @Test + void testModelNext_chainOrder() { + List order = new ArrayList<>(); + + ModelNext core = + (c, p) -> { + order.add("model"); + return responseWithText("result"); + }; + + ModelNext wrapper = + (c, p) -> { + order.add("before-model"); + ModelResponse resp = core.apply(c, p); + order.add("after-model"); + return resp; + }; + + ModelRequest request = ModelRequest.builder().build(); + wrapper.apply(ctx, new ModelParams(request, null)); + + assertEquals(List.of("before-model", "model", "after-model"), order); + } + + @Test + void testModelNext_canModifyRequest() { + ModelRequest original = ModelRequest.builder().addUserMessage("original").build(); + ModelRequest modified = ModelRequest.builder().addUserMessage("injected").build(); + + ModelNext core = + (c, p) -> { + assertEquals(modified, p.getRequest()); + return responseWithText("ok"); + }; + + ModelNext wrapper = + (c, p) -> { + ModelParams newParams = p.withRequest(modified); + return core.apply(c, newParams); + }; + + wrapper.apply(ctx, new ModelParams(original, null)); + } + + @Test + void testModelNext_preservesStreamCallback() { + List streamed = new ArrayList<>(); + ModelParams params = + new ModelParams(ModelRequest.builder().build(), chunk -> streamed.add("chunk")); + + ModelNext next = + (c, p) -> { + assertNotNull(p.getStreamCallback()); + return responseWithText("ok"); + }; + + next.apply(ctx, params); + assertNotNull(params.getStreamCallback()); + } + + @Test + void testModelNext_exceptionPropagates() { + ModelNext failing = + (c, p) -> { + throw new GenkitException("model failed"); + }; + + assertThrows( + GenkitException.class, + () -> failing.apply(ctx, new ModelParams(ModelRequest.builder().build(), null))); + } + + // ========================================================================= + // ToolNext tests + // ========================================================================= + + @Test + void testToolNext_passThrough() { + ToolRequest toolReq = new ToolRequest("myTool", Map.of("key", "value")); + Tool tool = createTestTool("myTool"); + ToolParams params = new ToolParams(Part.toolRequest(toolReq), tool); + Part expected = Part.toolResponse(new ToolResponse("myTool", "tool output")); + + ToolNext next = (c, p) -> expected; + + Part result = next.apply(ctx, params); + assertSame(expected, result); + } + + @Test + void testToolNext_chainOrder() { + List order = new ArrayList<>(); + + ToolNext core = + (c, p) -> { + order.add("tool-exec"); + return Part.toolResponse(new ToolResponse(p.getRequest().getName(), "result")); + }; + + ToolNext wrapper = + (c, p) -> { + order.add("before-tool"); + Part resp = core.apply(c, p); + order.add("after-tool"); + return resp; + }; + + ToolRequest toolReq = new ToolRequest("test", Map.of()); + wrapper.apply(ctx, new ToolParams(Part.toolRequest(toolReq), createTestTool("test"))); + + assertEquals(List.of("before-tool", "tool-exec", "after-tool"), order); + } + + @Test + void testToolNext_accessesToolInfo() { + Tool tool = createTestTool("weatherTool"); + ToolRequest toolReq = new ToolRequest("weatherTool", Map.of("city", "Paris")); + + ToolNext next = + (c, p) -> { + assertEquals("weatherTool", p.getRequest().getName()); + assertEquals("weatherTool", p.getTool().getName()); + return Part.toolResponse(new ToolResponse("weatherTool", "sunny")); + }; + + Part resp = next.apply(ctx, new ToolParams(Part.toolRequest(toolReq), tool)); + assertEquals("weatherTool", resp.getToolResponse().getName()); + } + + @Test + void testToolNext_exceptionPropagates() { + ToolNext failing = + (c, p) -> { + throw new GenkitException("tool failed"); + }; + + ToolRequest toolReq = new ToolRequest("t", Map.of()); + assertThrows( + GenkitException.class, + () -> failing.apply(ctx, new ToolParams(Part.toolRequest(toolReq), createTestTool("t")))); + } + + // ========================================================================= + // BaseGenerationMiddleware tests + // ========================================================================= + + @Test + void testBaseMiddleware_defaultsPassThrough() { + BaseGenerationMiddleware base = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "noop"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + }; + + // wrapGenerate passes through + ModelRequest req = ModelRequest.builder().addUserMessage("test").build(); + ModelResponse expected = responseWithText("pass"); + GenerateNext gNext = (c, p) -> expected; + ModelResponse gResult = base.wrapGenerate(ctx, new GenerateParams(req, 0), gNext); + assertSame(expected, gResult); + + // wrapModel passes through + ModelNext mNext = (c, p) -> expected; + ModelResponse mResult = base.wrapModel(ctx, new ModelParams(req, null), mNext); + assertSame(expected, mResult); + + // wrapTool passes through + Part toolExpected = Part.toolResponse(new ToolResponse("t", "data")); + ToolNext tNext = (c, p) -> toolExpected; + Part tResult = + base.wrapTool( + ctx, + new ToolParams(Part.toolRequest(new ToolRequest("t", Map.of())), createTestTool("t")), + tNext); + assertSame(toolExpected, tResult); + + // tools returns empty + assertTrue(base.tools().isEmpty()); + } + + @Test + void testCustomMiddleware_overridesSelectedHooks() { + AtomicInteger modelCallCount = new AtomicInteger(0); + + BaseGenerationMiddleware middleware = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "model-counter"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + modelCallCount.incrementAndGet(); + return next.apply(ctx, params); + } + }; + + ModelRequest req = ModelRequest.builder().build(); + ModelResponse resp = responseWithText("ok"); + + // wrapModel is overridden + middleware.wrapModel(ctx, new ModelParams(req, null), (c, p) -> resp); + assertEquals(1, modelCallCount.get()); + + // wrapGenerate still passes through (default) + ModelResponse gResp = middleware.wrapGenerate(ctx, new GenerateParams(req, 0), (c, p) -> resp); + assertSame(resp, gResp); + assertEquals(1, modelCallCount.get()); // not incremented + } + + // ========================================================================= + // Chaining multiple middleware + // ========================================================================= + + @Test + void testChainGenerateHooks_nestedOrder() { + List order = new ArrayList<>(); + + GenerationMiddleware outer = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "outer"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public ModelResponse wrapGenerate( + ActionContext ctx, GenerateParams params, GenerateNext next) throws GenkitException { + order.add("outer-before"); + ModelResponse resp = next.apply(ctx, params); + order.add("outer-after"); + return resp; + } + }; + + GenerationMiddleware inner = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "inner"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public ModelResponse wrapGenerate( + ActionContext ctx, GenerateParams params, GenerateNext next) throws GenkitException { + order.add("inner-before"); + ModelResponse resp = next.apply(ctx, params); + order.add("inner-after"); + return resp; + } + }; + + // Chain: outer wraps inner wraps core + // This mirrors the chaining in Genkit.chainGenerateHooks() + List middlewares = List.of(outer, inner); + GenerateNext core = + (c, p) -> { + order.add("core"); + return responseWithText("done"); + }; + + // Build chain by reverse iteration (first middleware = outermost) + GenerateNext chain = core; + for (int i = middlewares.size() - 1; i >= 0; i--) { + GenerationMiddleware mw = middlewares.get(i); + GenerateNext wrapped = chain; + chain = (c, p) -> mw.wrapGenerate(c, p, wrapped); + } + + ModelRequest req = ModelRequest.builder().build(); + chain.apply(ctx, new GenerateParams(req, 0)); + + assertEquals( + List.of("outer-before", "inner-before", "core", "inner-after", "outer-after"), order); + } + + @Test + void testChainModelHooks_nestedOrder() { + List order = new ArrayList<>(); + + GenerationMiddleware first = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "first"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + order.add("first-before"); + ModelResponse resp = next.apply(ctx, params); + order.add("first-after"); + return resp; + } + }; + + GenerationMiddleware second = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "second"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + order.add("second-before"); + ModelResponse resp = next.apply(ctx, params); + order.add("second-after"); + return resp; + } + }; + + List middlewares = List.of(first, second); + ModelNext core = + (c, p) -> { + order.add("model"); + return responseWithText("result"); + }; + + ModelNext chain = core; + for (int i = middlewares.size() - 1; i >= 0; i--) { + GenerationMiddleware mw = middlewares.get(i); + ModelNext wrapped = chain; + chain = (c, p) -> mw.wrapModel(c, p, wrapped); + } + + chain.apply(ctx, new ModelParams(ModelRequest.builder().build(), null)); + + assertEquals( + List.of("first-before", "second-before", "model", "second-after", "first-after"), order); + } + + @Test + void testChainToolHooks_nestedOrder() { + List order = new ArrayList<>(); + + GenerationMiddleware first = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "first"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public Part wrapTool(ActionContext ctx, ToolParams params, ToolNext next) + throws GenkitException { + order.add("first-before"); + Part resp = next.apply(ctx, params); + order.add("first-after"); + return resp; + } + }; + + GenerationMiddleware second = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "second"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public Part wrapTool(ActionContext ctx, ToolParams params, ToolNext next) + throws GenkitException { + order.add("second-before"); + Part resp = next.apply(ctx, params); + order.add("second-after"); + return resp; + } + }; + + List middlewares = List.of(first, second); + ToolNext core = + (c, p) -> { + order.add("tool"); + return Part.toolResponse(new ToolResponse(p.getRequest().getName(), "output")); + }; + + ToolNext chain = core; + for (int i = middlewares.size() - 1; i >= 0; i--) { + GenerationMiddleware mw = middlewares.get(i); + ToolNext wrapped = chain; + chain = (c, p) -> mw.wrapTool(c, p, wrapped); + } + + ToolRequest toolReq = new ToolRequest("myTool", Map.of()); + chain.apply(ctx, new ToolParams(Part.toolRequest(toolReq), createTestTool("myTool"))); + + assertEquals( + List.of("first-before", "second-before", "tool", "second-after", "first-after"), order); + } + + // ========================================================================= + // newInstance() isolation + // ========================================================================= + + @Test + void testNewInstance_isolatesState() { + AtomicInteger sharedCounter = new AtomicInteger(0); + + GenerationMiddleware template = + new BaseGenerationMiddleware() { + private final AtomicInteger calls = new AtomicInteger(0); + + @Override + public String name() { + return "stateful"; + } + + @Override + public GenerationMiddleware newInstance() { + // Each instance gets its own counter + return new BaseGenerationMiddleware() { + private final AtomicInteger instanceCalls = new AtomicInteger(0); + + @Override + public String name() { + return "stateful"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + instanceCalls.incrementAndGet(); + sharedCounter.incrementAndGet(); + return next.apply(ctx, params); + } + }; + } + }; + + // Simulate two generate() calls creating separate instances + GenerationMiddleware instance1 = template.newInstance(); + GenerationMiddleware instance2 = template.newInstance(); + + ModelResponse resp = responseWithText("ok"); + ModelNext passThrough = (c, p) -> resp; + ModelParams params = new ModelParams(ModelRequest.builder().build(), null); + + // Call instance1 three times + instance1.wrapModel(ctx, params, passThrough); + instance1.wrapModel(ctx, params, passThrough); + instance1.wrapModel(ctx, params, passThrough); + + // Call instance2 once + instance2.wrapModel(ctx, params, passThrough); + + // Shared counter sees all 4 calls + assertEquals(4, sharedCounter.get()); + + // But instances are independent (verified by the fact that both ran without error) + } + + // ========================================================================= + // Helper + // ========================================================================= + + private static Tool createTestTool(String name) { + Map schema = new HashMap<>(); + schema.put("type", "string"); + return new Tool<>(name, "Test tool", schema, schema, String.class, (ctx, input) -> "result"); + } +} diff --git a/core/src/main/java/com/google/genkit/core/middleware/Middleware.java b/core/src/main/java/com/google/genkit/core/middleware/Middleware.java index db508e7d8..4479eda02 100644 --- a/core/src/main/java/com/google/genkit/core/middleware/Middleware.java +++ b/core/src/main/java/com/google/genkit/core/middleware/Middleware.java @@ -41,7 +41,10 @@ * * @param The input type * @param The output type + * @deprecated Use {@code com.google.genkit.ai.middleware.GenerationMiddleware} instead, which + * supports distinct Generate, Model, and Tool hooks. */ +@Deprecated @FunctionalInterface public interface Middleware { diff --git a/docs/src/content.config.ts b/docs/src/content.config.ts index 7fbcf2c33..3cd166154 100644 --- a/docs/src/content.config.ts +++ b/docs/src/content.config.ts @@ -1,3 +1,19 @@ +/** + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + import { defineCollection } from "astro:content"; import { docsLoader } from "@astrojs/starlight/loaders"; import { docsSchema } from "@astrojs/starlight/schema"; diff --git a/docs/src/content/docs/middleware.md b/docs/src/content/docs/middleware.md index 216848067..89c945724 100644 --- a/docs/src/content/docs/middleware.md +++ b/docs/src/content/docs/middleware.md @@ -1,11 +1,18 @@ --- title: Middleware -description: Add cross-cutting concerns to your AI workflows with middleware. +description: Add cross-cutting concerns to your AI workflows with flow middleware and generation middleware. --- -Middleware lets you intercept and modify the behavior of flow executions. Use middleware for logging, caching, rate limiting, retries, input validation, and more. Middleware follows the chain-of-responsibility pattern — each middleware can modify the request, call the next handler, and modify the response. +Middleware lets you intercept and modify the behavior of flow executions and AI generation. Genkit provides two middleware systems: -## Defining middleware +- **Flow Middleware** — wraps the entire flow function. Use for logging, caching, rate limiting, retries, and input validation. +- **Generation Middleware (V2)** — hooks into the `generate()` pipeline at three levels: model calls, tool executions, and loop iterations. Use for metering, observability, and tool interception. + +## Flow Middleware + +Flow middleware follows the chain-of-responsibility pattern — each middleware can modify the request, call the next handler, and modify the response. + +### Defining middleware A middleware is a function that receives the request, an `ActionContext`, and a `next` function to call the next handler in the chain: @@ -20,7 +27,7 @@ Middleware loggingMiddleware = (request, context, next) -> { }; ``` -## Attaching middleware to flows +### Attaching middleware to flows Pass middleware as a list when defining a flow: @@ -46,11 +53,11 @@ Flow chatFlow = genkit.defineFlow( Middleware executes in order — the first middleware in the list runs first (outermost), wrapping all subsequent middleware and the flow handler. -## Built-in middleware +### Built-in middleware The `CommonMiddleware` class provides factory methods for common patterns: -### Logging +#### Logging ```java import com.google.genkit.core.middleware.CommonMiddleware; @@ -62,7 +69,7 @@ Middleware logging = CommonMiddleware.logging("chat"); Middleware logging = CommonMiddleware.logging("chat", myLogger); ``` -### Retry with exponential backoff +#### Retry with exponential backoff ```java // Retry up to 3 times with 100ms initial delay @@ -73,7 +80,7 @@ Middleware retry = CommonMiddleware.retry(3, 100, error -> error.getMessage().contains("rate limit")); ``` -### Input validation +#### Input validation ```java Middleware validate = CommonMiddleware.validate(input -> { @@ -86,7 +93,7 @@ Middleware validate = CommonMiddleware.validate(input -> { }); ``` -### Request and response transformation +#### Request and response transformation ```java // Sanitize input @@ -98,7 +105,7 @@ Middleware format = CommonMiddleware.transformResponse( output -> "[" + Instant.now() + "] " + output); ``` -### Caching +#### Caching ```java import com.google.genkit.core.middleware.MiddlewareCache; @@ -111,28 +118,28 @@ Middleware cache = CommonMiddleware.cache( The `MiddlewareCache` interface requires `get(String key)` and `put(String key, O value)` methods. -### Rate limiting +#### Rate limiting ```java // Max 10 requests per 60 seconds Middleware rateLimit = CommonMiddleware.rateLimit(10, 60_000); ``` -### Timeout +#### Timeout ```java // 30 second timeout Middleware timeout = CommonMiddleware.timeout(30_000); ``` -### Error handling +#### Error handling ```java Middleware errorHandler = CommonMiddleware.errorHandler( error -> "Sorry, something went wrong: " + error.getMessage()); ``` -### Conditional middleware +#### Conditional middleware Apply middleware only when a condition is met: @@ -143,7 +150,7 @@ Middleware conditional = CommonMiddleware.conditional( ); ``` -### Before/after hooks +#### Before/after hooks ```java Middleware hooks = CommonMiddleware.beforeAfter( @@ -152,14 +159,14 @@ Middleware hooks = CommonMiddleware.beforeAfter( ); ``` -### Timing +#### Timing ```java Middleware timing = CommonMiddleware.timing( duration -> System.out.println("Took " + duration + "ms")); ``` -## Building a middleware chain +### Building a middleware chain Use `MiddlewareChain` for more control over middleware ordering: @@ -182,7 +189,7 @@ String result = chain.execute(input, context, (ctx, req) -> { }); ``` -## Custom middleware example +### Custom middleware example A metrics-collecting middleware: @@ -207,7 +214,7 @@ Middleware metricsMiddleware = (request, context, next) -> { }; ``` -## Built-in middleware reference +### Built-in middleware reference | Factory Method | Description | |---------------|-------------| @@ -224,7 +231,208 @@ Middleware metricsMiddleware = (request, context, next) -> { | `beforeAfter(before, after)` | Run hooks before and after | | `timing(callback)` | Measure execution duration | +## Generation Middleware (V2) + +Generation Middleware provides fine-grained hooks into the generation pipeline, letting you intercept model calls, tool executions, and generate loop iterations independently. Unlike flow-level middleware (which wraps the entire flow function), Generation Middleware operates inside `generate()` and is attached per call. + +### Three hooks + +| Hook | Wraps | Receives | Use cases | +|------|-------|----------|-----------| +| `wrapGenerate` | Each iteration of the tool loop | `GenerateParams` (request + iteration number) | Timing, logging per turn, retry logic | +| `wrapModel` | Each model API call | `ModelParams` (request + stream callback) | Token metering, request/response rewriting, caching | +| `wrapTool` | Each tool execution | `ToolParams` (request part + resolved tool) | Tool authorization, audit logging, error handling | + +Hooks nest naturally: `wrapGenerate` is the outermost layer, `wrapModel` runs inside it, and `wrapTool` runs for each tool the model requests. + +``` +wrapGenerate (iteration 0) +├── wrapModel → model API call +├── wrapTool → tool1 +├── wrapTool → tool2 +└── wrapGenerate (iteration 1) ← recursive via tool loop + ├── wrapModel → model API call + └── (no more tool calls → return) +``` + +### Defining Generation Middleware + +Implement the `GenerationMiddleware` interface or extend `BaseGenerationMiddleware` (which passes through by default). Override only the hooks you need: + +```java +import com.google.genkit.ai.middleware.BaseGenerationMiddleware; +import com.google.genkit.ai.middleware.GenerationMiddleware; +import com.google.genkit.ai.middleware.ModelNext; +import com.google.genkit.ai.middleware.ModelParams; + +class TokenMeteringMiddleware extends BaseGenerationMiddleware { + + private final AtomicInteger totalTokens = new AtomicInteger(0); + + @Override + public String name() { + return "token-metering"; + } + + @Override + public GenerationMiddleware newInstance() { + return new TokenMeteringMiddleware(); // fresh counters per generate() + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + ModelResponse response = next.apply(ctx, params); + // Inspect response for token usage + logger.info("Tokens used: {}", response.getUsage()); + return response; + } +} +``` + +Key points: + +- **`name()`** — unique identifier for the middleware. +- **`newInstance()`** — called once per `generate()` invocation. Return a fresh object so per-request state (counters, timers) is isolated. Stateless middleware can return `this`. +- **`next.apply(ctx, params)`** — calls the next middleware in the chain (or the core handler). You must call it to continue the pipeline. Skip it to short-circuit (e.g., return a cached response). + +### Attaching middleware to generate() + +Use `GenerateOptions.builder().use()`: + +```java +GenerationMiddleware metering = new TokenMeteringMiddleware(); +GenerationMiddleware logging = new ModelLoggingMiddleware(); + +ModelResponse response = genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-4o-mini") + .prompt("Explain middleware") + .use(metering, logging) + .build()); +``` + +Middleware order matters — the **first** middleware listed is **outermost** (runs first on the way in, last on the way out). + +### Multi-hook middleware + +A single middleware can implement all three hooks to observe every stage: + +```java +class FullObservabilityMiddleware extends BaseGenerationMiddleware { + + private final AtomicInteger iterations = new AtomicInteger(0); + private final AtomicInteger modelCalls = new AtomicInteger(0); + private final AtomicInteger toolCalls = new AtomicInteger(0); + + @Override + public String name() { return "full-observability"; } + + @Override + public GenerationMiddleware newInstance() { + return new FullObservabilityMiddleware(); + } + + @Override + public ModelResponse wrapGenerate(ActionContext ctx, GenerateParams params, + GenerateNext next) throws GenkitException { + int iter = iterations.incrementAndGet(); + logger.info("=== Generate iteration {} ===", iter); + ModelResponse resp = next.apply(ctx, params); + logger.info("=== Iteration {} done (model: {}, tools: {}) ===", + iter, modelCalls.get(), toolCalls.get()); + return resp; + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, + ModelNext next) throws GenkitException { + modelCalls.incrementAndGet(); + return next.apply(ctx, params); + } + + @Override + public Part wrapTool(ActionContext ctx, ToolParams params, + ToolNext next) throws GenkitException { + toolCalls.incrementAndGet(); + logger.info("Tool: {}", params.getRequest().getName()); + return next.apply(ctx, params); + } +} +``` + +### Middleware-provided tools + +Middleware can inject additional tools into the generation by overriding `tools()`: + +```java +@Override +public List> tools() { + return List.of(myCustomTool); +} +``` + +These tools are merged with the tools from `GenerateOptions.tools()` and are available for the model to call. + +### Middleware with interrupts and restarts + +Generation Middleware integrates with the [interrupt system](/docs/interrupts). When a tool throws `ToolInterruptException`, the `wrapTool` hook still fires — the exception propagates through the middleware chain, so you can observe or handle it. + +When resuming with `ResumeOptions.builder().restart(toolRequest)`, the restarted tool executes through the full `wrapTool` chain, and the subsequent model call goes through a new `wrapGenerate` iteration. This ensures middleware sees every operation regardless of whether it was an initial call or a restart. + +``` +Initial generate: + wrapGenerate(0) + ├── wrapModel → model requests tool4 + ├── wrapTool → tool1 (completes) + ├── wrapTool → tool2 (completes) + └── wrapTool → tool4 (interrupts!) → return interrupted response + +Restart generate: + wrapTool → tool4 (restart, completes) + wrapGenerate(1) + ├── wrapModel → model returns final answer + └── return response +``` + +### BaseGenerationMiddleware + +`BaseGenerationMiddleware` provides pass-through defaults for all hooks. Extend it to override only what you need: + +```java +class TimingMiddleware extends BaseGenerationMiddleware { + + @Override + public String name() { return "timing"; } + + @Override + public GenerationMiddleware newInstance() { return new TimingMiddleware(); } + + @Override + public ModelResponse wrapGenerate(ActionContext ctx, GenerateParams params, + GenerateNext next) throws GenkitException { + long start = System.currentTimeMillis(); + ModelResponse resp = next.apply(ctx, params); + logger.info("Iteration {} took {}ms", + params.getIteration(), System.currentTimeMillis() - start); + return resp; + } +} +``` + +### Generation Middleware vs Flow Middleware + +| | Flow Middleware | Generation Middleware (V2) | +|---|---|---| +| **Scope** | The entire flow function | Inside `generate()` — model, tools, iterations | +| **Attached to** | `defineFlow(..., middleware)` | `GenerateOptions.builder().use()` | +| **Typed to** | Flow input/output types | `ModelRequest` / `ModelResponse` / `Part` | +| **State** | Shared across requests | Fresh per `generate()` via `newInstance()` | +| **Best for** | Auth, rate limiting, validation | Observability, metering, tool interception | + +You can use both together — flow middleware wraps the outer flow, and generation middleware wraps the inner AI pipeline. + ## Samples -- [middleware sample](https://github.com/genkit-ai/genkit-java/tree/main/samples/middleware) — Custom and built-in middleware patterns -- [middleware-v2 sample](https://github.com/genkit-ai/genkit-java/tree/main/samples/middleware-v2) — Updated middleware approaches +- [middleware sample](https://github.com/genkit-ai/genkit-java/tree/main/samples/middleware) — Flow-level middleware patterns (logging, retry, caching, validation) +- [middleware-v2 sample](https://github.com/genkit-ai/genkit-java/tree/main/samples/middleware-v2) — Generation Middleware with all three hooks and interrupt/restart lifecycle diff --git a/genkit/src/main/java/com/google/genkit/Genkit.java b/genkit/src/main/java/com/google/genkit/Genkit.java index ef6c45f2e..84a90ecaa 100644 --- a/genkit/src/main/java/com/google/genkit/Genkit.java +++ b/genkit/src/main/java/com/google/genkit/Genkit.java @@ -20,6 +20,7 @@ import com.google.genkit.ai.*; import com.google.genkit.ai.evaluation.*; +import com.google.genkit.ai.middleware.*; import com.google.genkit.ai.session.*; import com.google.genkit.ai.telemetry.ModelTelemetryHelper; import com.google.genkit.core.*; @@ -658,85 +659,380 @@ private ModelResponse generateInternal(GenerateOptions options) throws Genkit ActionContext ctx = new ActionContext(registry); int maxTurns = options.getMaxTurns() != null ? options.getMaxTurns() : 5; - int turn = 0; + + // Create fresh middleware instances for this invocation + List middlewares = createMiddlewareInstances(options.getUse()); + + // Collect tools from middleware instances and merge with options tools + List> allTools = new ArrayList<>(); + if (options.getTools() != null) { + allTools.addAll(options.getTools()); + } + for (GenerationMiddleware mw : middlewares) { + List> mwTools = mw.tools(); + if (mwTools != null && !mwTools.isEmpty()) { + allTools.addAll(mwTools); + } + } + + // Add middleware tool definitions to the model request + if (allTools.size() > (options.getTools() != null ? options.getTools().size() : 0)) { + List allToolDefs = new ArrayList<>(); + if (request.getTools() != null) { + allToolDefs.addAll(request.getTools()); + } + for (GenerationMiddleware mw : middlewares) { + List> mwTools = mw.tools(); + if (mwTools != null) { + for (Tool t : mwTools) { + allToolDefs.add(t.getDefinition()); + } + } + } + request.setTools(allToolDefs); + } // Handle resume option if provided if (options.getResume() != null) { request = handleResumeOption(request, options); } - while (turn < maxTurns) { - // Create span metadata for the model call - SpanMetadata modelSpanMetadata = - SpanMetadata.builder() - .name(options.getModel()) - .type(ActionType.MODEL.getValue()) - .subtype("model") - .build(); + // Extract pending restart requests (handled inside the generate loop for proper middleware + // lifecycle: wrapTool hooks fire for restarted tools, then wrapGenerate fires for next turn) + final List pendingRestarts = new java.util.ArrayList<>(); + if (options.getResume() != null && options.getResume().getRestart() != null) { + pendingRestarts.addAll(options.getResume().getRestart()); + } - String flowName = ctx.getFlowName(); - if (flowName != null) { - modelSpanMetadata.getAttributes().put("genkit:metadata:flow:name", flowName); - } + // Build model call wrapped with WrapModel hooks + ModelNext wrappedModelCall = buildWrappedModelCall(model, options, ctx, middlewares); - final ModelRequest currentRequest = request; - final String flowNameForTelemetry = flowName; - final String spanPath = "/generate/" + options.getModel(); - ModelResponse response = - Tracer.runInNewSpan( - ctx, + // Use an array to hold the reference for recursive WrapGenerate wrapping + final GenerateNext[] generateRef = new GenerateNext[1]; + + // Core generate iteration: model call → tool handling → recurse + GenerateNext rawGenerate = + (actx, params) -> { + ModelRequest req = params.getRequest(); + int turn = params.getIteration(); + + if (turn >= maxTurns) { + throw new GenkitException("Max tool execution turns (" + maxTurns + ") exceeded"); + } + + // Handle pending restart tools through middleware before calling model. + // This ensures wrapTool hooks fire for restarted tools, and subsequent + // recursion through generateRef fires wrapGenerate for the next turn. + if (!pendingRestarts.isEmpty()) { + List restarts = new java.util.ArrayList<>(pendingRestarts); + pendingRestarts.clear(); + + // Convert restart requests to tool request parts for middleware execution + List restartParts = + restarts.stream() + .map(Part::toolRequest) + .collect(java.util.stream.Collectors.toList()); + + // Execute through WrapTool chain (fires wrapTool hooks) + ToolExecutionResult toolResult = + executeToolsWithMiddleware(actx, restartParts, allTools, middlewares); + + // If a restart tool interrupts again, fail + if (!toolResult.getInterrupts().isEmpty()) { + throw new GenkitException( + "Tool triggered an interrupt during restart. " + + "Re-interrupting during restart is not supported."); + } + + // Add restart tool responses to messages + List updatedMessages = new java.util.ArrayList<>(req.getMessages()); + + // If last message is a TOOL message (from respond directives), merge restart responses + if (!updatedMessages.isEmpty() + && updatedMessages.get(updatedMessages.size() - 1).getRole() == Role.TOOL) { + Message existingToolMsg = updatedMessages.get(updatedMessages.size() - 1); + List mergedContent = new java.util.ArrayList<>(existingToolMsg.getContent()); + for (Part restartResp : toolResult.getResponses()) { + Map metadata = + restartResp.getMetadata() != null + ? new java.util.HashMap<>(restartResp.getMetadata()) + : new java.util.HashMap<>(); + metadata.put("source", "restart"); + restartResp.setMetadata(metadata); + mergedContent.add(restartResp); + } + existingToolMsg.setContent(mergedContent); + } else { + // Create new TOOL message with restart responses + Message toolResponseMessage = new Message(); + toolResponseMessage.setRole(Role.TOOL); + List restartResponses = new java.util.ArrayList<>(); + for (Part restartResp : toolResult.getResponses()) { + Map metadata = + restartResp.getMetadata() != null + ? new java.util.HashMap<>(restartResp.getMetadata()) + : new java.util.HashMap<>(); + metadata.put("source", "restart"); + restartResp.setMetadata(metadata); + restartResponses.add(restartResp); + } + toolResponseMessage.setContent(restartResponses); + Map toolMsgMetadata = new java.util.HashMap<>(); + toolMsgMetadata.put("resumed", true); + toolResponseMessage.setMetadata(toolMsgMetadata); + updatedMessages.add(toolResponseMessage); + } + + // Recurse through WrapGenerate hooks for the next turn + ModelRequest nextRequest = + ModelRequest.builder() + .messages(updatedMessages) + .config(req.getConfig()) + .tools(req.getTools()) + .output(req.getOutput()) + .build(); + + return generateRef[0].apply(actx, new GenerateParams(nextRequest, turn + 1)); + } + + // Call model through WrapModel chain + ModelParams mparams = new ModelParams(req, null); + ModelResponse response = wrappedModelCall.apply(actx, mparams); + + // Check if the model requested tool calls + List toolRequestParts = extractToolRequestParts(response); + if (toolRequestParts.isEmpty()) { + return response; + } + + // Execute tools through WrapTool chain (includes middleware-provided tools) + ToolExecutionResult toolResult = + executeToolsWithMiddleware(actx, toolRequestParts, allTools, middlewares); + + // If there are interrupts, return immediately + if (!toolResult.getInterrupts().isEmpty()) { + ModelResponse interruptedResponse = buildInterruptedResponse(response, toolResult); + // Set original request so getMessages() includes conversation history + interruptedResponse.setRequest(req); + return interruptedResponse; + } + + // Build next request with updated messages + Message assistantMessage = response.getMessage(); + List updatedMessages = new java.util.ArrayList<>(req.getMessages()); + updatedMessages.add(assistantMessage); + + Message toolResponseMessage = new Message(); + toolResponseMessage.setRole(Role.TOOL); + toolResponseMessage.setContent(toolResult.getResponses()); + updatedMessages.add(toolResponseMessage); + + ModelRequest nextRequest = + ModelRequest.builder() + .messages(updatedMessages) + .config(req.getConfig()) + .tools(req.getTools()) + .output(req.getOutput()) + .build(); + + // Recurse through the wrapped generate function (goes through WrapGenerate hooks) + return generateRef[0].apply(actx, new GenerateParams(nextRequest, turn + 1)); + }; + + // Chain WrapGenerate hooks around the core iteration + generateRef[0] = chainGenerateHooks(middlewares, rawGenerate); + + // Start generation + return generateRef[0].apply(ctx, new GenerateParams(request, 0)); + } + + /** Creates fresh middleware instances for a single generate invocation. */ + private List createMiddlewareInstances(List use) { + if (use == null || use.isEmpty()) { + return List.of(); + } + return use.stream().map(GenerationMiddleware::newInstance).toList(); + } + + /** Builds the model call function wrapped with WrapModel hooks from middleware. */ + private ModelNext buildWrappedModelCall( + Model model, + GenerateOptions options, + ActionContext ctx, + List middlewares) { + + // Core model call with telemetry + ModelNext core = + (actx, mparams) -> { + ModelRequest req = mparams.getRequest(); + + SpanMetadata modelSpanMetadata = + SpanMetadata.builder() + .name(options.getModel()) + .type(ActionType.MODEL.getValue()) + .subtype("model") + .build(); + + String flowName = actx.getFlowName(); + if (flowName != null) { + modelSpanMetadata.getAttributes().put("genkit:metadata:flow:name", flowName); + } + + final String spanPath = "/generate/" + options.getModel(); + return Tracer.runInNewSpan( + actx, modelSpanMetadata, - request, - (spanCtx, req) -> { - // Wrap model execution with telemetry to record generate metrics + req, + (spanCtx, r) -> { return ModelTelemetryHelper.runWithTelemetry( options.getModel(), - flowNameForTelemetry, + flowName, spanPath, - currentRequest, - r -> model.run(ctx.withSpanContext(spanCtx), r)); + req, + mr -> model.run(actx.withSpanContext(spanCtx), mr)); }); + }; - // Check if the model requested tool calls - List toolRequestParts = extractToolRequestParts(response); - if (toolRequestParts.isEmpty()) { - // No tool calls, return the response - return response; - } + return chainModelHooks(middlewares, core); + } + + /** Chains WrapGenerate hooks. First middleware is outermost. */ + private GenerateNext chainGenerateHooks( + List middlewares, GenerateNext core) { + if (middlewares.isEmpty()) { + return core; + } + GenerateNext current = core; + for (int i = middlewares.size() - 1; i >= 0; i--) { + final GenerationMiddleware mw = middlewares.get(i); + final GenerateNext next = current; + current = (ctx, params) -> mw.wrapGenerate(ctx, params, next); + } + return current; + } + + /** Chains WrapModel hooks. First middleware is outermost. */ + private ModelNext chainModelHooks(List middlewares, ModelNext core) { + if (middlewares.isEmpty()) { + return core; + } + ModelNext current = core; + for (int i = middlewares.size() - 1; i >= 0; i--) { + final GenerationMiddleware mw = middlewares.get(i); + final ModelNext next = current; + current = (ctx, params) -> mw.wrapModel(ctx, params, next); + } + return current; + } + + /** Chains WrapTool hooks. First middleware is outermost. */ + private ToolNext chainToolHooks(List middlewares, ToolNext core) { + if (middlewares.isEmpty()) { + return core; + } + ToolNext current = core; + for (int i = middlewares.size() - 1; i >= 0; i--) { + final GenerationMiddleware mw = middlewares.get(i); + final ToolNext next = current; + current = (ctx, params) -> mw.wrapTool(ctx, params, next); + } + return current; + } + + /** Executes tools with WrapTool middleware hooks applied. */ + private ToolExecutionResult executeToolsWithMiddleware( + ActionContext ctx, + List toolRequestParts, + List> tools, + List middlewares) { + + // Build WrapTool chain + ToolNext wrappedToolCall = + chainToolHooks( + middlewares, + (actx, tparams) -> { + Tool tool = tparams.getTool(); + ToolRequest toolReq = tparams.getRequest(); + + Object toolInput = toolReq.getInput(); + Class inputClass = tool.getInputClass(); + if (inputClass != null && toolInput != null && !inputClass.isInstance(toolInput)) { + toolInput = JsonUtils.convert(toolInput, inputClass); + } + + @SuppressWarnings("unchecked") + Tool typedTool = (Tool) tool; + Object result = typedTool.run(actx, toolInput); + + return Part.toolResponse( + new ToolResponse(toolReq.getRef(), toolReq.getName(), result)); + }); + + List responseParts = new java.util.ArrayList<>(); + List interrupts = new java.util.ArrayList<>(); + Map interruptMap = new java.util.HashMap<>(); + Map pendingOutputMap = new java.util.HashMap<>(); - // Execute tools and handle interrupts - ToolExecutionResult toolResult = - executeToolsWithInterruptHandling(ctx, toolRequestParts, options.getTools()); + for (Part toolRequestPart : toolRequestParts) { + ToolRequest toolRequest = toolRequestPart.getToolRequest(); + String toolName = toolRequest.getName(); + String key = toolName + "#" + (toolRequest.getRef() != null ? toolRequest.getRef() : ""); - // If there are interrupts, return immediately with interrupted response - if (!toolResult.getInterrupts().isEmpty()) { - return buildInterruptedResponse(response, toolResult); + Tool tool = findTool(toolName, tools); + if (tool == null) { + Part errorPart = new Part(); + ToolResponse errorResponse = + new ToolResponse( + toolRequest.getRef(), toolName, Map.of("error", "Tool not found: " + toolName)); + errorPart.setToolResponse(errorResponse); + responseParts.add(errorPart); + continue; } - // Add the assistant message with tool requests - Message assistantMessage = response.getMessage(); - List updatedMessages = new java.util.ArrayList<>(request.getMessages()); - updatedMessages.add(assistantMessage); + try { + // Execute through WrapTool chain + ToolParams tparams = new ToolParams(toolRequestPart, tool); + Part responsePart = wrappedToolCall.apply(ctx, tparams); - // Add tool response message - Message toolResponseMessage = new Message(); - toolResponseMessage.setRole(Role.TOOL); - toolResponseMessage.setContent(toolResult.getResponses()); - updatedMessages.add(toolResponseMessage); + responseParts.add(responsePart); - // Update request with new messages for next turn - request = - ModelRequest.builder() - .messages(updatedMessages) - .config(request.getConfig()) - .tools(request.getTools()) - .output(request.getOutput()) - .build(); + pendingOutputMap.put(key, responsePart.getToolResponse().getOutput()); - turn++; + logger.debug("Executed tool '{}' successfully", toolName); + + } catch (ToolInterruptException e) { + Map interruptMetadata = e.getMetadata(); + + Part interruptPart = new Part(); + interruptPart.setToolRequest(toolRequest); + Map metadata = + toolRequestPart.getMetadata() != null + ? new java.util.HashMap<>(toolRequestPart.getMetadata()) + : new java.util.HashMap<>(); + metadata.put( + "interrupt", + interruptMetadata != null && !interruptMetadata.isEmpty() ? interruptMetadata : true); + interruptPart.setMetadata(metadata); + + interrupts.add(interruptPart); + interruptMap.put(key, interruptPart); + + logger.debug("Tool '{}' triggered interrupt", toolName); + + } catch (Exception e) { + logger.error("Tool execution failed for '{}': {}", toolName, e.getMessage()); + Part errorPart = new Part(); + ToolResponse errorResponse = + new ToolResponse( + toolRequest.getRef(), + toolName, + Map.of("error", "Tool execution failed: " + e.getMessage())); + errorPart.setToolResponse(errorResponse); + responseParts.add(errorPart); + } } - throw new GenkitException("Max tool execution turns (" + maxTurns + ") exceeded"); + return new ToolExecutionResult(responseParts, interrupts, interruptMap, pendingOutputMap); } /** Handles resume options by processing respond and restart directives. */ @@ -756,9 +1052,16 @@ private ModelRequest handleResumeOption(ModelRequest request, GenerateOptions // Build tool response parts from resume options List toolResponseParts = new java.util.ArrayList<>(); + // Collect tool names/refs from respond directives + java.util.Set respondedTools = new java.util.HashSet<>(); + // Handle respond directives if (resume.getRespond() != null) { for (ToolResponse toolResponse : resume.getRespond()) { + respondedTools.add( + toolResponse.getName() + + "#" + + (toolResponse.getRef() != null ? toolResponse.getRef() : "")); Part responsePart = new Part(); responsePart.setToolResponse(toolResponse); Map metadata = new java.util.HashMap<>(); @@ -768,52 +1071,63 @@ private ModelRequest handleResumeOption(ModelRequest request, GenerateOptions } } - // Handle restart directives - execute the tools + // Note: restart directives are handled inside the generate loop + // for proper middleware lifecycle (wrapTool and wrapGenerate hooks fire correctly) + boolean hasRespond = resume.getRespond() != null && !resume.getRespond().isEmpty(); + boolean hasRestart = resume.getRestart() != null && !resume.getRestart().isEmpty(); + + if (!hasRespond && !hasRestart) { + throw new GenkitException("Resume options must contain either respond or restart directives"); + } + + // Collect tool names/refs from restart directives to avoid duplicating their responses + java.util.Set restartedTools = new java.util.HashSet<>(); if (resume.getRestart() != null) { - ActionContext ctx = new ActionContext(registry); - for (ToolRequest restartRequest : resume.getRestart()) { - Tool tool = findTool(restartRequest.getName(), options.getTools()); - if (tool == null) { - throw new GenkitException("Tool not found for restart: " + restartRequest.getName()); - } + for (ToolRequest toolRequest : resume.getRestart()) { + restartedTools.add( + toolRequest.getName() + + "#" + + (toolRequest.getRef() != null ? toolRequest.getRef() : "")); + } + } - try { - @SuppressWarnings("unchecked") - Tool typedTool = (Tool) tool; - Object result = typedTool.run(ctx, restartRequest.getInput()); - - Part responsePart = new Part(); - ToolResponse toolResponse = - new ToolResponse(restartRequest.getRef(), restartRequest.getName(), result); - responsePart.setToolResponse(toolResponse); - Map metadata = new java.util.HashMap<>(); - metadata.put("source", "restart"); - responsePart.setMetadata(metadata); - toolResponseParts.add(responsePart); - } catch (ToolInterruptException e) { - // Tool interrupted again during restart - throw new GenkitException( - "Tool '" - + restartRequest.getName() - + "' triggered an interrupt during restart. " - + "Re-interrupting during restart is not supported."); + // Add tool responses for completed tools (pendingOutput metadata) that aren't + // being explicitly responded to or restarted. This ensures all tool_calls in the + // model message have matching tool responses (required by providers like OpenAI). + for (Part part : lastMessage.getContent()) { + if (part.getToolRequest() != null && part.getMetadata() != null) { + Object pendingOutput = part.getMetadata().get("pendingOutput"); + if (pendingOutput != null) { + String key = + part.getToolRequest().getName() + + "#" + + (part.getToolRequest().getRef() != null ? part.getToolRequest().getRef() : ""); + if (!respondedTools.contains(key) && !restartedTools.contains(key)) { + Part responsePart = new Part(); + ToolResponse toolResponse = + new ToolResponse( + part.getToolRequest().getRef(), part.getToolRequest().getName(), pendingOutput); + responsePart.setToolResponse(toolResponse); + Map metadata = new java.util.HashMap<>(); + metadata.put("pendingOutput", true); + responsePart.setMetadata(metadata); + toolResponseParts.add(responsePart); + } } } } - if (toolResponseParts.isEmpty()) { - throw new GenkitException("Resume options must contain either respond or restart directives"); + if (!toolResponseParts.isEmpty()) { + // Add tool response message for completed and responded tools + Message toolResponseMessage = new Message(); + toolResponseMessage.setRole(Role.TOOL); + toolResponseMessage.setContent(toolResponseParts); + Map toolMsgMetadata = new java.util.HashMap<>(); + toolMsgMetadata.put("resumed", true); + toolResponseMessage.setMetadata(toolMsgMetadata); + messages.add(toolResponseMessage); } - // Add tool response message - Message toolResponseMessage = new Message(); - toolResponseMessage.setRole(Role.TOOL); - toolResponseMessage.setContent(toolResponseParts); - Map toolMsgMetadata = new java.util.HashMap<>(); - toolMsgMetadata.put("resumed", true); - toolResponseMessage.setMetadata(toolMsgMetadata); - messages.add(toolResponseMessage); - return ModelRequest.builder() .messages(messages) .config(request.getConfig()) diff --git a/pom.xml b/pom.xml index 4ccf05dde..de25ae4f3 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,7 @@ samples/evaluators-plugin samples/complex-io samples/middleware + samples/middleware-v2 samples/mcp samples/chat-session samples/multi-agent diff --git a/samples/firebase/dependency-reduced-pom.xml b/samples/firebase/dependency-reduced-pom.xml index a98fb527c..56d745010 100644 --- a/samples/firebase/dependency-reduced-pom.xml +++ b/samples/firebase/dependency-reduced-pom.xml @@ -27,7 +27,7 @@ com.google.cloud.functions function-maven-plugin - 1.0.0 + 1.0.1 com.google.genkit.samples.firebase.functions.GeneratePoemFunction @@ -68,12 +68,12 @@ com.google.cloud.functions functions-framework-api - 2.0.0 + 2.0.1 provided - 2.0.0 + 2.0.1 21 21 1.0.0-SNAPSHOT diff --git a/samples/middleware-v2/README.md b/samples/middleware-v2/README.md new file mode 100644 index 000000000..5726e101c --- /dev/null +++ b/samples/middleware-v2/README.md @@ -0,0 +1,161 @@ +# Genkit Java Middleware V2 Sample + +This sample demonstrates the **V2 GenerationMiddleware** system, which provides three distinct hooks into the generation pipeline: + +- **WrapGenerate** — wraps each iteration of the tool loop +- **WrapModel** — wraps each model API call +- **WrapTool** — wraps each tool execution + +Unlike V1 middleware (which wraps flows), V2 middleware is attached per `generate()` call via `GenerateOptions.builder().use()` and hooks directly into the AI generation pipeline. + +## Prerequisites + +- Java 21+ +- Maven 3.6+ +- OpenAI API key + +## Running the Sample + +### Option 1: Direct Run + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/middleware-v2 + +# Run the sample +./run.sh +# Or: mvn compile exec:java +``` + +### Option 2: With Genkit Dev UI (Recommended) + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/middleware-v2 + +# Run with Genkit CLI +genkit start -- ./run.sh +``` + +The Dev UI will be available at http://localhost:4000 + +## Middleware Examples + +### 1. ModelLoggingMiddleware (WrapModel) +Logs every model API call with a per-invocation counter. Demonstrates `newInstance()` for fresh state per `generate()` call. + +### 2. GenerateTimingMiddleware (WrapGenerate) +Measures wall-clock time for each generate loop iteration (model call + tool execution). + +### 3. ToolMonitorMiddleware (WrapTool) +Logs tool execution name and duration. Stateless — `newInstance()` returns `this`. + +### 4. FullObservabilityMiddleware (All 3 hooks) +A single middleware that implements all three hooks, showing how one middleware can observe the entire pipeline with per-invocation counters. + +## Available Endpoints + +| Endpoint | Description | Middleware | +|----------|-------------|------------| +| `/v2-chat` | AI chat | Model logging + generate timing | +| `/v2-observable` | AI chat | Full observability (all 3 hooks) | +| `/v2-stacked` | AI chat | Three separate middleware stacked | +| `/v2-baseline` | AI chat | No middleware (baseline) | + +## Example Requests + +```bash +# Chat with model logging + timing +curl -X POST http://localhost:8080/v2-chat \ + -H 'Content-Type: application/json' \ + -d '"What is middleware?"' + +# Chat with full observability +curl -X POST http://localhost:8080/v2-observable \ + -H 'Content-Type: application/json' \ + -d '"Explain Java records"' + +# Chat with stacked middleware +curl -X POST http://localhost:8080/v2-stacked \ + -H 'Content-Type: application/json' \ + -d '"Hello world"' + +# Baseline (no middleware) +curl -X POST http://localhost:8080/v2-baseline \ + -H 'Content-Type: application/json' \ + -d '"Hello world"' +``` + +## Creating Custom V2 Middleware + +Extend `BaseGenerationMiddleware` and override only the hooks you need: + +```java +import com.google.genkit.ai.middleware.*; +import com.google.genkit.core.ActionContext; + +public class MyMiddleware extends BaseGenerationMiddleware { + + @Override + public String name() { return "my-middleware"; } + + @Override + public GenerationMiddleware newInstance() { return new MyMiddleware(); } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + System.out.println("Before model call"); + ModelResponse resp = next.apply(ctx, params); + System.out.println("After model call: " + resp.getText().length() + " chars"); + return resp; + } +} +``` + +Then attach it to a `generate()` call: + +```java +ModelResponse response = genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-4o-mini") + .prompt("Hello") + .use(new MyMiddleware()) + .build()); +``` + +## Architecture + +V2 middleware wraps the generation pipeline at three levels: + +``` +generate() call + └─ WrapGenerate (per tool-loop iteration) + └─ WrapModel (per model API call) + └─ WrapTool (per tool execution) + └─ recurse → next WrapGenerate iteration +``` + +Each `generate()` call creates fresh middleware instances via `newInstance()`, enabling per-invocation state (counters, timers) without shared mutable state across requests. + +Middleware are chained in order — the first middleware in the `use()` list is the outermost wrapper. + +## V1 vs V2 Middleware + +| | V1 (`Middleware`) | V2 (`GenerationMiddleware`) | +|---|---|---| +| **Scope** | Wraps flows | Wraps generation pipeline | +| **Hooks** | Single `apply()` | 3 hooks: Generate, Model, Tool | +| **Attachment** | `defineFlow(..., middleware)` | `GenerateOptions.builder().use(...)` | +| **State** | Shared across calls | Fresh per `generate()` via `newInstance()` | + +## See Also + +- [V1 Middleware Sample](../middleware/) — flow-level middleware +- [Genkit Documentation](https://github.com/genkit-ai/genkit-java) diff --git a/samples/middleware-v2/pom.xml b/samples/middleware-v2/pom.xml new file mode 100644 index 000000000..a2a1443bb --- /dev/null +++ b/samples/middleware-v2/pom.xml @@ -0,0 +1,92 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + com.google.genkit.samples + genkit-sample-middleware-v2 + jar + Genkit Middleware V2 Sample + Sample application demonstrating V2 GenerationMiddleware with Generate, Model, and Tool hooks + + + UTF-8 + 21 + 21 + 1.0.0-SNAPSHOT + true + true + com.google.genkit.samples.MiddlewareV2Sample + + + + + com.google.genkit + genkit + ${genkit.version} + + + com.google.genkit + genkit-plugin-openai + ${genkit.version} + + + com.google.genkit + genkit-plugin-jetty + ${genkit.version} + + + ch.qos.logback + logback-classic + 1.5.32 + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.15.0 + + 21 + 21 + + + + org.codehaus.mojo + exec-maven-plugin + 3.6.3 + + ${exec.mainClass} + + + + + diff --git a/samples/middleware-v2/run.sh b/samples/middleware-v2/run.sh new file mode 100755 index 000000000..54805f669 --- /dev/null +++ b/samples/middleware-v2/run.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +# Run the Genkit Middleware V2 Sample + +set -e + +# Navigate to the sample directory +cd "$(dirname "$0")" + +# Check for OPENAI_API_KEY +if [ -z "$OPENAI_API_KEY" ]; then + echo "Warning: OPENAI_API_KEY is not set. The sample may not work correctly." + echo "Set it with: export OPENAI_API_KEY=your-api-key" +fi + +# Build and run +echo "Building and running Genkit Middleware V2 Sample..." +mvn compile exec:java -q diff --git a/samples/middleware-v2/src/main/java/com/google/genkit/samples/MiddlewareInterruptRestartSample.java b/samples/middleware-v2/src/main/java/com/google/genkit/samples/MiddlewareInterruptRestartSample.java new file mode 100644 index 000000000..8400c1034 --- /dev/null +++ b/samples/middleware-v2/src/main/java/com/google/genkit/samples/MiddlewareInterruptRestartSample.java @@ -0,0 +1,385 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.GenerateOptions; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Part; +import com.google.genkit.ai.ResumeOptions; +import com.google.genkit.ai.Tool; +import com.google.genkit.ai.ToolInterruptException; +import com.google.genkit.ai.ToolRequest; +import com.google.genkit.ai.middleware.BaseGenerationMiddleware; +import com.google.genkit.ai.middleware.GenerateNext; +import com.google.genkit.ai.middleware.GenerateParams; +import com.google.genkit.ai.middleware.GenerationMiddleware; +import com.google.genkit.ai.middleware.ModelNext; +import com.google.genkit.ai.middleware.ModelParams; +import com.google.genkit.ai.middleware.ToolNext; +import com.google.genkit.ai.middleware.ToolParams; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; +import com.google.genkit.plugins.openai.OpenAIPlugin; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Sample that tests the middleware lifecycle during interrupt restart. + * + *

This validates Pavel's feedback on PR #125: when restarting an interrupted tool, the + * middleware must follow the correct nested lifecycle. Specifically: + * + *

+ * === Initial generate call (triggers tool4 interrupt mid-flow) ===
+ *
+ * generate - 1
+ *     model
+ *     tool1
+ *     tool2
+ *     tool3
+ *     generate - 2
+ *         model
+ *         tool4  // <--- INTERRUPT
+ *
+ * === Restart of tool4 (correct lifecycle) ===
+ *
+ * generate - 1          // restart generate call
+ *     tool4             // RESTART (through wrapTool middleware)
+ *     generate - 2      // nested - NOT flat!
+ *         model
+ *         // done
+ * 
+ * + *

The WRONG (naive) implementation would flatten this to: + * + *

+ * generate - 1
+ *     tool4   // RESTART
+ *     model   // flat - no nested generate
+ *     // done
+ * 
+ * + *

This sample uses a lifecycle-tracking middleware that records every hook invocation, then + * verifies the correct nesting after restart. + * + *

To run: + * + *

    + *
  1. Set the OPENAI_API_KEY environment variable + *
  2. Run: mvn exec:java + * -Dexec.mainClass="com.google.genkit.samples.MiddlewareInterruptRestartSample" -pl + * samples/middleware-v2 + *
+ */ +public class MiddlewareInterruptRestartSample { + + // ========================================================================= + // Lifecycle-tracking middleware + // ========================================================================= + + /** + * Middleware that records every hook invocation as a structured log entry. Used to verify the + * correct nesting of generate/model/tool calls. + */ + static class LifecycleTracker extends BaseGenerationMiddleware { + + private final List log; + private final AtomicInteger depth = new AtomicInteger(0); + + LifecycleTracker(List log) { + this.log = log; + } + + @Override + public String name() { + return "lifecycle-tracker"; + } + + @Override + public GenerationMiddleware newInstance() { + return new LifecycleTracker(log); + } + + private String indent() { + return " ".repeat(depth.get()); + } + + @Override + public ModelResponse wrapGenerate(ActionContext ctx, GenerateParams params, GenerateNext next) + throws GenkitException { + String entry = indent() + "generate - " + (params.getIteration() + 1); + log.add(entry); + System.out.println(entry); + depth.incrementAndGet(); + try { + return next.apply(ctx, params); + } finally { + depth.decrementAndGet(); + } + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + String entry = indent() + "model"; + log.add(entry); + System.out.println(entry); + return next.apply(ctx, params); + } + + @Override + public Part wrapTool(ActionContext ctx, ToolParams params, ToolNext next) + throws GenkitException { + String toolName = params.getRequest().getName(); + String entry = indent() + toolName; + log.add(entry); + System.out.println(entry); + return next.apply(ctx, params); + } + } + + // ========================================================================= + // Data classes + // ========================================================================= + + public static class ActionInput { + private String action; + + public ActionInput() {} + + public String getAction() { + return action; + } + + public void setAction(String action) { + this.action = action; + } + } + + // ========================================================================= + // Main + // ========================================================================= + + public static void main(String[] args) throws Exception { + Genkit genkit = + Genkit.builder() + .options(GenkitOptions.builder().devMode(false).build()) + .plugin(OpenAIPlugin.create()) + .build(); + + // Define regular tools (tool1, tool2, tool3) with object input schemas (required by OpenAI) + @SuppressWarnings("unchecked") + Tool, String> tool1 = + genkit.defineTool( + "tool1", + "First tool - runs task 1", + Map.of("type", "object", "properties", Map.of(), "additionalProperties", false), + (Class>) (Class) Map.class, + (ctx, input) -> "tool1-result"); + + @SuppressWarnings("unchecked") + Tool, String> tool2 = + genkit.defineTool( + "tool2", + "Second tool - runs task 2", + Map.of("type", "object", "properties", Map.of(), "additionalProperties", false), + (Class>) (Class) Map.class, + (ctx, input) -> "tool2-result"); + + @SuppressWarnings("unchecked") + Tool, String> tool3 = + genkit.defineTool( + "tool3", + "Third tool - runs task 3", + Map.of("type", "object", "properties", Map.of(), "additionalProperties", false), + (Class>) (Class) Map.class, + (ctx, input) -> "tool3-result"); + + // Define tool4 as a restartable tool that interrupts on first call. + // On restart (second call), it succeeds. This simulates a tool that needs + // human confirmation before proceeding. + final java.util.concurrent.atomic.AtomicBoolean tool4HasInterrupted = + new java.util.concurrent.atomic.AtomicBoolean(false); + + @SuppressWarnings("unchecked") + Tool, String> tool4 = + genkit.defineTool( + "tool4", + "Fourth tool - requires confirmation, interrupts on first call", + Map.of( + "type", + "object", + "properties", + Map.of("action", Map.of("type", "string")), + "additionalProperties", + false), + (Class>) (Class) Map.class, + (ctx, input) -> { + if (!tool4HasInterrupted.getAndSet(true)) { + // First call: interrupt to request confirmation + throw new ToolInterruptException( + Map.of("reason", "needs confirmation", "action", String.valueOf(input))); + } + // Restart call: proceed normally + return "tool4-confirmed-result"; + }); + + // Shared log across initial + restart calls + List lifecycleLog = Collections.synchronizedList(new ArrayList<>()); + + GenerationMiddleware tracker = new LifecycleTracker(lifecycleLog); + + System.out.println("==========================================================="); + System.out.println(" Middleware Interrupt Restart Lifecycle Test"); + System.out.println("==========================================================="); + System.out.println(); + + // --------------------------------------------------------------- + // Step 1: Initial generate - model should call tools, tool4 interrupts + // --------------------------------------------------------------- + // Note: The actual model call pattern depends on the LLM response. + // We ask it to call all 4 tools. The first 3 succeed, tool4 interrupts. + System.out.println(">>> Step 1: Initial generate (expecting tool4 to interrupt)"); + System.out.println("-----------------------------------------------------------"); + + ModelResponse response = + genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-4o-mini") + .system( + "You are a task executor. When asked to run all tasks, you MUST call all 4 " + + "tools in order: tool1, tool2, tool3, tool4. Call them all at once.") + .prompt("Run all tasks now.") + .tools(List.of(tool1, tool2, tool3, tool4)) + .use(tracker) + .maxTurns(5) + .build()); + + System.out.println(); + if (response.isInterrupted()) { + System.out.println(">>> tool4 interrupted as expected!"); + System.out.println(); + + // --------------------------------------------------------------- + // Step 2: Restart tool4 with middleware - should fire wrapTool + nested wrapGenerate + // --------------------------------------------------------------- + System.out.println(">>> Step 2: Restart tool4 (expecting nested lifecycle)"); + System.out.println("-----------------------------------------------------------"); + + // Find the interrupted tool request + Part interruptPart = response.getInterrupts().get(0); + ToolRequest interruptedRequest = interruptPart.getToolRequest(); + + // Create restart request (re-execute tool4 with same input) + ToolRequest restartRequest = new ToolRequest(); + restartRequest.setName(interruptedRequest.getName()); + restartRequest.setRef(interruptedRequest.getRef()); + restartRequest.setInput(interruptedRequest.getInput()); + + ModelResponse resumedResponse = + genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-4o-mini") + .messages(response.getMessages()) + .tools(List.of(tool1, tool2, tool3, tool4)) + .use(tracker) + .resume(ResumeOptions.builder().restart(restartRequest).build()) + .maxTurns(5) + .build()); + + System.out.println(); + System.out.println(">>> Restart completed. Final response:"); + System.out.println(resumedResponse.getText()); + } else { + System.out.println(">>> No interrupt occurred (model didn't call tool4)."); + System.out.println(">>> Response: " + response.getText()); + } + + // --------------------------------------------------------------- + // Print full lifecycle log + // --------------------------------------------------------------- + System.out.println(); + System.out.println("==========================================================="); + System.out.println(" Full Lifecycle Log"); + System.out.println("==========================================================="); + for (String entry : lifecycleLog) { + System.out.println(entry); + } + + System.out.println(); + System.out.println("==========================================================="); + System.out.println(" Verification"); + System.out.println("==========================================================="); + + // Verify that the restart lifecycle shows nested generate calls + // After restart, we expect to see at minimum: + // generate - 1 (restart iteration) + // tool4 (through wrapTool) + // generate - 2 (nested, NOT flat) + // model (model call after tool4 completes) + boolean foundRestartGenerate = false; + boolean foundRestartTool = false; + boolean foundNestedGenerate = false; + boolean foundNestedModel = false; + + // Look at the restart portion of the log (after the initial call) + boolean inRestartPhase = false; + for (String entry : lifecycleLog) { + // The restart phase starts with the second "generate - 1" + if (!inRestartPhase && entry.trim().equals("generate - 1")) { + if (foundRestartGenerate) { + // This is the second "generate - 1", so we're in restart phase + inRestartPhase = true; + foundRestartGenerate = true; + continue; + } + foundRestartGenerate = true; + } + if (inRestartPhase) { + if (entry.trim().equals("tool4")) { + foundRestartTool = true; + } + if (entry.trim().equals("generate - 2")) { + foundNestedGenerate = true; + } + if (foundNestedGenerate && entry.trim().equals("model")) { + foundNestedModel = true; + } + } + } + + System.out.println("Restart fires wrapTool for tool4: " + foundRestartTool); + System.out.println("Restart has nested generate: " + foundNestedGenerate); + System.out.println("Nested generate calls model: " + foundNestedModel); + + if (foundRestartTool && foundNestedGenerate && foundNestedModel) { + System.out.println(); + System.out.println("PASS: Restart follows correct nested lifecycle!"); + } else { + System.out.println(); + System.out.println("FAIL: Restart lifecycle is flat (naive implementation)."); + System.out.println(" Expected: generate -> tool4 -> generate -> model"); + } + } +} diff --git a/samples/middleware-v2/src/main/java/com/google/genkit/samples/MiddlewareV2Sample.java b/samples/middleware-v2/src/main/java/com/google/genkit/samples/MiddlewareV2Sample.java new file mode 100644 index 000000000..476eb2e04 --- /dev/null +++ b/samples/middleware-v2/src/main/java/com/google/genkit/samples/MiddlewareV2Sample.java @@ -0,0 +1,402 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.GenerateOptions; +import com.google.genkit.ai.GenerationConfig; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Part; +import com.google.genkit.ai.Tool; +import com.google.genkit.ai.middleware.BaseGenerationMiddleware; +import com.google.genkit.ai.middleware.GenerateNext; +import com.google.genkit.ai.middleware.GenerateParams; +import com.google.genkit.ai.middleware.GenerationMiddleware; +import com.google.genkit.ai.middleware.ModelNext; +import com.google.genkit.ai.middleware.ModelParams; +import com.google.genkit.ai.middleware.ToolNext; +import com.google.genkit.ai.middleware.ToolParams; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.Flow; +import com.google.genkit.core.GenkitException; +import com.google.genkit.plugins.jetty.JettyPlugin; +import com.google.genkit.plugins.jetty.JettyPluginOptions; +import com.google.genkit.plugins.openai.OpenAIPlugin; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Sample application demonstrating the V2 GenerationMiddleware system. + * + *

V2 middleware hooks into three distinct stages of the generation pipeline: + * + *

    + *
  • WrapGenerate — wraps each iteration of the tool loop + *
  • WrapModel — wraps each model API call + *
  • WrapTool — wraps each tool execution + *
+ * + *

Middleware is attached per {@code generate()} call via {@code GenerateOptions.builder().use()} + * rather than per flow. + * + *

Each {@code generate()} call creates a fresh middleware instance via {@code newInstance()}, + * enabling per-invocation state (counters, timers) without shared mutable state across requests. + * + *

To run: + * + *

    + *
  1. Set the OPENAI_API_KEY environment variable + *
  2. Run: mvn exec:java + *
+ */ +public class MiddlewareV2Sample { + + private static final Logger logger = LoggerFactory.getLogger(MiddlewareV2Sample.class); + + // ========================================================================= + // Example 1: Model logging middleware (WrapModel hook) + // ========================================================================= + + /** + * Logs every model API call with a call counter. The counter resets per generate() invocation + * because {@code newInstance()} returns a fresh object. + */ + static class ModelLoggingMiddleware extends BaseGenerationMiddleware { + + private final AtomicInteger modelCalls = new AtomicInteger(0); + + @Override + public String name() { + return "model-logging"; + } + + @Override + public GenerationMiddleware newInstance() { + return new ModelLoggingMiddleware(); + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + int callNum = modelCalls.incrementAndGet(); + logger.info("[model-logging] Model call #{}", callNum); + ModelResponse resp = next.apply(ctx, params); + logger.info( + "[model-logging] Model call #{} returned ({} chars)", + callNum, + resp.getText() != null ? resp.getText().length() : 0); + return resp; + } + } + + // ========================================================================= + // Example 2: Generate timing middleware (WrapGenerate hook) + // ========================================================================= + + /** + * Measures the wall-clock time of each generate loop iteration including model call + tool + * execution within that iteration. + */ + static class GenerateTimingMiddleware extends BaseGenerationMiddleware { + + @Override + public String name() { + return "generate-timing"; + } + + @Override + public GenerationMiddleware newInstance() { + return new GenerateTimingMiddleware(); + } + + @Override + public ModelResponse wrapGenerate(ActionContext ctx, GenerateParams params, GenerateNext next) + throws GenkitException { + long start = System.currentTimeMillis(); + logger.info("[generate-timing] Starting iteration {}", params.getIteration()); + ModelResponse resp = next.apply(ctx, params); + logger.info( + "[generate-timing] Iteration {} completed in {}ms", + params.getIteration(), + System.currentTimeMillis() - start); + return resp; + } + } + + // ========================================================================= + // Example 3: Tool monitor middleware (WrapTool hook) + // ========================================================================= + + /** Logs tool execution name and duration. Stateless, so newInstance() returns {@code this}. */ + static class ToolMonitorMiddleware extends BaseGenerationMiddleware { + + @Override + public String name() { + return "tool-monitor"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; // stateless — safe to reuse + } + + @Override + public Part wrapTool(ActionContext ctx, ToolParams params, ToolNext next) + throws GenkitException { + String toolName = params.getRequest().getName(); + logger.info("[tool-monitor] Executing tool: {}", toolName); + long start = System.currentTimeMillis(); + Part resp = next.apply(ctx, params); + logger.info( + "[tool-monitor] Tool {} completed in {}ms", toolName, System.currentTimeMillis() - start); + return resp; + } + } + + // ========================================================================= + // Example 4: Combined multi-hook middleware + // ========================================================================= + + /** + * A single middleware that implements all three hooks. Demonstrates that one middleware can + * observe every stage of the pipeline. + */ + static class FullObservabilityMiddleware extends BaseGenerationMiddleware { + + private final AtomicInteger iterations = new AtomicInteger(0); + private final AtomicInteger modelCalls = new AtomicInteger(0); + private final AtomicInteger toolCalls = new AtomicInteger(0); + + @Override + public String name() { + return "full-observability"; + } + + @Override + public GenerationMiddleware newInstance() { + return new FullObservabilityMiddleware(); + } + + @Override + public ModelResponse wrapGenerate(ActionContext ctx, GenerateParams params, GenerateNext next) + throws GenkitException { + int iter = iterations.incrementAndGet(); + logger.info("[observability] === Generate iteration {} ===", iter); + ModelResponse resp = next.apply(ctx, params); + logger.info( + "[observability] === Iteration {} done (model calls: {}, tool calls: {}) ===", + iter, + modelCalls.get(), + toolCalls.get()); + return resp; + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + int call = modelCalls.incrementAndGet(); + logger.info("[observability] Model call #{}", call); + return next.apply(ctx, params); + } + + @Override + public Part wrapTool(ActionContext ctx, ToolParams params, ToolNext next) + throws GenkitException { + int call = toolCalls.incrementAndGet(); + logger.info("[observability] Tool call #{}: {}", call, params.getRequest().getName()); + return next.apply(ctx, params); + } + } + + // ========================================================================= + // Main + // ========================================================================= + + public static void main(String[] args) throws Exception { + JettyPlugin jetty = new JettyPlugin(JettyPluginOptions.builder().port(8080).build()); + + Genkit genkit = + Genkit.builder() + .options(GenkitOptions.builder().devMode(true).reflectionPort(3100).build()) + .plugin(OpenAIPlugin.create()) + .plugin(jetty) + .build(); + + // Instantiate middleware (templates — newInstance() is called per generate()) + GenerationMiddleware modelLogging = new ModelLoggingMiddleware(); + GenerationMiddleware generateTiming = new GenerateTimingMiddleware(); + GenerationMiddleware toolMonitor = new ToolMonitorMiddleware(); + GenerationMiddleware fullObservability = new FullObservabilityMiddleware(); + + // Define a simple tool so the WrapTool hook gets exercised + @SuppressWarnings("unchecked") + Tool, Map> weatherTool = + genkit.defineTool( + "getWeather", + "Gets the current weather for a given city", + Map.of( + "type", + "object", + "properties", + Map.of("city", Map.of("type", "string", "description", "The city name")), + "required", + new String[] {"city"}), + (Class>) (Class) Map.class, + (ctx, input) -> { + String city = (String) input.get("city"); + Map weather = new HashMap<>(); + weather.put("city", city); + weather.put("temperature", "22°C"); + weather.put("conditions", "Sunny"); + return weather; + }); + + // ======================================================= + // Flow 1: Simple chat with model logging + generate timing + // ======================================================= + + Flow chatFlow = + genkit.defineFlow( + "v2-chat", + String.class, + String.class, + (ctx, userMessage) -> { + ModelResponse response = + genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-4o-mini") + .system("You are a helpful assistant. Be concise.") + .prompt(userMessage) + .use(modelLogging, generateTiming) + .config( + GenerationConfig.builder() + .temperature(0.7) + .maxOutputTokens(200) + .build()) + .build()); + return response.getText(); + }); + + // ======================================================= + // Flow 2: Chat with all three hooks via full observability + // ======================================================= + + Flow observableFlow = + genkit.defineFlow( + "v2-observable", + String.class, + String.class, + (ctx, userMessage) -> { + ModelResponse response = + genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-4o-mini") + .system( + "You are a helpful assistant. Use the getWeather tool when asked about weather.") + .prompt(userMessage) + .tools(List.of(weatherTool)) + .use(fullObservability) + .config( + GenerationConfig.builder() + .temperature(0.7) + .maxOutputTokens(300) + .build()) + .build()); + return response.getText(); + }); + + // ======================================================= + // Flow 3: Stacking multiple middleware together + // ======================================================= + + Flow stackedFlow = + genkit.defineFlow( + "v2-stacked", + String.class, + String.class, + (ctx, userMessage) -> { + ModelResponse response = + genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-4o-mini") + .prompt(userMessage) + .use(modelLogging, generateTiming, toolMonitor) + .config( + GenerationConfig.builder() + .temperature(0.7) + .maxOutputTokens(200) + .build()) + .build()); + return response.getText(); + }); + + // ======================================================= + // Flow 4: No middleware (baseline for comparison) + // ======================================================= + + Flow baselineFlow = + genkit.defineFlow( + "v2-baseline", + String.class, + String.class, + (ctx, userMessage) -> { + ModelResponse response = + genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-4o-mini") + .prompt(userMessage) + .config( + GenerationConfig.builder() + .temperature(0.7) + .maxOutputTokens(200) + .build()) + .build()); + return response.getText(); + }); + + logger.info("\n========================================"); + logger.info("Genkit Middleware V2 Sample Started!"); + logger.info("========================================\n"); + + logger.info("Available flows:"); + logger.info(" - v2-chat: Model logging + generate timing middleware"); + logger.info(" - v2-observable: Full observability (all 3 hooks in one middleware)"); + logger.info(" - v2-stacked: Three separate middleware stacked together"); + logger.info(" - v2-baseline: No middleware (baseline comparison)\n"); + + logger.info("Server running on http://localhost:8080"); + logger.info("Reflection server running on http://localhost:3100"); + logger.info("\nExample requests:"); + logger.info( + " curl -X POST http://localhost:8080/v2-chat -H 'Content-Type: application/json' -d '\"What is middleware?\"'"); + logger.info( + " curl -X POST http://localhost:8080/v2-observable -H 'Content-Type: application/json' -d '\"Explain Java records\"'"); + logger.info( + " curl -X POST http://localhost:8080/v2-stacked -H 'Content-Type: application/json' -d '\"Hello world\"'"); + logger.info( + " curl -X POST http://localhost:8080/v2-baseline -H 'Content-Type: application/json' -d '\"Hello world\"'"); + + jetty.start(); + } +} diff --git a/samples/middleware-v2/src/main/resources/logback.xml b/samples/middleware-v2/src/main/resources/logback.xml new file mode 100644 index 000000000..fe98c37a8 --- /dev/null +++ b/samples/middleware-v2/src/main/resources/logback.xml @@ -0,0 +1,26 @@ + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + + + + + + + + + +