2 changes: 2 additions & 0 deletions genkit-tools/common/src/types/model.ts
@@ -399,5 +399,7 @@ export const GenerateActionOptionsSchema = z.object({
maxTurns: z.number().optional(),
/** Custom step name for this generate call to display in trace views. Defaults to "generate". */
stepName: z.string().optional(),
/** Registered middleware to be used with this model call. */
middleware: z.array(z.string()).optional(),
});
export type GenerateActionOptions = z.infer<typeof GenerateActionOptionsSchema>;
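The new `middleware` field takes middleware by registered name rather than by function reference, which keeps `GenerateActionOptions` serializable. A minimal sketch of how a named middleware might be registered and then referenced by name; the registration call, import paths, and the model/middleware names are assumptions, since the diff only shows the lookup side (`registry.lookupValue('modelMiddleware', name)` in `js/ai/src/generate/action.ts`):

```ts
import { Registry } from '@genkit-ai/core/registry'; // import path assumed
import type { ModelMiddleware } from 'genkit/model';

const registry = new Registry();

// Hypothetical named middleware that logs each model request.
const logRequests: ModelMiddleware = async (req, next) => {
  console.log(`model call with ${req.messages.length} messages`);
  return next(req);
};

// Assumed registration call; only the lookup under the 'modelMiddleware'
// registry key is visible in this PR.
registry.registerValue('modelMiddleware', 'logRequests', logRequests);

// A serializable GenerateActionOptions payload can then reference it by name.
const actionOptions = {
  model: 'googleai/gemini-2.0-flash', // hypothetical model name
  messages: [{ role: 'user', content: [{ text: 'Hello' }] }],
  middleware: ['logRequests'],
};
```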
6 changes: 6 additions & 0 deletions genkit-tools/genkit-schema.json
@@ -466,6 +466,12 @@
},
"stepName": {
"type": "string"
},
"middleware": {
"type": "array",
"items": {
"type": "string"
}
}
},
"required": [
1 change: 1 addition & 0 deletions go/ai/gen.go
@@ -125,6 +125,7 @@ type GenerateActionOptions struct {
Docs []*Document `json:"docs,omitempty"`
MaxTurns int `json:"maxTurns,omitempty"`
Messages []*Message `json:"messages,omitempty"`
Middleware []string `json:"middleware,omitempty"`
Model string `json:"model,omitempty"`
Output *GenerateActionOutputConfig `json:"output,omitempty"`
Resume *GenerateActionResume `json:"resume,omitempty"`
5 changes: 3 additions & 2 deletions js/ai/src/generate.ts
@@ -51,7 +51,7 @@ import {
type GenerationCommonConfigSchema,
type MessageData,
type ModelArgument,
type ModelMiddleware,
type ModelMiddlewareArgument,
type Part,
type ToolRequestPart,
type ToolResponsePart,
@@ -170,7 +170,7 @@ export interface GenerateOptions<
*/
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
/** Middleware to be used with this model call. */
use?: ModelMiddleware[];
use?: ModelMiddlewareArgument[];
/** Additional context (data, like e.g. auth) to be passed down to tools, prompts and other sub actions. */
context?: ActionContext;
/** Abort signal for the generate request. */
@@ -538,6 +538,7 @@ export async function toGenerateActionOptions<
returnToolRequests: options.returnToolRequests,
maxTurns: options.maxTurns,
stepName: options.stepName,
middleware: options.use?.filter((m): m is string => typeof m === 'string'),
};
// if config is empty and it was not explicitly passed in, we delete it, don't want {}
if (Object.keys(params.config).length === 0 && !options.config) {
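With `use` widened to `ModelMiddlewareArgument[]`, a call site can mix inline middleware functions with registered names; `toGenerateActionOptions` forwards only the string entries, since functions cannot be serialized into `GenerateActionOptions`. A hedged sketch of such a call (the model name and middleware name are illustrative, and `'logRequests'` is assumed to have been registered elsewhere):

```ts
import { genkit } from 'genkit';

const ai = genkit({ plugins: [] });

async function main() {
  const response = await ai.generate({
    model: 'googleai/gemini-2.0-flash', // illustrative model name
    prompt: 'Tell me a joke',
    use: [
      'logRequests', // registered middleware, referenced by name (serializable)
      async (req, next) => next(req), // inline middleware, in-process only
    ],
  });
  console.log(response.text);
}
```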
102 changes: 80 additions & 22 deletions js/ai/src/generate/action.ts
@@ -15,6 +15,7 @@
*/

import {
ActionRunOptions,
GenkitError,
StreamingCallback,
defineAction,
@@ -42,6 +42,8 @@ import {
GenerateResponseChunkSchema,
GenerateResponseSchema,
MessageData,
ModelMiddlewareArgument,
ModelMiddlewareWithOptions,
resolveModel,
type GenerateActionOptions,
type GenerateActionOutputConfig,
@@ -85,17 +88,16 @@ export function defineGenerateAction(registry: Registry): GenerateAction {
outputSchema: GenerateResponseSchema,
streamSchema: GenerateResponseChunkSchema,
},
async (request, { streamingRequested, sendChunk }) => {
async (request, { streamingRequested, sendChunk, context }) => {
const generateFn = (
sendChunk?: StreamingCallback<GenerateResponseChunk>
) =>
generate(registry, {
rawRequest: request,
currentTurn: 0,
messageIndex: 0,
// Generate util action does not support middleware. Maybe when we add named/registered middleware....
middleware: [],
streamingCallback: sendChunk,
context,
});
return streamingRequested
? generateFn((c: GenerateResponseChunk) =>
@@ -113,18 +115,18 @@ export async function generateHelper(
registry: Registry,
options: {
rawRequest: GenerateActionOptions;
middleware?: ModelMiddleware[];
middleware?: ModelMiddlewareArgument[];
currentTurn?: number;
messageIndex?: number;
abortSignal?: AbortSignal;
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
context?: Record<string, any>;
}
): Promise<GenerateResponseData> {
const currentTurn = options.currentTurn ?? 0;
const messageIndex = options.messageIndex ?? 0;
// do tracing
return await runInNewSpan(
registry,
{
metadata: {
name: options.rawRequest.stepName || 'generate',
@@ -143,6 +145,7 @@
messageIndex,
abortSignal: options.abortSignal,
streamingCallback: options.streamingCallback,
context: options.context,
});
metadata.output = JSON.stringify(output);
return output;
@@ -247,13 +250,15 @@ async function generate(
messageIndex,
abortSignal,
streamingCallback,
context,
}: {
rawRequest: GenerateActionOptions;
middleware: ModelMiddleware[] | undefined;
middleware?: ModelMiddlewareArgument[] | undefined;
currentTurn: number;
messageIndex: number;
abortSignal?: AbortSignal;
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
context?: Record<string, any>;
}
): Promise<GenerateResponseData> {
const { model, tools, resources, format } = await resolveParameters(
@@ -319,30 +324,62 @@
streamingCallback(makeChunk('tool', resumedToolMessage));
}

const rawMiddleware = rawRequest.middleware || [];
const argMiddleware = middleware || [];
const effectiveRawMiddleware = rawMiddleware.filter(
(m) => !argMiddleware.includes(m)
);
const allMiddleware = [...argMiddleware, ...effectiveRawMiddleware];

var response: GenerateResponse;
const sendChunk =
streamingCallback &&
(((chunk: GenerateResponseChunkData) =>
streamingCallback &&
streamingCallback(makeChunk('model', chunk))) as any);
const dispatch = async (
index: number,
req: z.infer<typeof GenerateRequestSchema>
req: z.infer<typeof GenerateRequestSchema>,
actionOpts: ActionRunOptions<any>
) => {
if (!middleware || index === middleware.length) {
if (index === allMiddleware.length) {
// end of the chain, call the original model action
return await model(req, {
abortSignal,
onChunk:
streamingCallback &&
(((chunk: GenerateResponseChunkData) =>
streamingCallback &&
streamingCallback(makeChunk('model', chunk))) as any),
});
return await model(req, actionOpts);
}

const currentMiddleware = middleware[index];
return currentMiddleware(req, async (modifiedReq) =>
dispatch(index + 1, modifiedReq || req)
);
let currentMiddleware = allMiddleware[index];
if (typeof currentMiddleware === 'string') {
const resolvedMiddleware = await registry.lookupValue<
ModelMiddleware | ModelMiddlewareWithOptions
>('modelMiddleware', currentMiddleware);
if (!resolvedMiddleware) {
throw new GenkitError({
status: 'NOT_FOUND',
message: `Middleware '${currentMiddleware}' not found.`,
});
}
currentMiddleware = resolvedMiddleware;
}

if (currentMiddleware.length === 3) {
return (currentMiddleware as ModelMiddlewareWithOptions)(
req,
actionOpts,
async (modifiedReq, opts) =>
dispatch(index + 1, modifiedReq || req, opts || actionOpts)
);
} else {
return (currentMiddleware as ModelMiddleware)(req, async (modifiedReq) =>
dispatch(index + 1, modifiedReq || req, actionOpts)
);
}
};

const modelResponse = await dispatch(0, request);
const modelResponse = await dispatch(0, request, {
abortSignal,
context,
onChunk: sendChunk,
});

if (model.__action.actionType === 'background-model') {
response = new GenerateResponse(
@@ -416,7 +453,28 @@ async function generate(
// then recursively call for another loop
return await generateHelper(registry, {
rawRequest: nextRequest,
middleware: middleware,
middleware: allMiddleware, // Pass the already-combined middleware to the next turn.
// nextRequest still carries the original name strings in rawRequest.middleware, but
// the combine step above filters out names already present in the argument
// middleware, so each middleware is resolved and run only once per turn.
currentTurn: currentTurn + 1,
messageIndex: messageIndex + 1,
streamingCallback,
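The dispatch chain above resolves string entries through `registry.lookupValue('modelMiddleware', name)` and uses the declared parameter count (`currentMiddleware.length === 3`) to decide whether a middleware receives the `ActionRunOptions`. A minimal sketch of an options-aware middleware, assuming the `(req, options, next)` shape implied by that branch; the export path and the context fields it reads are assumptions:

```ts
import type { ModelMiddlewareWithOptions } from 'genkit/model'; // export path assumed

const attachUserContext: ModelMiddlewareWithOptions = async (req, opts, next) => {
  // Per-call context (e.g. auth) arrives through the ActionRunOptions.
  const uid = (opts?.context as any)?.auth?.uid; // illustrative context shape
  if (!uid) return next(req, opts);
  // Annotate the request before handing it down the chain.
  return next(
    {
      ...req,
      messages: [
        { role: 'system', content: [{ text: `Acting on behalf of user ${uid}` }] },
        ...req.messages,
      ],
    },
    opts
  );
};
```

Because it is declared with three parameters, the `length === 3` check routes it through the options-aware branch; a classic two-parameter `ModelMiddleware` keeps working unchanged.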
2 changes: 2 additions & 0 deletions js/ai/src/model-types.ts
@@ -410,5 +410,7 @@ export const GenerateActionOptionsSchema = z.object({
maxTurns: z.number().optional(),
/** Custom step name for this generate call to display in trace views. Defaults to "generate". */
stepName: z.string().optional(),
/** Registered middleware to be used with this model call. */
middleware: z.array(z.string()).optional(),
});
export type GenerateActionOptions = z.infer<typeof GenerateActionOptionsSchema>;
18 changes: 15 additions & 3 deletions js/ai/src/model.ts
@@ -18,6 +18,7 @@ import {
ActionFnArg,
BackgroundAction,
GenkitError,
MiddlewareWithOptions,
Operation,
OperationSchema,
action,
@@ -108,6 +109,17 @@ export type ModelMiddleware = SimpleMiddleware<
z.infer<typeof GenerateResponseSchema>
>;

export type ModelMiddlewareWithOptions = MiddlewareWithOptions<
z.infer<typeof GenerateRequestSchema>,
z.infer<typeof GenerateResponseSchema>,
z.infer<typeof GenerateResponseChunkSchema>
>;

export type ModelMiddlewareArgument =
| ModelMiddleware
| ModelMiddlewareWithOptions
| string;

export type DefineModelOptions<
CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
> = {
@@ -121,7 +133,7 @@ export type DefineModelOptions<
/** Descriptive name for this model e.g. 'Google AI - Gemini Pro'. */
label?: string;
/** Middleware to be used with this model. */
use?: ModelMiddleware[];
use?: ModelMiddlewareArgument[];
};

export function model<CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny>(
@@ -324,11 +336,11 @@ export function backgroundModel<
}

function getModelMiddleware(options: {
use?: ModelMiddleware[];
use?: ModelMiddlewareArgument[];
name: string;
supports?: ModelInfo['supports'];
}) {
const middleware: ModelMiddleware[] = options.use || [];
const middleware: ModelMiddlewareArgument[] = options.use || [];
if (!options?.supports?.context) middleware.push(augmentWithContext());
const constratedSimulator = simulateConstrainedGeneration();
middleware.push((req, next) => {
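`DefineModelOptions.use` is widened the same way, so a model definition can, in principle, name registered middleware alongside inline functions. A sketch under that assumption; the plugin setup, model name, and middleware name are invented, and resolution of name strings at model-definition time is assumed rather than shown in this diff:

```ts
import { genkit } from 'genkit';

const ai = genkit({ plugins: [] });

const echoModel = ai.defineModel(
  {
    name: 'example/echo', // invented model name
    use: [
      'logRequests', // registered middleware name (resolution assumed)
      async (req, next) => next(req), // inline middleware
    ],
  },
  async (req) => ({
    // Echo back the last message's first text part.
    message: {
      role: 'model',
      content: [{ text: req.messages.at(-1)?.content[0]?.text ?? '' }],
    },
    finishReason: 'stop',
  })
);
```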
1 change: 1 addition & 0 deletions js/ai/src/model/middleware.ts
@@ -308,6 +308,7 @@ const DEFAULT_RETRY_STATUSES: StatusName[] = [
];

const DEFAULT_FALLBACK_STATUSES: StatusName[] = [
'UNKNOWN',
'UNAVAILABLE',
'DEADLINE_EXCEEDED',
'RESOURCE_EXHAUSTED',