diff --git a/packages/opencode/src/auth/index.ts b/packages/opencode/src/auth/index.ts index ef9846a375e..98e151f2c5c 100644 --- a/packages/opencode/src/auth/index.ts +++ b/packages/opencode/src/auth/index.ts @@ -17,6 +17,7 @@ export namespace Auth { .object({ type: z.literal("api"), key: z.string(), + customBaseUrl: z.string().optional(), }) .meta({ ref: "ApiAuth" }) diff --git a/packages/opencode/src/cli/cmd/auth.ts b/packages/opencode/src/cli/cmd/auth.ts index 382232f5ace..09ec964c3bf 100644 --- a/packages/opencode/src/cli/cmd/auth.ts +++ b/packages/opencode/src/cli/cmd/auth.ts @@ -253,6 +253,22 @@ export const AuthLoginCommand = cmd({ prompts.log.info("You can create an api key at https://vercel.link/ai-gateway-token") } + let customBaseUrl: string | undefined + if (provider === "cloudflare-workers-ai") { + if (!providers[provider].api) { + UI.error("Cloudflare API URL not found") + return + } + + const accountIdInput = await prompts.text({ + message: "Enter your Cloudflare Account ID", + validate: (x) => (x && x.length > 0 ? undefined : "Required"), + }) + if (prompts.isCancel(accountIdInput)) throw new UI.CancelledError() + + customBaseUrl = providers[provider].api?.replace("{CLOUDFLARE_ACCOUNT_ID}", accountIdInput) + } + const key = await prompts.password({ message: "Enter your API key", validate: (x) => (x && x.length > 0 ? undefined : "Required"), @@ -261,6 +277,7 @@ export const AuthLoginCommand = cmd({ await Auth.set(provider, { type: "api", key, + ...(customBaseUrl && { customBaseUrl }), }) prompts.outro("Done") diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index 2d30a738ae2..5ca80736e4d 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -18,7 +18,7 @@ export namespace Provider { type CustomLoader = ( provider: ModelsDev.Provider, - api?: string, + auth?: Auth.Info, ) => Promise<{ autoload: boolean getModel?: (sdk: any, modelID: string) => Promise @@ -143,6 +143,18 @@ export namespace Provider { }, } }, + "cloudflare-workers-ai": async (_provider, auth) => { + if (!auth || auth.type !== "api" || !auth.customBaseUrl) { + return { autoload: false, options: {} } + } + + return { + autoload: true, + options: { + baseURL: auth.customBaseUrl, + }, + } + }, } const state = Instance.state(async () => { @@ -256,18 +268,23 @@ export namespace Provider { ) } - // load apikeys + // load apikeys and customBaseUrl for (const [providerID, provider] of Object.entries(await Auth.all())) { if (disabled.has(providerID)) continue if (provider.type === "api") { - mergeProvider(providerID, { apiKey: provider.key }, "api") + const options: Record = { apiKey: provider.key } + if (provider.customBaseUrl) { + options["baseURL"] = provider.customBaseUrl + } + mergeProvider(providerID, options, "api") } } // load custom for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) { if (disabled.has(providerID)) continue - const result = await fn(database[providerID]) + const auth = await Auth.get(providerID) + const result = await fn(database[providerID], auth) if (result && (result.autoload || providers[providerID])) { mergeProvider(providerID, result.options ?? {}, "custom", result.getModel) }