diff --git a/js/core/src/action.ts b/js/core/src/action.ts index a9621c543a..89ee1f8c09 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -17,7 +17,7 @@ import type { JSONSchema7 } from 'json-schema'; import * as z from 'zod'; import { getAsyncContext } from './async-context.js'; -import { lazy } from './async.js'; +import { Channel, lazy } from './async.js'; import { getContext, runWithContext, type ActionContext } from './context.js'; import type { ActionType, Registry } from './registry.js'; import { parseSchema } from './schema.js'; @@ -51,15 +51,25 @@ export interface ActionMetadata< O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, > { + /** The type of action (e.g. 'prompt', 'flow'). */ actionType?: ActionType; + /** The key of the action. */ key?: string; + /** The name of the action. */ name: string; + /** Description of the action. */ description?: string; + /** Input Zod schema. */ inputSchema?: I; + /** Input JSON schema. */ inputJsonSchema?: JSONSchema7; + /** Output Zod schema. */ outputSchema?: O; + /** Output JSON schema. */ outputJsonSchema?: JSONSchema7; + /** Stream Zod schema. */ streamSchema?: S; + /** Metadata for the action. */ metadata?: Record; } @@ -90,7 +100,7 @@ export interface ActionResult { /** * Options (side channel) data to pass to the model. */ -export interface ActionRunOptions { +export interface ActionRunOptions { /** * Streaming callback (optional). */ @@ -118,12 +128,22 @@ export interface ActionRunOptions { * Note: This only fires once for the root action span, not for nested spans. */ onTraceStart?: (traceInfo: { traceId: string; spanId: string }) => void; + + /** + * Streaming input (optional). + */ + inputStream?: AsyncIterable; + + /** + * Initialization data provided to the action. + */ + init?: Init; } /** * Options (side channel) data to pass to the model. */ -export interface ActionFnArg { +export interface ActionFnArg { /** * Whether the caller of the action requested streaming. */ @@ -153,6 +173,16 @@ export interface ActionFnArg { abortSignal: AbortSignal; registry?: Registry; + + /** + * Streaming input. + */ + inputStream: AsyncIterable; + + /** + * Initialization data provided to the action. + */ + init?: Init; } /** @@ -168,6 +198,24 @@ export interface StreamingResponse< output: Promise>; } +/** + * Streaming response from a bi-directional action. + */ +export interface BidiStreamingResponse< + O extends z.ZodTypeAny = z.ZodTypeAny, + S extends z.ZodTypeAny = z.ZodTypeAny, + I extends z.ZodTypeAny = z.ZodTypeAny, +> extends StreamingResponse { + /** + * Sends a chunk of data to the action (for bi-directional streaming). + */ + send(chunk: z.infer): void; + /** + * Closes the input stream to the action. + */ + close(): void; +} + /** * Self-describing, validating, observable, locally and remotely callable function. */ @@ -175,21 +223,44 @@ export type Action< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, - RunOptions extends ActionRunOptions = ActionRunOptions, + RunOptions extends ActionRunOptions< + z.infer, + z.infer + > = ActionRunOptions, z.infer>, + Init extends z.ZodTypeAny = z.ZodTypeAny, > = ((input?: z.infer, options?: RunOptions) => Promise>) & { + /** @hidden */ __action: ActionMetadata; + /** @hidden */ __registry?: Registry; run( input?: z.infer, - options?: ActionRunOptions> + options?: ActionRunOptions, z.infer, z.infer> ): Promise>>; stream( input?: z.infer, - opts?: ActionRunOptions> + opts?: ActionRunOptions, z.infer, z.infer> ): StreamingResponse; }; +export interface BidiAction< + IS extends z.ZodTypeAny = z.ZodTypeAny, + O extends z.ZodTypeAny = z.ZodTypeAny, + OS extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, + RunOptions extends ActionRunOptions< + z.infer, + z.infer, + z.infer + > = ActionRunOptions, z.infer, z.infer>, +> extends Action { + streamBidi( + init?: z.infer, + opts?: RunOptions + ): BidiStreamingResponse; +} + /** * Action factory params. */ @@ -197,24 +268,66 @@ export type ActionParams< I extends z.ZodTypeAny, O extends z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, > = { + /** + * Name of the action, or an object with pluginId and actionId. + */ name: | string | { pluginId: string; actionId: string; }; + /** + * Description of the action. + */ description?: string; + /** + * Input Zod schema. + */ inputSchema?: I; + /** + * Input JSON schema. + */ inputJsonSchema?: JSONSchema7; + /** + * Output Zod schema. + */ outputSchema?: O; + /** + * Output JSON schema. + */ outputJsonSchema?: JSONSchema7; + /** + * Metadata for the action. + */ metadata?: Record; + /** + * Middleware to apply to the action. + */ use?: Middleware, z.infer, z.infer>[]; + /** + * Stream Zod schema. + */ streamSchema?: S; + /** + * The type of action. + */ actionType: ActionType; + /** + * Zod schema for the initialization data. + */ + initSchema?: Init; + /** + * JSON schema for the initialization data. + */ + initJsonSchema?: JSONSchema7; }; +/** + * Configuration for an async action (lazy loaded). + */ export type ActionAsyncParams< I extends z.ZodTypeAny, O extends z.ZodTypeAny, @@ -222,19 +335,25 @@ export type ActionAsyncParams< > = ActionParams & { fn: ( input: z.infer, - options: ActionFnArg> + options: ActionFnArg, z.infer> ) => Promise>; }; +/** + * Simple middleware that only modifies request/response. + */ export type SimpleMiddleware = ( req: I, next: (req?: I) => Promise ) => Promise; +/** + * Middleware that has access to options (including streaming callback). + */ export type MiddlewareWithOptions = ( req: I, - options: ActionRunOptions | undefined, - next: (req?: I, options?: ActionRunOptions) => Promise + options: ActionRunOptions | undefined, + next: (req?: I, options?: ActionRunOptions) => Promise ) => Promise; /** @@ -251,26 +370,27 @@ export function actionWithMiddleware< I extends z.ZodTypeAny, O extends z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, >( - action: Action, + action: Action, middleware: Middleware, z.infer, z.infer>[] -): Action { +): Action { const wrapped = (async ( req: z.infer, - options?: ActionRunOptions> + options?: ActionRunOptions, z.infer, z.infer> ) => { return (await wrapped.run(req, options)).result; - }) as Action; + }) as Action; wrapped.__action = action.__action; wrapped.run = async ( req: z.infer, - options?: ActionRunOptions> + options?: ActionRunOptions, z.infer, z.infer> ): Promise>> => { let telemetry; const dispatch = async ( index: number, req: z.infer, - opts?: ActionRunOptions> + opts?: ActionRunOptions, z.infer, z.infer> ) => { if (index === middleware.length) { // end of the chain, call the original model action @@ -297,6 +417,11 @@ export function actionWithMiddleware< } }; wrapped.stream = action.stream; + if ((action as any as BidiAction).streamBidi) { + (wrapped as BidiAction).streamBidi = ( + action as BidiAction + ).streamBidi; + } return { result: await dispatch(0, req, options), telemetry }; }; @@ -310,13 +435,14 @@ export function action< I extends z.ZodTypeAny, O extends z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, >( - config: ActionParams, + config: ActionParams, fn: ( input: z.infer, - options: ActionFnArg> + options: ActionFnArg, z.infer> ) => Promise> -): Action> { +): Action, any, Init> { const actionName = typeof config.name === 'string' ? config.name @@ -336,20 +462,46 @@ export function action< const actionFn = (async ( input?: I, - options?: ActionRunOptions> + options?: ActionRunOptions, z.infer, z.infer> ) => { return (await actionFn.run(input, options)).result; - }) as Action>; + }) as Action, any, Init>; actionFn.__action = { ...actionMetadata }; actionFn.run = async ( input: z.infer, - options?: ActionRunOptions> + options?: ActionRunOptions, z.infer, z.infer> ): Promise>> => { - input = parseSchema(input, { - schema: config.inputSchema, - jsonSchema: config.inputJsonSchema, - }); + if (config.inputSchema || config.inputJsonSchema) { + if (!options?.inputStream) { + input = parseSchema(input, { + schema: config.inputSchema, + jsonSchema: config.inputJsonSchema, + }); + } else { + const inputStream = options.inputStream; + options = { + ...options, + inputStream: (async function* () { + for await (const item of inputStream) { + yield parseSchema(item, { + schema: config.inputSchema, + jsonSchema: config.inputJsonSchema, + }); + } + })(), + }; + } + } + + if (config.initSchema || config.initJsonSchema) { + const validatedInit = parseSchema(options?.init, { + schema: config.initSchema, + jsonSchema: config.initJsonSchema, + }); + options = { ...options, init: validatedInit }; + } + let traceId; let spanId; const genkitKey = actionFn.__action.key; @@ -396,13 +548,15 @@ export function action< !!options?.onChunk && options.onChunk !== sentinelNoopStreamingCallback, sendChunk: options?.onChunk ?? sentinelNoopStreamingCallback, + inputStream: + options?.inputStream ?? asyncIterableFromArray([input]), trace: { traceId, spanId, }, registry: actionFn.__registry, abortSignal: options?.abortSignal ?? makeNoopAbortSignal(), - }); + } as ActionFnArg, z.infer>); // if context is explicitly passed in, we run action with the provided context, // otherwise we let upstream context carry through. const output = await runWithContext(options?.context, actFn); @@ -432,7 +586,7 @@ export function action< actionFn.stream = ( input?: z.infer, - opts?: ActionRunOptions> + opts?: ActionRunOptions, z.infer, z.infer> ): StreamingResponse => { let chunkStreamController: ReadableStreamController>; const chunkStream = new ReadableStream>({ @@ -444,17 +598,24 @@ export function action< }); const invocationPromise = actionFn - .run(config.inputSchema ? config.inputSchema.parse(input) : input, { - onChunk: ((chunk: z.infer) => { - chunkStreamController.enqueue(chunk); - }) as S extends z.ZodVoid ? undefined : StreamingCallback>, - context: { - ...actionFn.__registry?.context, - ...(opts?.context ?? getContext()), - }, - abortSignal: opts?.abortSignal, - telemetryLabels: opts?.telemetryLabels, - }) + .run( + !opts?.inputStream && config.inputSchema + ? config.inputSchema.parse(input) + : input, + { + onChunk: ((chunk: z.infer) => { + chunkStreamController.enqueue(chunk); + }) as S extends z.ZodVoid ? undefined : StreamingCallback>, + context: { + ...actionFn.__registry?.context, + ...(opts?.context ?? getContext()), + }, + inputStream: opts?.inputStream, + abortSignal: opts?.abortSignal, + telemetryLabels: opts?.telemetryLabels, + init: opts?.init, + } + ) .then((s) => s.result) .finally(() => { chunkStreamController.close(); @@ -495,14 +656,15 @@ export function defineAction< I extends z.ZodTypeAny, O extends z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, >( registry: Registry, - config: ActionParams, + config: ActionParams, fn: ( input: z.infer, - options: ActionFnArg> + options: ActionFnArg, z.infer, z.infer> ) => Promise> -): Action { +): Action, z.infer>, Init> { if (isInRuntimeContext()) { throw new Error( 'Cannot define new actions at runtime.\n' + @@ -518,6 +680,97 @@ export function defineAction< return act; } +/** + * Defines a bi-directional action with the given config and registers it in the registry. + */ +export function defineBidiAction< + IS extends z.ZodTypeAny, + O extends z.ZodTypeAny, + OS extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, +>( + registry: Registry, + config: ActionParams, + fn: ( + input: ActionFnArg, z.infer, z.infer> + ) => AsyncGenerator, z.infer, void> +): BidiAction { + const act = bidiAction(config, fn); + registry.registerAction(config.actionType, act); + return act; +} + +/** + * Creates a bi-directional action with the given config. + */ +export function bidiAction< + IS extends z.ZodTypeAny, + O extends z.ZodTypeAny, + OS extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, +>( + config: ActionParams, + fn: ( + input: ActionFnArg, z.infer, z.infer> + ) => AsyncGenerator, z.infer, void> +): BidiAction { + const meta = { ...config.metadata, bidi: true }; + const act = action({ ...config, metadata: meta }, async (input, options) => { + const stream = options.inputStream; + + const outputGen = fn({ + ...options, + init: options.init, + inputStream: stream, + } as ActionFnArg, z.infer, z.infer>); + + const iter = outputGen[Symbol.asyncIterator](); + let result: z.infer; + while (true) { + const { value, done } = await iter.next(); + if (done) { + result = value; + break; + } + options.sendChunk(value); + } + return result; + }) as unknown as BidiAction; + + act.streamBidi = (init, opts) => { + let channel: Channel> | undefined; + let stream = opts?.inputStream; + if (!stream) { + channel = new Channel>(); + stream = channel; + } + + const result = act.stream(undefined, { + ...opts, + init: init, + inputStream: stream, + }); + + return { + ...result, + send: (chunk) => { + if (!channel) { + throw new Error('Cannot send to a provided stream.'); + } + channel.send(chunk); + }, + close: () => { + if (!channel) { + throw new Error('Cannot close a provided stream.'); + } + channel.close(); + }, + }; + }; + + return act; +} + /** * Defines an action with the given config promise and registers it in the registry. */ @@ -616,3 +869,9 @@ export function runInActionRuntimeContext(fn: () => R) { export function runOutsideActionRuntimeContext(fn: () => R) { return getAsyncContext().run(runtimeContextAslKey, 'outside', fn); } + +async function* asyncIterableFromArray(array: T[]): AsyncIterable { + for (const item of array) { + yield item; + } +} diff --git a/js/core/src/flow.ts b/js/core/src/flow.ts index 56c3a3d6b9..cdb480efb5 100644 --- a/js/core/src/flow.ts +++ b/js/core/src/flow.ts @@ -15,7 +15,13 @@ */ import type { z } from 'zod'; -import { ActionFnArg, action, type Action } from './action.js'; +import { + ActionFnArg, + ActionRunOptions, + JSONSchema7, + action, + type Action, +} from './action.js'; import { Registry, type HasRegistry } from './registry.js'; import { SPAN_TYPE_ATTR, runInNewSpan } from './tracing.js'; @@ -26,7 +32,8 @@ export interface Flow< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, -> extends Action {} + Init extends z.ZodTypeAny = z.ZodTypeAny, +> extends Action, z.infer>, Init> {} /** * Configuration for a streaming flow. @@ -35,6 +42,7 @@ export interface FlowConfig< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, > { /** Name of the flow. */ name: string; @@ -46,6 +54,10 @@ export interface FlowConfig< streamSchema?: S; /** Metadata of the flow used by tooling. */ metadata?: Record; + /** Schema of the initialization data. */ + initSchema?: Init; + /** JSON schema of the initialization data. */ + initJsonSchema?: JSONSchema7; } /** @@ -137,12 +149,18 @@ function flowAction< ); } +/** + * A flow step that executes the provided function. + */ export function run( name: string, func: () => Promise, _?: Registry ): Promise; +/** + * A flow step that executes the provided function with input. + */ export function run( name: string, input: any, diff --git a/js/core/tests/bidi-action_test.ts b/js/core/tests/bidi-action_test.ts new file mode 100644 index 0000000000..3d7ddc0a77 --- /dev/null +++ b/js/core/tests/bidi-action_test.ts @@ -0,0 +1,229 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import * as assert from 'assert'; +import { beforeEach, describe, it } from 'node:test'; +import { z } from 'zod'; +import { defineBidiAction } from '../src/action.js'; +import { initNodeFeatures } from '../src/node.js'; +import { Registry } from '../src/registry.js'; + +initNodeFeatures(); + +describe('bidi action', () => { + var registry: Registry; + beforeEach(() => { + registry = new Registry(); + }); + + it('streamBidi ergonomic (push)', async () => { + const act = defineBidiAction( + registry, + { + name: 'chat', + actionType: 'custom', + outputSchema: z.string(), + inputSchema: z.string(), // Option 1: messages are strings + }, + async function* ({ inputStream }) { + for await (const chunk of inputStream) { + yield `echo ${chunk}`; + } + return 'done'; + } + ); + + const session = act.streamBidi(); + session.send('1'); + session.send('2'); + session.close(); + + const chunks: string[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + assert.deepStrictEqual(chunks, ['echo 1', 'echo 2']); + assert.strictEqual(await session.output, 'done'); + }); + + it('streamBidi pull (generator)', async () => { + const act = defineBidiAction( + registry, + { + name: 'chat', + actionType: 'custom', + outputSchema: z.string(), + inputSchema: z.string(), + }, + async function* ({ inputStream }) { + for await (const chunk of inputStream) { + yield `echo ${chunk}`; + } + return 'done'; + } + ); + + async function* inputGen() { + yield '1'; + yield '2'; + } + + const session = act.streamBidi(undefined, { inputStream: inputGen() }); // Pass stream in options! + + const chunks: string[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + assert.deepStrictEqual(chunks, ['echo 1', 'echo 2']); + assert.strictEqual(await session.output, 'done'); + }); + + it('classic run works on bidi action', async () => { + const act = defineBidiAction( + registry, + { + name: 'chat', + actionType: 'custom', + outputSchema: z.string(), + inputSchema: z.string(), + }, + async function* ({ inputStream }) { + const inputs: string[] = []; + for await (const chunk of inputStream) { + inputs.push(chunk); + } + return `done: ${inputs.join(', ')}`; + } + ); + + const result = await act.run('1'); // Pass single message as input! + assert.strictEqual(result.result, 'done: 1'); + }); + + it('classic run works on bidi action with streaming', async () => { + const act = defineBidiAction( + registry, + { + name: 'chat', + actionType: 'custom', + outputSchema: z.string(), + inputSchema: z.string(), + }, + async function* ({ inputStream }) { + for await (const chunk of inputStream) { + yield `echo ${chunk}`; + } + return 'done'; + } + ); + + const chunks: string[] = []; + const result = await act.run('1', { + onChunk: (c) => chunks.push(c), + }); + + assert.deepStrictEqual(chunks, ['echo 1']); + assert.strictEqual(result.result, 'done'); + }); + + it('bidi action receives init data', async () => { + const act = defineBidiAction( + registry, + { + name: 'chatWithInit', + actionType: 'custom', + outputSchema: z.string(), + inputSchema: z.string(), + initSchema: z.object({ prefix: z.string() }), + }, + async function* ({ inputStream, init }) { + const prefix = init?.prefix || ''; + for await (const chunk of inputStream) { + yield `${prefix}${chunk}`; + } + return 'done'; + } + ); + + const session = act.streamBidi({ prefix: '>> ' }); // Pass init as first argument! + session.send('1'); + session.send('2'); + session.close(); + + const chunks: string[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + assert.deepStrictEqual(chunks, ['>> 1', '>> 2']); + assert.strictEqual(await session.output, 'done'); + }); + it('classic run works on bidi action with init', async () => { + const act = defineBidiAction( + registry, + { + name: 'chatWithInitRun', + actionType: 'custom', + outputSchema: z.string(), + inputSchema: z.string(), + initSchema: z.object({ prefix: z.string() }), + }, + async function* ({ inputStream, init }) { + const prefix = init?.prefix || ''; + const inputs: string[] = []; + for await (const chunk of inputStream) { + inputs.push(chunk); + } + return `${prefix}${inputs.join(', ')}`; + } + ); + + const result = await act.run('1', { init: { prefix: '>> ' } }); + assert.strictEqual(result.result, '>> 1'); + }); + + it('classic stream works on bidi action with init', async () => { + const act = defineBidiAction( + registry, + { + name: 'chatWithInitStream', + actionType: 'custom', + outputSchema: z.string(), + inputSchema: z.string(), + initSchema: z.object({ prefix: z.string() }), + }, + async function* ({ inputStream, init }) { + const prefix = init?.prefix || ''; + for await (const chunk of inputStream) { + yield `${prefix}${chunk}`; + } + return 'done'; + } + ); + + const session = act.stream('1', { init: { prefix: '>> ' } }); + + const chunks: string[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + assert.deepStrictEqual(chunks, ['>> 1']); + assert.strictEqual(await session.output, 'done'); + }); +}); diff --git a/js/genkit/src/genkit-beta.ts b/js/genkit/src/genkit-beta.ts index a784954f53..50f2d38967 100644 --- a/js/genkit/src/genkit-beta.ts +++ b/js/genkit/src/genkit-beta.ts @@ -38,7 +38,15 @@ import { type SessionData, type SessionOptions, } from '@genkit-ai/ai/session'; -import type { Operation, z } from '@genkit-ai/core'; +import { + defineBidiFlow, + type Action, + type ActionFnArg, + type ActionRunOptions, + type FlowConfig, + type Operation, + type z, +} from '@genkit-ai/core'; import { v4 as uuidv4 } from 'uuid'; import type { Formatter } from './formats'; import { Genkit, type GenkitOptions } from './genkit'; @@ -289,4 +297,23 @@ export class GenkitBeta extends Genkit { defineResource(opts: ResourceOptions, fn: ResourceFn): ResourceAction { return defineResource(this.registry, opts, fn); } + + /** + * Defines and registers a bi-directional flow. + */ + defineBidiFlow< + I extends z.ZodTypeAny = z.ZodTypeAny, + O extends z.ZodTypeAny = z.ZodTypeAny, + S extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, + >( + config: FlowConfig, + fn: ( + input: ActionFnArg, z.infer, z.infer> + ) => AsyncGenerator, z.infer, void> + ): Action, z.infer>, Init> { + const flow = defineBidiFlow(this.registry, config, fn); + this.flows.push(flow); + return flow; + } } diff --git a/js/testapps/flow-sample1/src/index.ts b/js/testapps/flow-sample1/src/index.ts index 8c3f55abf1..e088233659 100644 --- a/js/testapps/flow-sample1/src/index.ts +++ b/js/testapps/flow-sample1/src/index.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { genkit, z } from 'genkit'; +import { genkit, z } from 'genkit/beta'; const ai = genkit({}); @@ -312,3 +312,33 @@ function generateString(length: number) { } return str.substring(0, length); } + +export const chatFlow = ai.defineBidiFlow( + { + name: 'chatFlow', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async function* ({ inputStream }) { + for await (const chunk of inputStream) { + yield `echo ${chunk}`; + } + return 'done'; + } +); + +export const chatFlowWithInit = ai.defineBidiFlow( + { + name: 'chatFlowWithInit', + inputSchema: z.string(), + outputSchema: z.string(), + initSchema: z.object({ prefix: z.string() }).optional(), + }, + async function* ({ inputStream, init }) { + const prefix = init?.prefix || ''; + for await (const chunk of inputStream) { + yield `${prefix}${chunk}`; + } + return 'done'; + } +);