diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index f795ea4e14..b300ac221a 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -419,7 +419,7 @@ function maybeRegisterDynamicTools< hasDynamicTools = true; // Create a temporary registry with dynamic tools for the duration of this // generate request. - registry = Registry.withParent(registry); + registry = registry.child(); } registry.registerAction('tool', t as Action); } diff --git a/js/ai/tests/formats/format_test.ts b/js/ai/tests/formats/format_test.ts index 04ff1b9f1a..6db4cd6302 100644 --- a/js/ai/tests/formats/format_test.ts +++ b/js/ai/tests/formats/format_test.ts @@ -15,6 +15,7 @@ */ import { z } from '@genkit-ai/core'; +import { NodeRegistry } from '@genkit-ai/core/node'; import { Registry } from '@genkit-ai/core/registry'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; @@ -24,7 +25,7 @@ describe('formats', () => { let registry: Registry; beforeEach(() => { - registry = new Registry(); + registry = new NodeRegistry(); configureFormats(registry); }); diff --git a/js/ai/tests/formats/json_test.ts b/js/ai/tests/formats/json_test.ts index 2172588080..ed4adb905d 100644 --- a/js/ai/tests/formats/json_test.ts +++ b/js/ai/tests/formats/json_test.ts @@ -15,6 +15,7 @@ */ import { z } from '@genkit-ai/core'; +import { NodeRegistry } from '@genkit-ai/core/node'; import { Registry } from '@genkit-ai/core/registry'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; @@ -29,12 +30,6 @@ import type { import { defineProgrammableModel, runAsync } from '../helpers.js'; describe('jsonFormat', () => { - let registry: Registry; - - beforeEach(() => { - registry = new Registry(); - }); - const streamingTests = [ { desc: 'parses complete JSON object', @@ -141,7 +136,7 @@ describe('jsonFormat e2e', () => { let registry: Registry; beforeEach(() => { - registry = new Registry(); + registry = new NodeRegistry(); configureFormats(registry); }); diff --git a/js/ai/tests/generate/action_test.ts b/js/ai/tests/generate/action_test.ts index 614e4c939e..db42e3e814 100644 --- a/js/ai/tests/generate/action_test.ts +++ b/js/ai/tests/generate/action_test.ts @@ -15,6 +15,7 @@ */ import { stripUndefinedProps, z } from '@genkit-ai/core'; +import { NodeRegistry } from '@genkit-ai/core/node'; import { Registry } from '@genkit-ai/core/registry'; import * as assert from 'assert'; import { readFileSync } from 'fs'; @@ -58,7 +59,7 @@ describe('spec', () => { let pm: ProgrammableModel; beforeEach(() => { - registry = new Registry(); + registry = new NodeRegistry(); defineGenerateAction(registry); pm = defineProgrammableModel(registry); defineTool( diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index da77f8dfd4..d2033e450f 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -15,6 +15,7 @@ */ import { z, type PluginProvider } from '@genkit-ai/core'; +import { NodeRegistry } from '@genkit-ai/core/node'; import { Registry } from '@genkit-ai/core/registry'; import * as assert from 'assert'; import { beforeEach, describe, it } from 'node:test'; @@ -32,7 +33,7 @@ import { import { defineTool } from '../../src/tool.js'; describe('toGenerateRequest', () => { - const registry = new Registry(); + const registry = new NodeRegistry(); // register tools const tellAFunnyJoke = defineTool( registry, @@ -332,7 +333,7 @@ describe('generate', () => { var echoModel: ModelAction; beforeEach(() => { - registry = new Registry(); + registry = new NodeRegistry(); echoModel = defineModel( registry, { @@ -407,7 +408,7 @@ describe('generate', () => { describe('generate', () => { let registry: Registry; beforeEach(() => { - registry = new Registry(); + registry = new NodeRegistry(); defineModel( registry, @@ -432,7 +433,7 @@ describe('generate', () => { describe('generateStream', () => { it('should stream out chunks', async () => { - const registry = new Registry(); + const registry = new NodeRegistry(); defineModel( registry, diff --git a/js/ai/tests/model/middleware_test.ts b/js/ai/tests/model/middleware_test.ts index edd2d66d88..d680155598 100644 --- a/js/ai/tests/model/middleware_test.ts +++ b/js/ai/tests/model/middleware_test.ts @@ -15,6 +15,7 @@ */ import { z } from '@genkit-ai/core'; +import { NodeRegistry } from '@genkit-ai/core/node'; import { Registry } from '@genkit-ai/core/registry'; import * as assert from 'assert'; import { beforeEach, describe, it } from 'node:test'; @@ -137,7 +138,7 @@ describe('validateSupport', () => { }); }); -const registry = new Registry(); +const registry = new NodeRegistry(); configureFormats(registry); const echoModel = defineModel(registry, { name: 'echo' }, async (req) => { @@ -400,7 +401,7 @@ describe.only('simulateConstrainedGeneration', () => { let registry: Registry; beforeEach(() => { - registry = new Registry(); + registry = new NodeRegistry(); configureFormats(registry); }); diff --git a/js/ai/tests/prompt/prompt_test.ts b/js/ai/tests/prompt/prompt_test.ts index b0e68d3b48..6fa51f5c69 100644 --- a/js/ai/tests/prompt/prompt_test.ts +++ b/js/ai/tests/prompt/prompt_test.ts @@ -15,7 +15,7 @@ */ import { runWithContext, z, type ActionContext } from '@genkit-ai/core'; -import { Registry } from '@genkit-ai/core/registry'; +import { NodeRegistry } from '@genkit-ai/core/node'; import { toJsonSchema } from '@genkit-ai/core/schema'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; @@ -34,7 +34,7 @@ describe('prompt', () => { let registry; beforeEach(() => { - registry = new Registry(); + registry = new NodeRegistry(); defineEchoModel(registry); defineTool( diff --git a/js/ai/tests/reranker/reranker_test.ts b/js/ai/tests/reranker/reranker_test.ts index 1c5020cc0f..569c156890 100644 --- a/js/ai/tests/reranker/reranker_test.ts +++ b/js/ai/tests/reranker/reranker_test.ts @@ -15,6 +15,7 @@ */ import { GenkitError, z } from '@genkit-ai/core'; +import { NodeRegistry } from '@genkit-ai/core/node'; import { Registry } from '@genkit-ai/core/registry'; import * as assert from 'assert'; import { beforeEach, describe, it } from 'node:test'; @@ -25,7 +26,7 @@ describe('reranker', () => { describe('defineReranker()', () => { let registry: Registry; beforeEach(() => { - registry = new Registry(); + registry = new NodeRegistry(); }); it('reranks documents based on custom logic', async () => { const customReranker = defineReranker( diff --git a/js/ai/tests/tool_test.ts b/js/ai/tests/tool_test.ts index 1cdf0f50e5..3ca08acf99 100644 --- a/js/ai/tests/tool_test.ts +++ b/js/ai/tests/tool_test.ts @@ -15,17 +15,17 @@ */ import { z } from '@genkit-ai/core'; -import { Registry } from '@genkit-ai/core/registry'; +import { NodeRegistry } from '@genkit-ai/core/node'; import * as assert from 'assert'; import { afterEach, describe, it } from 'node:test'; import { defineInterrupt, defineTool } from '../src/tool.js'; describe('defineInterrupt', () => { - let registry = new Registry(); + let registry = new NodeRegistry(); registry.apiStability = 'beta'; afterEach(() => { - registry = new Registry(); + registry = new NodeRegistry(); registry.apiStability = 'beta'; }); @@ -107,10 +107,10 @@ describe('defineInterrupt', () => { }); describe('defineTool', () => { - let registry = new Registry(); + let registry = new NodeRegistry(); registry.apiStability = 'beta'; afterEach(() => { - registry = new Registry(); + registry = new NodeRegistry(); registry.apiStability = 'beta'; }); diff --git a/js/core/package.json b/js/core/package.json index ded46b624e..696af3b4eb 100644 --- a/js/core/package.json +++ b/js/core/package.json @@ -92,6 +92,12 @@ "require": "./lib/schema.js", "import": "./lib/schema.mjs", "default": "./lib/schema.js" + }, + "./node": { + "types": "./lib/node.d.ts", + "require": "./lib/node.js", + "import": "./lib/node.mjs", + "default": "./lib/node.js" } }, "typesVersions": { @@ -116,6 +122,9 @@ ], "schema": [ "lib/schema" + ], + "node": [ + "lib/node" ] } } diff --git a/js/core/src/als-async-store.ts b/js/core/src/als-async-store.ts new file mode 100644 index 0000000000..1190d2022c --- /dev/null +++ b/js/core/src/als-async-store.ts @@ -0,0 +1,36 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { AsyncLocalStorage } from 'node:async_hooks'; +import { AsyncStore } from './registry.js'; + +/** + * Node AsyncLocalStorage based AsyncStore impl. + */ +export class AlsAsyncStore implements AsyncStore { + private asls: Record> = {}; + + getStore(key: string): T | undefined { + return this.asls[key]?.getStore(); + } + + run(key: string, store: T, callback: () => R): R { + if (!this.asls[key]) { + this.asls[key] = new AsyncLocalStorage(); + } + return this.asls[key].run(store, callback); + } +} diff --git a/js/core/src/flow.ts b/js/core/src/flow.ts index 128c98265a..1ed59a1553 100644 --- a/js/core/src/flow.ts +++ b/js/core/src/flow.ts @@ -14,13 +14,18 @@ * limitations under the License. */ -import { AsyncLocalStorage } from 'node:async_hooks'; import type { z } from 'zod'; import { defineAction, type Action, type StreamingCallback } from './action.js'; import type { ActionContext } from './context.js'; -import { Registry, type HasRegistry } from './registry.js'; +import { + Registry, + _getAsyncStoreFactory, + type HasRegistry, +} from './registry.js'; import { SPAN_TYPE_ATTR, runInNewSpan } from './tracing.js'; +const legacyRegistryAlsKey = 'legacyRegistryAls'; + /** * Flow is an observable, streamable, (optionally) strongly typed function. */ @@ -132,18 +137,24 @@ function defineFlowAction< metadata: config.metadata, }, async (input, { sendChunk, context, trace }) => { - return await legacyRegistryAls.run(registry, () => { - const ctx = sendChunk; - (ctx as FlowSideChannel>).sendChunk = sendChunk; - (ctx as FlowSideChannel>).context = context; - (ctx as FlowSideChannel>).trace = trace; - return fn(input, ctx as FlowSideChannel>); - }); + return await legacyRegistryAls().run( + legacyRegistryAlsKey, + registry, + () => { + const ctx = sendChunk; + (ctx as FlowSideChannel>).sendChunk = sendChunk; + (ctx as FlowSideChannel>).context = context; + (ctx as FlowSideChannel>).trace = trace; + return fn(input, ctx as FlowSideChannel>); + } + ); } ); } -const legacyRegistryAls = new AsyncLocalStorage(); +function legacyRegistryAls() { + return _getAsyncStoreFactory()(); +} export function run( name: string, @@ -191,7 +202,7 @@ export function run( } if (!registry) { - registry = legacyRegistryAls.getStore(); + registry = legacyRegistryAls().getStore(legacyRegistryAlsKey); } if (!registry) { throw new Error( diff --git a/js/core/src/index.ts b/js/core/src/index.ts index 3933ddb53f..9d2401bd66 100644 --- a/js/core/src/index.ts +++ b/js/core/src/index.ts @@ -56,7 +56,6 @@ export { type FlowSideChannel, } from './flow.js'; export * from './plugin.js'; -export * from './reflection.js'; export { defineJsonSchema, defineSchema, type JSONSchema } from './schema.js'; export * from './telemetryTypes.js'; export * from './utils.js'; diff --git a/js/core/src/node.ts b/js/core/src/node.ts new file mode 100644 index 0000000000..2d84c7f534 --- /dev/null +++ b/js/core/src/node.ts @@ -0,0 +1,37 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { AlsAsyncStore } from './als-async-store.js'; +import { _setAsyncStoreFactory, Registry } from './registry.js'; +export * from './reflection.js'; + +export class NodeRegistry extends Registry { + constructor(parent?: NodeRegistry) { + if (parent) { + super(parent); + } else { + const store = new AlsAsyncStore(); + const asyncStoreFactory = () => store; + + super({ asyncStoreFactory }); + _setAsyncStoreFactory(asyncStoreFactory); + } + } + + child(): Registry { + return new NodeRegistry(this); + } +} diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index 45a5a769fb..bc007058ed 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -15,7 +15,6 @@ */ import { Dotprompt } from 'dotprompt'; -import { AsyncLocalStorage } from 'node:async_hooks'; import type * as z from 'zod'; import { runOutsideActionRuntimeContext, @@ -122,14 +121,15 @@ export class Registry { /** Additional runtime context data for flows and tools. */ context?: ActionContext; - constructor(parent?: Registry) { - if (parent) { + constructor(opts: { asyncStoreFactory: AsyncStoreFactory } | Registry) { + if (opts instanceof Registry) { + const parent = opts as Registry; this.parent = parent; this.apiStability = parent?.apiStability; this.asyncStore = parent.asyncStore; this.dotprompt = parent.dotprompt; } else { - this.asyncStore = new AsyncStore(); + this.asyncStore = opts.asyncStoreFactory(); this.dotprompt = new Dotprompt({ schemaResolver: async (name) => { const resolvedSchema = await this.lookupSchema(name); @@ -145,8 +145,17 @@ export class Registry { } } + /** + * Creates a new child registry overlaid onto this registry instance. + */ + child(): Registry { + throw new Error('Method not implemented.'); + } + /** * Creates a new registry overlaid onto the provided registry. + * + * @deprecated use {@link child}. * @param parent The parent registry. * @returns The new overlaid registry. */ @@ -438,22 +447,30 @@ export class Registry { } } -/** - * Manages AsyncLocalStorage instances in a single place. - */ -export class AsyncStore { - private asls: Record> = {}; - - getStore(key: string): T | undefined { - return this.asls[key]?.getStore(); +/** @deprecated available for legacy backwards compatibility reasons. */ +export function _getAsyncStoreFactory(): AsyncStoreFactory { + const factory = globalThis.__genkit__asyncStoreFactory; + if (!factory) { + throw new GenkitError({ + status: 'FAILED_PRECONDITION', + message: 'Failed to find AsyncStoreFactory, probable misconfiguration.', + }); } - run(key: string, store: T, callback: () => R): R { - if (!this.asls[key]) { - this.asls[key] = new AsyncLocalStorage(); - } - return this.asls[key].run(store, callback); - } + return factory; +} + +/** @deprecated available for legacy backwards compatibility reasons. */ +export function _setAsyncStoreFactory(factory: AsyncStoreFactory) { + globalThis.__genkit__asyncStoreFactory = factory; +} + +export type AsyncStoreFactory = () => AsyncStore; + +export interface AsyncStore { + getStore(key: string): T | undefined; + + run(key: string, store: T, callback: () => R): R; } /** diff --git a/js/core/tests/action_test.ts b/js/core/tests/action_test.ts index 77ed2af526..8d431eec93 100644 --- a/js/core/tests/action_test.ts +++ b/js/core/tests/action_test.ts @@ -18,12 +18,13 @@ import * as assert from 'assert'; import { beforeEach, describe, it } from 'node:test'; import { z } from 'zod'; import { action, defineAction } from '../src/action.js'; +import { NodeRegistry } from '../src/node.js'; import { Registry } from '../src/registry.js'; describe('action', () => { var registry: Registry; beforeEach(() => { - registry = new Registry(); + registry = new NodeRegistry(); }); it('applies middleware', async () => { diff --git a/js/core/tests/flow_test.ts b/js/core/tests/flow_test.ts index 0441751c71..6514439f3e 100644 --- a/js/core/tests/flow_test.ts +++ b/js/core/tests/flow_test.ts @@ -19,6 +19,7 @@ import * as assert from 'assert'; import { beforeEach, describe, it } from 'node:test'; import { defineFlow, run } from '../src/flow.js'; import { defineAction, getContext, z } from '../src/index.js'; +import { NodeRegistry } from '../src/node.js'; import { Registry } from '../src/registry.js'; import { enableTelemetry } from '../src/tracing.js'; import { TestSpanExporter } from './utils.js'; @@ -48,7 +49,7 @@ describe('flow', () => { beforeEach(() => { // Skips starting reflection server. delete process.env.GENKIT_ENV; - registry = new Registry(); + registry = new NodeRegistry(); }); describe('runFlow', () => { @@ -145,7 +146,7 @@ describe('flow', () => { let registry: Registry; beforeEach(() => { - registry = new Registry(); + registry = new NodeRegistry(); }); it('should run the flow', async () => { diff --git a/js/core/tests/registry_test.ts b/js/core/tests/registry_test.ts index 2d51b15c9b..da51815aab 100644 --- a/js/core/tests/registry_test.ts +++ b/js/core/tests/registry_test.ts @@ -21,12 +21,13 @@ import { defineAction, runInActionRuntimeContext, } from '../src/action.js'; +import { NodeRegistry } from '../src/node.js'; import { Registry } from '../src/registry.js'; describe('registry class', () => { var registry: Registry; beforeEach(() => { - registry = new Registry(); + registry = new NodeRegistry(); }); describe('listActions', () => { @@ -134,7 +135,7 @@ describe('registry class', () => { }); it('returns all registered actions, including parent', async () => { - const child = Registry.withParent(registry); + const child = registry.child(); const fooSomethingAction = action( registry, @@ -264,7 +265,7 @@ describe('registry class', () => { }); it('returns all registered actions, including parent', async () => { - const child = Registry.withParent(registry); + const child = registry.child(); const fooSomethingAction = action( registry, @@ -527,7 +528,7 @@ describe('registry class', () => { }); it('should lookup parent registry when child missing action', async () => { - const childRegistry = new Registry(registry); + const childRegistry = registry.child(); const fooAction = action( registry, @@ -544,7 +545,7 @@ describe('registry class', () => { }); it('registration on the child registry should not modify parent', async () => { - const childRegistry = Registry.withParent(registry); + const childRegistry = registry.child(); assert.strictEqual(childRegistry.parent, registry); diff --git a/js/genkit/src/common.ts b/js/genkit/src/common.ts index 5c5f40b6c9..cd2eece8c8 100644 --- a/js/genkit/src/common.ts +++ b/js/genkit/src/common.ts @@ -117,7 +117,6 @@ export { GENKIT_CLIENT_HEADER, GENKIT_VERSION, GenkitError, - ReflectionServer, StatusCodes, StatusSchema, UserFacingError, @@ -138,11 +137,14 @@ export { type JSONSchema, type JSONSchema7, type Middleware, - type ReflectionServerOptions, - type RunActionResponse, type Status, type StatusName, type StreamingCallback, type StreamingResponse, type TelemetryConfig, } from '@genkit-ai/core'; +export { + ReflectionServer, + type ReflectionServerOptions, + type RunActionResponse, +} from '@genkit-ai/core/node'; diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 1e7c687b82..6416448d17 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -98,7 +98,6 @@ import { import { dynamicTool, type ToolFn } from '@genkit-ai/ai/tool'; import { GenkitError, - ReflectionServer, defineFlow, defineJsonSchema, defineSchema, @@ -114,6 +113,7 @@ import { type z, } from '@genkit-ai/core'; import { Channel } from '@genkit-ai/core/async'; +import { NodeRegistry, ReflectionServer } from '@genkit-ai/core/node'; import type { HasRegistry } from '@genkit-ai/core/registry'; import type { BaseEvalDataPointSchema } from './evaluator.js'; import { logger } from './logging.js'; @@ -168,7 +168,7 @@ export class Genkit implements HasRegistry { constructor(options?: GenkitOptions) { this.options = options || {}; - this.registry = new Registry(); + this.registry = new NodeRegistry(); if (this.options.context) { this.registry.context = this.options.context; }