Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 1 addition & 50 deletions src/azure.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import * as Errors from './error';
import { FinalRequestOptions } from './internal/request-options';
import { isObj, readEnv } from './internal/utils';
import { ClientOptions, OpenAI } from './client';
import { buildHeaders, NullableHeaders } from './internal/headers';

/** API Client for interfacing with the Azure OpenAI API. */
export interface AzureClientOptions extends ClientOptions {
Expand Down Expand Up @@ -37,7 +36,6 @@ export interface AzureClientOptions extends ClientOptions {

/** API Client for interfacing with the Azure OpenAI API. */
export class AzureOpenAI extends OpenAI {
private _azureADTokenProvider: (() => Promise<string>) | undefined;
deploymentName: string | undefined;
apiVersion: string = '';

Expand Down Expand Up @@ -90,9 +88,6 @@ export class AzureOpenAI extends OpenAI {
);
}

// define a sentinel value to avoid any typing issues
apiKey ??= API_KEY_SENTINEL;

opts.defaultQuery = { ...opts.defaultQuery, 'api-version': apiVersion };

if (!baseURL) {
Expand All @@ -114,13 +109,12 @@ export class AzureOpenAI extends OpenAI {
}

super({
apiKey,
apiKey: azureADTokenProvider ?? apiKey,
baseURL,
...opts,
...(dangerouslyAllowBrowser !== undefined ? { dangerouslyAllowBrowser } : {}),
});

this._azureADTokenProvider = azureADTokenProvider;
this.apiVersion = apiVersion;
this.deploymentName = deployment;
}
Expand All @@ -140,47 +134,6 @@ export class AzureOpenAI extends OpenAI {
}
return super.buildRequest(options, props);
}

async _getAzureADToken(): Promise<string | undefined> {
if (typeof this._azureADTokenProvider === 'function') {
const token = await this._azureADTokenProvider();
if (!token || typeof token !== 'string') {
throw new Errors.OpenAIError(
`Expected 'azureADTokenProvider' argument to return a string but it returned ${token}`,
);
}
return token;
}
return undefined;
}

protected override async authHeaders(opts: FinalRequestOptions): Promise<NullableHeaders | undefined> {
return;
}

protected override async prepareOptions(opts: FinalRequestOptions): Promise<void> {
opts.headers = buildHeaders([opts.headers]);

/**
* The user should provide a bearer token provider if they want
* to use Azure AD authentication. The user shouldn't set the
* Authorization header manually because the header is overwritten
* with the Azure AD token if a bearer token provider is provided.
*/
if (opts.headers.values.get('Authorization') || opts.headers.values.get('api-key')) {
return super.prepareOptions(opts);
}

const token = await this._getAzureADToken();
if (token) {
opts.headers.values.set('Authorization', `Bearer ${token}`);
} else if (this.apiKey !== API_KEY_SENTINEL) {
opts.headers.values.set('api-key', this.apiKey);
} else {
throw new Errors.OpenAIError('Unable to handle auth');
}
return super.prepareOptions(opts);
}
}

const _deployments_endpoints = new Set([
Expand All @@ -194,5 +147,3 @@ const _deployments_endpoints = new Set([
'/batches',
'/images/edits',
]);

const API_KEY_SENTINEL = '<Missing Key>';
39 changes: 27 additions & 12 deletions src/beta/realtime/websocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,17 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter {
* @internal
*/
onURL?: (url: URL) => void;
/** Indicates the token was resolved by the factory just before connecting. @internal */
__resolvedApiKey?: boolean;
},
client?: Pick<OpenAI, 'apiKey' | 'baseURL'>,
) {
super();

const hasProvider = typeof (client as any)?._options?.apiKey === 'function';
const dangerouslyAllowBrowser =
props.dangerouslyAllowBrowser ??
(client as any)?._options?.dangerouslyAllowBrowser ??
(client?.apiKey.startsWith('ek_') ? true : null);

(client?.apiKey?.startsWith('ek_') ? true : null);
if (!dangerouslyAllowBrowser && isRunningInBrowser()) {
throw new OpenAIError(
"It looks like you're running in a browser-like environment.\n\nThis is disabled by default, as it risks exposing your secret API credentials to attackers.\n\nYou can avoid this error by creating an ephemeral session token:\nhttps://platform.openai.com/docs/api-reference/realtime-sessions\n",
Expand All @@ -49,6 +50,16 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter {

client ??= new OpenAI({ dangerouslyAllowBrowser });

if (hasProvider && !props?.__resolvedApiKey) {
throw new Error(
[
'Cannot open Realtime WebSocket with a function-based apiKey.',
'Use the .create() method so that the key is resolved before connecting:',
'await OpenAIRealtimeWebSocket.create(client, { model })',
].join('\n'),
);
}

this.url = buildRealtimeURL(client, props.model);
props.onURL?.(this.url);

Expand Down Expand Up @@ -94,20 +105,23 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter {
}
}

static async create(
client: Pick<OpenAI, 'apiKey' | 'baseURL' | '_callApiKey'>,
props: { model: string; dangerouslyAllowBrowser?: boolean },
): Promise<OpenAIRealtimeWebSocket> {
return new OpenAIRealtimeWebSocket({ ...props, __resolvedApiKey: await client._callApiKey() }, client);
}

static async azure(
client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
client: Pick<AzureOpenAI, '_callApiKey' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
options: { deploymentName?: string; dangerouslyAllowBrowser?: boolean } = {},
): Promise<OpenAIRealtimeWebSocket> {
const token = await client._getAzureADToken();
const isApiKeyProvider = await client._callApiKey();
function onURL(url: URL) {
if (client.apiKey !== '<Missing Key>') {
url.searchParams.set('api-key', client.apiKey);
if (isApiKeyProvider) {
url.searchParams.set('Authorization', `Bearer ${client.apiKey}`);
} else {
if (token) {
url.searchParams.set('Authorization', `Bearer ${token}`);
} else {
throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.');
}
url.searchParams.set('api-key', client.apiKey);
}
}
const deploymentName = options.deploymentName ?? client.deploymentName;
Expand All @@ -120,6 +134,7 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter {
model: deploymentName,
onURL,
...(dangerouslyAllowBrowser ? { dangerouslyAllowBrowser } : {}),
__resolvedApiKey: isApiKeyProvider,
},
client,
);
Expand Down
58 changes: 38 additions & 20 deletions src/beta/realtime/ws.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,31 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter {
socket: WS.WebSocket;

constructor(
props: { model: string; options?: WS.ClientOptions | undefined },
props: {
model: string;
options?: WS.ClientOptions | undefined;
/** @internal */ __resolvedApiKey?: boolean;
},
client?: Pick<OpenAI, 'apiKey' | 'baseURL'>,
) {
super();
client ??= new OpenAI();

const hasProvider = typeof (client as any)?._options?.apiKey === 'function';
if (hasProvider && !props.__resolvedApiKey) {
throw new Error(
[
'Cannot open Realtime WebSocket with a function-based apiKey.',
'Use the .create() method so that the key is resolved before connecting:',
'await OpenAIRealtimeWS.create(client, { model })',
].join('\n'),
);
}
this.url = buildRealtimeURL(client, props.model);
this.socket = new WS.WebSocket(this.url, {
...props.options,
headers: {
...props.options?.headers,
...(isAzure(client) ? {} : { Authorization: `Bearer ${client.apiKey}` }),
...(isAzure(client) && !props.__resolvedApiKey ? {} : { Authorization: `Bearer ${client.apiKey}` }),
'OpenAI-Beta': 'realtime=v1',
},
});
Expand Down Expand Up @@ -51,16 +64,34 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter {
});
}

static async create(
client: Pick<OpenAI, 'apiKey' | 'baseURL' | '_callApiKey'>,
props: { model: string; options?: WS.ClientOptions | undefined },
): Promise<OpenAIRealtimeWS> {
return new OpenAIRealtimeWS({ ...props, __resolvedApiKey: await client._callApiKey() }, client);
}

static async azure(
client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
options: { deploymentName?: string; options?: WS.ClientOptions | undefined } = {},
client: Pick<AzureOpenAI, '_callApiKey' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
props: { deploymentName?: string; options?: WS.ClientOptions | undefined } = {},
): Promise<OpenAIRealtimeWS> {
const deploymentName = options.deploymentName ?? client.deploymentName;
const isApiKeyProvider = await client._callApiKey();
const deploymentName = props.deploymentName ?? client.deploymentName;
if (!deploymentName) {
throw new Error('No deployment name provided');
}
return new OpenAIRealtimeWS(
{ model: deploymentName, options: { headers: await getAzureHeaders(client) } },
{
model: deploymentName,
options: {
...props.options,
headers: {
...props.options?.headers,
...(isApiKeyProvider ? {} : { 'api-key': client.apiKey }),
},
},
__resolvedApiKey: isApiKeyProvider,
},
client,
);
}
Expand All @@ -81,16 +112,3 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter {
}
}
}

async function getAzureHeaders(client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiKey'>) {
if (client.apiKey !== '<Missing Key>') {
return { 'api-key': client.apiKey };
} else {
const token = await client._getAzureADToken();
if (token) {
return { Authorization: `Bearer ${token}` };
} else {
throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.');
}
}
}
48 changes: 42 additions & 6 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,21 @@ import {
} from './internal/utils/log';
import { isEmptyObj } from './internal/utils/values';

export type ApiKeySetter = () => Promise<string>;

export interface ClientOptions {
/**
* Defaults to process.env['OPENAI_API_KEY'].
* API key used for authentication.
*
* - Accepts either a static string or an async function that resolves to a string.
* - Defaults to process.env['OPENAI_API_KEY'].
* - When a function is provided, it is invoked before each request so you can rotate
* or refresh credentials at runtime.
* - The function must return a non-empty string; otherwise an OpenAIError is thrown.
* - If the function throws, the error is wrapped in an OpenAIError with the original
* error available as `cause`.
*/
apiKey?: string | undefined;

apiKey?: string | ApiKeySetter | undefined;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring should mention the new function behaviour

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, addressed in 577ebc6.

/**
* Defaults to process.env['OPENAI_ORG_ID'].
*/
Expand Down Expand Up @@ -349,7 +358,7 @@ export class OpenAI {
}: ClientOptions = {}) {
if (apiKey === undefined) {
throw new Errors.OpenAIError(
"The OPENAI_API_KEY environment variable is missing or empty; either provide it, or instantiate the OpenAI client with an apiKey option, like new OpenAI({ apiKey: 'My API Key' }).",
'Missing credentials. Please pass an `apiKey`, or set the `OPENAI_API_KEY` environment variable.',
);
}

Expand Down Expand Up @@ -385,7 +394,7 @@ export class OpenAI {

this._options = options;

this.apiKey = apiKey;
this.apiKey = typeof apiKey === 'string' ? apiKey : 'Missing Key';
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: should probably be something like <not set yet>?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, we should move API_KEY_SENTINEL our of the azure specific file and use that

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is only used once here so I am not sure if there is value to have it in a constant var.

this.organization = organization;
this.project = project;
this.webhookSecret = webhookSecret;
Expand Down Expand Up @@ -453,6 +462,31 @@ export class OpenAI {
return Errors.APIError.generate(status, error, message, headers);
}

async _callApiKey(): Promise<boolean> {
const apiKey = this._options.apiKey;
if (typeof apiKey !== 'function') return false;

let token: unknown;
try {
token = await apiKey();
} catch (err: any) {
if (err instanceof Errors.OpenAIError) throw err;
throw new Errors.OpenAIError(
`Failed to get token from 'apiKey' function: ${err.message}`,
// @ts-ignore
{ cause: err },
);
}

if (typeof token !== 'string' || !token) {
throw new Errors.OpenAIError(
`Expected 'apiKey' function argument to return a string but it returned ${token}`,
);
}
this.apiKey = token;
return true;
}

buildURL(
path: string,
query: Record<string, unknown> | null | undefined,
Expand All @@ -479,7 +513,9 @@ export class OpenAI {
/**
* Used as a callback for mutating the given `FinalRequestOptions` object.
*/
protected async prepareOptions(options: FinalRequestOptions): Promise<void> {}
protected async prepareOptions(options: FinalRequestOptions): Promise<void> {
await this._callApiKey();
}

/**
* Used as a callback for mutating the given `RequestInit` object.
Expand Down
Loading
Loading