diff --git a/apps/server/src/doctor.ts b/apps/server/src/doctor.ts index 3326b506..2330e300 100644 --- a/apps/server/src/doctor.ts +++ b/apps/server/src/doctor.ts @@ -34,12 +34,15 @@ const PROVIDER_LABELS: Record = { codex: "Codex (OpenAI)", claudeAgent: "Claude Code", copilot: "GitHub Copilot", + gemini: "Gemini CLI", + openclaw: "OpenClaw", }; function printStatus(status: ServerProviderStatus): void { const icon = STATUS_ICONS[status.status] ?? "?"; const label = PROVIDER_LABELS[status.provider] ?? status.provider; - const auth = AUTH_LABELS[status.authStatus] ?? status.authStatus; + const authStatus = status.authStatus ?? status.auth?.status ?? "unknown"; + const auth = AUTH_LABELS[authStatus] ?? authStatus; console.log(""); console.log(` ${icon} ${label}`); diff --git a/apps/server/src/orchestration/Layers/ProjectionOverviewQuery.ts b/apps/server/src/orchestration/Layers/ProjectionOverviewQuery.ts index 0b6de5bb..3c25a50a 100644 --- a/apps/server/src/orchestration/Layers/ProjectionOverviewQuery.ts +++ b/apps/server/src/orchestration/Layers/ProjectionOverviewQuery.ts @@ -184,6 +184,7 @@ const makeProjectionOverviewQuery = Effect.gen(function* () { p.title, p.workspace_root AS "workspaceRoot", p.default_model AS "defaultModel", + p.default_model_selection AS "defaultModelSelection", p.scripts_json AS "scripts", p.created_at AS "createdAt", p.updated_at AS "updatedAt", @@ -209,6 +210,7 @@ const makeProjectionOverviewQuery = Effect.gen(function* () { t.project_id AS "projectId", t.title, t.model, + t.model_selection AS "modelSelection", t.runtime_mode AS "runtimeMode", t.interaction_mode AS "interactionMode", t.branch, @@ -361,6 +363,7 @@ const makeProjectionOverviewQuery = Effect.gen(function* () { title: row.title, workspaceRoot: row.workspaceRoot, defaultModel: row.defaultModel, + defaultModelSelection: row.defaultModelSelection, scripts: row.scripts, activeThreadCount: row.activeThreadCount, createdAt: row.createdAt, @@ -374,6 +377,7 @@ const makeProjectionOverviewQuery = Effect.gen(function* () { projectId: row.projectId, title: row.title, model: row.model, + modelSelection: row.modelSelection, runtimeMode: row.runtimeMode, interactionMode: row.interactionMode, branch: row.branch, diff --git a/apps/server/src/orchestration/Layers/ProjectionPipeline.ts b/apps/server/src/orchestration/Layers/ProjectionPipeline.ts index 3ce3b5b1..fa152f0d 100644 --- a/apps/server/src/orchestration/Layers/ProjectionPipeline.ts +++ b/apps/server/src/orchestration/Layers/ProjectionPipeline.ts @@ -398,6 +398,7 @@ const makeOrchestrationProjectionPipeline = Effect.gen(function* () { title: event.payload.title, workspaceRoot: event.payload.workspaceRoot, defaultModel: event.payload.defaultModel, + defaultModelSelection: event.payload.defaultModelSelection, scripts: event.payload.scripts, createdAt: event.payload.createdAt, updatedAt: event.payload.updatedAt, @@ -421,6 +422,9 @@ const makeOrchestrationProjectionPipeline = Effect.gen(function* () { ...(event.payload.defaultModel !== undefined ? { defaultModel: event.payload.defaultModel } : {}), + ...(event.payload.defaultModelSelection !== undefined + ? { defaultModelSelection: event.payload.defaultModelSelection } + : {}), ...(event.payload.scripts !== undefined ? { scripts: event.payload.scripts } : {}), updatedAt: event.payload.updatedAt, }); @@ -456,6 +460,7 @@ const makeOrchestrationProjectionPipeline = Effect.gen(function* () { projectId: event.payload.projectId, title: event.payload.title, model: event.payload.model, + modelSelection: event.payload.modelSelection, runtimeMode: event.payload.runtimeMode, interactionMode: event.payload.interactionMode, branch: event.payload.branch, @@ -479,6 +484,9 @@ const makeOrchestrationProjectionPipeline = Effect.gen(function* () { ...existingRow.value, ...(event.payload.title !== undefined ? { title: event.payload.title } : {}), ...(event.payload.model !== undefined ? { model: event.payload.model } : {}), + ...(event.payload.modelSelection !== undefined + ? { modelSelection: event.payload.modelSelection } + : {}), ...(event.payload.branch !== undefined ? { branch: event.payload.branch } : {}), ...(event.payload.worktreePath !== undefined ? { worktreePath: event.payload.worktreePath } diff --git a/apps/server/src/orchestration/Layers/ProjectionSnapshotQuery.ts b/apps/server/src/orchestration/Layers/ProjectionSnapshotQuery.ts index b4ca0588..78b0c644 100644 --- a/apps/server/src/orchestration/Layers/ProjectionSnapshotQuery.ts +++ b/apps/server/src/orchestration/Layers/ProjectionSnapshotQuery.ts @@ -153,6 +153,7 @@ const makeProjectionSnapshotQuery = Effect.gen(function* () { title, workspace_root AS "workspaceRoot", default_model AS "defaultModel", + default_model_selection AS "defaultModelSelection", scripts_json AS "scripts", created_at AS "createdAt", updated_at AS "updatedAt", @@ -172,6 +173,7 @@ const makeProjectionSnapshotQuery = Effect.gen(function* () { project_id AS "projectId", title, model, + model_selection AS "modelSelection", runtime_mode AS "runtimeMode", interaction_mode AS "interactionMode", branch, @@ -549,6 +551,7 @@ const makeProjectionSnapshotQuery = Effect.gen(function* () { title: row.title, workspaceRoot: row.workspaceRoot, defaultModel: row.defaultModel, + defaultModelSelection: row.defaultModelSelection, scripts: row.scripts, createdAt: row.createdAt, updatedAt: row.updatedAt, @@ -563,6 +566,7 @@ const makeProjectionSnapshotQuery = Effect.gen(function* () { projectId: row.projectId, title: row.title, model: row.model, + modelSelection: row.modelSelection, runtimeMode: row.runtimeMode, interactionMode: row.interactionMode, branch: row.branch, diff --git a/apps/server/src/orchestration/Layers/ProjectionThreadDetailQuery.ts b/apps/server/src/orchestration/Layers/ProjectionThreadDetailQuery.ts index 813744e9..a7f7a7d1 100644 --- a/apps/server/src/orchestration/Layers/ProjectionThreadDetailQuery.ts +++ b/apps/server/src/orchestration/Layers/ProjectionThreadDetailQuery.ts @@ -371,6 +371,7 @@ const makeProjectionThreadDetailQuery = Effect.gen(function* () { projectId: threadRow.value.projectId, title: threadRow.value.title, model: threadRow.value.model, + modelSelection: threadRow.value.modelSelection, runtimeMode: threadRow.value.runtimeMode, interactionMode: threadRow.value.interactionMode, branch: threadRow.value.branch, diff --git a/apps/server/src/orchestration/Layers/ProviderCommandReactor.ts b/apps/server/src/orchestration/Layers/ProviderCommandReactor.ts index 6276523f..496c7821 100644 --- a/apps/server/src/orchestration/Layers/ProviderCommandReactor.ts +++ b/apps/server/src/orchestration/Layers/ProviderCommandReactor.ts @@ -5,6 +5,7 @@ import { CommandId, DEFAULT_GIT_TEXT_GENERATION_MODEL, EventId, + type ModelSelection, type OrchestrationEvent, type ProjectId, type ProviderModelOptions, @@ -30,6 +31,12 @@ import { type ProviderCommandReactorShape, } from "../Services/ProviderCommandReactor.ts"; import { inferProviderForModel } from "@okcode/shared/model"; +import { + getModelSelectionModel, + getModelSelectionOptions, + getModelSelectionProvider, + toCanonicalModelSelection, +} from "@okcode/shared/modelSelection"; import { resolveRuntimeEnvironment } from "../../runtimeEnvironment.ts"; type ProviderIntentEvent = Extract< @@ -68,6 +75,16 @@ function mapProviderSessionStatusToOrchestrationStatus( } } +function resolveThreadModelSelection(thread: { + readonly model: string; + readonly modelSelection?: ModelSelection | null | undefined; +}): ModelSelection { + return ( + thread.modelSelection ?? + toCanonicalModelSelection(inferProviderForModel(thread.model), thread.model, undefined) + ); +} + const turnStartKeyForEvent = (event: ProviderIntentEvent): string => event.commandId !== null ? `command:${event.commandId}` : `event:${event.eventId}`; @@ -257,6 +274,7 @@ const make = Effect.gen(function* () { threadId: ThreadId, createdAt: string, options?: { + readonly modelSelection?: ModelSelection; readonly provider?: ProviderKind; readonly model?: string; readonly modelOptions?: ProviderModelOptions; @@ -275,7 +293,18 @@ const make = Effect.gen(function* () { ) ? thread.session.providerName : undefined; - const threadProvider: ProviderKind = currentProvider ?? inferProviderForModel(thread.model); + const threadSelection = resolveThreadModelSelection(thread); + const requestedSelection = + options?.modelSelection ?? + (options?.provider || options?.model || options?.modelOptions + ? toCanonicalModelSelection( + options?.provider ?? threadSelection.provider, + options?.model ?? threadSelection.model, + options?.modelOptions ?? getModelSelectionOptions(threadSelection), + ) + : threadSelection); + const threadProvider: ProviderKind = + currentProvider ?? getModelSelectionProvider(threadSelection); if (options?.provider !== undefined && options.provider !== threadProvider) { return yield* new ProviderAdapterRequestError({ provider: threadProvider, @@ -283,9 +312,19 @@ const make = Effect.gen(function* () { detail: `Thread '${threadId}' is bound to provider '${threadProvider}' and cannot switch to '${options.provider}'.`, }); } + if ( + options?.modelSelection !== undefined && + getModelSelectionProvider(options.modelSelection) !== threadProvider + ) { + return yield* new ProviderAdapterRequestError({ + provider: threadProvider, + method: "thread.turn.start", + detail: `Thread '${threadId}' is bound to provider '${threadProvider}' and cannot use model selection for '${getModelSelectionProvider(options.modelSelection)}'.`, + }); + } if ( options?.model !== undefined && - inferProviderForModel(options.model, threadProvider) !== threadProvider + getModelSelectionProvider(requestedSelection) !== threadProvider ) { return yield* new ProviderAdapterRequestError({ provider: threadProvider, @@ -294,7 +333,9 @@ const make = Effect.gen(function* () { }); } const preferredProvider: ProviderKind = currentProvider ?? threadProvider; - const desiredModel = options?.model ?? thread.model; + const desiredModel = getModelSelectionModel(requestedSelection); + const desiredModelOptions = + options?.modelOptions ?? getModelSelectionOptions(requestedSelection); const { cwd: effectiveCwd, staleWorktreePath } = resolveSessionCwd({ thread, projects: readModel.projects, @@ -330,7 +371,7 @@ const make = Effect.gen(function* () { : {}), ...(effectiveCwd ? { cwd: effectiveCwd } : {}), ...(desiredModel ? { model: desiredModel } : {}), - ...(options?.modelOptions !== undefined ? { modelOptions: options.modelOptions } : {}), + ...(desiredModelOptions !== undefined ? { modelOptions: desiredModelOptions } : {}), ...(options?.providerOptions !== undefined ? { providerOptions: options.providerOptions } : {}), @@ -376,8 +417,8 @@ const make = Effect.gen(function* () { const previousModelOptions = threadModelOptions.get(threadId); const shouldRestartForModelOptionsChange = currentProvider === "claudeAgent" && - options?.modelOptions !== undefined && - !sameModelOptions(previousModelOptions, options.modelOptions); + desiredModelOptions !== undefined && + !sameModelOptions(previousModelOptions, desiredModelOptions); const activeSessionCwd = activeSession?.cwd; const shouldRestartForCwdChange = staleWorktreePath !== null || activeSessionCwd !== effectiveCwd; @@ -441,6 +482,7 @@ const make = Effect.gen(function* () { readonly messageText: string; readonly providerInput?: string; readonly attachments?: ReadonlyArray; + readonly modelSelection?: ModelSelection; readonly provider?: ProviderKind; readonly model?: string; readonly modelOptions?: ProviderModelOptions; @@ -453,6 +495,7 @@ const make = Effect.gen(function* () { return; } yield* ensureSessionForThread(input.threadId, input.createdAt, { + ...(input.modelSelection !== undefined ? { modelSelection: input.modelSelection } : {}), ...(input.provider !== undefined ? { provider: input.provider } : {}), ...(input.model !== undefined ? { model: input.model } : {}), ...(input.modelOptions !== undefined ? { modelOptions: input.modelOptions } : {}), @@ -463,6 +506,11 @@ const make = Effect.gen(function* () { } if (input.modelOptions !== undefined) { threadModelOptions.set(input.threadId, input.modelOptions); + } else if (input.modelSelection?.options) { + const selectionOptions = getModelSelectionOptions(input.modelSelection); + if (selectionOptions !== undefined) { + threadModelOptions.set(input.threadId, selectionOptions); + } } const normalizedInput = toNonEmptyProviderInput(input.providerInput ?? input.messageText); const normalizedAttachments = input.attachments ?? []; @@ -476,13 +524,20 @@ const make = Effect.gen(function* () { ? "in-session" : (yield* providerService.getCapabilities(activeSession.provider)).sessionModelSwitch; const modelForTurn = sessionModelSwitch === "unsupported" ? activeSession?.model : input.model; + const requestedModelForTurn = input.modelSelection + ? getModelSelectionModel(input.modelSelection) + : modelForTurn; + const requestedModelOptionsForTurn = + input.modelOptions ?? getModelSelectionOptions(input.modelSelection); yield* providerService.sendTurn({ threadId: input.threadId, ...(normalizedInput ? { input: normalizedInput } : {}), ...(normalizedAttachments.length > 0 ? { attachments: normalizedAttachments } : {}), - ...(modelForTurn !== undefined ? { model: modelForTurn } : {}), - ...(input.modelOptions !== undefined ? { modelOptions: input.modelOptions } : {}), + ...(requestedModelForTurn !== undefined ? { model: requestedModelForTurn } : {}), + ...(requestedModelOptionsForTurn !== undefined + ? { modelOptions: requestedModelOptionsForTurn } + : {}), ...(input.interactionMode !== undefined ? { interactionMode: input.interactionMode } : {}), }); }); @@ -598,6 +653,9 @@ const make = Effect.gen(function* () { ? { providerInput: event.payload.providerInput } : {}), ...(message.attachments !== undefined ? { attachments: message.attachments } : {}), + ...(event.payload.modelSelection != null + ? { modelSelection: event.payload.modelSelection } + : {}), ...(event.payload.provider !== undefined ? { provider: event.payload.provider } : {}), ...(event.payload.model !== undefined ? { model: event.payload.model } : {}), ...(event.payload.modelOptions !== undefined diff --git a/apps/server/src/orchestration/decider.ts b/apps/server/src/orchestration/decider.ts index df54fb91..0694ba61 100644 --- a/apps/server/src/orchestration/decider.ts +++ b/apps/server/src/orchestration/decider.ts @@ -6,6 +6,8 @@ import type { ThreadId, } from "@okcode/contracts"; import { Effect } from "effect"; +import { inferProviderForModel } from "@okcode/shared/model"; +import { toCanonicalModelSelection } from "@okcode/shared/modelSelection"; import { OrchestrationCommandInvariantError } from "./Errors.ts"; import { @@ -127,6 +129,15 @@ export const decideOrchestrationCommand = Effect.fn("decideOrchestrationCommand" title: command.title, workspaceRoot: command.workspaceRoot, defaultModel: command.defaultModel ?? null, + defaultModelSelection: + command.defaultModelSelection ?? + (command.defaultModel + ? toCanonicalModelSelection( + inferProviderForModel(command.defaultModel), + command.defaultModel, + undefined, + ) + : null), scripts: command.scripts ?? [], createdAt: command.createdAt, updatedAt: command.createdAt, @@ -184,6 +195,17 @@ export const decideOrchestrationCommand = Effect.fn("decideOrchestrationCommand" ...(command.title !== undefined ? { title: command.title } : {}), ...(command.workspaceRoot !== undefined ? { workspaceRoot: command.workspaceRoot } : {}), ...(command.defaultModel !== undefined ? { defaultModel: command.defaultModel } : {}), + ...(command.defaultModelSelection !== undefined + ? { defaultModelSelection: command.defaultModelSelection } + : command.defaultModel !== undefined + ? { + defaultModelSelection: toCanonicalModelSelection( + inferProviderForModel(command.defaultModel), + command.defaultModel, + undefined, + ), + } + : {}), ...(command.scripts !== undefined ? { scripts: command.scripts } : {}), updatedAt: occurredAt, }, @@ -245,6 +267,13 @@ export const decideOrchestrationCommand = Effect.fn("decideOrchestrationCommand" projectId: command.projectId, title: command.title, model: command.model, + modelSelection: + command.modelSelection ?? + toCanonicalModelSelection( + inferProviderForModel(command.model), + command.model, + undefined, + ), runtimeMode: command.runtimeMode, interactionMode: command.interactionMode, branch: command.branch, @@ -306,6 +335,17 @@ export const decideOrchestrationCommand = Effect.fn("decideOrchestrationCommand" threadId: command.threadId, ...(command.title !== undefined ? { title: command.title } : {}), ...(command.model !== undefined ? { model: command.model } : {}), + ...(command.modelSelection !== undefined + ? { modelSelection: command.modelSelection } + : command.model !== undefined + ? { + modelSelection: toCanonicalModelSelection( + inferProviderForModel(command.model), + command.model, + undefined, + ), + } + : {}), ...(command.branch !== undefined ? { branch: command.branch } : {}), ...(command.worktreePath !== undefined ? { worktreePath: command.worktreePath } : {}), ...(command.githubRef !== undefined ? { githubRef: command.githubRef } : {}), @@ -423,6 +463,9 @@ export const decideOrchestrationCommand = Effect.fn("decideOrchestrationCommand" threadId: command.threadId, messageId: command.message.messageId, ...(command.providerInput !== undefined ? { providerInput: command.providerInput } : {}), + ...(command.modelSelection !== undefined + ? { modelSelection: command.modelSelection } + : {}), ...(command.provider !== undefined ? { provider: command.provider } : {}), ...(command.model !== undefined ? { model: command.model } : {}), ...(command.modelOptions !== undefined ? { modelOptions: command.modelOptions } : {}), diff --git a/apps/server/src/orchestration/projector.ts b/apps/server/src/orchestration/projector.ts index 70d37746..0241f8eb 100644 --- a/apps/server/src/orchestration/projector.ts +++ b/apps/server/src/orchestration/projector.ts @@ -6,6 +6,8 @@ import { OrchestrationThread, } from "@okcode/contracts"; import { Effect, Schema } from "effect"; +import { inferProviderForModel } from "@okcode/shared/model"; +import { toCanonicalModelSelection } from "@okcode/shared/modelSelection"; import { toProjectorDecodeError, type OrchestrationProjectorDecodeError } from "./Errors.ts"; import { @@ -182,6 +184,7 @@ export function projectEvent( title: payload.title, workspaceRoot: payload.workspaceRoot, defaultModel: payload.defaultModel, + defaultModelSelection: payload.defaultModelSelection, scripts: payload.scripts, createdAt: payload.createdAt, updatedAt: payload.updatedAt, @@ -214,6 +217,17 @@ export function projectEvent( ...(payload.defaultModel !== undefined ? { defaultModel: payload.defaultModel } : {}), + ...(payload.defaultModelSelection !== undefined + ? { defaultModelSelection: payload.defaultModelSelection } + : payload.defaultModel !== undefined + ? { + defaultModelSelection: toCanonicalModelSelection( + inferProviderForModel(payload.defaultModel), + payload.defaultModel, + undefined, + ), + } + : {}), ...(payload.scripts !== undefined ? { scripts: payload.scripts } : {}), updatedAt: payload.updatedAt, } @@ -253,6 +267,13 @@ export function projectEvent( projectId: payload.projectId, title: payload.title, model: payload.model, + modelSelection: + payload.modelSelection ?? + toCanonicalModelSelection( + inferProviderForModel(payload.model), + payload.model, + undefined, + ), runtimeMode: payload.runtimeMode, interactionMode: payload.interactionMode, branch: payload.branch, @@ -303,6 +324,17 @@ export function projectEvent( threads: updateThread(nextBase.threads, payload.threadId, { ...(payload.title !== undefined ? { title: payload.title } : {}), ...(payload.model !== undefined ? { model: payload.model } : {}), + ...(payload.modelSelection !== undefined + ? { modelSelection: payload.modelSelection } + : payload.model !== undefined + ? { + modelSelection: toCanonicalModelSelection( + inferProviderForModel(payload.model), + payload.model, + undefined, + ), + } + : {}), ...(payload.branch !== undefined ? { branch: payload.branch } : {}), ...(payload.worktreePath !== undefined ? { worktreePath: payload.worktreePath } : {}), ...(payload.githubRef !== undefined ? { githubRef: payload.githubRef } : {}), diff --git a/apps/server/src/persistence/Layers/ProjectionProjects.ts b/apps/server/src/persistence/Layers/ProjectionProjects.ts index f5d5c9ee..39deb63f 100644 --- a/apps/server/src/persistence/Layers/ProjectionProjects.ts +++ b/apps/server/src/persistence/Layers/ProjectionProjects.ts @@ -15,7 +15,12 @@ import { ProjectScript } from "@okcode/contracts"; // Makes sure that the scripts are parsed from the JSON string the DB returns const ProjectionProjectDbRowSchema = ProjectionProject.mapFields( - Struct.assign({ scripts: Schema.fromJsonString(Schema.Array(ProjectScript)) }), + Struct.assign({ + defaultModelSelection: Schema.NullOr( + Schema.fromJsonString(ProjectionProject.fields.defaultModelSelection), + ), + scripts: Schema.fromJsonString(Schema.Array(ProjectScript)), + }), ); function toPersistenceSqlOrDecodeError(sqlOperation: string, decodeOperation: string) { @@ -37,6 +42,7 @@ const makeProjectionProjectRepository = Effect.gen(function* () { title, workspace_root, default_model, + default_model_selection, scripts_json, created_at, updated_at, @@ -47,6 +53,7 @@ const makeProjectionProjectRepository = Effect.gen(function* () { ${row.title}, ${row.workspaceRoot}, ${row.defaultModel}, + ${row.defaultModelSelection}, ${row.scripts}, ${row.createdAt}, ${row.updatedAt}, @@ -57,6 +64,7 @@ const makeProjectionProjectRepository = Effect.gen(function* () { title = excluded.title, workspace_root = excluded.workspace_root, default_model = excluded.default_model, + default_model_selection = excluded.default_model_selection, scripts_json = excluded.scripts_json, created_at = excluded.created_at, updated_at = excluded.updated_at, @@ -74,6 +82,7 @@ const makeProjectionProjectRepository = Effect.gen(function* () { title, workspace_root AS "workspaceRoot", default_model AS "defaultModel", + default_model_selection AS "defaultModelSelection", scripts_json AS "scripts", created_at AS "createdAt", updated_at AS "updatedAt", @@ -93,6 +102,7 @@ const makeProjectionProjectRepository = Effect.gen(function* () { title, workspace_root AS "workspaceRoot", default_model AS "defaultModel", + default_model_selection AS "defaultModelSelection", scripts_json AS "scripts", created_at AS "createdAt", updated_at AS "updatedAt", @@ -112,7 +122,10 @@ const makeProjectionProjectRepository = Effect.gen(function* () { }); const upsert: ProjectionProjectRepositoryShape["upsert"] = (row) => - upsertProjectionProjectRow(row).pipe( + upsertProjectionProjectRow({ + ...row, + defaultModelSelection: row.defaultModelSelection ?? null, + }).pipe( Effect.mapError( toPersistenceSqlOrDecodeError( "ProjectionProjectRepository.upsert:query", diff --git a/apps/server/src/persistence/Layers/ProjectionThreads.ts b/apps/server/src/persistence/Layers/ProjectionThreads.ts index bae3adba..0115818e 100644 --- a/apps/server/src/persistence/Layers/ProjectionThreads.ts +++ b/apps/server/src/persistence/Layers/ProjectionThreads.ts @@ -1,6 +1,6 @@ import * as SqlClient from "effect/unstable/sql/SqlClient"; import * as SqlSchema from "effect/unstable/sql/SqlSchema"; -import { Effect, Layer } from "effect"; +import { Effect, Layer, Schema, Struct } from "effect"; import { toPersistenceSqlError } from "../Errors.ts"; import { @@ -12,11 +12,22 @@ import { type ProjectionThreadRepositoryShape, } from "../Services/ProjectionThreads.ts"; +// Schema.NullOr wraps fromJsonString so that NULL database rows decode to null +// rather than failing to parse. Pre-migration rows and threads without a model +// selection will have a NULL model_selection column. +const ProjectionThreadDbRowSchema = ProjectionThread.mapFields( + Struct.assign({ + modelSelection: Schema.NullOr( + Schema.fromJsonString(ProjectionThread.fields.modelSelection), + ), + }), +); + const makeProjectionThreadRepository = Effect.gen(function* () { const sql = yield* SqlClient.SqlClient; const upsertProjectionThreadRow = SqlSchema.void({ - Request: ProjectionThread, + Request: ProjectionThreadDbRowSchema, execute: (row) => sql` INSERT INTO projection_threads ( @@ -24,6 +35,7 @@ const makeProjectionThreadRepository = Effect.gen(function* () { project_id, title, model, + model_selection, runtime_mode, interaction_mode, branch, @@ -39,6 +51,7 @@ const makeProjectionThreadRepository = Effect.gen(function* () { ${row.projectId}, ${row.title}, ${row.model}, + ${row.modelSelection}, ${row.runtimeMode}, ${row.interactionMode}, ${row.branch}, @@ -54,6 +67,7 @@ const makeProjectionThreadRepository = Effect.gen(function* () { project_id = excluded.project_id, title = excluded.title, model = excluded.model, + model_selection = excluded.model_selection, runtime_mode = excluded.runtime_mode, interaction_mode = excluded.interaction_mode, branch = excluded.branch, @@ -68,7 +82,7 @@ const makeProjectionThreadRepository = Effect.gen(function* () { const getProjectionThreadRow = SqlSchema.findOneOption({ Request: GetProjectionThreadInput, - Result: ProjectionThread, + Result: ProjectionThreadDbRowSchema, execute: ({ threadId }) => sql` SELECT @@ -76,6 +90,7 @@ const makeProjectionThreadRepository = Effect.gen(function* () { project_id AS "projectId", title, model, + model_selection AS "modelSelection", runtime_mode AS "runtimeMode", interaction_mode AS "interactionMode", branch, @@ -92,7 +107,7 @@ const makeProjectionThreadRepository = Effect.gen(function* () { const listProjectionThreadRows = SqlSchema.findAll({ Request: ListProjectionThreadsByProjectInput, - Result: ProjectionThread, + Result: ProjectionThreadDbRowSchema, execute: ({ projectId }) => sql` SELECT @@ -100,6 +115,7 @@ const makeProjectionThreadRepository = Effect.gen(function* () { project_id AS "projectId", title, model, + model_selection AS "modelSelection", runtime_mode AS "runtimeMode", interaction_mode AS "interactionMode", branch, @@ -125,7 +141,10 @@ const makeProjectionThreadRepository = Effect.gen(function* () { }); const upsert: ProjectionThreadRepositoryShape["upsert"] = (row) => - upsertProjectionThreadRow(row).pipe( + upsertProjectionThreadRow({ + ...row, + modelSelection: row.modelSelection ?? null, + }).pipe( Effect.mapError(toPersistenceSqlError("ProjectionThreadRepository.upsert:query")), ); diff --git a/apps/server/src/persistence/Migrations/025_CanonicalizeModelSelections.ts b/apps/server/src/persistence/Migrations/025_CanonicalizeModelSelections.ts new file mode 100644 index 00000000..57fdd880 --- /dev/null +++ b/apps/server/src/persistence/Migrations/025_CanonicalizeModelSelections.ts @@ -0,0 +1,42 @@ +import { Effect } from "effect"; +import * as SqlClient from "effect/unstable/sql/SqlClient"; + +export default Effect.gen(function* () { + const sql = yield* SqlClient.SqlClient; + + yield* sql` + ALTER TABLE projection_projects + ADD COLUMN default_model_selection TEXT + `.pipe(Effect.catch(() => Effect.void)); + + yield* sql` + ALTER TABLE projection_threads + ADD COLUMN model_selection TEXT + `.pipe(Effect.catch(() => Effect.void)); + + yield* sql` + UPDATE projection_projects + SET default_model_selection = CASE + WHEN default_model IS NULL OR TRIM(default_model) = '' THEN NULL + WHEN default_model LIKE 'claude-%' THEN json_object('provider', 'claudeAgent', 'model', default_model) + WHEN default_model LIKE 'openclaw/%' THEN json_object('provider', 'openclaw', 'model', default_model) + WHEN default_model LIKE 'copilot/%' THEN json_object('provider', 'copilot', 'model', default_model) + WHEN default_model LIKE 'gemini-%' OR default_model LIKE 'auto-gemini-%' THEN json_object('provider', 'gemini', 'model', default_model) + ELSE json_object('provider', 'codex', 'model', default_model) + END + WHERE default_model_selection IS NULL + `; + + yield* sql` + UPDATE projection_threads + SET model_selection = CASE + WHEN model IS NULL OR TRIM(model) = '' THEN NULL + WHEN model LIKE 'claude-%' THEN json_object('provider', 'claudeAgent', 'model', model) + WHEN model LIKE 'openclaw/%' THEN json_object('provider', 'openclaw', 'model', model) + WHEN model LIKE 'copilot/%' THEN json_object('provider', 'copilot', 'model', model) + WHEN model LIKE 'gemini-%' OR model LIKE 'auto-gemini-%' THEN json_object('provider', 'gemini', 'model', model) + ELSE json_object('provider', 'codex', 'model', model) + END + WHERE model_selection IS NULL + `; +}); diff --git a/apps/server/src/persistence/Services/ProjectionProjects.ts b/apps/server/src/persistence/Services/ProjectionProjects.ts index 3ef1c7eb..af692cae 100644 --- a/apps/server/src/persistence/Services/ProjectionProjects.ts +++ b/apps/server/src/persistence/Services/ProjectionProjects.ts @@ -6,7 +6,7 @@ * * @module ProjectionProjectRepository */ -import { IsoDateTime, ProjectId, ProjectScript } from "@okcode/contracts"; +import { IsoDateTime, ModelSelection, ProjectId, ProjectScript } from "@okcode/contracts"; import { Option, Schema, ServiceMap } from "effect"; import type { Effect } from "effect"; @@ -17,6 +17,7 @@ export const ProjectionProject = Schema.Struct({ title: Schema.String, workspaceRoot: Schema.String, defaultModel: Schema.NullOr(Schema.String), + defaultModelSelection: Schema.optional(Schema.NullOr(ModelSelection)), scripts: Schema.Array(ProjectScript), createdAt: IsoDateTime, updatedAt: IsoDateTime, diff --git a/apps/server/src/persistence/Services/ProjectionThreads.ts b/apps/server/src/persistence/Services/ProjectionThreads.ts index 1b57abcf..4d1bef76 100644 --- a/apps/server/src/persistence/Services/ProjectionThreads.ts +++ b/apps/server/src/persistence/Services/ProjectionThreads.ts @@ -8,6 +8,7 @@ */ import { IsoDateTime, + ModelSelection, ProjectId, ProviderInteractionMode, RuntimeMode, @@ -24,6 +25,7 @@ export const ProjectionThread = Schema.Struct({ projectId: ProjectId, title: Schema.String, model: Schema.String, + modelSelection: Schema.optional(Schema.NullOr(ModelSelection)), runtimeMode: RuntimeMode, interactionMode: ProviderInteractionMode, branch: Schema.NullOr(Schema.String), diff --git a/apps/server/src/provider/Layers/GeminiAdapter.ts b/apps/server/src/provider/Layers/GeminiAdapter.ts new file mode 100644 index 00000000..65d0fb80 --- /dev/null +++ b/apps/server/src/provider/Layers/GeminiAdapter.ts @@ -0,0 +1,440 @@ +import crypto from "node:crypto"; +import { + EventId, + ProviderRuntimeEvent, + ProviderSession, + ProviderTurnStartResult, + RuntimeItemId, + TurnId, +} from "@okcode/contracts"; +import type { ProviderSendTurnInput, ThreadId } from "@okcode/contracts"; +import { Effect, Layer, Queue, Ref, Stream } from "effect"; +import { ChildProcess, ChildProcessSpawner } from "effect/unstable/process"; + +import { + ProviderAdapterProcessError, + ProviderAdapterRequestError, + ProviderAdapterSessionNotFoundError, +} from "../Errors.ts"; +import { GeminiAdapter, type GeminiAdapterShape } from "../Services/GeminiAdapter.ts"; + +type GeminiSessionContext = { + readonly session: ProviderSession; + readonly resumeId?: string | undefined; + readonly turns: ReadonlyArray<{ readonly id: TurnId; readonly items: ReadonlyArray }>; +}; + +type GeminiStreamEvent = + | { type: "init"; session_id?: string; model?: string } + | { type: "message"; role?: string; content?: string; delta?: boolean } + | { type: "tool_use"; tool_name?: string; tool_id?: string; parameters?: Record } + | { + type: "tool_result"; + tool_id?: string; + status?: "success" | "error"; + output?: string; + error?: { type?: string; message?: string }; + } + | { + type: "error"; + severity?: "warning" | "error"; + message?: string; + } + | { + type: "result"; + status?: "success" | "error"; + error?: { type?: string; message?: string }; + stats?: Record; + }; + +function nowIso(): string { + return new Date().toISOString(); +} + +function eventId(prefix: string): EventId { + return EventId.makeUnsafe(`${prefix}_${crypto.randomUUID()}`); +} + +function turnId(): TurnId { + return TurnId.makeUnsafe(`turn_${crypto.randomUUID()}`); +} + +function runtimeItemId(value: string): RuntimeItemId { + return RuntimeItemId.makeUnsafe(value); +} + +function decodeNdjson(stdout: string): ReadonlyArray { + return stdout + .split("\n") + .map((line) => line.trim()) + .filter(Boolean) + .flatMap((line) => { + try { + return [JSON.parse(line) as GeminiStreamEvent]; + } catch { + return []; + } + }); +} + +const makeGeminiAdapter = Effect.gen(function* () { + const runtimeEventQueue = yield* Queue.unbounded(); + const sessionsRef = yield* Ref.make(new Map()); + const spawner = yield* ChildProcessSpawner.ChildProcessSpawner; + + const emit = (event: ProviderRuntimeEvent) => + Queue.offer(runtimeEventQueue, event).pipe(Effect.asVoid); + + const getContext = (threadId: ThreadId) => + Ref.get(sessionsRef).pipe( + Effect.flatMap((sessions) => { + const context = sessions.get(threadId); + return context + ? Effect.succeed(context) + : Effect.fail( + new ProviderAdapterSessionNotFoundError({ + provider: "gemini", + threadId, + }), + ); + }), + ); + + const setContext = (threadId: ThreadId, context: GeminiSessionContext) => + Ref.update(sessionsRef, (sessions) => { + const next = new Map(sessions); + next.set(threadId, context); + return next; + }); + + const startSession: GeminiAdapterShape["startSession"] = (input) => + Effect.gen(function* () { + const session: ProviderSession = { + provider: "gemini", + status: "ready", + runtimeMode: input.runtimeMode, + ...(input.cwd ? { cwd: input.cwd } : {}), + ...(input.model ? { model: input.model } : {}), + ...(input.resumeCursor ? { resumeCursor: input.resumeCursor } : {}), + threadId: input.threadId, + createdAt: nowIso(), + updatedAt: nowIso(), + }; + yield* setContext(input.threadId, { + session, + resumeId: typeof input.resumeCursor === "string" ? input.resumeCursor : undefined, + turns: [], + }); + yield* emit({ + eventId: eventId("gemini_session_started"), + provider: "gemini", + type: "session.started", + threadId: input.threadId, + createdAt: nowIso(), + payload: typeof input.resumeCursor === "string" ? { resume: input.resumeCursor } : {}, + }); + yield* emit({ + eventId: eventId("gemini_session_state"), + provider: "gemini", + type: "session.state.changed", + threadId: input.threadId, + createdAt: nowIso(), + payload: { state: "ready" }, + }); + return session; + }); + + const sendTurn: GeminiAdapterShape["sendTurn"] = (input: ProviderSendTurnInput) => + Effect.gen(function* () { + const existing = yield* getContext(input.threadId); + const nextModel = input.model ?? existing.session.model ?? "auto-gemini-3"; + const currentTurnId = turnId(); + const prompt = input.input ?? ""; + const args = [ + "-p", + prompt, + "--output-format", + "stream-json", + "--model", + nextModel, + "--sandbox", + "--approval-mode", + existing.session.runtimeMode === "full-access" ? "yolo" : "suggest", + ]; + if (existing.resumeId) { + args.push("--resume", existing.resumeId); + } + + const nextSession: ProviderSession = { + ...existing.session, + status: "running", + model: nextModel, + activeTurnId: currentTurnId, + updatedAt: nowIso(), + }; + yield* setContext(input.threadId, { ...existing, session: nextSession }); + yield* emit({ + eventId: eventId("gemini_turn_started"), + provider: "gemini", + type: "turn.started", + threadId: input.threadId, + turnId: currentTurnId, + createdAt: nowIso(), + payload: { model: nextModel }, + }); + + const { stdout, stderr, exitCode } = yield* Effect.scoped( + Effect.gen(function* () { + const command = ChildProcess.make("gemini", args, { + shell: process.platform === "win32", + env: process.env, + }); + const child = yield* spawner.spawn(command); + const stdout = yield* Stream.runFold( + child.stdout, + () => "", + (acc, chunk) => acc + new TextDecoder().decode(chunk), + ); + const stderr = yield* Stream.runFold( + child.stderr, + () => "", + (acc, chunk) => acc + new TextDecoder().decode(chunk), + ); + const exitCode = Number(yield* child.exitCode); + return { stdout, stderr, exitCode }; + }), + ).pipe( + Effect.mapError( + (cause) => + new ProviderAdapterProcessError({ + provider: "gemini", + threadId: input.threadId, + detail: cause instanceof Error ? cause.message : String(cause), + }), + ), + ); + + const streamEvents = decodeNdjson(stdout); + let resumeId = existing.resumeId; + for (const streamEvent of streamEvents) { + if (streamEvent.type === "init" && typeof streamEvent.session_id === "string") { + resumeId = streamEvent.session_id; + continue; + } + if ( + streamEvent.type === "message" && + streamEvent.role === "assistant" && + typeof streamEvent.content === "string" + ) { + yield* emit({ + eventId: eventId("gemini_content_delta"), + provider: "gemini", + type: "content.delta", + threadId: input.threadId, + turnId: currentTurnId, + createdAt: nowIso(), + payload: { + streamKind: "assistant_text", + delta: streamEvent.content, + }, + }); + continue; + } + if (streamEvent.type === "tool_use" && streamEvent.tool_id) { + yield* emit({ + eventId: eventId("gemini_tool_started"), + provider: "gemini", + type: "item.started", + threadId: input.threadId, + turnId: currentTurnId, + itemId: runtimeItemId(streamEvent.tool_id), + createdAt: nowIso(), + payload: { + itemType: "dynamic_tool_call", + title: streamEvent.tool_name, + data: streamEvent.parameters, + }, + }); + continue; + } + if (streamEvent.type === "tool_result" && streamEvent.tool_id) { + yield* emit({ + eventId: eventId("gemini_tool_completed"), + provider: "gemini", + type: "item.completed", + threadId: input.threadId, + turnId: currentTurnId, + itemId: runtimeItemId(streamEvent.tool_id), + createdAt: nowIso(), + payload: { + itemType: "dynamic_tool_call", + status: streamEvent.status === "error" ? "failed" : "completed", + detail: streamEvent.output ?? streamEvent.error?.message, + }, + }); + continue; + } + if (streamEvent.type === "error" && streamEvent.message) { + yield* emit({ + eventId: eventId("gemini_runtime_warning"), + provider: "gemini", + type: streamEvent.severity === "error" ? "runtime.error" : "runtime.warning", + threadId: input.threadId, + turnId: currentTurnId, + createdAt: nowIso(), + payload: + streamEvent.severity === "error" + ? { message: streamEvent.message, class: "provider_error" } + : { message: streamEvent.message }, + } as ProviderRuntimeEvent); + } + } + + const resultEvent = streamEvents.toReversed().find((entry) => entry.type === "result"); + const failed = + exitCode !== 0 || + resultEvent?.status === "error" || + (!resultEvent && streamEvents.length === 0 && stderr.trim().length > 0); + + const completedSession: ProviderSession = { + ...nextSession, + status: failed ? "error" : "ready", + activeTurnId: undefined, + updatedAt: nowIso(), + ...(failed + ? { + lastError: + resultEvent?.error?.message ?? (stderr.trim() || "Gemini CLI turn failed."), + } + : {}), + ...(resumeId ? { resumeCursor: resumeId } : {}), + }; + yield* setContext(input.threadId, { + session: completedSession, + resumeId, + turns: [ + ...existing.turns, + { + id: currentTurnId, + items: [], + }, + ], + }); + + if (failed) { + const message = resultEvent?.error?.message ?? (stderr.trim() || "Gemini CLI turn failed."); + yield* emit({ + eventId: eventId("gemini_runtime_error"), + provider: "gemini", + type: "runtime.error", + threadId: input.threadId, + turnId: currentTurnId, + createdAt: nowIso(), + payload: { message, class: "provider_error" }, + }); + } + + yield* emit({ + eventId: eventId("gemini_turn_completed"), + provider: "gemini", + type: "turn.completed", + threadId: input.threadId, + turnId: currentTurnId, + createdAt: nowIso(), + payload: { + state: failed ? "failed" : "completed", + ...(failed ? { errorMessage: resultEvent?.error?.message ?? stderr.trim() } : {}), + ...(resultEvent?.stats ? { usage: resultEvent.stats } : {}), + }, + }); + + if (failed) { + return yield* new ProviderAdapterRequestError({ + provider: "gemini", + method: "gemini turn", + detail: + resultEvent?.error?.message ?? (stderr.trim() || "Gemini CLI exited with an error."), + }); + } + + return { + threadId: input.threadId, + turnId: currentTurnId, + ...(resumeId ? { resumeCursor: resumeId } : {}), + } satisfies ProviderTurnStartResult; + }); + + const interruptTurn: GeminiAdapterShape["interruptTurn"] = () => Effect.void; + const respondToRequest: GeminiAdapterShape["respondToRequest"] = ( + _threadId, + _requestId, + _decision, + ) => Effect.void; + const respondToUserInput: GeminiAdapterShape["respondToUserInput"] = ( + _threadId, + _requestId, + _answers, + ) => Effect.void; + + const stopSession: GeminiAdapterShape["stopSession"] = (threadId) => + Ref.update(sessionsRef, (sessions) => { + const next = new Map(sessions); + next.delete(threadId); + return next; + }); + + const listSessions: GeminiAdapterShape["listSessions"] = () => + Ref.get(sessionsRef).pipe( + Effect.map((sessions) => Array.from(sessions.values(), (entry) => entry.session)), + ); + + const hasSession: GeminiAdapterShape["hasSession"] = (threadId) => + Ref.get(sessionsRef).pipe(Effect.map((sessions) => sessions.has(threadId))); + + const readThread: GeminiAdapterShape["readThread"] = (threadId) => + getContext(threadId).pipe( + Effect.map((context) => ({ + threadId, + turns: context.turns, + })), + ); + + const rollbackThread: GeminiAdapterShape["rollbackThread"] = (threadId, numTurns) => + getContext(threadId).pipe( + Effect.flatMap((context) => { + const turns = + numTurns <= 0 + ? context.turns + : context.turns.slice(0, Math.max(0, context.turns.length - numTurns)); + return setContext(threadId, { ...context, turns }).pipe( + Effect.as({ + threadId, + turns, + }), + ); + }), + ); + + const stopAll: GeminiAdapterShape["stopAll"] = () => Ref.set(sessionsRef, new Map()); + + return { + provider: "gemini", + capabilities: { + sessionModelSwitch: "restart-session", + }, + startSession, + sendTurn, + interruptTurn, + respondToRequest, + respondToUserInput, + stopSession, + listSessions, + hasSession, + readThread, + rollbackThread, + stopAll, + streamEvents: Stream.fromQueue(runtimeEventQueue), + } satisfies GeminiAdapterShape; +}); + +export const GeminiAdapterLive = Layer.effect(GeminiAdapter, makeGeminiAdapter); diff --git a/apps/server/src/provider/Layers/ProviderAdapterRegistry.ts b/apps/server/src/provider/Layers/ProviderAdapterRegistry.ts index 257219a3..e15f88b6 100644 --- a/apps/server/src/provider/Layers/ProviderAdapterRegistry.ts +++ b/apps/server/src/provider/Layers/ProviderAdapterRegistry.ts @@ -8,6 +8,7 @@ * @module ProviderAdapterRegistryLive */ import { Effect, Layer } from "effect"; +import { Option } from "effect"; import { ProviderUnsupportedError, type ProviderAdapterError } from "../Errors.ts"; import type { ProviderAdapterShape } from "../Services/ProviderAdapter.ts"; @@ -18,6 +19,7 @@ import { import { ClaudeAdapter } from "../Services/ClaudeAdapter.ts"; import { CopilotAdapter } from "../Services/CopilotAdapter.ts"; import { CodexAdapter } from "../Services/CodexAdapter.ts"; +import { GeminiAdapter } from "../Services/GeminiAdapter.ts"; import { OpenClawAdapter } from "../Services/OpenClawAdapter.ts"; export interface ProviderAdapterRegistryLiveOptions { @@ -26,6 +28,7 @@ export interface ProviderAdapterRegistryLiveOptions { const makeProviderAdapterRegistry = (options?: ProviderAdapterRegistryLiveOptions) => Effect.gen(function* () { + const maybeGeminiAdapter = yield* Effect.serviceOption(GeminiAdapter); const adapters = options?.adapters !== undefined ? options.adapters @@ -34,6 +37,7 @@ const makeProviderAdapterRegistry = (options?: ProviderAdapterRegistryLiveOption yield* ClaudeAdapter, yield* OpenClawAdapter, yield* CopilotAdapter, + ...(Option.isSome(maybeGeminiAdapter) ? [maybeGeminiAdapter.value] : []), ]; const byProvider = new Map(adapters.map((adapter) => [adapter.provider, adapter])); diff --git a/apps/server/src/provider/Layers/ProviderHealth.ts b/apps/server/src/provider/Layers/ProviderHealth.ts index 59b06166..f5f7c399 100644 --- a/apps/server/src/provider/Layers/ProviderHealth.ts +++ b/apps/server/src/provider/Layers/ProviderHealth.ts @@ -10,6 +10,7 @@ import * as OS from "node:os"; import { CopilotClient } from "@github/copilot-sdk"; import type { + ServerProvider, ServerProviderAuthStatus, ServerProviderStatus, ServerProviderStatusState, @@ -25,12 +26,14 @@ import { isCodexCliVersionSupported, parseCodexCliVersion, } from "../codexCliVersion"; +import { withServerProviderModels } from "../providerCatalog.ts"; import { ProviderHealth, type ProviderHealthShape } from "../Services/ProviderHealth"; const DEFAULT_TIMEOUT_MS = 4_000; const CODEX_PROVIDER = "codex" as const; const CLAUDE_AGENT_PROVIDER = "claudeAgent" as const; const COPILOT_PROVIDER = "copilot" as const; +const GEMINI_PROVIDER = "gemini" as const; class OpenClawHealthProbeError extends Data.TaggedError("OpenClawHealthProbeError")<{ cause: unknown; @@ -44,6 +47,21 @@ function formatHealthProbeCause(cause: unknown): string { return cause instanceof Error ? cause.message : String(cause); } +function createServerProviderStatus( + input: Omit, +): ServerProviderStatus { + return withServerProviderModels({ + ...input, + available: (input.installed ?? false) && (input.enabled ?? true), + authStatus: input.auth?.status ?? "unknown", + }); +} + +function nonEmptyVersion(stdout: string, stderr: string): string | null { + const version = nonEmptyTrimmed(stdout) ?? nonEmptyTrimmed(stderr); + return version ?? null; +} + const OPENCLAW_HEALTH_REQUIRED_METHODS = [ "sessions.create", "sessions.get", @@ -130,6 +148,17 @@ function extractAuthString(value: unknown): string | undefined { return undefined; } +function hasGeminiHeadlessAuthEnv(): boolean { + if (nonEmptyTrimmed(process.env.GEMINI_API_KEY) || nonEmptyTrimmed(process.env.GOOGLE_API_KEY)) { + return true; + } + return Boolean( + nonEmptyTrimmed(process.env.GOOGLE_APPLICATION_CREDENTIALS) && + nonEmptyTrimmed(process.env.GOOGLE_CLOUD_PROJECT) && + nonEmptyTrimmed(process.env.GOOGLE_CLOUD_LOCATION), + ); +} + const CLAUDE_CLI_AUTH_METHODS = new Set(["claude.ai", "oauth"]); const CLAUDE_SUPPORTED_AUTH_METHODS = new Set(["apiKey", "authToken", "claude.ai", "oauth"]); const CLAUDE_AUTH_GUIDANCE = @@ -334,6 +363,28 @@ const runClaudeCommand = (args: ReadonlyArray) => return { stdout, stderr, code: exitCode } satisfies CommandResult; }).pipe(Effect.scoped); +const runGeminiCommand = (args: ReadonlyArray) => + Effect.gen(function* () { + const spawner = yield* ChildProcessSpawner.ChildProcessSpawner; + const command = ChildProcess.make("gemini", [...args], { + shell: process.platform === "win32", + env: process.env, + }); + + const child = yield* spawner.spawn(command); + + const [stdout, stderr, exitCode] = yield* Effect.all( + [ + collectStreamAsString(child.stdout), + collectStreamAsString(child.stderr), + child.exitCode.pipe(Effect.map(Number)), + ], + { concurrency: "unbounded" }, + ); + + return { stdout, stderr, code: exitCode } satisfies CommandResult; + }).pipe(Effect.scoped); + export const checkCopilotProviderStatus: Effect.Effect = Effect.gen(function* () { const checkedAt = new Date().toISOString(); @@ -346,17 +397,19 @@ export const checkCopilotProviderStatus: Effect.Effect undefined) .catch(() => undefined), ); - return { + return createServerProviderStatus({ provider: COPILOT_PROVIDER, + enabled: true, + installed: false, + version: null, status: "error" as const, - available: false, - authStatus: "unknown" as const, + auth: { status: "unknown" as const }, checkedAt, message: "GitHub Copilot CLI timed out while starting.", - } satisfies ServerProviderStatus; + }); } const authResult = yield* Effect.tryPromise({ @@ -389,44 +444,52 @@ export const checkCopilotProviderStatus: Effect.Effect = Effect.gen(function* () { + const checkedAt = new Date().toISOString(); + const versionProbe = yield* runGeminiCommand(["--version"]).pipe( + Effect.timeoutOption(DEFAULT_TIMEOUT_MS), + Effect.result, + ); + + if (Result.isFailure(versionProbe)) { + const error = versionProbe.failure; + return createServerProviderStatus({ + provider: GEMINI_PROVIDER, + enabled: true, + installed: false, + version: null, + status: "error", + auth: { status: "unknown" }, + checkedAt, + message: isCommandMissingCause(error) + ? "Gemini CLI (`gemini`) is not installed or not on PATH." + : `Failed to execute Gemini CLI health check: ${error instanceof Error ? error.message : String(error)}.`, + }); + } + + if (Option.isNone(versionProbe.success)) { + return createServerProviderStatus({ + provider: GEMINI_PROVIDER, + enabled: true, + installed: false, + version: null, + status: "error", + auth: { status: "unknown" }, + checkedAt, + message: "Gemini CLI is installed but failed to run. Timed out while running command.", + }); + } + + const version = versionProbe.success.value; + if (version.code !== 0) { + return createServerProviderStatus({ + provider: GEMINI_PROVIDER, + enabled: true, + installed: false, + version: nonEmptyVersion(version.stdout, version.stderr), + status: "error", + auth: { status: "unknown" }, + checkedAt, + message: detailFromResult(version) ?? "Gemini CLI is installed but failed to run.", + }); + } + + if (hasGeminiHeadlessAuthEnv()) { + return createServerProviderStatus({ + provider: GEMINI_PROVIDER, + enabled: true, + installed: true, + version: nonEmptyVersion(version.stdout, version.stderr), + status: "ready", + auth: { status: "authenticated", type: "headless", label: "Environment credentials" }, + checkedAt, + }); + } + + return createServerProviderStatus({ + provider: GEMINI_PROVIDER, + enabled: true, + installed: true, + version: nonEmptyVersion(version.stdout, version.stderr), + status: "warning", + auth: { status: "unknown" }, + checkedAt, + message: + "Gemini CLI is installed. Headless auth was not prevalidated; cached OAuth may still work locally, or configure GEMINI_API_KEY / Vertex credentials for non-interactive use.", + }); }); // ── Layer ─────────────────────────────────────────────────────────── @@ -932,6 +1112,7 @@ export const ProviderHealthLive = Layer.effect( checkClaudeProviderStatus, checkCopilotProviderStatus, checkOpenClawProviderStatus, + checkGeminiProviderStatus, ], { concurrency: "unbounded", diff --git a/apps/server/src/provider/Services/GeminiAdapter.ts b/apps/server/src/provider/Services/GeminiAdapter.ts new file mode 100644 index 00000000..b15ea18a --- /dev/null +++ b/apps/server/src/provider/Services/GeminiAdapter.ts @@ -0,0 +1,12 @@ +import { ServiceMap } from "effect"; + +import type { ProviderAdapterError } from "../Errors.ts"; +import type { ProviderAdapterShape } from "./ProviderAdapter.ts"; + +export interface GeminiAdapterShape extends ProviderAdapterShape { + readonly provider: "gemini"; +} + +export class GeminiAdapter extends ServiceMap.Service()( + "okcode/provider/Services/GeminiAdapter", +) {} diff --git a/apps/server/src/provider/providerCatalog.test.ts b/apps/server/src/provider/providerCatalog.test.ts new file mode 100644 index 00000000..7eac423e --- /dev/null +++ b/apps/server/src/provider/providerCatalog.test.ts @@ -0,0 +1,34 @@ +import { describe, expect, it } from "vitest"; + +import { createServerProviderModels } from "./providerCatalog"; + +describe("createServerProviderModels", () => { + it("includes Gemini built-ins in the server snapshot inventory", () => { + expect(createServerProviderModels("gemini").map((model) => model.slug)).toEqual([ + "auto-gemini-3", + "auto-gemini-2.5", + "gemini-2.5-pro", + "gemini-2.5-flash", + "gemini-3-pro-preview", + "gemini-3-flash-preview", + ]); + }); + + it("merges custom models without dropping built-in capabilities", () => { + const models = createServerProviderModels("codex", [ + { slug: "gpt-5.4", name: "Duplicate built-in" }, + { slug: "custom-codex-preview", name: "Custom Codex Preview" }, + ]); + + expect(models.find((model) => model.slug === "gpt-5.4")).toMatchObject({ + isCustom: false, + capabilities: expect.any(Object), + }); + expect(models.find((model) => model.slug === "custom-codex-preview")).toEqual({ + slug: "custom-codex-preview", + name: "Custom Codex Preview", + isCustom: true, + capabilities: null, + }); + }); +}); diff --git a/apps/server/src/provider/providerCatalog.ts b/apps/server/src/provider/providerCatalog.ts new file mode 100644 index 00000000..3b76bc5c --- /dev/null +++ b/apps/server/src/provider/providerCatalog.ts @@ -0,0 +1,166 @@ +import type { + ModelCapabilities, + ProviderKind, + ServerProvider, + ServerProviderModel, +} from "@okcode/contracts"; + +type ProviderCatalogEntry = { + readonly slug: string; + readonly name: string; + readonly capabilities?: ModelCapabilities | null | undefined; +}; + +const noCapabilities = null; + +export const BUILT_IN_PROVIDER_MODELS: Record> = { + codex: [ + { + slug: "gpt-5.4", + name: "GPT-5.4", + capabilities: { + reasoningEffortLevels: [ + { value: "low", label: "Low" }, + { value: "medium", label: "Medium" }, + { value: "high", label: "High", isDefault: true }, + { value: "xhigh", label: "Extra High" }, + ], + supportsFastMode: true, + supportsThinkingToggle: false, + contextWindowOptions: [], + promptInjectedEffortLevels: [], + }, + }, + { + slug: "gpt-5.4-mini", + name: "GPT-5.4 Mini", + capabilities: { + reasoningEffortLevels: [ + { value: "low", label: "Low" }, + { value: "medium", label: "Medium" }, + { value: "high", label: "High", isDefault: true }, + { value: "xhigh", label: "Extra High" }, + ], + supportsFastMode: true, + supportsThinkingToggle: false, + contextWindowOptions: [], + promptInjectedEffortLevels: [], + }, + }, + { slug: "gpt-5.3-codex", name: "GPT-5.3 Codex", capabilities: noCapabilities }, + { slug: "gpt-5.3-codex-spark", name: "GPT-5.3 Codex Spark", capabilities: noCapabilities }, + { slug: "gpt-5.2-codex", name: "GPT-5.2 Codex", capabilities: noCapabilities }, + { slug: "gpt-5.2", name: "GPT-5.2", capabilities: noCapabilities }, + ], + claudeAgent: [ + { + slug: "claude-opus-4-6", + name: "Claude Opus 4.6", + capabilities: { + reasoningEffortLevels: [ + { value: "low", label: "Low" }, + { value: "medium", label: "Medium" }, + { value: "high", label: "High", isDefault: true }, + { value: "max", label: "Max" }, + { value: "ultrathink", label: "Ultrathink" }, + ], + supportsFastMode: true, + supportsThinkingToggle: false, + contextWindowOptions: [], + promptInjectedEffortLevels: ["ultrathink"], + }, + }, + { + slug: "claude-sonnet-4-6", + name: "Claude Sonnet 4.6", + capabilities: { + reasoningEffortLevels: [ + { value: "low", label: "Low" }, + { value: "medium", label: "Medium" }, + { value: "high", label: "High", isDefault: true }, + { value: "ultrathink", label: "Ultrathink" }, + ], + supportsFastMode: false, + supportsThinkingToggle: false, + contextWindowOptions: [], + promptInjectedEffortLevels: ["ultrathink"], + }, + }, + { + slug: "claude-haiku-4-5", + name: "Claude Haiku 4.5", + capabilities: { + reasoningEffortLevels: [], + supportsFastMode: false, + supportsThinkingToggle: true, + contextWindowOptions: [], + promptInjectedEffortLevels: [], + }, + }, + ], + openclaw: [], + copilot: [ + { slug: "gpt-5.4", name: "GPT-5.4", capabilities: noCapabilities }, + { slug: "gpt-5.4-mini", name: "GPT-5.4 Mini", capabilities: noCapabilities }, + { slug: "gpt-5.3-codex", name: "GPT-5.3 Codex", capabilities: noCapabilities }, + { slug: "gpt-5.2-codex", name: "GPT-5.2 Codex", capabilities: noCapabilities }, + { slug: "gpt-5.2", name: "GPT-5.2", capabilities: noCapabilities }, + { slug: "gpt-5-mini", name: "GPT-5 Mini", capabilities: noCapabilities }, + { slug: "gpt-4.1", name: "GPT-4.1", capabilities: noCapabilities }, + { slug: "claude-sonnet-4-6", name: "Claude Sonnet 4.6", capabilities: noCapabilities }, + { slug: "claude-sonnet-4-5", name: "Claude Sonnet 4.5", capabilities: noCapabilities }, + { slug: "claude-haiku-4-5", name: "Claude Haiku 4.5", capabilities: noCapabilities }, + { slug: "claude-opus-4-6", name: "Claude Opus 4.6", capabilities: noCapabilities }, + { slug: "claude-opus-4-5", name: "Claude Opus 4.5", capabilities: noCapabilities }, + { slug: "gemini-3.1-pro", name: "Gemini 3.1 Pro", capabilities: noCapabilities }, + { slug: "gemini-2.5-pro", name: "Gemini 2.5 Pro", capabilities: noCapabilities }, + { slug: "grok-code-fast-1", name: "Grok Code Fast 1", capabilities: noCapabilities }, + ], + gemini: [ + { slug: "auto-gemini-3", name: "Auto (Gemini 3)", capabilities: noCapabilities }, + { slug: "auto-gemini-2.5", name: "Auto (Gemini 2.5)", capabilities: noCapabilities }, + { slug: "gemini-2.5-pro", name: "Gemini 2.5 Pro", capabilities: noCapabilities }, + { slug: "gemini-2.5-flash", name: "Gemini 2.5 Flash", capabilities: noCapabilities }, + { slug: "gemini-3-pro-preview", name: "Gemini 3 Pro Preview", capabilities: noCapabilities }, + { + slug: "gemini-3-flash-preview", + name: "Gemini 3 Flash Preview", + capabilities: noCapabilities, + }, + ], +}; + +export function createServerProviderModels( + provider: ProviderKind, + customModels: ReadonlyArray<{ slug: string; name?: string }> = [], +): ReadonlyArray { + const builtIns = BUILT_IN_PROVIDER_MODELS[provider].map((model) => ({ + slug: model.slug, + name: model.name, + isCustom: false, + capabilities: model.capabilities ?? null, + })) satisfies ReadonlyArray; + const seen = new Set(builtIns.map((model) => model.slug)); + const custom = customModels.flatMap((model) => { + if (!model.slug || seen.has(model.slug)) return []; + return [ + { + slug: model.slug, + name: model.name?.trim() || model.slug, + isCustom: true, + capabilities: null, + } satisfies ServerProviderModel, + ]; + }); + return [...builtIns, ...custom]; +} + +export function withServerProviderModels( + provider: Omit, + customModels?: ReadonlyArray<{ slug: string; name?: string }>, +): ServerProvider { + return { + ...provider, + models: createServerProviderModels(provider.provider, customModels), + }; +} diff --git a/apps/server/src/serverLayers.ts b/apps/server/src/serverLayers.ts index f4e5a8c5..43f4197f 100644 --- a/apps/server/src/serverLayers.ts +++ b/apps/server/src/serverLayers.ts @@ -1,5 +1,6 @@ import * as NodeServices from "@effect/platform-node/NodeServices"; import { Effect, FileSystem, Layer, Path } from "effect"; +import { ChildProcessSpawner } from "effect/unstable/process"; import * as SqlClient from "effect/unstable/sql/SqlClient"; import { CheckpointDiffQueryLive } from "./checkpointing/Layers/CheckpointDiffQuery"; @@ -22,6 +23,7 @@ import { ProviderUnsupportedError } from "./provider/Errors"; import { makeClaudeAdapterLive } from "./provider/Layers/ClaudeAdapter"; import { makeCopilotAdapterLive } from "./provider/Layers/CopilotAdapter"; import { makeCodexAdapterLive } from "./provider/Layers/CodexAdapter"; +import { GeminiAdapterLive } from "./provider/Layers/GeminiAdapter"; import { makeOpenClawAdapterLive } from "./provider/Layers/OpenClawAdapter"; import { ProviderHealthLive } from "./provider/Layers/ProviderHealth"; import { ProviderAdapterRegistryLive } from "./provider/Layers/ProviderAdapterRegistry"; @@ -73,7 +75,11 @@ const makeRuntimePtyAdapterLayer = () => export function makeServerProviderLayer(): Layer.Layer< ProviderService, ProviderUnsupportedError, - SqlClient.SqlClient | ServerConfig | FileSystem.FileSystem | ProviderRuntimeEventFeed + | SqlClient.SqlClient + | ServerConfig + | FileSystem.FileSystem + | ProviderRuntimeEventFeed + | ChildProcessSpawner.ChildProcessSpawner > { return Effect.gen(function* () { const { providerEventLogPath } = yield* ServerConfig; @@ -108,11 +114,13 @@ export function makeServerProviderLayer(): Layer.Layer< const copilotAdapterLayer = makeCopilotAdapterLive( nativeEventLogger ? { nativeEventLogger } : undefined, ); + const geminiAdapterLayer = GeminiAdapterLive; const adapterRegistryLayer = ProviderAdapterRegistryLive.pipe( Layer.provide(codexAdapterLayer), Layer.provide(claudeAdapterLayer), Layer.provide(openclawAdapterLayer), Layer.provide(copilotAdapterLayer), + Layer.provide(geminiAdapterLayer), Layer.provideMerge(providerSessionDirectoryLayer), ); return makeProviderServiceLive( diff --git a/apps/server/src/sme/authValidation.ts b/apps/server/src/sme/authValidation.ts index e476bd0c..9d5e5bcc 100644 --- a/apps/server/src/sme/authValidation.ts +++ b/apps/server/src/sme/authValidation.ts @@ -31,6 +31,8 @@ export function getAllowedSmeAuthMethods(provider: ProviderKind): readonly SmeAu return ["auto"]; case "codex": return ["auto", "chatgpt", "apiKey", "customProvider"]; + case "gemini": + return ["auto", "apiKey"]; case "openclaw": return ["auto", "password", "none"]; } @@ -44,6 +46,8 @@ export function getDefaultSmeAuthMethod(provider: ProviderKind): SmeAuthMethod { return "auto"; case "codex": return "chatgpt"; + case "gemini": + return "apiKey"; case "openclaw": return "password"; } diff --git a/apps/web/src/appSettings.ts b/apps/web/src/appSettings.ts index e52cf25c..9580e6a7 100644 --- a/apps/web/src/appSettings.ts +++ b/apps/web/src/appSettings.ts @@ -54,7 +54,8 @@ type CustomModelSettingsKey = | "customCodexModels" | "customClaudeModels" | "customOpenClawModels" - | "customCopilotModels"; + | "customCopilotModels" + | "customGeminiModels"; export type ProviderCustomModelConfig = { provider: ProviderKind; settingsKey: CustomModelSettingsKey; @@ -70,6 +71,7 @@ const BUILT_IN_MODEL_SLUGS_BY_PROVIDER: Record claudeAgent: new Set(getModelOptions("claudeAgent").map((option) => option.slug)), openclaw: new Set(getModelOptions("openclaw").map((option) => option.slug)), copilot: new Set(getModelOptions("copilot").map((option) => option.slug)), + gemini: new Set(getModelOptions("gemini").map((option) => option.slug)), }; const withDefaults = @@ -139,6 +141,7 @@ export const AppSettingsSchema = Schema.Struct({ customClaudeModels: Schema.Array(Schema.String).pipe(withDefaults(() => [])), customCopilotModels: Schema.Array(Schema.String).pipe(withDefaults(() => [])), customOpenClawModels: Schema.Array(Schema.String).pipe(withDefaults(() => [])), + customGeminiModels: Schema.Array(Schema.String).pipe(withDefaults(() => [])), openclawGatewayUrl: Schema.String.check(Schema.isMaxLength(4096)).pipe(withDefaults(() => "")), openclawPassword: Schema.String.check(Schema.isMaxLength(4096)).pipe(withDefaults(() => "")), textGenerationModel: Schema.optional(TrimmedNonEmptyString), @@ -188,6 +191,15 @@ const PROVIDER_CUSTOM_MODEL_CONFIG: Record getCustomModelsByProvider(settings), [settings]); - const selectedModel = useMemo(() => { - const draftModel = composerDraft.model; - if (!draftModel) { - return baseThreadModel; - } - return resolveAppModelSelection(selectedProvider, customModelsByProvider, draftModel); - }, [baseThreadModel, composerDraft.model, customModelsByProvider, selectedProvider]); - const draftModelOptions = composerDraft.modelOptions; + const providerModelsByProvider = useMemo( + () => + getProviderModelOptionsByProvider({ + providers: providerStatuses, + customModelsByProvider, + }), + [customModelsByProvider, providerStatuses], + ); + const baseModelSelection = + activeThread?.modelSelection ?? activeProject?.defaultModelSelection ?? null; + const baseThreadModel = + (baseModelSelection ? getModelSelectionModel(baseModelSelection) : null) ?? + activeThread?.model ?? + activeProject?.model ?? + getDefaultModel(selectedProvider); + const draftModelOptions = + composerDraft.modelOptions ?? + (baseModelSelection ? getModelSelectionOptions(baseModelSelection) : null); + const selectedModelSelection = useMemo( + () => + resolveLiveModelSelection({ + providerModelsByProvider, + fallbackProvider: selectedProvider, + preferredModelSelection: + composerDraft.provider || composerDraft.model || composerDraft.modelOptions + ? null + : baseModelSelection, + provider: selectedProvider, + model: composerDraft.model ?? baseThreadModel, + modelOptions: draftModelOptions, + }), + [ + baseModelSelection, + baseThreadModel, + composerDraft.model, + composerDraft.modelOptions, + composerDraft.provider, + draftModelOptions, + providerModelsByProvider, + selectedProvider, + ], + ); + const selectedModel = getModelSelectionModel(selectedModelSelection); const composerProviderState = useMemo( () => getComposerProviderState({ @@ -914,22 +938,26 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { ); const selectedPromptEffort = composerProviderState.promptEffort; const selectedModelOptionsForDispatch = composerProviderState.modelOptionsForDispatch; - const providerOptionsForDispatch = useMemo(() => getProviderStartOptions(settings), [settings]); - const selectedModelForPicker = selectedModel; - const modelOptionsByProvider = useMemo( - () => getCustomModelOptionsByProvider(settings), - [settings], + const selectedModelSelectionForDispatch = useMemo( + () => + resolveLiveModelSelection({ + providerModelsByProvider, + fallbackProvider: selectedProvider, + preferredModelSelection: selectedModelSelection, + modelOptions: selectedModelOptionsForDispatch, + }), + [ + providerModelsByProvider, + selectedModelOptionsForDispatch, + selectedModelSelection, + selectedProvider, + ], ); - const selectedModelForPickerWithCustomFallback = useMemo(() => { - const currentOptions = modelOptionsByProvider[selectedProvider]; - return currentOptions.some((option) => option.slug === selectedModelForPicker) - ? selectedModelForPicker - : (normalizeModelSlug(selectedModelForPicker, selectedProvider) ?? selectedModelForPicker); - }, [modelOptionsByProvider, selectedModelForPicker, selectedProvider]); + const providerOptionsForDispatch = useMemo(() => getProviderStartOptions(settings), [settings]); const searchableModelOptions = useMemo( () => (lockedProvider !== null ? [lockedProvider] : selectableProviders).flatMap((provider) => - modelOptionsByProvider[provider].map(({ slug, name }) => ({ + providerModelsByProvider[provider].map(({ slug, name }) => ({ provider, providerLabel: getThreadProviderLabel(provider), slug, @@ -939,7 +967,7 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { searchProvider: getThreadProviderLabel(provider).toLowerCase(), })), ), - [lockedProvider, modelOptionsByProvider, selectableProviders], + [lockedProvider, providerModelsByProvider, selectableProviders], ); const phase = derivePhase(activeThread?.session ?? null); const isSendBusy = sendPhase !== "idle"; @@ -2224,7 +2252,7 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { async (input: { threadId: ThreadId; createdAt: string; - model?: string; + modelSelection?: ModelSelection; runtimeMode: RuntimeMode; interactionMode: ProviderInteractionMode; }) => { @@ -2236,12 +2264,15 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { return; } - if (input.model !== undefined && input.model !== serverThread.model) { + if ( + input.modelSelection !== undefined && + !modelSelectionsAreEqual(input.modelSelection, serverThread.modelSelection ?? null) + ) { await api.orchestration.dispatchCommand({ type: "thread.meta.update", commandId: newCommandId(), threadId: input.threadId, - model: input.model, + modelSelection: input.modelSelection, }); } @@ -2486,7 +2517,7 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { queuedMessages.length, runtimeMode, scheduleStickToBottom, - selectedModelForPickerWithCustomFallback, + selectedModel, selectedProvider, sidebarProposedPlan, ]); @@ -2829,7 +2860,7 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { await persistThreadSettingsForNextTurn({ threadId: threadIdForSend, createdAt: messageCreatedAt, - ...(selectedModel ? { model: selectedModel } : {}), + modelSelection: selectedModelSelectionForDispatch, runtimeMode, interactionMode, }); @@ -2853,12 +2884,8 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { attachments: turnAttachments, }, ...(nextQueued.providerInput ? { providerInput: nextQueued.providerInput } : {}), - model: selectedModel || undefined, - ...(selectedModelOptionsForDispatch - ? { modelOptions: selectedModelOptionsForDispatch } - : {}), + modelSelection: selectedModelSelectionForDispatch, ...(providerOptionsForDispatch ? { providerOptions: providerOptionsForDispatch } : {}), - provider: selectedProvider, assistantDeliveryMode: settings.enableAssistantStreaming ? "streaming" : "buffered", runtimeMode, interactionMode, @@ -3669,8 +3696,7 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { } title = truncateTitle(titleSeed); } - let threadCreateModel: ModelSlug = - selectedModel || (activeProject.model as ModelSlug) || DEFAULT_MODEL_BY_PROVIDER.codex; + const threadCreateModelSelection = selectedModelSelectionForDispatch; if (isLocalDraftThread) { await api.orchestration.dispatchCommand({ @@ -3679,7 +3705,8 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { threadId: threadIdForSend, projectId: activeProject.id, title, - model: threadCreateModel, + model: selectedModel, + modelSelection: threadCreateModelSelection, runtimeMode, interactionMode, branch: nextThreadBranch, @@ -3728,7 +3755,7 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { await persistThreadSettingsForNextTurn({ threadId: threadIdForSend, createdAt: messageCreatedAt, - ...(selectedModel ? { model: selectedModel } : {}), + modelSelection: selectedModelSelectionForDispatch, runtimeMode, interactionMode, }); @@ -3747,12 +3774,8 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { attachments: turnAttachments, }, ...(hiddenProviderInput ? { providerInput: hiddenProviderInput } : {}), - model: selectedModel || undefined, - ...(selectedModelOptionsForDispatch - ? { modelOptions: selectedModelOptionsForDispatch } - : {}), + modelSelection: selectedModelSelectionForDispatch, ...(providerOptionsForDispatch ? { providerOptions: providerOptionsForDispatch } : {}), - provider: selectedProvider, assistantDeliveryMode: settings.enableAssistantStreaming ? "streaming" : "buffered", runtimeMode, interactionMode, @@ -4054,7 +4077,7 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { await persistThreadSettingsForNextTurn({ threadId: threadIdForSend, createdAt: messageCreatedAt, - ...(selectedModel ? { model: selectedModel } : {}), + modelSelection: selectedModelSelectionForDispatch, runtimeMode, interactionMode: nextInteractionMode, }); @@ -4073,11 +4096,7 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { text: outgoingMessageText, attachments: [], }, - provider: selectedProvider, - model: selectedModel || undefined, - ...(selectedModelOptionsForDispatch - ? { modelOptions: selectedModelOptionsForDispatch } - : {}), + modelSelection: selectedModelSelectionForDispatch, ...(providerOptionsForDispatch ? { providerOptions: providerOptionsForDispatch } : {}), assistantDeliveryMode: settings.enableAssistantStreaming ? "streaming" : "buffered", runtimeMode, @@ -4125,10 +4144,9 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { resetSendPhase, runtimeMode, selectedPromptEffort, - selectedModel, - selectedModelOptionsForDispatch, - providerOptionsForDispatch, selectedProvider, + selectedModelSelectionForDispatch, + providerOptionsForDispatch, setComposerDraftInteractionMode, setThreadError, settings.enableAssistantStreaming, @@ -4161,11 +4179,7 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { text: implementationPrompt, }); const nextThreadTitle = truncateTitle(buildPlanImplementationThreadTitle(planMarkdown)); - const nextThreadModel: ModelSlug = - selectedModel || - (activeThread.model as ModelSlug) || - (activeProject.model as ModelSlug) || - DEFAULT_MODEL_BY_PROVIDER.codex; + const nextThreadModelSelection = selectedModelSelectionForDispatch; sendInFlightRef.current = true; beginSendPhase("sending-turn"); @@ -4181,7 +4195,8 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { threadId: nextThreadId, projectId: activeProject.id, title: nextThreadTitle, - model: nextThreadModel, + model: selectedModel, + modelSelection: nextThreadModelSelection, runtimeMode, interactionMode: "code", branch: activeThread.branch, @@ -4199,11 +4214,7 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { text: outgoingImplementationPrompt, attachments: [], }, - provider: selectedProvider, - model: selectedModel || undefined, - ...(selectedModelOptionsForDispatch - ? { modelOptions: selectedModelOptionsForDispatch } - : {}), + modelSelection: selectedModelSelectionForDispatch, ...(providerOptionsForDispatch ? { providerOptions: providerOptionsForDispatch } : {}), assistantDeliveryMode: settings.enableAssistantStreaming ? "streaming" : "buffered", runtimeMode, @@ -4256,10 +4267,10 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { resetSendPhase, runtimeMode, selectedPromptEffort, + selectedProvider, selectedModel, - selectedModelOptionsForDispatch, + selectedModelSelectionForDispatch, providerOptionsForDispatch, - selectedProvider, settings.enableAssistantStreaming, syncServerReadModel, ]); @@ -4271,10 +4282,9 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { scheduleComposerFocus(); return; } - const resolvedModel = resolveAppModelSelection(provider, customModelsByProvider, model); setComposerDraftProvider(activeThread.id, provider); - setComposerDraftModel(activeThread.id, resolvedModel); - setStickyComposerModel(resolvedModel); + setComposerDraftModel(activeThread.id, model); + setStickyComposerModel(model); scheduleComposerFocus(); }, [ @@ -4284,7 +4294,6 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { setComposerDraftModel, setComposerDraftProvider, setStickyComposerModel, - customModelsByProvider, ], ); const setPromptFromTraits = useCallback( @@ -5334,10 +5343,14 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { + (lockedProvider !== null + ? [lockedProvider] + : selectableProviders + ).includes(provider.provider), + )} {...(composerProviderState.modelPickerIconClassName ? { activeProviderIconClassName: diff --git a/apps/web/src/components/chat/ProviderModelPicker.browser.tsx b/apps/web/src/components/chat/ProviderModelPicker.browser.tsx index 7b1dfcd5..f93c7c6e 100644 --- a/apps/web/src/components/chat/ProviderModelPicker.browser.tsx +++ b/apps/web/src/components/chat/ProviderModelPicker.browser.tsx @@ -1,29 +1,73 @@ -import { type ModelSlug, type ProviderKind } from "@okcode/contracts"; +import { type ModelSlug, type ProviderKind, type ServerProviderStatus } from "@okcode/contracts"; import { page } from "vitest/browser"; import { afterEach, describe, expect, it, vi } from "vitest"; import { render } from "vitest-browser-react"; import { ProviderModelPicker } from "./ProviderModelPicker"; -const MODEL_OPTIONS_BY_PROVIDER = { - claudeAgent: [ - { slug: "claude-opus-4-6", name: "Claude Opus 4.6" }, - { slug: "claude-sonnet-4-6", name: "Claude Sonnet 4.6" }, - { slug: "claude-haiku-4-5", name: "Claude Haiku 4.5" }, - ], - codex: [ - { slug: "gpt-5-codex", name: "GPT-5 Codex" }, - { slug: "gpt-5.3-codex", name: "GPT-5.3 Codex" }, - ], - copilot: [{ slug: "claude-sonnet-4-5", name: "Claude Sonnet 4.5" }], - openclaw: [], -} as const satisfies Record>; +const PROVIDERS = [ + { + provider: "codex", + status: "ready", + available: true, + enabled: true, + installed: true, + version: "1.0.0", + authStatus: "authenticated", + auth: { status: "authenticated" }, + checkedAt: "2026-04-13T00:00:00.000Z", + models: [ + { slug: "gpt-5-codex", name: "GPT-5 Codex", isCustom: false, capabilities: null }, + { slug: "gpt-5.3-codex", name: "GPT-5.3 Codex", isCustom: false, capabilities: null }, + ], + }, + { + provider: "claudeAgent", + status: "ready", + available: true, + enabled: true, + installed: true, + version: "1.0.0", + authStatus: "authenticated", + auth: { status: "authenticated" }, + checkedAt: "2026-04-13T00:00:00.000Z", + models: [ + { slug: "claude-opus-4-6", name: "Claude Opus 4.6", isCustom: false, capabilities: null }, + { + slug: "claude-sonnet-4-6", + name: "Claude Sonnet 4.6", + isCustom: false, + capabilities: null, + }, + ], + }, + { + provider: "gemini", + status: "ready", + available: true, + enabled: true, + installed: true, + version: "1.0.0", + authStatus: "authenticated", + auth: { status: "authenticated" }, + checkedAt: "2026-04-13T00:00:00.000Z", + models: [ + { slug: "auto-gemini-3", name: "Auto Gemini 3", isCustom: false, capabilities: null }, + { + slug: "gemini-2.5-pro", + name: "Gemini 2.5 Pro", + isCustom: false, + capabilities: null, + }, + ], + }, +] as const satisfies ReadonlyArray; async function mountPicker(props: { provider: ProviderKind; model: ModelSlug; lockedProvider: ProviderKind | null; - availableProviders?: ReadonlyArray; + providers?: ReadonlyArray; }) { const host = document.createElement("div"); document.body.append(host); @@ -33,8 +77,7 @@ async function mountPicker(props: { provider={props.provider} model={props.model} lockedProvider={props.lockedProvider} - availableProviders={props.availableProviders ?? ["codex", "claudeAgent", "openclaw"]} - modelOptionsByProvider={MODEL_OPTIONS_BY_PROVIDER} + providers={props.providers ?? PROVIDERS} onProviderModelChange={onProviderModelChange} />, { container: host }, @@ -54,12 +97,11 @@ describe("ProviderModelPicker", () => { document.body.innerHTML = ""; }); - it("shows both Codex and Claude Code model groups when provider switching is allowed", async () => { + it("renders live provider snapshots including Gemini", async () => { const mounted = await mountPicker({ provider: "claudeAgent", model: "claude-opus-4-6", lockedProvider: null, - availableProviders: ["codex", "claudeAgent"], }); try { @@ -69,20 +111,19 @@ describe("ProviderModelPicker", () => { const text = document.body.textContent ?? ""; expect(text).toContain("Codex"); expect(text).toContain("Claude Code"); - expect(text).toContain("GPT-5 Codex"); - expect(text).toContain("Claude Sonnet 4.6"); - expect(text).toContain("Claude Haiku 4.5"); + expect(text).toContain("Gemini"); + expect(text).toContain("Auto Gemini 3"); }); } finally { await mounted.cleanup(); } }); - it("shows models directly when the provider is locked mid-thread", async () => { + it("shows only the locked provider group mid-thread", async () => { const mounted = await mountPicker({ - provider: "claudeAgent", - model: "claude-opus-4-6", - lockedProvider: "claudeAgent", + provider: "gemini", + model: "auto-gemini-3", + lockedProvider: "gemini", }); try { @@ -90,52 +131,26 @@ describe("ProviderModelPicker", () => { await vi.waitFor(() => { const text = document.body.textContent ?? ""; - expect(text).toContain("Claude Sonnet 4.6"); - expect(text).toContain("Claude Haiku 4.5"); - expect(text).not.toContain("Codex"); + expect(text).toContain("Gemini 2.5 Pro"); + expect(text).not.toContain("GPT-5 Codex"); }); } finally { await mounted.cleanup(); } }); - it("dispatches the canonical slug when a model is selected", async () => { - const mounted = await mountPicker({ - provider: "claudeAgent", - model: "claude-opus-4-6", - lockedProvider: "claudeAgent", - }); - - try { - await page.getByRole("button").click(); - await page.getByRole("menuitemradio", { name: "Claude Sonnet 4.6" }).click(); - - expect(mounted.onProviderModelChange).toHaveBeenCalledWith( - "claudeAgent", - "claude-sonnet-4-6", - ); - } finally { - await mounted.cleanup(); - } - }); - - it("only shows authenticated providers when switching is allowed", async () => { + it("dispatches the selected provider/model pair", async () => { const mounted = await mountPicker({ - provider: "codex", - model: "gpt-5-codex", - lockedProvider: null, - availableProviders: ["codex"], + provider: "gemini", + model: "auto-gemini-3", + lockedProvider: "gemini", }); try { await page.getByRole("button").click(); + await page.getByRole("menuitemradio", { name: "Gemini 2.5 Pro" }).click(); - await vi.waitFor(() => { - const text = document.body.textContent ?? ""; - expect(text).toContain("Codex"); - expect(text).toContain("GPT-5.3 Codex"); - expect(text).not.toContain("Claude Code"); - }); + expect(mounted.onProviderModelChange).toHaveBeenCalledWith("gemini", "gemini-2.5-pro"); } finally { await mounted.cleanup(); } diff --git a/apps/web/src/components/chat/ProviderModelPicker.tsx b/apps/web/src/components/chat/ProviderModelPicker.tsx index 0f030a8f..d627b1bc 100644 --- a/apps/web/src/components/chat/ProviderModelPicker.tsx +++ b/apps/web/src/components/chat/ProviderModelPicker.tsx @@ -1,52 +1,33 @@ -import { type ModelSlug, type ProviderKind } from "@okcode/contracts"; -import { resolveSelectableModel } from "@okcode/shared/model"; +import { type ModelSlug, type ProviderKind, type ServerProviderStatus } from "@okcode/contracts"; import { memo, useState } from "react"; -import { type ProviderPickerKind, PROVIDER_OPTIONS } from "../../session-logic"; import { ChevronDownIcon } from "lucide-react"; + +import { cn } from "~/lib/utils"; +import { getThreadProviderLabel } from "~/lib/providerAvailability"; +import { ClaudeAI, Gemini, GitHubIcon, type Icon, OpenAI, OpenClawIcon } from "../Icons"; import { Button } from "../ui/button"; import { Menu, MenuGroup, MenuGroupLabel, - MenuItem, MenuPopup, MenuRadioGroup, MenuRadioItem, MenuSeparator as MenuDivider, MenuTrigger, } from "../ui/menu"; -import { - ClaudeAI, - CursorIcon, - Gemini, - GitHubIcon, - Icon, - OpenAI, - OpenClawIcon, - OpenCodeIcon, -} from "../Icons"; -import { cn } from "~/lib/utils"; -import { getThreadProviderLabel } from "~/lib/providerAvailability"; -const PROVIDER_ICON_BY_PROVIDER: Record = { +const PROVIDER_ICON_BY_PROVIDER: Record = { codex: OpenAI, claudeAgent: ClaudeAI, - openclaw: OpenClawIcon, + gemini: Gemini, copilot: GitHubIcon, - cursor: CursorIcon, + openclaw: OpenClawIcon, }; -const UNAVAILABLE_PROVIDER_OPTIONS = PROVIDER_OPTIONS.filter((option) => !option.available); -const COMING_SOON_PROVIDER_OPTIONS = [ - { id: "opencode", label: "OpenCode", icon: OpenCodeIcon }, - { id: "gemini", label: "Gemini", icon: Gemini }, -] as const; - -function providerIconClassName( - provider: ProviderKind | ProviderPickerKind, - fallbackClassName: string, -): string { +function providerIconClassName(provider: ProviderKind, fallbackClassName: string): string { if (provider === "claudeAgent") return "text-[#d97757]"; + if (provider === "gemini") return "text-[#78c2ff]"; if (provider === "openclaw") return "text-[#6cb4ee]"; if (provider === "copilot") return "text-white/85"; return fallbackClassName; @@ -56,12 +37,18 @@ function getProviderLabel(provider: ProviderKind): string { return getThreadProviderLabel(provider); } +function getProviderSnapshot( + providers: ReadonlyArray, + provider: ProviderKind, +): ServerProviderStatus | null { + return providers.find((entry) => entry.provider === provider) ?? null; +} + export const ProviderModelPicker = memo(function ProviderModelPicker(props: { provider: ProviderKind; model: ModelSlug; lockedProvider: ProviderKind | null; - availableProviders: ReadonlyArray; - modelOptionsByProvider: Record>; + providers: ReadonlyArray; activeProviderIconClassName?: string; compact?: boolean; disabled?: boolean; @@ -70,21 +57,26 @@ export const ProviderModelPicker = memo(function ProviderModelPicker(props: { const [isMenuOpen, setIsMenuOpen] = useState(false); const activeProvider = props.lockedProvider ?? props.provider; const visibleProviders = - props.lockedProvider !== null ? [props.lockedProvider] : props.availableProviders; - const selectedProviderOptions = props.modelOptionsByProvider[activeProvider]; + props.lockedProvider !== null + ? [props.lockedProvider] + : props.providers.map((provider) => provider.provider); + const activeProviderSnapshot = getProviderSnapshot(props.providers, activeProvider); const selectedModelLabel = - selectedProviderOptions.find((option) => option.slug === props.model)?.name ?? props.model; + activeProviderSnapshot?.models?.find((option) => option.slug === props.model)?.name ?? + props.model; const ProviderIcon = PROVIDER_ICON_BY_PROVIDER[activeProvider]; + const handleModelChange = (provider: ProviderKind, value: string) => { - if (props.disabled) return; - if (!value) return; - const resolvedModel = resolveSelectableModel( - provider, - value, - props.modelOptionsByProvider[provider], + if (props.disabled || !value) { + return; + } + const option = getProviderSnapshot(props.providers, provider)?.models?.find( + (model) => model.slug === value, ); - if (!resolvedModel) return; - props.onProviderModelChange(provider, resolvedModel); + if (!option) { + return; + } + props.onProviderModelChange(provider, option.slug); setIsMenuOpen(false); }; @@ -131,95 +123,46 @@ export const ProviderModelPicker = memo(function ProviderModelPicker(props: { - {props.lockedProvider !== null ? ( - - - {getProviderLabel(props.lockedProvider)} · locked for this thread - - handleModelChange(props.lockedProvider!, value)} - > - {props.modelOptionsByProvider[props.lockedProvider].map((modelOption) => ( - setIsMenuOpen(false)} - > - {modelOption.name} - - ))} - - - ) : ( - <> - {visibleProviders.map((provider, index) => { - const option = { - value: provider, - label: getThreadProviderLabel(provider), - }; - const OptionIcon = PROVIDER_ICON_BY_PROVIDER[option.value]; - return ( - - {index > 0 ? : null} - - - handleModelChange(option.value, value)} + {visibleProviders.map((provider, index) => { + const providerSnapshot = getProviderSnapshot(props.providers, provider); + const ProviderOptionIcon = PROVIDER_ICON_BY_PROVIDER[provider]; + if (!providerSnapshot?.models || providerSnapshot.models.length === 0) { + return null; + } + + return ( + + {index > 0 ? : null} + + + handleModelChange(provider, value)} + > + {providerSnapshot.models.map((modelOption) => ( + setIsMenuOpen(false)} > - {props.modelOptionsByProvider[option.value].map((modelOption) => ( - setIsMenuOpen(false)} - > - {modelOption.name} - - ))} - - - ); - })} - {UNAVAILABLE_PROVIDER_OPTIONS.length > 0 && } - {UNAVAILABLE_PROVIDER_OPTIONS.map((option) => { - const OptionIcon = PROVIDER_ICON_BY_PROVIDER[option.value]; - return ( - - - ); - })} - {UNAVAILABLE_PROVIDER_OPTIONS.length === 0 && } - {COMING_SOON_PROVIDER_OPTIONS.map((option) => { - const OptionIcon = option.icon; - return ( - - - ); - })} - - )} + {modelOption.name} + + ))} + + + ); + })} ); diff --git a/apps/web/src/components/chat/composerProviderRegistry.tsx b/apps/web/src/components/chat/composerProviderRegistry.tsx index 4c5472b5..c4c3e263 100644 --- a/apps/web/src/components/chat/composerProviderRegistry.tsx +++ b/apps/web/src/components/chat/composerProviderRegistry.tsx @@ -139,6 +139,15 @@ const composerProviderRegistry: Record = { renderTraitsMenuContent: () => null, renderTraitsPicker: () => null, }, + gemini: { + getState: () => ({ + provider: "gemini", + promptEffort: null, + modelOptionsForDispatch: undefined, + }), + renderTraitsMenuContent: () => null, + renderTraitsPicker: () => null, + }, }; export function getComposerProviderState(input: ComposerProviderStateInput): ComposerProviderState { diff --git a/apps/web/src/components/chat/providerStatusPresentation.ts b/apps/web/src/components/chat/providerStatusPresentation.ts index 0b03ea75..624bc9cb 100644 --- a/apps/web/src/components/chat/providerStatusPresentation.ts +++ b/apps/web/src/components/chat/providerStatusPresentation.ts @@ -8,6 +8,7 @@ const PROVIDER_LABELS = { claudeAgent: "Claude Code", openclaw: "OpenClaw", copilot: "GitHub Copilot", + gemini: "Gemini CLI", } as const; export function getProviderLabel(provider: ServerProviderStatus["provider"]): string { @@ -15,10 +16,11 @@ export function getProviderLabel(provider: ServerProviderStatus["provider"]): st } export function getProviderSetupPhase(status: ServerProviderStatus): ProviderSetupPhase { + const authStatus = status.authStatus ?? status.auth?.status; if (!status.available) { return "install"; } - if (status.authStatus === "unauthenticated") { + if (authStatus === "unauthenticated") { return "authenticate"; } if (status.status === "ready") { diff --git a/apps/web/src/components/sme/SmeConversationDialog.tsx b/apps/web/src/components/sme/SmeConversationDialog.tsx index 22f724fa..17cf0a98 100644 --- a/apps/web/src/components/sme/SmeConversationDialog.tsx +++ b/apps/web/src/components/sme/SmeConversationDialog.tsx @@ -98,6 +98,7 @@ export function SmeConversationDialog({ claudeAgent: settings.customClaudeModels, copilot: settings.customCopilotModels, openclaw: settings.customOpenClawModels, + gemini: settings.customGeminiModels, }, null, ); @@ -114,6 +115,7 @@ export function SmeConversationDialog({ settings.customClaudeModels, settings.customCopilotModels, settings.customCodexModels, + settings.customGeminiModels, settings.customOpenClawModels, ]); @@ -153,6 +155,7 @@ export function SmeConversationDialog({ claudeAgent: settings.customClaudeModels, copilot: settings.customCopilotModels, openclaw: settings.customOpenClawModels, + gemini: settings.customGeminiModels, }, nextProvider === "openclaw" ? "default" : null, ), diff --git a/apps/web/src/components/sme/smeConversationConfig.ts b/apps/web/src/components/sme/smeConversationConfig.ts index 62af7df1..acb49fb2 100644 --- a/apps/web/src/components/sme/smeConversationConfig.ts +++ b/apps/web/src/components/sme/smeConversationConfig.ts @@ -5,6 +5,7 @@ export const SME_PROVIDER_LABELS: Record = { claudeAgent: "Claude Code", copilot: "GitHub Copilot", openclaw: "OpenClaw", + gemini: "Gemini CLI", }; export function getDefaultSmeAuthMethod(provider: ProviderKind): SmeAuthMethod { @@ -17,6 +18,8 @@ export function getDefaultSmeAuthMethod(provider: ProviderKind): SmeAuthMethod { return "chatgpt"; case "openclaw": return "password"; + case "gemini": + return "apiKey"; } } @@ -45,6 +48,11 @@ export function getSmeAuthMethodOptions( { value: "none", label: "Device Token Only" }, { value: "auto", label: "Auto (prefer shared secret)" }, ]; + case "gemini": + return [ + { value: "apiKey", label: "API Key" }, + { value: "auto", label: "Auto" }, + ]; } } diff --git a/apps/web/src/lib/providerAvailability.ts b/apps/web/src/lib/providerAvailability.ts index 8630277c..c0d86da8 100644 --- a/apps/web/src/lib/providerAvailability.ts +++ b/apps/web/src/lib/providerAvailability.ts @@ -3,6 +3,7 @@ import type { ProviderKind, ServerProviderStatus } from "@okcode/contracts"; const THREAD_PROVIDER_ORDER: readonly ProviderKind[] = [ "codex", "claudeAgent", + "gemini", "copilot", "openclaw", ]; @@ -10,6 +11,7 @@ const THREAD_PROVIDER_ORDER: readonly ProviderKind[] = [ const THREAD_PROVIDER_LABELS: Record = { codex: "Codex", claudeAgent: "Claude Code", + gemini: "Gemini", copilot: "GitHub Copilot", openclaw: "OpenClaw", }; @@ -48,12 +50,13 @@ export function isProviderReadyForThreadSelection(input: { input.provider === "claudeAgent" && (input.claudeAuthTokenHelperCommand ?? "").trim().length > 0 && status.available && - status.authStatus === "unauthenticated" + (status.authStatus ?? status.auth?.status) === "unauthenticated" ) { return true; } - return status.available && status.status === "ready" && status.authStatus !== "unauthenticated"; + const authStatus = status.authStatus ?? status.auth?.status; + return Boolean(status.available && status.status === "ready" && authStatus !== "unauthenticated"); } export function getSelectableThreadProviders(input: { diff --git a/apps/web/src/modelSelection.ts b/apps/web/src/modelSelection.ts new file mode 100644 index 00000000..21539c9f --- /dev/null +++ b/apps/web/src/modelSelection.ts @@ -0,0 +1,42 @@ +import type { ModelSelection, ProviderKind, ProviderModelOptions } from "@okcode/contracts"; +import { + getModelSelectionModel, + getModelSelectionOptions, + getModelSelectionProvider, + normalizeModelSelectionWithCapabilities, + toCanonicalModelSelection, +} from "@okcode/shared/modelSelection"; + +import { getProviderDefaultModel, type ProviderModelOption } from "./providerModels"; + +export function resolveLiveModelSelection(input: { + providerModelsByProvider: Record>; + fallbackProvider: ProviderKind; + preferredModelSelection?: ModelSelection | null | undefined; + provider?: ProviderKind | null | undefined; + model?: string | null | undefined; + modelOptions?: ProviderModelOptions | null | undefined; +}): ModelSelection { + const draftProvider = input.provider ?? input.fallbackProvider; + const baseSelection = + input.preferredModelSelection ?? + toCanonicalModelSelection( + draftProvider, + input.model ?? getProviderDefaultModel(draftProvider, input.providerModelsByProvider), + input.modelOptions ?? undefined, + ); + const provider = getModelSelectionProvider(baseSelection); + const providerModels = input.providerModelsByProvider[provider]; + const resolvedModel = providerModels.some( + (entry) => entry.slug === getModelSelectionModel(baseSelection), + ) + ? getModelSelectionModel(baseSelection) + : getProviderDefaultModel(provider, input.providerModelsByProvider); + const capabilities = + providerModels.find((entry) => entry.slug === resolvedModel)?.capabilities ?? null; + + return normalizeModelSelectionWithCapabilities( + toCanonicalModelSelection(provider, resolvedModel, getModelSelectionOptions(baseSelection)), + capabilities ? [{ slug: resolvedModel, capabilities }] : [], + ); +} diff --git a/apps/web/src/providerModels.ts b/apps/web/src/providerModels.ts new file mode 100644 index 00000000..48e07bcd --- /dev/null +++ b/apps/web/src/providerModels.ts @@ -0,0 +1,68 @@ +import type { ModelCapabilities, ProviderKind, ServerProviderStatus } from "@okcode/contracts"; +import { getDefaultModel } from "@okcode/shared/model"; + +import { normalizeCustomModelSlugs, type AppModelOption } from "./appSettings"; + +export type ProviderModelOption = AppModelOption & { + capabilities?: ModelCapabilities | null | undefined; +}; + +const PROVIDER_KINDS: readonly ProviderKind[] = [ + "codex", + "claudeAgent", + "copilot", + "openclaw", + "gemini", +]; + +export function getProviderSnapshot( + providers: ReadonlyArray, + provider: ProviderKind, +): ServerProviderStatus | null { + return providers.find((entry) => entry.provider === provider) ?? null; +} + +export function getProviderModelOptionsByProvider(input: { + providers: ReadonlyArray; + customModelsByProvider: Record; +}): Record> { + return PROVIDER_KINDS.reduce( + (acc, provider) => { + const snapshotModels = getProviderSnapshot(input.providers, provider)?.models ?? []; + const options: ProviderModelOption[] = snapshotModels.map((model) => ({ + slug: model.slug, + name: model.name, + isCustom: model.isCustom, + capabilities: model.capabilities, + })); + const seen = new Set(options.map((model) => model.slug)); + + for (const slug of normalizeCustomModelSlugs( + input.customModelsByProvider[provider], + provider, + )) { + if (seen.has(slug)) { + continue; + } + seen.add(slug); + options.push({ + slug, + name: slug, + isCustom: true, + capabilities: null, + }); + } + + acc[provider] = options; + return acc; + }, + {} as Record>, + ); +} + +export function getProviderDefaultModel( + provider: ProviderKind, + providerModelsByProvider: Record>, +): string { + return providerModelsByProvider[provider][0]?.slug ?? getDefaultModel(provider); +} diff --git a/apps/web/src/routes/_chat.settings.index.tsx b/apps/web/src/routes/_chat.settings.index.tsx index c2ceb4b1..79f3008e 100644 --- a/apps/web/src/routes/_chat.settings.index.tsx +++ b/apps/web/src/routes/_chat.settings.index.tsx @@ -306,6 +306,12 @@ const PROVIDER_AUTH_GUIDES: Record< verifyCmd: "copilot auth status", note: "GitHub Copilot must be installed and signed in before it appears in the thread picker.", }, + gemini: { + installCmd: "npm install -g @google/gemini-cli", + authCmd: "set GEMINI_API_KEY or GOOGLE_API_KEY", + verifyCmd: "gemini --version", + note: "Gemini CLI appears in the thread picker when the binary is installed and headless auth is available or locally cached.", + }, openclaw: { verifyCmd: "Test Connection", note: "OpenClaw uses the gateway URL and shared secret below rather than a local CLI login. Shared-secret auth usually works without device pairing and is the recommended default for Tailscale and remote gateways.", @@ -491,6 +497,7 @@ function SettingsRouteView() { const [openInstallProviders, setOpenInstallProviders] = useState>({ codex: Boolean(settings.codexBinaryPath || settings.codexHomePath), claudeAgent: Boolean(settings.claudeBinaryPath), + gemini: false, copilot: Boolean(settings.copilotBinaryPath || settings.copilotConfigDir), openclaw: Boolean(settings.openclawGatewayUrl || settings.openclawPassword), }); @@ -501,6 +508,7 @@ function SettingsRouteView() { >({ codex: "", claudeAgent: "", + gemini: "", copilot: "", openclaw: "", }); @@ -1427,6 +1435,7 @@ function SettingsRouteView() { setOpenInstallProviders({ codex: false, claudeAgent: false, + gemini: false, copilot: false, openclaw: false, }); diff --git a/apps/web/src/routes/_chat.settings.tsx b/apps/web/src/routes/_chat.settings.tsx index 1bc71eca..5bec7874 100644 --- a/apps/web/src/routes/_chat.settings.tsx +++ b/apps/web/src/routes/_chat.settings.tsx @@ -384,6 +384,12 @@ const PROVIDER_AUTH_GUIDES: Record< verifyCmd: "claude auth status", note: "Claude Code must be installed and configured with an Anthropic API key or auth token before it appears in the thread picker. Use Environment to add a Claude auth token in one click, or configure a helper command in the Claude install panel.", }, + gemini: { + installCmd: "npm install -g @google/gemini-cli", + authCmd: "set GEMINI_API_KEY or GOOGLE_API_KEY", + verifyCmd: "gemini --version", + note: "Gemini CLI appears in the thread picker when the binary is installed and headless auth is available or locally cached.", + }, openclaw: { verifyCmd: "Test Connection", note: "OpenClaw uses the gateway URL and shared secret below rather than a local CLI login. Shared-secret auth usually works without device pairing and is the recommended default for Tailscale and remote gateways.", @@ -821,6 +827,7 @@ function SettingsRouteView() { const [openInstallProviders, setOpenInstallProviders] = useState>({ codex: Boolean(settings.codexBinaryPath || settings.codexHomePath), claudeAgent: Boolean(settings.claudeBinaryPath || settings.claudeAuthTokenHelperCommand), + gemini: false, openclaw: Boolean(settings.openclawGatewayUrl || settings.openclawPassword), copilot: Boolean(settings.copilotBinaryPath || settings.copilotConfigDir), }); @@ -831,6 +838,7 @@ function SettingsRouteView() { >({ codex: "", claudeAgent: "", + gemini: "", openclaw: "", copilot: "", }); @@ -1311,6 +1319,7 @@ function SettingsRouteView() { setOpenInstallProviders({ codex: false, claudeAgent: false, + gemini: false, openclaw: false, copilot: false, }); @@ -1318,6 +1327,7 @@ function SettingsRouteView() { setCustomModelInputByProvider({ codex: "", claudeAgent: "", + gemini: "", openclaw: "", copilot: "", }); @@ -2641,6 +2651,7 @@ function SettingsRouteView() { setOpenInstallProviders({ codex: false, claudeAgent: false, + gemini: false, openclaw: false, copilot: false, }); diff --git a/apps/web/src/store.ts b/apps/web/src/store.ts index 56973782..38539eec 100644 --- a/apps/web/src/store.ts +++ b/apps/web/src/store.ts @@ -11,6 +11,7 @@ import { resolveModelSlug, resolveModelSlugForProvider, } from "@okcode/shared/model"; +import { getModelSelectionModel, getModelSelectionProvider } from "@okcode/shared/modelSelection"; import { create } from "zustand"; import { type ChatMessage, type Project, type Thread } from "./types"; import { Debouncer } from "@tanstack/react-pacer"; @@ -187,7 +188,12 @@ function mapProjectsFromReadModel( cwd: project.workspaceRoot, model: existing?.model ?? - resolveModelSlug(project.defaultModel ?? DEFAULT_MODEL_BY_PROVIDER.codex), + resolveModelSlug( + project.defaultModelSelection + ? getModelSelectionModel(project.defaultModelSelection) + : (project.defaultModel ?? DEFAULT_MODEL_BY_PROVIDER.codex), + ), + defaultModelSelection: project.defaultModelSelection ?? null, expanded: resolveProjectExpandedState({ existingExpanded: existing?.expanded, persistedExpanded: persistedProjectExpansionByCwd.get(project.workspaceRoot), @@ -296,12 +302,15 @@ export function syncServerReadModel(state: AppState, readModel: OrchestrationRea projectId: thread.projectId, title: thread.title, model: resolveModelSlugForProvider( - inferProviderForThreadModel({ - model: thread.model, - sessionProviderName: thread.session?.providerName ?? null, - }), - thread.model, + thread.modelSelection + ? getModelSelectionProvider(thread.modelSelection) + : inferProviderForThreadModel({ + model: thread.model, + sessionProviderName: thread.session?.providerName ?? null, + }), + thread.modelSelection ? getModelSelectionModel(thread.modelSelection) : thread.model, ), + modelSelection: thread.modelSelection ?? null, runtimeMode: thread.runtimeMode, interactionMode: thread.interactionMode, session: thread.session diff --git a/apps/web/src/types.ts b/apps/web/src/types.ts index b29b28f7..7cb08280 100644 --- a/apps/web/src/types.ts +++ b/apps/web/src/types.ts @@ -5,6 +5,7 @@ import type { OrchestrationThreadActivity, ProjectScript as ContractProjectScript, GitHubRef, + ModelSelection, ThreadId, ProjectId, TurnId, @@ -92,6 +93,7 @@ export interface Project { name: string; cwd: string; model: string; + defaultModelSelection?: ModelSelection | null; expanded: boolean; createdAt?: string | undefined; updatedAt?: string | undefined; @@ -104,6 +106,7 @@ export interface Thread { projectId: ProjectId; title: string; model: string; + modelSelection?: ModelSelection | null; runtimeMode: RuntimeMode; interactionMode: ProviderInteractionMode; session: ThreadSession | null; diff --git a/packages/contracts/src/model.ts b/packages/contracts/src/model.ts index 1c9f9e38..454ce37d 100644 --- a/packages/contracts/src/model.ts +++ b/packages/contracts/src/model.ts @@ -1,4 +1,5 @@ import { Schema } from "effect"; +import { TrimmedNonEmptyString } from "./baseSchemas"; import type { ProviderKind } from "./orchestration"; export const CODEX_REASONING_EFFORT_OPTIONS = ["xhigh", "high", "medium", "low"] as const; @@ -9,6 +10,7 @@ export const OPENCLAW_REASONING_EFFORT_OPTIONS = ["low", "medium", "high"] as co export type OpenClawReasoningEffort = (typeof OPENCLAW_REASONING_EFFORT_OPTIONS)[number]; export const COPILOT_REASONING_EFFORT_OPTIONS = ["low", "medium", "high", "xhigh"] as const; export type CopilotReasoningEffort = (typeof COPILOT_REASONING_EFFORT_OPTIONS)[number]; +export type GeminiReasoningEffort = never; export type ProviderReasoningEffort = | CodexReasoningEffort | ClaudeCodeEffort @@ -25,6 +27,7 @@ export const ClaudeModelOptions = Schema.Struct({ thinking: Schema.optional(Schema.Boolean), effort: Schema.optional(Schema.Literals(CLAUDE_CODE_EFFORT_OPTIONS)), fastMode: Schema.optional(Schema.Boolean), + contextWindow: Schema.optional(Schema.String), }); export type ClaudeModelOptions = typeof ClaudeModelOptions.Type; @@ -38,11 +41,15 @@ export const CopilotModelOptions = Schema.Struct({ }); export type CopilotModelOptions = typeof CopilotModelOptions.Type; +export const GeminiModelOptions = Schema.Struct({}); +export type GeminiModelOptions = typeof GeminiModelOptions.Type; + export const ProviderModelOptions = Schema.Struct({ codex: Schema.optional(CodexModelOptions), claudeAgent: Schema.optional(ClaudeModelOptions), openclaw: Schema.optional(OpenClawModelOptions), copilot: Schema.optional(CopilotModelOptions), + gemini: Schema.optional(GeminiModelOptions), }); export type ProviderModelOptions = typeof ProviderModelOptions.Type; @@ -83,6 +90,14 @@ export const MODEL_OPTIONS_BY_PROVIDER = { { slug: "gemini-2.5-pro", name: "Gemini 2.5 Pro" }, { slug: "grok-code-fast-1", name: "Grok Code Fast 1" }, ], + gemini: [ + { slug: "auto-gemini-3", name: "Auto (Gemini 3)" }, + { slug: "auto-gemini-2.5", name: "Auto (Gemini 2.5)" }, + { slug: "gemini-2.5-pro", name: "Gemini 2.5 Pro" }, + { slug: "gemini-2.5-flash", name: "Gemini 2.5 Flash" }, + { slug: "gemini-3-pro-preview", name: "Gemini 3 Pro Preview" }, + { slug: "gemini-3-flash-preview", name: "Gemini 3 Flash Preview" }, + ], } as const satisfies Record; export type ModelOptionsByProvider = typeof MODEL_OPTIONS_BY_PROVIDER; @@ -94,6 +109,7 @@ export const DEFAULT_MODEL_BY_PROVIDER: Record = { claudeAgent: "claude-sonnet-4-6", openclaw: "default", copilot: "gpt-5.3-codex", + gemini: "auto-gemini-3", }; // Backward compatibility for existing Codex-only call sites. @@ -161,6 +177,19 @@ export const MODEL_SLUG_ALIASES_BY_PROVIDER: Record; export const DEFAULT_REASONING_EFFORT_BY_PROVIDER = { @@ -175,4 +205,28 @@ export const DEFAULT_REASONING_EFFORT_BY_PROVIDER = { claudeAgent: "high", openclaw: "high", copilot: "high", + gemini: "high", } as const satisfies Record; + +export const EffortOption = Schema.Struct({ + value: TrimmedNonEmptyString, + label: TrimmedNonEmptyString, + isDefault: Schema.optional(Schema.Boolean), +}); +export type EffortOption = typeof EffortOption.Type; + +export const ContextWindowOption = Schema.Struct({ + value: TrimmedNonEmptyString, + label: TrimmedNonEmptyString, + isDefault: Schema.optional(Schema.Boolean), +}); +export type ContextWindowOption = typeof ContextWindowOption.Type; + +export const ModelCapabilities = Schema.Struct({ + reasoningEffortLevels: Schema.Array(EffortOption), + supportsFastMode: Schema.Boolean, + supportsThinkingToggle: Schema.Boolean, + contextWindowOptions: Schema.Array(ContextWindowOption), + promptInjectedEffortLevels: Schema.Array(TrimmedNonEmptyString), +}); +export type ModelCapabilities = typeof ModelCapabilities.Type; diff --git a/packages/contracts/src/orchestration.ts b/packages/contracts/src/orchestration.ts index ee935aed..4a510649 100644 --- a/packages/contracts/src/orchestration.ts +++ b/packages/contracts/src/orchestration.ts @@ -1,5 +1,12 @@ import { Option, Schema, SchemaIssue, SchemaTransformation, Struct } from "effect"; -import { ProviderModelOptions } from "./model"; +import { + ClaudeModelOptions, + CodexModelOptions, + CopilotModelOptions, + GeminiModelOptions, + OpenClawModelOptions, + ProviderModelOptions, +} from "./model"; import { ApprovalRequestId, CheckpointRef, @@ -29,7 +36,13 @@ export const ORCHESTRATION_WS_CHANNELS = { domainEvent: "orchestration.domainEvent", } as const; -export const ProviderKind = Schema.Literals(["codex", "claudeAgent", "openclaw", "copilot"]); +export const ProviderKind = Schema.Literals([ + "codex", + "claudeAgent", + "openclaw", + "copilot", + "gemini", +]); export type ProviderKind = typeof ProviderKind.Type; export const ProviderApprovalPolicy = Schema.Literals([ "untrusted", @@ -68,14 +81,63 @@ export const CopilotProviderStartOptions = Schema.Struct({ configDir: Schema.optional(TrimmedNonEmptyString), }); +export const GeminiProviderStartOptions = Schema.Struct({ + binaryPath: Schema.optional(TrimmedNonEmptyString), +}); + export const ProviderStartOptions = Schema.Struct({ codex: Schema.optional(CodexProviderStartOptions), claudeAgent: Schema.optional(ClaudeProviderStartOptions), openclaw: Schema.optional(OpenClawProviderStartOptions), copilot: Schema.optional(CopilotProviderStartOptions), + gemini: Schema.optional(GeminiProviderStartOptions), }); export type ProviderStartOptions = typeof ProviderStartOptions.Type; +export const CodexModelSelection = Schema.Struct({ + provider: Schema.Literal("codex"), + model: TrimmedNonEmptyString, + options: Schema.optional(CodexModelOptions), +}); +export type CodexModelSelection = typeof CodexModelSelection.Type; + +export const ClaudeModelSelection = Schema.Struct({ + provider: Schema.Literal("claudeAgent"), + model: TrimmedNonEmptyString, + options: Schema.optional(ClaudeModelOptions), +}); +export type ClaudeModelSelection = typeof ClaudeModelSelection.Type; + +export const OpenClawModelSelection = Schema.Struct({ + provider: Schema.Literal("openclaw"), + model: TrimmedNonEmptyString, + options: Schema.optional(OpenClawModelOptions), +}); +export type OpenClawModelSelection = typeof OpenClawModelSelection.Type; + +export const CopilotModelSelection = Schema.Struct({ + provider: Schema.Literal("copilot"), + model: TrimmedNonEmptyString, + options: Schema.optional(CopilotModelOptions), +}); +export type CopilotModelSelection = typeof CopilotModelSelection.Type; + +export const GeminiModelSelection = Schema.Struct({ + provider: Schema.Literal("gemini"), + model: TrimmedNonEmptyString, + options: Schema.optional(GeminiModelOptions), +}); +export type GeminiModelSelection = typeof GeminiModelSelection.Type; + +export const ModelSelection = Schema.Union([ + CodexModelSelection, + ClaudeModelSelection, + OpenClawModelSelection, + CopilotModelSelection, + GeminiModelSelection, +]); +export type ModelSelection = typeof ModelSelection.Type; + export const RuntimeMode = Schema.Literals(["approval-required", "full-access"]); export type RuntimeMode = typeof RuntimeMode.Type; export const DEFAULT_RUNTIME_MODE: RuntimeMode = "full-access"; @@ -207,6 +269,7 @@ export const OrchestrationProject = Schema.Struct({ title: TrimmedNonEmptyString, workspaceRoot: TrimmedNonEmptyString, defaultModel: Schema.NullOr(TrimmedNonEmptyString), + defaultModelSelection: Schema.optional(Schema.NullOr(ModelSelection)), scripts: Schema.Array(ProjectScript), createdAt: IsoDateTime, updatedAt: IsoDateTime, @@ -336,6 +399,7 @@ export const OrchestrationThread = Schema.Struct({ projectId: ProjectId, title: TrimmedNonEmptyString, model: TrimmedNonEmptyString, + modelSelection: Schema.optional(Schema.NullOr(ModelSelection)), runtimeMode: RuntimeMode, interactionMode: ProviderInteractionMode.pipe( Schema.withDecodingDefault(() => DEFAULT_PROVIDER_INTERACTION_MODE), @@ -366,6 +430,7 @@ export const OrchestrationOverviewProject = Schema.Struct({ title: TrimmedNonEmptyString, workspaceRoot: TrimmedNonEmptyString, defaultModel: Schema.NullOr(TrimmedNonEmptyString), + defaultModelSelection: Schema.optional(Schema.NullOr(ModelSelection)), scripts: Schema.Array(ProjectScript), activeThreadCount: NonNegativeInt, createdAt: IsoDateTime, @@ -378,6 +443,7 @@ export const OrchestrationOverviewThread = Schema.Struct({ projectId: ProjectId, title: TrimmedNonEmptyString, model: TrimmedNonEmptyString, + modelSelection: Schema.optional(Schema.NullOr(ModelSelection)), runtimeMode: RuntimeMode, interactionMode: ProviderInteractionMode.pipe( Schema.withDecodingDefault(() => DEFAULT_PROVIDER_INTERACTION_MODE), @@ -420,6 +486,7 @@ export const ProjectCreateCommand = Schema.Struct({ title: TrimmedNonEmptyString, workspaceRoot: TrimmedNonEmptyString, defaultModel: Schema.optional(TrimmedNonEmptyString), + defaultModelSelection: Schema.optional(ModelSelection), scripts: Schema.optional(Schema.Array(ProjectScript)), createdAt: IsoDateTime, }); @@ -431,6 +498,7 @@ const ProjectMetaUpdateCommand = Schema.Struct({ title: Schema.optional(TrimmedNonEmptyString), workspaceRoot: Schema.optional(TrimmedNonEmptyString), defaultModel: Schema.optional(TrimmedNonEmptyString), + defaultModelSelection: Schema.optional(Schema.NullOr(ModelSelection)), scripts: Schema.optional(Schema.Array(ProjectScript)), }); @@ -447,6 +515,7 @@ const ThreadCreateCommand = Schema.Struct({ projectId: ProjectId, title: TrimmedNonEmptyString, model: TrimmedNonEmptyString, + modelSelection: Schema.optional(Schema.NullOr(ModelSelection)), runtimeMode: RuntimeMode, interactionMode: ProviderInteractionMode.pipe( Schema.withDecodingDefault(() => DEFAULT_PROVIDER_INTERACTION_MODE), @@ -469,6 +538,7 @@ const ThreadMetaUpdateCommand = Schema.Struct({ threadId: ThreadId, title: Schema.optional(TrimmedNonEmptyString), model: Schema.optional(TrimmedNonEmptyString), + modelSelection: Schema.optional(ModelSelection), branch: Schema.optional(Schema.NullOr(TrimmedNonEmptyString)), worktreePath: Schema.optional(Schema.NullOr(TrimmedNonEmptyString)), githubRef: Schema.optional(GitHubRef), @@ -503,6 +573,7 @@ export const ThreadTurnStartCommand = Schema.Struct({ providerInput: Schema.optional( TrimmedNonEmptyString.check(Schema.isMaxLength(PROVIDER_SEND_TURN_MAX_INPUT_CHARS)), ), + modelSelection: Schema.optional(ModelSelection), provider: Schema.optional(ProviderKind), model: Schema.optional(TrimmedNonEmptyString), modelOptions: Schema.optional(ProviderModelOptions), @@ -529,6 +600,7 @@ const ClientThreadTurnStartCommand = Schema.Struct({ providerInput: Schema.optional( TrimmedNonEmptyString.check(Schema.isMaxLength(PROVIDER_SEND_TURN_MAX_INPUT_CHARS)), ), + modelSelection: Schema.optional(ModelSelection), provider: Schema.optional(ProviderKind), model: Schema.optional(TrimmedNonEmptyString), modelOptions: Schema.optional(ProviderModelOptions), @@ -733,6 +805,7 @@ export const ProjectCreatedPayload = Schema.Struct({ title: TrimmedNonEmptyString, workspaceRoot: TrimmedNonEmptyString, defaultModel: Schema.NullOr(TrimmedNonEmptyString), + defaultModelSelection: Schema.optional(Schema.NullOr(ModelSelection)), scripts: Schema.Array(ProjectScript), createdAt: IsoDateTime, updatedAt: IsoDateTime, @@ -743,6 +816,7 @@ export const ProjectMetaUpdatedPayload = Schema.Struct({ title: Schema.optional(TrimmedNonEmptyString), workspaceRoot: Schema.optional(TrimmedNonEmptyString), defaultModel: Schema.optional(Schema.NullOr(TrimmedNonEmptyString)), + defaultModelSelection: Schema.optional(Schema.NullOr(ModelSelection)), scripts: Schema.optional(Schema.Array(ProjectScript)), updatedAt: IsoDateTime, }); @@ -761,6 +835,7 @@ export const ThreadCreatedPayload = Schema.Struct({ projectId: ProjectId, title: TrimmedNonEmptyString, model: TrimmedNonEmptyString, + modelSelection: Schema.optional(Schema.NullOr(ModelSelection)), runtimeMode: RuntimeMode.pipe(Schema.withDecodingDefault(() => DEFAULT_RUNTIME_MODE)), interactionMode: ProviderInteractionMode.pipe( Schema.withDecodingDefault(() => DEFAULT_PROVIDER_INTERACTION_MODE), @@ -785,6 +860,7 @@ export const ThreadMetaUpdatedPayload = Schema.Struct({ threadId: ThreadId, title: Schema.optional(TrimmedNonEmptyString), model: Schema.optional(TrimmedNonEmptyString), + modelSelection: Schema.optional(Schema.NullOr(ModelSelection)), branch: Schema.optional(Schema.NullOr(TrimmedNonEmptyString)), worktreePath: Schema.optional(Schema.NullOr(TrimmedNonEmptyString)), githubRef: Schema.optional(GitHubRef), @@ -823,6 +899,7 @@ export const ThreadTurnStartRequestedPayload = Schema.Struct({ providerInput: Schema.optional( TrimmedNonEmptyString.check(Schema.isMaxLength(PROVIDER_SEND_TURN_MAX_INPUT_CHARS)), ), + modelSelection: Schema.optional(Schema.NullOr(ModelSelection)), provider: Schema.optional(ProviderKind), model: Schema.optional(TrimmedNonEmptyString), modelOptions: Schema.optional(ProviderModelOptions), diff --git a/packages/contracts/src/server.ts b/packages/contracts/src/server.ts index 1ebaa476..4193bf00 100644 --- a/packages/contracts/src/server.ts +++ b/packages/contracts/src/server.ts @@ -9,6 +9,7 @@ import { } from "./keybindings"; import { EditorId } from "./editor"; import { ProviderKind } from "./orchestration"; +import { ModelCapabilities } from "./model"; const KeybindingsMalformedConfigIssue = Schema.Struct({ kind: Schema.Literal("keybindings.malformed-config"), @@ -39,17 +40,40 @@ export const ServerProviderAuthStatus = Schema.Literals([ ]); export type ServerProviderAuthStatus = typeof ServerProviderAuthStatus.Type; -export const ServerProviderStatus = Schema.Struct({ +export const ServerProviderAuth = Schema.Struct({ + status: ServerProviderAuthStatus, + type: Schema.optional(TrimmedNonEmptyString), + label: Schema.optional(TrimmedNonEmptyString), +}); +export type ServerProviderAuth = typeof ServerProviderAuth.Type; + +export const ServerProviderModel = Schema.Struct({ + slug: TrimmedNonEmptyString, + name: TrimmedNonEmptyString, + isCustom: Schema.Boolean, + capabilities: Schema.NullOr(ModelCapabilities), +}); +export type ServerProviderModel = typeof ServerProviderModel.Type; + +export const ServerProvider = Schema.Struct({ provider: ProviderKind, + enabled: Schema.optional(Schema.Boolean), + installed: Schema.optional(Schema.Boolean), + version: Schema.optional(Schema.NullOr(TrimmedNonEmptyString)), status: ServerProviderStatusState, - available: Schema.Boolean, - authStatus: ServerProviderAuthStatus, + auth: Schema.optional(ServerProviderAuth), checkedAt: IsoDateTime, message: Schema.optional(TrimmedNonEmptyString), + models: Schema.optional(Schema.Array(ServerProviderModel)), + // Compatibility aliases for older web/server code paths during migration. + available: Schema.optional(Schema.Boolean), + authStatus: Schema.optional(ServerProviderAuthStatus), }); -export type ServerProviderStatus = typeof ServerProviderStatus.Type; +export type ServerProvider = typeof ServerProvider.Type; +export const ServerProviderStatus = ServerProvider; +export type ServerProviderStatus = ServerProvider; -const ServerProviderStatuses = Schema.Array(ServerProviderStatus); +const ServerProviderStatuses = Schema.Array(ServerProvider); export const ServerConfig = Schema.Struct({ cwd: TrimmedNonEmptyString, diff --git a/packages/shared/package.json b/packages/shared/package.json index 21aa7fe1..7c2288fd 100644 --- a/packages/shared/package.json +++ b/packages/shared/package.json @@ -8,6 +8,10 @@ "types": "./src/model.ts", "import": "./src/model.ts" }, + "./modelSelection": { + "types": "./src/modelSelection.ts", + "import": "./src/modelSelection.ts" + }, "./git": { "types": "./src/git.ts", "import": "./src/git.ts" diff --git a/packages/shared/src/model.ts b/packages/shared/src/model.ts index 36bed69d..727e29cb 100644 --- a/packages/shared/src/model.ts +++ b/packages/shared/src/model.ts @@ -25,6 +25,7 @@ const MODEL_SLUG_SET_BY_PROVIDER: Record> = codex: new Set(MODEL_OPTIONS_BY_PROVIDER.codex.map((option) => option.slug)), openclaw: new Set(), copilot: new Set(MODEL_OPTIONS_BY_PROVIDER.copilot.map((option) => option.slug)), + gemini: new Set(MODEL_OPTIONS_BY_PROVIDER.gemini.map((option) => option.slug)), }; const CLAUDE_OPUS_4_6_MODEL = "claude-opus-4-6"; @@ -171,11 +172,17 @@ export function inferProviderForModel( return "copilot"; } + const normalizedGemini = normalizeModelSlug(model, "gemini"); + if (normalizedGemini && MODEL_SLUG_SET_BY_PROVIDER.gemini.has(normalizedGemini)) { + return "gemini"; + } + if (typeof model === "string") { const trimmed = model.trim(); if (trimmed.startsWith("claude-")) return "claudeAgent"; if (trimmed.startsWith("openclaw/")) return "openclaw"; if (trimmed.startsWith("copilot/")) return "copilot"; + if (trimmed.startsWith("gemini-") || trimmed.startsWith("auto-gemini-")) return "gemini"; } return fallback; } @@ -191,6 +198,7 @@ export function getReasoningEffortOptions( export function getReasoningEffortOptions( provider: "copilot", ): ReadonlyArray; +export function getReasoningEffortOptions(provider: "gemini"): ReadonlyArray; export function getReasoningEffortOptions( provider?: ProviderKind, model?: string | null | undefined, @@ -215,6 +223,7 @@ export function getDefaultReasoningEffort(provider: "codex"): CodexReasoningEffo export function getDefaultReasoningEffort(provider: "claudeAgent"): ClaudeCodeEffort; export function getDefaultReasoningEffort(provider: "openclaw"): OpenClawReasoningEffort; export function getDefaultReasoningEffort(provider: "copilot"): CopilotReasoningEffort; +export function getDefaultReasoningEffort(provider: "gemini"): ProviderReasoningEffort; export function getDefaultReasoningEffort(provider?: ProviderKind): ProviderReasoningEffort; export function getDefaultReasoningEffort( provider: ProviderKind = "codex", @@ -238,6 +247,10 @@ export function resolveReasoningEffortForProvider( provider: "copilot", effort: string | null | undefined, ): CopilotReasoningEffort | null; +export function resolveReasoningEffortForProvider( + provider: "gemini", + effort: string | null | undefined, +): null; export function resolveReasoningEffortForProvider( provider: ProviderKind, effort: string | null | undefined, diff --git a/packages/shared/src/modelSelection.test.ts b/packages/shared/src/modelSelection.test.ts new file mode 100644 index 00000000..ad25fa39 --- /dev/null +++ b/packages/shared/src/modelSelection.test.ts @@ -0,0 +1,55 @@ +import { describe, expect, it } from "vitest"; + +import { + toCanonicalModelSelection, + normalizeModelSelectionWithCapabilities, +} from "./modelSelection"; + +describe("toCanonicalModelSelection", () => { + it("normalizes provider aliases into a canonical selection", () => { + expect(toCanonicalModelSelection("gemini", "Gemini 2.5 Pro", undefined)).toEqual({ + provider: "gemini", + model: "gemini-2.5-pro", + }); + }); + + it("falls back to the provider default when the model is missing", () => { + expect(toCanonicalModelSelection("gemini", null, undefined)).toEqual({ + provider: "gemini", + model: "auto-gemini-3", + }); + }); +}); + +describe("normalizeModelSelectionWithCapabilities", () => { + it("prunes unsupported codex options from the canonical selection", () => { + expect( + normalizeModelSelectionWithCapabilities( + { + provider: "codex", + model: "gpt-5.4", + options: { reasoningEffort: "xhigh", fastMode: true }, + }, + [ + { + slug: "gpt-5.4", + capabilities: { + reasoningEffortLevels: [ + { value: "medium", label: "Medium" }, + { value: "high", label: "High", isDefault: true }, + ], + supportsFastMode: false, + supportsThinkingToggle: false, + contextWindowOptions: [], + promptInjectedEffortLevels: [], + }, + }, + ], + ), + ).toEqual({ + provider: "codex", + model: "gpt-5.4", + options: {}, + }); + }); +}); diff --git a/packages/shared/src/modelSelection.ts b/packages/shared/src/modelSelection.ts new file mode 100644 index 00000000..c57001e0 --- /dev/null +++ b/packages/shared/src/modelSelection.ts @@ -0,0 +1,147 @@ +import { + DEFAULT_MODEL_BY_PROVIDER, + MODEL_OPTIONS_BY_PROVIDER, + MODEL_SLUG_ALIASES_BY_PROVIDER, + type ModelCapabilities, + type ModelSelection, + type ProviderKind, + type ProviderModelOptions, +} from "@okcode/contracts"; + +type SelectableModel = { + readonly slug: string; + readonly capabilities?: ModelCapabilities | null | undefined; +}; + +const PROVIDER_MODEL_SET = { + codex: new Set(MODEL_OPTIONS_BY_PROVIDER.codex.map((option) => option.slug)), + claudeAgent: new Set(MODEL_OPTIONS_BY_PROVIDER.claudeAgent.map((option) => option.slug)), + openclaw: new Set(), + copilot: new Set(MODEL_OPTIONS_BY_PROVIDER.copilot.map((option) => option.slug)), + gemini: new Set(MODEL_OPTIONS_BY_PROVIDER.gemini.map((option) => option.slug)), +} as const satisfies Record>; + +export function normalizeModelSelectionModel( + provider: ProviderKind, + model: string | null | undefined, +): string { + const trimmed = typeof model === "string" ? model.trim() : ""; + const aliasMap = MODEL_SLUG_ALIASES_BY_PROVIDER[provider] as Record; + const aliased = trimmed ? (aliasMap[trimmed] ?? aliasMap[trimmed.toLowerCase()] ?? trimmed) : ""; + if (aliased && (PROVIDER_MODEL_SET[provider] as ReadonlySet).has(aliased)) { + return aliased; + } + return trimmed || DEFAULT_MODEL_BY_PROVIDER[provider]; +} + +export function toCanonicalModelSelection( + provider: ProviderKind, + model: string | null | undefined, + modelOptions: ProviderModelOptions | null | undefined, +): ModelSelection { + const normalizedModel = normalizeModelSelectionModel(provider, model); + const providerOptions = modelOptions?.[provider]; + return providerOptions + ? ({ provider, model: normalizedModel, options: providerOptions } as ModelSelection) + : ({ provider, model: normalizedModel } as ModelSelection); +} + +export function getModelSelectionProvider( + selection: ModelSelection | null | undefined, +): ProviderKind { + return selection?.provider ?? "codex"; +} + +export function getModelSelectionModel(selection: ModelSelection | null | undefined): string { + const provider = selection?.provider ?? "codex"; + return selection?.model ?? DEFAULT_MODEL_BY_PROVIDER[provider]; +} + +export function getModelSelectionOptions( + selection: ModelSelection | null | undefined, +): ProviderModelOptions | undefined { + if (!selection?.options) return undefined; + return { [selection.provider]: selection.options } as ProviderModelOptions; +} + +export function modelSelectionsAreEqual( + a: ModelSelection | null | undefined, + b: ModelSelection | null | undefined, +): boolean { + if (a == null && b == null) return true; + if (a == null || b == null) return false; + if (a.provider !== b.provider || a.model !== b.model) return false; + const aOpts = a.options ?? null; + const bOpts = b.options ?? null; + if (aOpts === null && bOpts === null) return true; + if (aOpts === null || bOpts === null) return false; + const aKeys = Object.keys(aOpts).sort(); + const bKeys = Object.keys(bOpts).sort(); + if (aKeys.join(",") !== bKeys.join(",")) return false; + return aKeys.every( + (k) => + (aOpts as Record)[k] === (bOpts as Record)[k], + ); +} + +export function normalizeModelSelectionWithCapabilities( + selection: ModelSelection, + models: ReadonlyArray, +): ModelSelection { + const matchedModel = models.find((candidate) => candidate.slug === selection.model); + if (!matchedModel?.capabilities) { + return selection; + } + const capabilities = matchedModel.capabilities; + const supportsFastMode = capabilities.supportsFastMode; + const supportsThinkingToggle = capabilities.supportsThinkingToggle; + const reasoningEffortValues = new Set( + capabilities.reasoningEffortLevels.map((option) => option.value), + ); + const contextWindowValues = new Set( + capabilities.contextWindowOptions.map((option) => option.value), + ); + + if (!selection.options) { + return selection; + } + + switch (selection.provider) { + case "codex": + return { + ...selection, + options: { + ...(selection.options.reasoningEffort && + reasoningEffortValues.has(selection.options.reasoningEffort) + ? { reasoningEffort: selection.options.reasoningEffort } + : {}), + ...(supportsFastMode && selection.options.fastMode !== undefined + ? { fastMode: selection.options.fastMode } + : {}), + }, + }; + case "claudeAgent": + return { + ...selection, + options: { + ...(supportsThinkingToggle && selection.options.thinking !== undefined + ? { thinking: selection.options.thinking } + : {}), + ...(selection.options.effort && reasoningEffortValues.has(selection.options.effort) + ? { effort: selection.options.effort } + : {}), + ...(supportsFastMode && selection.options.fastMode !== undefined + ? { fastMode: selection.options.fastMode } + : {}), + ...(selection.options.contextWindow && + contextWindowValues.has(selection.options.contextWindow) + ? { contextWindow: selection.options.contextWindow } + : {}), + }, + }; + case "openclaw": + case "copilot": + case "gemini": + return selection; + } +}