diff --git a/src/azure.ts b/src/azure.ts index 490b82b9f..5719fc40c 100644 --- a/src/azure.ts +++ b/src/azure.ts @@ -3,7 +3,6 @@ import * as Errors from './error'; import { FinalRequestOptions } from './internal/request-options'; import { isObj, readEnv } from './internal/utils'; import { ClientOptions, OpenAI } from './client'; -import { buildHeaders, NullableHeaders } from './internal/headers'; /** API Client for interfacing with the Azure OpenAI API. */ export interface AzureClientOptions extends ClientOptions { @@ -37,7 +36,6 @@ export interface AzureClientOptions extends ClientOptions { /** API Client for interfacing with the Azure OpenAI API. */ export class AzureOpenAI extends OpenAI { - private _azureADTokenProvider: (() => Promise) | undefined; deploymentName: string | undefined; apiVersion: string = ''; @@ -90,9 +88,6 @@ export class AzureOpenAI extends OpenAI { ); } - // define a sentinel value to avoid any typing issues - apiKey ??= API_KEY_SENTINEL; - opts.defaultQuery = { ...opts.defaultQuery, 'api-version': apiVersion }; if (!baseURL) { @@ -114,13 +109,12 @@ export class AzureOpenAI extends OpenAI { } super({ - apiKey, + apiKey: azureADTokenProvider ?? apiKey, baseURL, ...opts, ...(dangerouslyAllowBrowser !== undefined ? { dangerouslyAllowBrowser } : {}), }); - this._azureADTokenProvider = azureADTokenProvider; this.apiVersion = apiVersion; this.deploymentName = deployment; } @@ -140,47 +134,6 @@ export class AzureOpenAI extends OpenAI { } return super.buildRequest(options, props); } - - async _getAzureADToken(): Promise { - if (typeof this._azureADTokenProvider === 'function') { - const token = await this._azureADTokenProvider(); - if (!token || typeof token !== 'string') { - throw new Errors.OpenAIError( - `Expected 'azureADTokenProvider' argument to return a string but it returned ${token}`, - ); - } - return token; - } - return undefined; - } - - protected override async authHeaders(opts: FinalRequestOptions): Promise { - return; - } - - protected override async prepareOptions(opts: FinalRequestOptions): Promise { - opts.headers = buildHeaders([opts.headers]); - - /** - * The user should provide a bearer token provider if they want - * to use Azure AD authentication. The user shouldn't set the - * Authorization header manually because the header is overwritten - * with the Azure AD token if a bearer token provider is provided. - */ - if (opts.headers.values.get('Authorization') || opts.headers.values.get('api-key')) { - return super.prepareOptions(opts); - } - - const token = await this._getAzureADToken(); - if (token) { - opts.headers.values.set('Authorization', `Bearer ${token}`); - } else if (this.apiKey !== API_KEY_SENTINEL) { - opts.headers.values.set('api-key', this.apiKey); - } else { - throw new Errors.OpenAIError('Unable to handle auth'); - } - return super.prepareOptions(opts); - } } const _deployments_endpoints = new Set([ @@ -194,5 +147,3 @@ const _deployments_endpoints = new Set([ '/batches', '/images/edits', ]); - -const API_KEY_SENTINEL = ''; diff --git a/src/beta/realtime/websocket.ts b/src/beta/realtime/websocket.ts index 2bf0b75d5..31284f12b 100644 --- a/src/beta/realtime/websocket.ts +++ b/src/beta/realtime/websocket.ts @@ -31,16 +31,17 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter { * @internal */ onURL?: (url: URL) => void; + /** Indicates the token was resolved by the factory just before connecting. @internal */ + __resolvedApiKey?: boolean; }, client?: Pick, ) { super(); - + const hasProvider = typeof (client as any)?._options?.apiKey === 'function'; const dangerouslyAllowBrowser = props.dangerouslyAllowBrowser ?? (client as any)?._options?.dangerouslyAllowBrowser ?? - (client?.apiKey.startsWith('ek_') ? true : null); - + (client?.apiKey?.startsWith('ek_') ? true : null); if (!dangerouslyAllowBrowser && isRunningInBrowser()) { throw new OpenAIError( "It looks like you're running in a browser-like environment.\n\nThis is disabled by default, as it risks exposing your secret API credentials to attackers.\n\nYou can avoid this error by creating an ephemeral session token:\nhttps://platform.openai.com/docs/api-reference/realtime-sessions\n", @@ -49,6 +50,16 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter { client ??= new OpenAI({ dangerouslyAllowBrowser }); + if (hasProvider && !props?.__resolvedApiKey) { + throw new Error( + [ + 'Cannot open Realtime WebSocket with a function-based apiKey.', + 'Use the .create() method so that the key is resolved before connecting:', + 'await OpenAIRealtimeWebSocket.create(client, { model })', + ].join('\n'), + ); + } + this.url = buildRealtimeURL(client, props.model); props.onURL?.(this.url); @@ -94,20 +105,23 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter { } } + static async create( + client: Pick, + props: { model: string; dangerouslyAllowBrowser?: boolean }, + ): Promise { + return new OpenAIRealtimeWebSocket({ ...props, __resolvedApiKey: await client._callApiKey() }, client); + } + static async azure( - client: Pick, + client: Pick, options: { deploymentName?: string; dangerouslyAllowBrowser?: boolean } = {}, ): Promise { - const token = await client._getAzureADToken(); + const isApiKeyProvider = await client._callApiKey(); function onURL(url: URL) { - if (client.apiKey !== '') { - url.searchParams.set('api-key', client.apiKey); + if (isApiKeyProvider) { + url.searchParams.set('Authorization', `Bearer ${client.apiKey}`); } else { - if (token) { - url.searchParams.set('Authorization', `Bearer ${token}`); - } else { - throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.'); - } + url.searchParams.set('api-key', client.apiKey); } } const deploymentName = options.deploymentName ?? client.deploymentName; @@ -120,6 +134,7 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter { model: deploymentName, onURL, ...(dangerouslyAllowBrowser ? { dangerouslyAllowBrowser } : {}), + __resolvedApiKey: isApiKeyProvider, }, client, ); diff --git a/src/beta/realtime/ws.ts b/src/beta/realtime/ws.ts index 3f51dfc4b..d1b834331 100644 --- a/src/beta/realtime/ws.ts +++ b/src/beta/realtime/ws.ts @@ -8,18 +8,31 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter { socket: WS.WebSocket; constructor( - props: { model: string; options?: WS.ClientOptions | undefined }, + props: { + model: string; + options?: WS.ClientOptions | undefined; + /** @internal */ __resolvedApiKey?: boolean; + }, client?: Pick, ) { super(); client ??= new OpenAI(); - + const hasProvider = typeof (client as any)?._options?.apiKey === 'function'; + if (hasProvider && !props.__resolvedApiKey) { + throw new Error( + [ + 'Cannot open Realtime WebSocket with a function-based apiKey.', + 'Use the .create() method so that the key is resolved before connecting:', + 'await OpenAIRealtimeWS.create(client, { model })', + ].join('\n'), + ); + } this.url = buildRealtimeURL(client, props.model); this.socket = new WS.WebSocket(this.url, { ...props.options, headers: { ...props.options?.headers, - ...(isAzure(client) ? {} : { Authorization: `Bearer ${client.apiKey}` }), + ...(isAzure(client) && !props.__resolvedApiKey ? {} : { Authorization: `Bearer ${client.apiKey}` }), 'OpenAI-Beta': 'realtime=v1', }, }); @@ -51,16 +64,34 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter { }); } + static async create( + client: Pick, + props: { model: string; options?: WS.ClientOptions | undefined }, + ): Promise { + return new OpenAIRealtimeWS({ ...props, __resolvedApiKey: await client._callApiKey() }, client); + } + static async azure( - client: Pick, - options: { deploymentName?: string; options?: WS.ClientOptions | undefined } = {}, + client: Pick, + props: { deploymentName?: string; options?: WS.ClientOptions | undefined } = {}, ): Promise { - const deploymentName = options.deploymentName ?? client.deploymentName; + const isApiKeyProvider = await client._callApiKey(); + const deploymentName = props.deploymentName ?? client.deploymentName; if (!deploymentName) { throw new Error('No deployment name provided'); } return new OpenAIRealtimeWS( - { model: deploymentName, options: { headers: await getAzureHeaders(client) } }, + { + model: deploymentName, + options: { + ...props.options, + headers: { + ...props.options?.headers, + ...(isApiKeyProvider ? {} : { 'api-key': client.apiKey }), + }, + }, + __resolvedApiKey: isApiKeyProvider, + }, client, ); } @@ -81,16 +112,3 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter { } } } - -async function getAzureHeaders(client: Pick) { - if (client.apiKey !== '') { - return { 'api-key': client.apiKey }; - } else { - const token = await client._getAzureADToken(); - if (token) { - return { Authorization: `Bearer ${token}` }; - } else { - throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.'); - } - } -} diff --git a/src/client.ts b/src/client.ts index a81b0f87f..306de4b7e 100644 --- a/src/client.ts +++ b/src/client.ts @@ -206,12 +206,21 @@ import { } from './internal/utils/log'; import { isEmptyObj } from './internal/utils/values'; +export type ApiKeySetter = () => Promise; + export interface ClientOptions { /** - * Defaults to process.env['OPENAI_API_KEY']. + * API key used for authentication. + * + * - Accepts either a static string or an async function that resolves to a string. + * - Defaults to process.env['OPENAI_API_KEY']. + * - When a function is provided, it is invoked before each request so you can rotate + * or refresh credentials at runtime. + * - The function must return a non-empty string; otherwise an OpenAIError is thrown. + * - If the function throws, the error is wrapped in an OpenAIError with the original + * error available as `cause`. */ - apiKey?: string | undefined; - + apiKey?: string | ApiKeySetter | undefined; /** * Defaults to process.env['OPENAI_ORG_ID']. */ @@ -349,7 +358,7 @@ export class OpenAI { }: ClientOptions = {}) { if (apiKey === undefined) { throw new Errors.OpenAIError( - "The OPENAI_API_KEY environment variable is missing or empty; either provide it, or instantiate the OpenAI client with an apiKey option, like new OpenAI({ apiKey: 'My API Key' }).", + 'Missing credentials. Please pass an `apiKey`, or set the `OPENAI_API_KEY` environment variable.', ); } @@ -385,7 +394,7 @@ export class OpenAI { this._options = options; - this.apiKey = apiKey; + this.apiKey = typeof apiKey === 'string' ? apiKey : 'Missing Key'; this.organization = organization; this.project = project; this.webhookSecret = webhookSecret; @@ -453,6 +462,31 @@ export class OpenAI { return Errors.APIError.generate(status, error, message, headers); } + async _callApiKey(): Promise { + const apiKey = this._options.apiKey; + if (typeof apiKey !== 'function') return false; + + let token: unknown; + try { + token = await apiKey(); + } catch (err: any) { + if (err instanceof Errors.OpenAIError) throw err; + throw new Errors.OpenAIError( + `Failed to get token from 'apiKey' function: ${err.message}`, + // @ts-ignore + { cause: err }, + ); + } + + if (typeof token !== 'string' || !token) { + throw new Errors.OpenAIError( + `Expected 'apiKey' function argument to return a string but it returned ${token}`, + ); + } + this.apiKey = token; + return true; + } + buildURL( path: string, query: Record | null | undefined, @@ -479,7 +513,9 @@ export class OpenAI { /** * Used as a callback for mutating the given `FinalRequestOptions` object. */ - protected async prepareOptions(options: FinalRequestOptions): Promise {} + protected async prepareOptions(options: FinalRequestOptions): Promise { + await this._callApiKey(); + } /** * Used as a callback for mutating the given `RequestInit` object. diff --git a/tests/index.test.ts b/tests/index.test.ts index c8b4b819c..661b28516 100644 --- a/tests/index.test.ts +++ b/tests/index.test.ts @@ -719,4 +719,81 @@ describe('retries', () => { ).toEqual(JSON.stringify({ a: 1 })); expect(count).toEqual(3); }); + + describe('auth', () => { + test('apiKey', async () => { + const client = new OpenAI({ + baseURL: 'http://localhost:5000/', + apiKey: 'My API Key', + }); + const { req } = await client.buildRequest({ path: '/foo', method: 'get' }); + expect(req.headers.get('authorization')).toEqual('Bearer My API Key'); + }); + + test('token', async () => { + const testFetch = async (url: any, { headers }: RequestInit = {}): Promise => { + return new Response(JSON.stringify({}), { headers: headers ?? [] }); + }; + const client = new OpenAI({ + baseURL: 'http://localhost:5000/', + apiKey: async () => 'my token', + fetch: testFetch, + }); + expect( + (await client.request({ method: 'post', path: 'https://example.com' }).asResponse()).headers.get( + 'authorization', + ), + ).toEqual('Bearer my token'); + }); + + test('token is refreshed', async () => { + let fail = true; + const testFetch = async (url: any, { headers }: RequestInit = {}): Promise => { + if (fail) { + fail = false; + return new Response(undefined, { + status: 429, + headers: { + 'Retry-After': '0.1', + }, + }); + } + return new Response(JSON.stringify({}), { + headers: headers ?? [], + }); + }; + let counter = 0; + async function apiKey() { + return `token-${counter++}`; + } + const client = new OpenAI({ + baseURL: 'http://localhost:5000/', + apiKey, + fetch: testFetch, + }); + expect( + ( + await client.chat.completions + .create({ + model: '', + messages: [{ role: 'system', content: 'Hello' }], + }) + .asResponse() + ).headers.get('authorization'), + ).toEqual('Bearer token-1'); + }); + + test('at least one', () => { + try { + new OpenAI({ + baseURL: 'http://localhost:5000/', + }); + } catch (error: any) { + expect(error).toBeInstanceOf(Error); + expect(error.message).toEqual( + 'Missing credentials. Please pass an `apiKey`, or set the `OPENAI_API_KEY` environment variable.', + ); + } + }); + }); }); diff --git a/tests/lib/azure.test.ts b/tests/lib/azure.test.ts index 49e3df1c0..b93defab0 100644 --- a/tests/lib/azure.test.ts +++ b/tests/lib/azure.test.ts @@ -268,9 +268,9 @@ describe('instantiate azure client', () => { ); }); - test.skip('AAD token is refreshed', async () => { + test('AAD token is refreshed', async () => { let fail = true; - const testFetch = async (url: RequestInfo, req: RequestInit | undefined): Promise => { + const testFetch = async (url: RequestInfo, { headers }: RequestInit = {}): Promise => { if (fail) { fail = false; return new Response(undefined, { @@ -280,8 +280,8 @@ describe('instantiate azure client', () => { }, }); } - return new Response(JSON.stringify({ auth: (req?.headers as Headers).get('authorization') }), { - headers: { 'content-type': 'application/json' }, + return new Response(JSON.stringify({}), { + headers: headers ?? [], }); }; let counter = 0; @@ -295,13 +295,15 @@ describe('instantiate azure client', () => { fetch: testFetch, }); expect( - await client.chat.completions.create({ - model, - messages: [{ role: 'system', content: 'Hello' }], - }), - ).toStrictEqual({ - auth: 'Bearer token-1', - }); + ( + await client.chat.completions + .create({ + model, + messages: [{ role: 'system', content: 'Hello' }], + }) + .asResponse() + ).headers.get('authorization'), + ).toEqual('Bearer token-1'); }); });