diff --git a/app/components/agent-list.tsx b/app/components/agent-list.tsx
index 0484622f..c82444ec 100644
--- a/app/components/agent-list.tsx
+++ b/app/components/agent-list.tsx
@@ -1,14 +1,15 @@
-import { useAgentStore } from "../store";
+import { useAgentStore, useModelStore } from "../store";
import Image from "next/image";
import { Theme } from "@/app/store";
import { useMemo, useCallback } from "react";
import { useTheme } from "../hooks/use-theme";
import { ProviderIcon } from "./setting/provider-icon";
import { useAppConfig } from "../store";
-import { Agent, AgentSource } from "../typing";
+import { Agent, AgentSource, ModelOption } from "../typing";
import { Switch } from "./shadcn/switch";
import { Button } from "./shadcn/button";
import { useTranslation } from "react-i18next";
+import { getModelDisplayText } from "../utils/model";
interface AgentItemProps {
item: Agent;
@@ -49,6 +50,25 @@ function AgentItem({ item, onEdit }: AgentItemProps) {
return modelList.find((model) => model.model === item.model.name);
}, [modelList, item.model.name]);
+ const formattedModelListMap = useModelStore(
+ (state) => state.formattedModelList,
+ );
+
+ const getDisplayText = useCallback(
+ (modelItem: ModelOption | UserModel): string => {
+ return getModelDisplayText(
+ {
+ model: modelItem.model,
+ provider: modelItem.provider,
+ display: modelItem.display,
+ apiKey: "apiKey" in modelItem ? modelItem.apiKey : undefined,
+ },
+ formattedModelListMap,
+ );
+ },
+ [formattedModelListMap],
+ );
+
const handleSwitch = async (checked: boolean) => {
if (checked !== item.enabled) {
updateAgent({
@@ -120,7 +140,9 @@ function AgentItem({ item, onEdit }: AgentItemProps) {
{renderProviderIcon()}
-
{currentModel?.display}
+
+ {currentModel ? getDisplayText(currentModel) : ""}
+
diff --git a/app/store/agent.ts b/app/store/agent.ts
index b6048690..778a6961 100644
--- a/app/store/agent.ts
+++ b/app/store/agent.ts
@@ -152,6 +152,9 @@ export const useAgentStore = createPersistStore(
const { renderAgents, updateAgent } = get();
const agents: Agent[] = renderAgents;
agents.forEach((agent) => {
+ // return directly if the agent is using a custom model
+ if (agent.model.apiKey) return;
+
if (models.every((m) => m.model !== agent.model.name)) {
console.log("[agent handle default model]", agent.model);
const defaultModel = useAppConfig.getState().defaultModel;
diff --git a/app/store/index.ts b/app/store/index.ts
index 1b3c6844..241c48a5 100644
--- a/app/store/index.ts
+++ b/app/store/index.ts
@@ -6,3 +6,4 @@ export * from "./auth";
export * from "./setting";
export * from "./task";
export * from "./agent";
+export * from "./model";
diff --git a/app/store/model.ts b/app/store/model.ts
new file mode 100644
index 00000000..a1fb80cb
--- /dev/null
+++ b/app/store/model.ts
@@ -0,0 +1,159 @@
+import { createPersistStore } from "../utils/store";
+import { StoreKey } from "../constant";
+import {
+ ProviderOption,
+ CustomModelOption,
+ GeminiModelOptions,
+ OpenAIModelOptions,
+ AnthropicModelOptions,
+} from "../typing";
+import { fetch } from "@tauri-apps/api/http";
+import { toast } from "@/app/utils/toast";
+
+const DEFAULT_MODEL_STATE = {
+ formattedModelList: {} as Record
,
+ isLoading: false,
+};
+
+function filterGeminiModels(models: GeminiModelOptions[]) {
+ return models.filter(
+ (item) =>
+ !item.name.includes("gemma") &&
+ item.supportedGenerationMethods!.includes("generateContent"),
+ );
+}
+
+function formatProviderModels(
+ providerInfo: ProviderOption,
+ modelsData: any[],
+): CustomModelOption[] {
+ const { provider } = providerInfo || {};
+ if (provider === "openai") {
+ return modelsData.map((model: OpenAIModelOptions) => ({
+ value: model.id,
+ label: model.id,
+ }));
+ }
+
+ if (provider === "anthropic") {
+ return modelsData.map((model: AnthropicModelOptions) => ({
+ value: model.id,
+ label: model.display_name,
+ }));
+ }
+
+ if (provider === "gemini") {
+ const filteredModels = filterGeminiModels(modelsData);
+ return filteredModels.map((model: GeminiModelOptions) => ({
+ value: model.name,
+ label: model.displayName,
+ }));
+ }
+
+ return [];
+}
+
+function formatProviderToken(
+ providerInfo: ProviderOption,
+ apiKey: string,
+): string {
+ const { provider } = providerInfo || {};
+ if (provider === "openai") {
+ return `Bearer ${apiKey}`;
+ }
+
+ if (provider === "anthropic") {
+ return apiKey;
+ }
+ return apiKey;
+}
+
+export const useModelStore = createPersistStore(
+ DEFAULT_MODEL_STATE,
+ (set, _get) => {
+ const methods = {
+ getModels: async (
+ provider: string,
+ apiKey: string,
+ providerList: ProviderOption[],
+ ) => {
+ if (!apiKey) {
+ return;
+ }
+
+ const providerInfo = providerList.find((p) => p.provider === provider);
+
+ if (!providerInfo) {
+ return;
+ }
+
+ const { default_endpoint, models_path, headers = {} } = providerInfo;
+ const requestUrl = `${default_endpoint}${models_path}`;
+ const replacedHeaders = Object.fromEntries(
+ Object.entries(headers).map(([key, value]) => [
+ key,
+ // @ts-ignore
+ value.includes("api-key")
+ ? formatProviderToken(providerInfo, apiKey)
+ : value,
+ ]),
+ );
+
+ set({
+ isLoading: true,
+ });
+
+ try {
+ const modelsReturn = await fetch(requestUrl, {
+ method: "GET",
+ headers: replacedHeaders,
+ });
+
+ const { data, status } = modelsReturn;
+ if (status === 200) {
+ const models = formatProviderModels(
+ providerInfo,
+ // @ts-ignore
+ data.data || data.models || [],
+ );
+ const currentState = _get();
+ set({
+ formattedModelList: {
+ ...currentState.formattedModelList,
+ [provider]: models,
+ },
+ isLoading: false,
+ });
+ return models;
+ } else {
+ set({ isLoading: false });
+ toast.error("Failed to get models, status code is " + status);
+ return [];
+ }
+ } catch (error) {
+ const currentState = _get();
+ set({
+ formattedModelList: {
+ ...currentState.formattedModelList,
+ [provider]: [],
+ },
+ isLoading: false,
+ });
+ toast.error(error as string);
+ console.error("Error fetching models:", error);
+ return [];
+ }
+ },
+ clearModelList: () => {
+ set({
+ formattedModelList: {},
+ });
+ },
+ };
+ return methods;
+ },
+ {
+ name: StoreKey.Config + "-model",
+ version: 0.1,
+ },
+);
diff --git a/app/utils/model.ts b/app/utils/model.ts
index 6469d2c9..92ccdde3 100644
--- a/app/utils/model.ts
+++ b/app/utils/model.ts
@@ -190,3 +190,31 @@ export function isGPT4Model(modelName: string): boolean {
!modelName.startsWith("gpt-4o-mini")
);
}
+
+/**
+ * Get model display text based on whether it has apiKey
+ * If model has apiKey, get label from formattedModelListMap, otherwise use default display
+ *
+ * @param modelItem Model item with model, provider, display, and optional apiKey
+ * @param formattedModelListMap Map of provider to formatted model list
+ * @returns Display text for the model
+ */
+export function getModelDisplayText(
+ modelItem: {
+ model: string;
+ provider: string;
+ display: string;
+ apiKey?: string | undefined;
+ },
+ formattedModelListMap: Record,
+): string {
+ if (modelItem.apiKey) {
+ const modelName = modelItem.model.split(":")[1] ?? modelItem.model;
+ return (
+ formattedModelListMap[modelItem.provider]?.find(
+ (model) => model.value === modelName,
+ )?.label ?? modelItem.display
+ );
+ }
+ return modelItem.display;
+}