diff --git a/app/components/agent-list.tsx b/app/components/agent-list.tsx index 0484622f..c82444ec 100644 --- a/app/components/agent-list.tsx +++ b/app/components/agent-list.tsx @@ -1,14 +1,15 @@ -import { useAgentStore } from "../store"; +import { useAgentStore, useModelStore } from "../store"; import Image from "next/image"; import { Theme } from "@/app/store"; import { useMemo, useCallback } from "react"; import { useTheme } from "../hooks/use-theme"; import { ProviderIcon } from "./setting/provider-icon"; import { useAppConfig } from "../store"; -import { Agent, AgentSource } from "../typing"; +import { Agent, AgentSource, ModelOption } from "../typing"; import { Switch } from "./shadcn/switch"; import { Button } from "./shadcn/button"; import { useTranslation } from "react-i18next"; +import { getModelDisplayText } from "../utils/model"; interface AgentItemProps { item: Agent; @@ -49,6 +50,25 @@ function AgentItem({ item, onEdit }: AgentItemProps) { return modelList.find((model) => model.model === item.model.name); }, [modelList, item.model.name]); + const formattedModelListMap = useModelStore( + (state) => state.formattedModelList, + ); + + const getDisplayText = useCallback( + (modelItem: ModelOption | UserModel): string => { + return getModelDisplayText( + { + model: modelItem.model, + provider: modelItem.provider, + display: modelItem.display, + apiKey: "apiKey" in modelItem ? modelItem.apiKey : undefined, + }, + formattedModelListMap, + ); + }, + [formattedModelListMap], + ); + const handleSwitch = async (checked: boolean) => { if (checked !== item.enabled) { updateAgent({ @@ -120,7 +140,9 @@ function AgentItem({ item, onEdit }: AgentItemProps) {
{renderProviderIcon()}
-

{currentModel?.display}

+

+ {currentModel ? getDisplayText(currentModel) : ""} +

diff --git a/app/store/agent.ts b/app/store/agent.ts index b6048690..778a6961 100644 --- a/app/store/agent.ts +++ b/app/store/agent.ts @@ -152,6 +152,9 @@ export const useAgentStore = createPersistStore( const { renderAgents, updateAgent } = get(); const agents: Agent[] = renderAgents; agents.forEach((agent) => { + // return directly if the agent is using a custom model + if (agent.model.apiKey) return; + if (models.every((m) => m.model !== agent.model.name)) { console.log("[agent handle default model]", agent.model); const defaultModel = useAppConfig.getState().defaultModel; diff --git a/app/store/index.ts b/app/store/index.ts index 1b3c6844..241c48a5 100644 --- a/app/store/index.ts +++ b/app/store/index.ts @@ -6,3 +6,4 @@ export * from "./auth"; export * from "./setting"; export * from "./task"; export * from "./agent"; +export * from "./model"; diff --git a/app/store/model.ts b/app/store/model.ts new file mode 100644 index 00000000..a1fb80cb --- /dev/null +++ b/app/store/model.ts @@ -0,0 +1,159 @@ +import { createPersistStore } from "../utils/store"; +import { StoreKey } from "../constant"; +import { + ProviderOption, + CustomModelOption, + GeminiModelOptions, + OpenAIModelOptions, + AnthropicModelOptions, +} from "../typing"; +import { fetch } from "@tauri-apps/api/http"; +import { toast } from "@/app/utils/toast"; + +const DEFAULT_MODEL_STATE = { + formattedModelList: {} as Record, + isLoading: false, +}; + +function filterGeminiModels(models: GeminiModelOptions[]) { + return models.filter( + (item) => + !item.name.includes("gemma") && + item.supportedGenerationMethods!.includes("generateContent"), + ); +} + +function formatProviderModels( + providerInfo: ProviderOption, + modelsData: any[], +): CustomModelOption[] { + const { provider } = providerInfo || {}; + if (provider === "openai") { + return modelsData.map((model: OpenAIModelOptions) => ({ + value: model.id, + label: model.id, + })); + } + + if (provider === "anthropic") { + return modelsData.map((model: AnthropicModelOptions) => ({ + value: model.id, + label: model.display_name, + })); + } + + if (provider === "gemini") { + const filteredModels = filterGeminiModels(modelsData); + return filteredModels.map((model: GeminiModelOptions) => ({ + value: model.name, + label: model.displayName, + })); + } + + return []; +} + +function formatProviderToken( + providerInfo: ProviderOption, + apiKey: string, +): string { + const { provider } = providerInfo || {}; + if (provider === "openai") { + return `Bearer ${apiKey}`; + } + + if (provider === "anthropic") { + return apiKey; + } + return apiKey; +} + +export const useModelStore = createPersistStore( + DEFAULT_MODEL_STATE, + (set, _get) => { + const methods = { + getModels: async ( + provider: string, + apiKey: string, + providerList: ProviderOption[], + ) => { + if (!apiKey) { + return; + } + + const providerInfo = providerList.find((p) => p.provider === provider); + + if (!providerInfo) { + return; + } + + const { default_endpoint, models_path, headers = {} } = providerInfo; + const requestUrl = `${default_endpoint}${models_path}`; + const replacedHeaders = Object.fromEntries( + Object.entries(headers).map(([key, value]) => [ + key, + // @ts-ignore + value.includes("api-key") + ? formatProviderToken(providerInfo, apiKey) + : value, + ]), + ); + + set({ + isLoading: true, + }); + + try { + const modelsReturn = await fetch(requestUrl, { + method: "GET", + headers: replacedHeaders, + }); + + const { data, status } = modelsReturn; + if (status === 200) { + const models = formatProviderModels( + providerInfo, + // @ts-ignore + data.data || data.models || [], + ); + const currentState = _get(); + set({ + formattedModelList: { + ...currentState.formattedModelList, + [provider]: models, + }, + isLoading: false, + }); + return models; + } else { + set({ isLoading: false }); + toast.error("Failed to get models, status code is " + status); + return []; + } + } catch (error) { + const currentState = _get(); + set({ + formattedModelList: { + ...currentState.formattedModelList, + [provider]: [], + }, + isLoading: false, + }); + toast.error(error as string); + console.error("Error fetching models:", error); + return []; + } + }, + clearModelList: () => { + set({ + formattedModelList: {}, + }); + }, + }; + return methods; + }, + { + name: StoreKey.Config + "-model", + version: 0.1, + }, +); diff --git a/app/utils/model.ts b/app/utils/model.ts index 6469d2c9..92ccdde3 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -190,3 +190,31 @@ export function isGPT4Model(modelName: string): boolean { !modelName.startsWith("gpt-4o-mini") ); } + +/** + * Get model display text based on whether it has apiKey + * If model has apiKey, get label from formattedModelListMap, otherwise use default display + * + * @param modelItem Model item with model, provider, display, and optional apiKey + * @param formattedModelListMap Map of provider to formatted model list + * @returns Display text for the model + */ +export function getModelDisplayText( + modelItem: { + model: string; + provider: string; + display: string; + apiKey?: string | undefined; + }, + formattedModelListMap: Record, +): string { + if (modelItem.apiKey) { + const modelName = modelItem.model.split(":")[1] ?? modelItem.model; + return ( + formattedModelListMap[modelItem.provider]?.find( + (model) => model.value === modelName, + )?.label ?? modelItem.display + ); + } + return modelItem.display; +}