diff --git a/signatures/cla.json b/signatures/cla.json index 800b3e0090..b23724845c 100644 --- a/signatures/cla.json +++ b/signatures/cla.json @@ -1687,6 +1687,14 @@ "created_at": "2026-02-22T10:57:33Z", "repoId": 1108837393, "pullRequestNo": 2045 + }, + { + "name": "DMax1314", + "id": 54206290, + "comment_id": 3943046087, + "created_at": "2026-02-23T07:06:14Z", + "repoId": 1108837393, + "pullRequestNo": 2068 } ] } \ No newline at end of file diff --git a/src/features/background-agent/manager.test.ts b/src/features/background-agent/manager.test.ts index 7bd7709f13..2f62f64352 100644 --- a/src/features/background-agent/manager.test.ts +++ b/src/features/background-agent/manager.test.ts @@ -2554,8 +2554,8 @@ describe("BackgroundManager.checkAndInterruptStaleTasks", () => { expect(task.status).toBe("running") }) - test("should NOT interrupt running session with no progress (undefined lastUpdate)", async () => { - //#given — no progress at all, but session is running + test("should interrupt running session with no progress after messageStalenessTimeout (API hang detection)", async () => { + //#given — no progress at all, session is running but exceeded stale timeout const client = { session: { prompt: async () => ({}), @@ -2580,11 +2580,11 @@ describe("BackgroundManager.checkAndInterruptStaleTasks", () => { getTaskMap(manager).set(task.id, task) - //#when — session is running despite no progress + //#when — session is running despite no progress for 15min (exceeds 10min timeout) await manager["checkAndInterruptStaleTasks"]({ "session-rnp": { type: "running" } }) - - //#then — running sessions are NEVER killed - expect(task.status).toBe("running") + //#then — running sessions with no progress ARE interrupted after stale timeout + expect(task.status).toBe("cancelled") + expect(task.error).toContain("possible API hang") }) test("should interrupt task with no lastUpdate after messageStalenessTimeout", async () => { diff --git a/src/features/background-agent/stale-fallback-handler.test.ts b/src/features/background-agent/stale-fallback-handler.test.ts new file mode 100644 index 0000000000..ba19b99cf3 --- /dev/null +++ b/src/features/background-agent/stale-fallback-handler.test.ts @@ -0,0 +1,238 @@ +import { describe, it, expect, mock } from "bun:test" + +import { resolveNextFallbackModel, buildFallbackLaunchInput, createStaleFallbackHandler } from "./stale-fallback-handler" +import type { BackgroundTask, LaunchInput } from "./types" +import type { OhMyOpenCodeConfig } from "../../config" + +function createTask(overrides: Partial = {}): BackgroundTask { + return { + id: "task-1", + sessionID: "ses-1", + parentSessionID: "parent-ses-1", + parentMessageID: "msg-1", + description: "test task", + prompt: "test prompt", + agent: "explore", + status: "cancelled", + startedAt: new Date(), + model: { providerID: "kimi", modelID: "kimi-k2.5-free" }, + ...overrides, + } +} + +function createConfig(overrides: Partial = {}): OhMyOpenCodeConfig { + return { + agents: { + explore: { + fallback_models: ["kimi/kimi-k2.5-free", "glm/glm-4-flash-250414"], + }, + }, + ...overrides, + } as unknown as OhMyOpenCodeConfig +} + +describe("resolveNextFallbackModel", () => { + it("should use task.fallbackModels when present", () => { + //#given + const task = createTask({ + fallbackModels: ["openai/gpt-4o", "anthropic/claude-sonnet-4-20250514"], + }) + const config = createConfig() + + //#when + const result = resolveNextFallbackModel(task, config) + + //#then + expect(result).toEqual({ + nextModel: "openai/gpt-4o", + remainingModels: ["anthropic/claude-sonnet-4-20250514"], + }) + }) + + it("should resolve from agent config when task.fallbackModels is empty", () => { + //#given + const task = createTask({ + agent: "explore", + model: { providerID: "kimi", modelID: "kimi-k2.5-free" }, + fallbackModels: undefined, + }) + const config = createConfig() + + //#when + const result = resolveNextFallbackModel(task, config) + + //#then + expect(result).toEqual({ + nextModel: "glm/glm-4-flash-250414", + remainingModels: [], + }) + }) + + it("should resolve from category config when agent has no fallback_models", () => { + //#given + const task = createTask({ + agent: "unknown-agent", + category: "quick", + model: { providerID: "kimi", modelID: "kimi-k2.5-free" }, + fallbackModels: undefined, + }) + const config = { + agents: {}, + categories: { + quick: { + fallback_models: ["kimi/kimi-k2.5-free", "openai/gpt-4o-mini"], + }, + }, + } as unknown as OhMyOpenCodeConfig + + //#when + const result = resolveNextFallbackModel(task, config) + + //#then + expect(result).toEqual({ + nextModel: "openai/gpt-4o-mini", + remainingModels: [], + }) + }) + + it("should return undefined when no fallback models available", () => { + //#given + const task = createTask({ + agent: "unknown-agent", + fallbackModels: undefined, + }) + const config = { agents: {} } as unknown as OhMyOpenCodeConfig + + //#when + const result = resolveNextFallbackModel(task, config) + + //#then + expect(result).toBeUndefined() + }) + + it("should return undefined when current model is the last in the chain", () => { + //#given + const task = createTask({ + agent: "explore", + model: { providerID: "glm", modelID: "glm-4-flash-250414" }, + fallbackModels: undefined, + }) + const config = createConfig() + + //#when + const result = resolveNextFallbackModel(task, config) + + //#then + expect(result).toBeUndefined() + }) +}) + +describe("buildFallbackLaunchInput", () => { + it("should build a valid LaunchInput with the new model", () => { + //#given + const task = createTask({ + parentModel: { providerID: "anthropic", modelID: "claude-opus-4-6" }, + parentAgent: "sisyphus", + isUnstableAgent: true, + category: "quick", + }) + + //#when + const result = buildFallbackLaunchInput(task, "openai/gpt-4o", ["anthropic/claude-sonnet-4-20250514"]) + + //#then + expect(result).toEqual({ + description: "test task", + prompt: "test prompt", + agent: "explore", + parentSessionID: "parent-ses-1", + parentMessageID: "msg-1", + parentModel: { providerID: "anthropic", modelID: "claude-opus-4-6" }, + parentAgent: "sisyphus", + parentTools: undefined, + model: { providerID: "openai", modelID: "gpt-4o" }, + isUnstableAgent: true, + category: "quick", + fallbackModels: ["anthropic/claude-sonnet-4-20250514"], + }) + }) + + it("should return undefined for invalid model format", () => { + //#given + const task = createTask() + + //#when + const result = buildFallbackLaunchInput(task, "no-slash-model", []) + + //#then + expect(result).toBeUndefined() + }) + + it("should handle model IDs with multiple slashes", () => { + //#given + const task = createTask() + + //#when + const result = buildFallbackLaunchInput(task, "anthropic/claude-sonnet-4-20250514/latest", []) + + //#then + expect(result?.model).toEqual({ + providerID: "anthropic", + modelID: "claude-sonnet-4-20250514/latest", + }) + }) +}) + +describe("createStaleFallbackHandler", () => { + it("should launch a new task with the next fallback model", async () => { + //#given + const mockLaunch = mock(() => + Promise.resolve(createTask({ id: "task-2", status: "pending" })), + ) + const config = createConfig() + const task = createTask({ + fallbackModels: ["openai/gpt-4o"], + }) + const handler = createStaleFallbackHandler(config, mockLaunch) + + //#when + await handler(task) + + //#then + expect(mockLaunch).toHaveBeenCalledTimes(1) + const launchInput = mockLaunch.mock.calls[0][0] as LaunchInput + expect(launchInput.model).toEqual({ providerID: "openai", modelID: "gpt-4o" }) + expect(launchInput.fallbackModels).toEqual([]) + }) + + it("should not launch when no fallback models are available", async () => { + //#given + const mockLaunch = mock(() => Promise.resolve(createTask())) + const config = { agents: {} } as unknown as OhMyOpenCodeConfig + const task = createTask({ + agent: "unknown", + fallbackModels: undefined, + }) + const handler = createStaleFallbackHandler(config, mockLaunch) + + //#when + await handler(task) + + //#then + expect(mockLaunch).not.toHaveBeenCalled() + }) + + it("should not throw when launch fails", async () => { + //#given + const mockLaunch = mock(() => Promise.reject(new Error("launch failed"))) + const config = createConfig() + const task = createTask({ + fallbackModels: ["openai/gpt-4o"], + }) + const handler = createStaleFallbackHandler(config, mockLaunch) + + //#when + then (should not throw) + await handler(task) + expect(mockLaunch).toHaveBeenCalledTimes(1) + }) +}) diff --git a/src/features/background-agent/stale-fallback-handler.ts b/src/features/background-agent/stale-fallback-handler.ts new file mode 100644 index 0000000000..547516584a --- /dev/null +++ b/src/features/background-agent/stale-fallback-handler.ts @@ -0,0 +1,139 @@ +import type { OhMyOpenCodeConfig } from "../../config" +import type { BackgroundTask, LaunchInput } from "./types" +import { normalizeFallbackModels } from "../../shared/model-resolver" +import { log } from "../../shared" + +function resolveFallbackModelsForTask( + task: BackgroundTask, + pluginConfig: OhMyOpenCodeConfig, +): string[] { + const agentName = task.agent?.toLowerCase() + if (agentName) { + const agentConfig = pluginConfig.agents?.[agentName as keyof typeof pluginConfig.agents] + if (agentConfig?.fallback_models) { + return normalizeFallbackModels(agentConfig.fallback_models) ?? [] + } + } + + if (task.category && pluginConfig.categories?.[task.category]) { + const categoryConfig = pluginConfig.categories[task.category] + if (categoryConfig?.fallback_models) { + return normalizeFallbackModels(categoryConfig.fallback_models) ?? [] + } + } + + return [] +} + +export function resolveNextFallbackModel( + task: BackgroundTask, + pluginConfig: OhMyOpenCodeConfig, +): { nextModel: string; remainingModels: string[] } | undefined { + if (task.fallbackModels && task.fallbackModels.length > 0) { + return { + nextModel: task.fallbackModels[0], + remainingModels: task.fallbackModels.slice(1), + } + } + + const allFallbackModels = resolveFallbackModelsForTask(task, pluginConfig) + if (allFallbackModels.length === 0) return undefined + + const currentModel = task.model + ? `${task.model.providerID}/${task.model.modelID}` + : undefined + + let startIndex = 0 + if (currentModel) { + const idx = allFallbackModels.indexOf(currentModel) + if (idx >= 0) startIndex = idx + 1 + } + + if (startIndex >= allFallbackModels.length) return undefined + + return { + nextModel: allFallbackModels[startIndex], + remainingModels: allFallbackModels.slice(startIndex + 1), + } +} + +function parseModelToProviderAndId( + model: string, +): { providerID: string; modelID: string } | undefined { + const parts = model.split("/") + if (parts.length < 2) return undefined + return { providerID: parts[0], modelID: parts.slice(1).join("/") } +} + +export function buildFallbackLaunchInput( + task: BackgroundTask, + nextModel: string, + remainingModels: string[], +): LaunchInput | undefined { + const parsed = parseModelToProviderAndId(nextModel) + if (!parsed) return undefined + + return { + description: task.description, + prompt: task.prompt, + agent: task.agent, + parentSessionID: task.parentSessionID, + parentMessageID: task.parentMessageID, + parentModel: task.parentModel, + parentAgent: task.parentAgent, + parentTools: task.parentTools, + model: parsed, + isUnstableAgent: task.isUnstableAgent, + category: task.category, + fallbackModels: remainingModels, + } +} + +export function createStaleFallbackHandler( + pluginConfig: OhMyOpenCodeConfig, + launchFn: (input: LaunchInput) => Promise, +): (task: BackgroundTask) => Promise { + return async (task: BackgroundTask) => { + const result = resolveNextFallbackModel(task, pluginConfig) + if (!result) { + log("[background-agent] No fallback models available for stale task", { + taskId: task.id, + agent: task.agent, + category: task.category, + }) + return + } + + const { nextModel, remainingModels } = result + const launchInput = buildFallbackLaunchInput(task, nextModel, remainingModels) + if (!launchInput) { + log("[background-agent] Invalid fallback model format", { + taskId: task.id, + model: nextModel, + }) + return + } + + log("[background-agent] Re-launching stale task with fallback model", { + taskId: task.id, + fromModel: task.model ? `${task.model.providerID}/${task.model.modelID}` : "unknown", + toModel: nextModel, + remainingFallbacks: remainingModels.length, + }) + + try { + const newTask = await launchFn(launchInput) + log("[background-agent] Fallback task launched", { + originalTaskId: task.id, + newTaskId: newTask.id, + model: nextModel, + }) + } catch (error) { + log("[background-agent] Failed to launch fallback task", { + taskId: task.id, + model: nextModel, + error: String(error), + }) + } + } +} diff --git a/src/features/background-agent/task-poller.test.ts b/src/features/background-agent/task-poller.test.ts index d411cb240e..caef61eb1e 100644 --- a/src/features/background-agent/task-poller.test.ts +++ b/src/features/background-agent/task-poller.test.ts @@ -184,14 +184,14 @@ describe("checkAndInterruptStaleTasks", () => { expect(task.status).toBe("running") }) - it("should NOT interrupt busy session even with no progress (undefined lastUpdate)", async () => { - //#given — task has no progress at all, but session is busy + it("should interrupt busy session with no progress when exceeding staleness timeout (possible API hang)", async () => { + //#given — task has no progress at all, session is busy but likely hung const task = createRunningTask({ startedAt: new Date(Date.now() - 15 * 60 * 1000), progress: undefined, }) - //#when — session is busy + //#when — session is busy but has had zero progress for 15min > 10min timeout await checkAndInterruptStaleTasks({ tasks: [task], client: mockClient as never, @@ -201,8 +201,9 @@ describe("checkAndInterruptStaleTasks", () => { sessionStatuses: { "ses-1": { type: "busy" } }, }) - //#then — task should survive because session is actively running - expect(task.status).toBe("running") + //#then — task should be killed as possible API hang (no progress despite being "busy") + expect(task.status).toBe("cancelled") + expect(task.error).toContain("possible API hang") }) it("should interrupt task when session is idle and lastUpdate exceeds stale timeout", async () => { @@ -254,14 +255,14 @@ describe("checkAndInterruptStaleTasks", () => { expect(task.status).toBe("running") }) - it("should NOT interrupt running session even with no progress (undefined lastUpdate)", async () => { - //#given — task has no progress at all, but session is running + it("should interrupt running session with no progress when exceeding staleness timeout (possible API hang)", async () => { + //#given — task has no progress at all, session is running but likely hung const task = createRunningTask({ startedAt: new Date(Date.now() - 15 * 60 * 1000), progress: undefined, }) - //#when — session is running + //#when — session is running but has had zero progress for 15min > 10min timeout await checkAndInterruptStaleTasks({ tasks: [task], client: mockClient as never, @@ -271,8 +272,9 @@ describe("checkAndInterruptStaleTasks", () => { sessionStatuses: { "ses-1": { type: "running" } }, }) - //#then — running sessions are NEVER killed, even without progress - expect(task.status).toBe("running") + //#then — task should be killed as possible API hang (no progress despite being "running") + expect(task.status).toBe("cancelled") + expect(task.error).toContain("possible API hang") }) it("should use default stale timeout when session status is unknown/missing", async () => { @@ -348,14 +350,14 @@ describe("checkAndInterruptStaleTasks", () => { expect(task.status).toBe("running") }) - it("should NOT interrupt busy session even with no progress (undefined lastUpdate)", async () => { - //#given — no progress at all, session is "busy" (thinking model with no streamed tokens yet) + it("should NOT interrupt busy session with no progress when within staleness timeout", async () => { + //#given — no progress, session is busy, but within staleness timeout const task = createRunningTask({ - startedAt: new Date(Date.now() - 15 * 60 * 1000), + startedAt: new Date(Date.now() - 5 * 60 * 1000), progress: undefined, }) - //#when — session is busy + //#when — session is busy and runtime (5min) < messageStalenessTimeoutMs (10min) await checkAndInterruptStaleTasks({ tasks: [task], client: mockClient as never, @@ -365,7 +367,7 @@ describe("checkAndInterruptStaleTasks", () => { sessionStatuses: { "ses-1": { type: "busy" } }, }) - //#then — busy sessions with no progress must survive + //#then — session is within timeout so should survive expect(task.status).toBe("running") }) diff --git a/src/features/background-agent/task-poller.ts b/src/features/background-agent/task-poller.ts index eca83bc661..e6a62a661d 100644 --- a/src/features/background-agent/task-poller.ts +++ b/src/features/background-agent/task-poller.ts @@ -85,12 +85,13 @@ export async function checkAndInterruptStaleTasks(args: { const runtime = now - startedAt.getTime() if (!task.progress?.lastUpdate) { - if (sessionIsRunning) continue if (runtime <= messageStalenessMs) continue const staleMinutes = Math.round(runtime / 60000) task.status = "cancelled" - task.error = `Stale timeout (no activity for ${staleMinutes}min since start)` + task.error = sessionIsRunning + ? `Stale timeout (no activity for ${staleMinutes}min since start — possible API hang)` + : `Stale timeout (no activity for ${staleMinutes}min since start)` task.completedAt = new Date() if (task.concurrencyKey) { diff --git a/src/features/background-agent/types.ts b/src/features/background-agent/types.ts index 6973dd7831..c578cb8807 100644 --- a/src/features/background-agent/types.ts +++ b/src/features/background-agent/types.ts @@ -49,6 +49,8 @@ export interface BackgroundTask { isUnstableAgent?: boolean /** Category used for this task (e.g., 'quick', 'visual-engineering') */ category?: string + /** Remaining fallback models for automatic retry on stale timeout */ + fallbackModels?: string[] /** Last message count for stability detection */ lastMsgCount?: number @@ -72,6 +74,7 @@ export interface LaunchInput { skills?: string[] skillContent?: string category?: string + fallbackModels?: string[] } export interface ResumeInput { diff --git a/src/features/claude-code-session-state/state.test.ts b/src/features/claude-code-session-state/state.test.ts index 82018316cf..c34171c084 100644 --- a/src/features/claude-code-session-state/state.test.ts +++ b/src/features/claude-code-session-state/state.test.ts @@ -4,11 +4,14 @@ import { getSessionAgent, clearSessionAgent, updateSessionAgent, + pinSessionAgent, + unpinSessionAgent, setMainSession, getMainSessionID, _resetForTesting, } from "./state" + describe("claude-code-session-state", () => { beforeEach(() => { // given - clean state before each test @@ -161,4 +164,94 @@ describe("claude-code-session-state", () => { expect(getSessionAgent(sessionID)).toBe(newAgent) }) }) + + describe("pinSessionAgent", () => { + test("should store pinned agent for session", () => { + // given + const sessionID = "test-pin-1" + + // when + pinSessionAgent(sessionID, "atlas") + + // then + expect(getSessionAgent(sessionID)).toBe("atlas") + }) + + test("should take precedence over updateSessionAgent", () => { + // given - pin atlas + const sessionID = "test-pin-priority" + pinSessionAgent(sessionID, "atlas") + + // when - SDK event tries to overwrite via updateSessionAgent + updateSessionAgent(sessionID, "prometheus") + + // then - pinned agent still wins + expect(getSessionAgent(sessionID)).toBe("atlas") + }) + + test("should take precedence over setSessionAgent", () => { + // given - pin atlas + const sessionID = "test-pin-over-set" + pinSessionAgent(sessionID, "atlas") + + // when - setSessionAgent tries to set + setSessionAgent(sessionID, "prometheus") + + // then - pinned agent still wins + expect(getSessionAgent(sessionID)).toBe("atlas") + }) + + test("should allow re-pinning to a different agent", () => { + // given - pin atlas + const sessionID = "test-repin" + pinSessionAgent(sessionID, "atlas") + + // when - re-pin to hephaestus + pinSessionAgent(sessionID, "hephaestus") + + // then + expect(getSessionAgent(sessionID)).toBe("hephaestus") + }) + }) + + describe("unpinSessionAgent", () => { + test("should allow updateSessionAgent to take effect after unpin", () => { + // given - pin atlas, then unpin + const sessionID = "test-unpin-then-update" + pinSessionAgent(sessionID, "atlas") + unpinSessionAgent(sessionID) + + // when - update via SDK event + updateSessionAgent(sessionID, "prometheus") + + // then - update takes effect since no pin exists + expect(getSessionAgent(sessionID)).toBe("prometheus") + }) + + test("should be a no-op when no pin exists", () => { + // given - only regular agent + const sessionID = "test-unpin-noop" + setSessionAgent(sessionID, "prometheus") + + // when - unpin (nothing pinned) + unpinSessionAgent(sessionID) + + // then - regular agent unchanged + expect(getSessionAgent(sessionID)).toBe("prometheus") + }) + }) + + describe("clearSessionAgent with pinned agents", () => { + test("should clear both pinned and regular agents", () => { + // given - pinned agent set + const sessionID = "test-clear-pinned" + pinSessionAgent(sessionID, "atlas") + + // when + clearSessionAgent(sessionID) + + // then - both cleared + expect(getSessionAgent(sessionID)).toBeUndefined() + }) + }) }) diff --git a/src/features/claude-code-session-state/state.ts b/src/features/claude-code-session-state/state.ts index 60a8f8a84d..8c92030970 100644 --- a/src/features/claude-code-session-state/state.ts +++ b/src/features/claude-code-session-state/state.ts @@ -17,24 +17,50 @@ export function _resetForTesting(): void { subagentSessions.clear() syncSubagentSessions.clear() sessionAgentMap.clear() + pinnedSessionAgentMap.clear() } const sessionAgentMap = new Map() +/** + * Pinned session agents — set by explicit commands like /start-work. + * These take precedence over regular sessionAgentMap to prevent SDK events + * from overwriting deliberately set agents. + */ +const pinnedSessionAgentMap = new Map() + export function setSessionAgent(sessionID: string, agent: string): void { if (!sessionAgentMap.has(sessionID)) { sessionAgentMap.set(sessionID, agent) } } +/** + * Pin an agent for a session — takes precedence over SDK-updated agents. + * Use this for explicit agent switches (e.g., /start-work) that should + * not be overwritten by SDK message.updated events. + */ +export function pinSessionAgent(sessionID: string, agent: string): void { + pinnedSessionAgentMap.set(sessionID, agent) + sessionAgentMap.set(sessionID, agent) +} + +/** + * Unpin an agent for a session — reverts to regular sessionAgentMap behavior. + */ +export function unpinSessionAgent(sessionID: string): void { + pinnedSessionAgentMap.delete(sessionID) +} + export function updateSessionAgent(sessionID: string, agent: string): void { sessionAgentMap.set(sessionID, agent) } export function getSessionAgent(sessionID: string): string | undefined { - return sessionAgentMap.get(sessionID) + return pinnedSessionAgentMap.get(sessionID) ?? sessionAgentMap.get(sessionID) } export function clearSessionAgent(sessionID: string): void { sessionAgentMap.delete(sessionID) + pinnedSessionAgentMap.delete(sessionID) } diff --git a/src/hooks/no-sisyphus-gpt/hook.ts b/src/hooks/no-sisyphus-gpt/hook.ts index 2042c7451c..3ef51b11b1 100644 --- a/src/hooks/no-sisyphus-gpt/hook.ts +++ b/src/hooks/no-sisyphus-gpt/hook.ts @@ -1,6 +1,6 @@ import type { PluginInput } from "@opencode-ai/plugin" import { isGptModel } from "../../agents/types" -import { getSessionAgent, updateSessionAgent } from "../../features/claude-code-session-state" +import { getSessionAgent, pinSessionAgent } from "../../features/claude-code-session-state" import { log } from "../../shared" import { getAgentConfigKey, getAgentDisplayName } from "../../shared/agent-display-names" @@ -47,7 +47,7 @@ export function createNoSisyphusGptHook(ctx: PluginInput) { if (output?.message) { output.message.agent = HEPHAESTUS_DISPLAY } - updateSessionAgent(input.sessionID, HEPHAESTUS_DISPLAY) + pinSessionAgent(input.sessionID, HEPHAESTUS_DISPLAY) } }, } diff --git a/src/hooks/runtime-fallback/agent-resolver.test.ts b/src/hooks/runtime-fallback/agent-resolver.test.ts new file mode 100644 index 0000000000..e270dd005f --- /dev/null +++ b/src/hooks/runtime-fallback/agent-resolver.test.ts @@ -0,0 +1,184 @@ +import { describe, test, expect, beforeEach, afterEach } from "bun:test" +import { + detectAgentFromSession, + normalizeAgentName, + resolveAgentForSession, + AGENT_NAMES, +} from "./agent-resolver" +import { + _resetForTesting, + updateSessionAgent, + pinSessionAgent, +} from "../../features/claude-code-session-state" + +describe("agent-resolver", () => { + beforeEach(() => { + _resetForTesting() + }) + + afterEach(() => { + _resetForTesting() + }) + + describe("detectAgentFromSession", () => { + test("should detect agent name embedded in sessionID", () => { + // given + const sessionID = "ses_abc123_atlas_work" + + // when + const result = detectAgentFromSession(sessionID) + + // then + expect(result).toBe("atlas") + }) + + test("should detect sisyphus-junior (hyphenated agent)", () => { + // given + const sessionID = "ses_sisyphus-junior_task" + + // when + const result = detectAgentFromSession(sessionID) + + // then + expect(result).toBe("sisyphus-junior") + }) + + test("should prefer longer match (sisyphus-junior over sisyphus)", () => { + // given + const sessionID = "ses_sisyphus-junior_123" + + // when + const result = detectAgentFromSession(sessionID) + + // then - should match sisyphus-junior, not just sisyphus + expect(result).toBe("sisyphus-junior") + }) + + test("should return undefined when no agent found", () => { + // given + const sessionID = "ses_abc123_random_session" + + // when + const result = detectAgentFromSession(sessionID) + + // then + expect(result).toBeUndefined() + }) + + test("should be case-insensitive", () => { + // given + const sessionID = "ses_ORACLE_query" + + // when + const result = detectAgentFromSession(sessionID) + + // then + expect(result).toBe("oracle") + }) + }) + + describe("normalizeAgentName", () => { + test("should return exact match for known agent", () => { + expect(normalizeAgentName("sisyphus")).toBe("sisyphus") + expect(normalizeAgentName("oracle")).toBe("oracle") + expect(normalizeAgentName("atlas")).toBe("atlas") + }) + + test("should normalize case", () => { + expect(normalizeAgentName("SISYPHUS")).toBe("sisyphus") + expect(normalizeAgentName("Oracle")).toBe("oracle") + }) + + test("should trim whitespace", () => { + expect(normalizeAgentName(" atlas ")).toBe("atlas") + }) + + test("should extract agent from display name", () => { + // given - display names like "Atlas (Work Orchestrator)" + expect(normalizeAgentName("Atlas (Work Orchestrator)")).toBe("atlas") + expect(normalizeAgentName("Prometheus (Planner)")).toBe("prometheus") + expect(normalizeAgentName("Hephaestus (Craftsman)")).toBe("hephaestus") + }) + + test("should return undefined for unknown agent", () => { + expect(normalizeAgentName("unknown-agent")).toBeUndefined() + expect(normalizeAgentName("random")).toBeUndefined() + }) + + test("should return undefined for empty/undefined input", () => { + expect(normalizeAgentName(undefined)).toBeUndefined() + expect(normalizeAgentName("")).toBeUndefined() + }) + }) + + describe("resolveAgentForSession", () => { + test("should prioritize session agent over event agent", () => { + // given - session agent set to atlas, event says prometheus + const sessionID = "test-resolve-priority" + updateSessionAgent(sessionID, "atlas") + + // when + const result = resolveAgentForSession(sessionID, "prometheus") + + // then - session agent wins + expect(result).toBe("atlas") + }) + + test("should prioritize pinned session agent over event agent", () => { + // given - pinned agent set to atlas + const sessionID = "test-resolve-pinned" + pinSessionAgent(sessionID, "atlas") + + // when - event says prometheus, updateSessionAgent also says prometheus + updateSessionAgent(sessionID, "prometheus") + const result = resolveAgentForSession(sessionID, "prometheus") + + // then - pinned agent wins + expect(result).toBe("atlas") + }) + + test("should fall back to event agent when no session agent exists", () => { + // given - no session agent + const sessionID = "test-resolve-event" + + // when + const result = resolveAgentForSession(sessionID, "oracle") + + // then + expect(result).toBe("oracle") + }) + + test("should fall back to sessionID pattern when no agents set", () => { + // given - no session or event agent, but sessionID contains agent name + const sessionID = "ses_librarian_search" + + // when + const result = resolveAgentForSession(sessionID) + + // then + expect(result).toBe("librarian") + }) + + test("should return undefined when nothing resolves", () => { + // given - no agents anywhere + const sessionID = "test-no-agent" + + // when + const result = resolveAgentForSession(sessionID) + + // then + expect(result).toBeUndefined() + }) + + test("should normalize display name from event agent", () => { + // given + const sessionID = "test-normalize-event" + + // when - event agent is a display name + const result = resolveAgentForSession(sessionID, "Atlas (Work Orchestrator)") + + // then - normalized to config key + expect(result).toBe("atlas") + }) + }) +}) diff --git a/src/hooks/runtime-fallback/agent-resolver.ts b/src/hooks/runtime-fallback/agent-resolver.ts index 1310a95bbc..f2956b8e8a 100644 --- a/src/hooks/runtime-fallback/agent-resolver.ts +++ b/src/hooks/runtime-fallback/agent-resolver.ts @@ -17,10 +17,10 @@ export const AGENT_NAMES = [ ] export const agentPattern = new RegExp( - `\\b(${AGENT_NAMES + `(?:^|[\\s_\\-/])(${AGENT_NAMES .sort((a, b) => b.length - a.length) .map((a) => a.replace(/-/g, "\\-")) - .join("|")})\\b`, + .join("|")})(?:$|[\\s_\\-/])`, "i", ) @@ -46,9 +46,12 @@ export function normalizeAgentName(agent: string | undefined): string | undefine } export function resolveAgentForSession(sessionID: string, eventAgent?: string): string | undefined { + // Session agent (set by updateSessionAgent, e.g. /start-work switching to atlas) + // takes priority over event agent (which can be stale — e.g. SDK still reports + // "prometheus" after /start-work already switched the session to "atlas"). return ( - normalizeAgentName(eventAgent) ?? normalizeAgentName(getSessionAgent(sessionID)) ?? + normalizeAgentName(eventAgent) ?? detectAgentFromSession(sessionID) ) } diff --git a/src/hooks/runtime-fallback/auto-retry.ts b/src/hooks/runtime-fallback/auto-retry.ts index dda3a3b6e6..80422b482a 100644 --- a/src/hooks/runtime-fallback/auto-retry.ts +++ b/src/hooks/runtime-fallback/auto-retry.ts @@ -6,6 +6,7 @@ import { getSessionAgent } from "../../features/claude-code-session-state" import { getFallbackModelsForSession } from "./fallback-models" import { prepareFallback } from "./fallback-state" import { SessionCategoryRegistry } from "../../shared/session-category-registry" +import { getAgentDisplayName } from "../../shared/agent-display-names" const SESSION_TTL_MS = 30 * 60 * 1000 @@ -102,6 +103,12 @@ export function createAutoRetryHelpers(deps: HookDeps) { modelID: modelParts.slice(1).join("/"), } + // Abort any in-flight request before sending the fallback prompt. + // The SDK cannot process two concurrent prompts on a single session — + // without this abort, the new promptAsync is silently dropped while the + // previous model's request still occupies the session. + await abortSessionRequest(sessionID, `pre-fallback.${source}`) + sessionRetryInFlight.add(sessionID) let retryDispatched = false try { @@ -135,10 +142,15 @@ export function createAutoRetryHelpers(deps: HookDeps) { sessionAwaitingFallbackResult.add(sessionID) scheduleSessionFallbackTimeout(sessionID, retryAgent) + const agentDisplayName = retryAgent ? getAgentDisplayName(retryAgent) : undefined + log(`[${HOOK_NAME}] Sending fallback prompt (${source})`, { + sessionID, agent: agentDisplayName, model: fallbackModelObj, + partsCount: retryParts.length, firstPart: retryParts[0]?.text?.slice(0, 80), + }) await ctx.client.session.promptAsync({ path: { id: sessionID }, body: { - ...(retryAgent ? { agent: retryAgent } : {}), + ...(agentDisplayName ? { agent: agentDisplayName } : {}), model: fallbackModelObj, parts: retryParts, }, @@ -151,6 +163,8 @@ export function createAutoRetryHelpers(deps: HookDeps) { } } catch (retryError) { log(`[${HOOK_NAME}] Auto-retry failed (${source})`, { sessionID, error: String(retryError) }) + sessionAwaitingFallbackResult.delete(sessionID) + clearSessionFallbackTimeout(sessionID) } finally { sessionRetryInFlight.delete(sessionID) if (!retryDispatched) { @@ -188,7 +202,23 @@ export function createAutoRetryHelpers(deps: HookDeps) { } } } catch { - return undefined + // messages query failed, continue to session.get fallback + } + + // Fallback: query SDK session.get for agent field + // Handles case where model fails before chat.message fires (e.g., model not found) + try { + const sessionInfo = await ctx.client.session.get({ path: { id: sessionID } }) + const sessionData = (sessionInfo?.data ?? sessionInfo) as Record + const sdkAgent = typeof sessionData?.agent === "string" ? sessionData.agent : undefined + const normalized = normalizeAgentName(sdkAgent) + if (normalized) { + log(`[${HOOK_NAME}] Resolved agent from session.get`, { sessionID, agent: normalized }) + return normalized + } + + } catch { + // session.get failed, no agent available } return undefined diff --git a/src/hooks/runtime-fallback/chat-message-handler.ts b/src/hooks/runtime-fallback/chat-message-handler.ts index 9d400f7d2b..34d0255445 100644 --- a/src/hooks/runtime-fallback/chat-message-handler.ts +++ b/src/hooks/runtime-fallback/chat-message-handler.ts @@ -1,10 +1,11 @@ import type { HookDeps } from "./types" +import type { AutoRetryHelpers } from "./auto-retry" import { HOOK_NAME } from "./constants" import { log } from "../../shared/logger" import { createFallbackState } from "./fallback-state" -export function createChatMessageHandler(deps: HookDeps) { - const { config, sessionStates, sessionLastAccess } = deps +export function createChatMessageHandler(deps: HookDeps, helpers: AutoRetryHelpers) { + const { config, sessionStates, sessionLastAccess, sessionRetryInFlight, sessionAwaitingFallbackResult } = deps return async ( input: { sessionID: string; agent?: string; model?: { providerID: string; modelID: string } }, @@ -34,6 +35,15 @@ export function createChatMessageHandler(deps: HookDeps) { from: state.currentModel, to: requestedModel, }) + + helpers.clearSessionFallbackTimeout(sessionID) + sessionAwaitingFallbackResult.delete(sessionID) + + if (sessionRetryInFlight.has(sessionID)) { + await helpers.abortSessionRequest(sessionID, "manual-model-change") + sessionRetryInFlight.delete(sessionID) + } + state = createFallbackState(requestedModel) sessionStates.set(sessionID, state) return diff --git a/src/hooks/runtime-fallback/constants.ts b/src/hooks/runtime-fallback/constants.ts index 60da6fb533..0add6e1d29 100644 --- a/src/hooks/runtime-fallback/constants.ts +++ b/src/hooks/runtime-fallback/constants.ts @@ -26,6 +26,8 @@ export const RETRYABLE_ERROR_PATTERNS = [ /rate.?limit/i, /too.?many.?requests/i, /quota.?exceeded/i, + /quota.?protection/i, + /key.?limit.?exceeded/i, /usage\s+limit\s+has\s+been\s+reached/i, /service.?unavailable/i, /overloaded/i, diff --git a/src/hooks/runtime-fallback/error-classifier.ts b/src/hooks/runtime-fallback/error-classifier.ts index f35819b76c..75b5d4daf8 100644 --- a/src/hooks/runtime-fallback/error-classifier.ts +++ b/src/hooks/runtime-fallback/error-classifier.ts @@ -148,6 +148,43 @@ export function containsErrorContent( return { hasError: false } } +export function detectErrorInTextParts( + parts: Array<{ type?: string; text?: string }> | undefined, +): { hasError: boolean; errorType?: string; errorMessage?: string } { + if (!parts || parts.length === 0) return { hasError: false } + + const textContent = parts + .filter((p) => p.type === "text" && typeof p.text === "string" && p.text.length > 0) + .map((p) => p.text!) + .join("\n") + + if (!textContent) return { hasError: false } + + const errorType = classifyErrorType({ message: textContent, name: "TextContent" }) + if (errorType) { + return { hasError: true, errorType, errorMessage: textContent } + } + + return { hasError: false } +} + +export function extractErrorContentFromParts( + parts: Array<{ type?: string; text?: string }> | undefined, +): { hasError: boolean; errorMessage?: string } { + if (!parts || parts.length === 0) return { hasError: false } + + const errorParts = parts.filter( + (p) => p.type === "error" && typeof p.text === "string" && p.text.length > 0, + ) + + if (errorParts.length > 0) { + const errorMessage = errorParts.map((p) => p.text).join("\n") + return { hasError: true, errorMessage } + } + + return { hasError: false } +} + export function isRetryableError(error: unknown, retryOnErrors: number[]): boolean { const statusCode = extractStatusCode(error, retryOnErrors) const message = getErrorMessage(error) diff --git a/src/hooks/runtime-fallback/event-handler.ts b/src/hooks/runtime-fallback/event-handler.ts index f73e6557f8..acaac3788e 100644 --- a/src/hooks/runtime-fallback/event-handler.ts +++ b/src/hooks/runtime-fallback/event-handler.ts @@ -11,15 +11,13 @@ export function createEventHandler(deps: HookDeps, helpers: AutoRetryHelpers) { const { config, pluginConfig, sessionStates, sessionLastAccess, sessionRetryInFlight, sessionAwaitingFallbackResult, sessionFallbackTimeouts } = deps const handleSessionCreated = (props: Record | undefined) => { - const sessionInfo = props?.info as { id?: string; model?: string } | undefined + const sessionInfo = props?.info as { id?: string } | undefined const sessionID = sessionInfo?.id - const model = sessionInfo?.model + if (!sessionID) return - if (sessionID && model) { - log(`[${HOOK_NAME}] Session created with model`, { sessionID, model }) - sessionStates.set(sessionID, createFallbackState(model)) - sessionLastAccess.set(sessionID, Date.now()) - } + // SDK Session type has no model/agent fields — state is created on-demand + // by handleSessionError or handleSessionStatus when the actual model is known + log(`[${HOOK_NAME}] Session created, state will be created on-demand`, { sessionID }) } const handleSessionDeleted = (props: Record | undefined) => { @@ -57,7 +55,6 @@ export function createEventHandler(deps: HookDeps, helpers: AutoRetryHelpers) { log(`[${HOOK_NAME}] Cleared fallback retry state on session.stop`, { sessionID }) } - const handleSessionIdle = (props: Record | undefined) => { const sessionID = props?.sessionID as string | undefined if (!sessionID) return @@ -81,6 +78,72 @@ export function createEventHandler(deps: HookDeps, helpers: AutoRetryHelpers) { } } + const handleSessionStatus = async (props: Record | undefined) => { + const sessionID = props?.sessionID as string | undefined + const status = props?.status as { type?: string; attempt?: number; message?: string; next?: number } | undefined + if (!sessionID || !status || status.type !== "retry") return + + const resolvedAgent = await helpers.resolveAgentForSessionFromContext(sessionID, undefined) + const fallbackModels = getFallbackModelsForSession(sessionID, resolvedAgent, pluginConfig) + + log(`[${HOOK_NAME}] Provider retry detected`, { + sessionID, attempt: status.attempt, message: status.message, + nextRetryMs: status.next, resolvedAgent, totalFallbackModels: fallbackModels.length, + }) + + if (fallbackModels.length === 0) { + if (config.notify_on_fallback) { + await deps.ctx.client.tui.showToast({ body: { + title: "Provider Retrying", variant: "info", duration: 3000, + message: `${status.message || "retrying..."} (no fallback models configured)`, + } }).catch(() => {}) + } + return + } + + let state = sessionStates.get(sessionID) + if (!state) { + const agentConfig = resolvedAgent + ? pluginConfig?.agents?.[resolvedAgent as keyof typeof pluginConfig.agents] : undefined + const initialModel = (agentConfig?.model as string | undefined) + ?? (pluginConfig?.agents?.sisyphus?.model as string | undefined) + if (!initialModel) { + log(`[${HOOK_NAME}] No model info for session.status fallback`, { sessionID }) + return + } + log(`[${HOOK_NAME}] Creating on-demand state for session.status`, { sessionID, model: initialModel, agent: resolvedAgent }) + state = createFallbackState(initialModel) + sessionStates.set(sessionID, state) + sessionLastAccess.set(sessionID, Date.now()) + } else { + sessionLastAccess.set(sessionID, Date.now()) + } + + sessionAwaitingFallbackResult.delete(sessionID) + helpers.clearSessionFallbackTimeout(sessionID) + + const result = prepareFallback(sessionID, state, fallbackModels, config) + + if (result.success && config.notify_on_fallback) { + const modelName = result.newModel?.split("/").pop() || result.newModel + await deps.ctx.client.tui.showToast({ body: { + title: "Retry Detected — Switching Model", variant: "warning", duration: 5000, + message: `${status.message || "Provider retrying"} → ${modelName} (attempt ${state.attemptCount} of ${fallbackModels.length})`, + } }).catch(() => {}) + } + + if (result.success && result.newModel) { + await helpers.autoRetryWithFallback(sessionID, result.newModel, resolvedAgent, "session.status") + } else if (!result.success) { + log(`[${HOOK_NAME}] session.status fallback failed`, { sessionID, error: result.error }) + if (result.maxAttemptsReached && config.notify_on_fallback) { + await deps.ctx.client.tui.showToast({ body: { + title: "All Fallbacks Exhausted", variant: "error", duration: 8000, + message: `All ${fallbackModels.length} fallback models exhausted after ${state.attemptCount} attempts`, + } }).catch(() => {}) + } + } + } const handleSessionError = async (props: Record | undefined) => { const sessionID = props?.sessionID as string | undefined const error = props?.error @@ -150,8 +213,16 @@ export function createEventHandler(deps: HookDeps, helpers: AutoRetryHelpers) { sessionStates.set(sessionID, state) sessionLastAccess.set(sessionID, Date.now()) } else { - log(`[${HOOK_NAME}] No model info available, cannot fallback`, { sessionID }) - return + const sisyphusModel = pluginConfig?.agents?.sisyphus?.model as string | undefined + if (sisyphusModel) { + log(`[${HOOK_NAME}] Using sisyphus model for state creation (no agent detected)`, { sessionID, model: sisyphusModel }) + state = createFallbackState(sisyphusModel) + sessionStates.set(sessionID, state) + sessionLastAccess.set(sessionID, Date.now()) + } else { + log(`[${HOOK_NAME}] No model info available, cannot fallback`, { sessionID }) + return + } } } } else { @@ -161,11 +232,13 @@ export function createEventHandler(deps: HookDeps, helpers: AutoRetryHelpers) { const result = prepareFallback(sessionID, state, fallbackModels, config) if (result.success && config.notify_on_fallback) { + const modelName = result.newModel?.split("/").pop() || result.newModel + const attemptInfo = `attempt ${state.attemptCount} of ${fallbackModels.length}` await deps.ctx.client.tui .showToast({ body: { title: "Model Fallback", - message: `Switching to ${result.newModel?.split("/").pop() || result.newModel} for next request`, + message: `Switching to ${modelName} (${attemptInfo})`, variant: "warning", duration: 5000, }, @@ -192,5 +265,6 @@ export function createEventHandler(deps: HookDeps, helpers: AutoRetryHelpers) { if (event.type === "session.stop") { await handleSessionStop(props); return } if (event.type === "session.idle") { handleSessionIdle(props); return } if (event.type === "session.error") { await handleSessionError(props); return } + if (event.type === "session.status") { await handleSessionStatus(props); return } } } diff --git a/src/hooks/runtime-fallback/hook.ts b/src/hooks/runtime-fallback/hook.ts index b378879909..69a048b33c 100644 --- a/src/hooks/runtime-fallback/hook.ts +++ b/src/hooks/runtime-fallback/hook.ts @@ -45,7 +45,7 @@ export function createRuntimeFallbackHook( const helpers = createAutoRetryHelpers(deps) const baseEventHandler = createEventHandler(deps, helpers) const messageUpdateHandler = createMessageUpdateHandler(deps, helpers) - const chatMessageHandler = createChatMessageHandler(deps) + const chatMessageHandler = createChatMessageHandler(deps, helpers) const cleanupInterval = setInterval(helpers.cleanupStaleSessions, 5 * 60 * 1000) cleanupInterval.unref() diff --git a/src/hooks/runtime-fallback/index.test.ts b/src/hooks/runtime-fallback/index.test.ts index 7660f19547..170f79afdc 100644 --- a/src/hooks/runtime-fallback/index.test.ts +++ b/src/hooks/runtime-fallback/index.test.ts @@ -28,6 +28,7 @@ describe("runtime-fallback", () => { messages?: (args: unknown) => Promise promptAsync?: (args: unknown) => Promise abort?: (args: unknown) => Promise + get?: (args: { path: { id: string } }) => Promise<{ data?: Record }> } }) { return { @@ -45,6 +46,7 @@ describe("runtime-fallback", () => { messages: overrides?.session?.messages ?? (async () => ({ data: [] })), promptAsync: overrides?.session?.promptAsync ?? (async () => ({})), abort: overrides?.session?.abort ?? (async () => ({})), + get: overrides?.session?.get ?? (async () => ({ data: {} })), }, }, directory: "/test/dir", @@ -62,8 +64,11 @@ describe("runtime-fallback", () => { } } - function createMockPluginConfigWithCategoryFallback(fallbackModels: string[]): OhMyOpenCodeConfig { + function createMockPluginConfigWithCategoryFallback(fallbackModels: string[], initialModel = "google/gemini-2.5-pro"): OhMyOpenCodeConfig { return { + agents: { + sisyphus: { model: initialModel }, + }, categories: { test: { fallback_models: fallbackModels, @@ -472,21 +477,20 @@ describe("runtime-fallback", () => { }) describe("session lifecycle", () => { - test("should create state on session.created", async () => { + test("should log on-demand state creation on session.created", async () => { const hook = createRuntimeFallbackHook(createMockPluginInput(), { config: createMockConfig() }) const sessionID = "test-session-create" - const model = "anthropic/claude-opus-4-5" await hook.event({ event: { type: "session.created", - properties: { info: { id: sessionID, model } }, + properties: { info: { id: sessionID } }, }, }) - const createLog = logCalls.find((c) => c.msg.includes("Session created with model")) + const createLog = logCalls.find((c) => c.msg.includes("state will be created on-demand")) expect(createLog).toBeDefined() - expect(createLog?.data).toMatchObject({ sessionID, model }) + expect(createLog?.data).toMatchObject({ sessionID }) }) test("should cleanup state on session.deleted", async () => { @@ -1902,10 +1906,11 @@ describe("runtime-fallback", () => { }) describe("fallback models configuration", () => { - function createMockPluginConfigWithAgentFallback(agentName: string, fallbackModels: string[]): OhMyOpenCodeConfig { + function createMockPluginConfigWithAgentFallback(agentName: string, fallbackModels: string[], initialModel = "anthropic/claude-opus-4-5"): OhMyOpenCodeConfig { return { agents: { [agentName]: { + model: initialModel, fallback_models: fallbackModels, }, }, @@ -2009,7 +2014,7 @@ describe("runtime-fallback", () => { expect(promptCalls.length).toBe(1) const callBody = promptCalls[0]?.body as Record - expect(callBody?.agent).toBe("prometheus") + expect(callBody?.agent).toBe("Prometheus (Plan Builder)") expect(callBody?.model).toEqual({ providerID: "github-copilot", modelID: "claude-opus-4.6" }) }) }) @@ -2021,7 +2026,7 @@ describe("runtime-fallback", () => { pluginConfig: createMockPluginConfigWithCategoryFallback([ "openai/gpt-5.2", "anthropic/claude-opus-4-5", - ]), + ], "anthropic/claude-opus-4-5"), }) const sessionID = "test-session-cooldown" SessionCategoryRegistry.register(sessionID, "test") @@ -2029,7 +2034,7 @@ describe("runtime-fallback", () => { await hook.event({ event: { type: "session.created", - properties: { info: { id: sessionID, model: "anthropic/claude-opus-4-5" } }, + properties: { info: { id: sessionID } }, }, }) @@ -2083,6 +2088,725 @@ describe("runtime-fallback", () => { expect(maxLog).toBeDefined() }) }) + describe("manual model change cleanup", () => { + test("should clear fallback timeout and abort in-flight request on manual model change", async () => { + const abortCalls: Array<{ path?: { id?: string } }> = [] + const retriedModels: string[] = [] + const pending = new Promise(() => {}) + + const hook = createRuntimeFallbackHook( + createMockPluginInput({ + session: { + messages: async () => ({ + data: [{ info: { role: "user" }, parts: [{ type: "text", text: "hello" }] }], + }), + promptAsync: async (args: unknown) => { + const model = (args as { body?: { model?: { providerID?: string; modelID?: string } } })?.body?.model + if (model?.providerID && model?.modelID) { + retriedModels.push(`${model.providerID}/${model.modelID}`) + } + if (retriedModels.length === 1) { + await pending + } + return {} + }, + abort: async (args: unknown) => { + abortCalls.push(args as { path?: { id?: string } }) + return {} + }, + }, + }), + { + config: createMockConfig({ notify_on_fallback: false, timeout_seconds: 30 }), + pluginConfig: createMockPluginConfigWithCategoryFallback([ + "github-copilot/claude-opus-4.6", + "anthropic/claude-opus-4-6", + ]), + session_timeout_ms: 500, + } + ) + + const sessionID = "test-manual-switch-cleanup" + SessionCategoryRegistry.register(sessionID, "test") + + await hook.event({ + event: { + type: "session.created", + properties: { info: { id: sessionID, model: "google/gemini-2.5-pro" } }, + }, + }) + + const sessionErrorPromise = hook.event({ + event: { + type: "session.error", + properties: { + sessionID, + error: { + name: "ProviderAuthError", + data: { + providerID: "google", + message: "Google Generative AI API key is missing. Pass it using the 'apiKey' parameter or the GOOGLE_GENERATIVE_AI_API_KEY environment variable.", + }, + }, + }, + }, + }) + + await new Promise((resolve) => setTimeout(resolve, 0)) + + const output: { message: { model?: { providerID: string; modelID: string } }; parts: Array<{ type: string; text?: string }> } = { + message: {}, + parts: [], + } + + await hook["chat.message"]?.( + { + sessionID, + model: { providerID: "openai", modelID: "gpt-5.2" }, + }, + output + ) + + const manualChangeLog = logCalls.find((c) => c.msg.includes("Detected manual model change")) + expect(manualChangeLog).toBeDefined() + + const abortLog = logCalls.find((c) => c.msg.includes("Aborted in-flight session request (manual-model-change)")) + expect(abortLog).toBeDefined() + + await new Promise((resolve) => setTimeout(resolve, 600)) + + const timeoutAbort = abortCalls.filter((c) => c.path?.id === sessionID) + const timeoutLog = logCalls.find((c) => c.msg.includes("Session fallback timeout reached")) + expect(timeoutLog).toBeUndefined() + + void sessionErrorPromise + }) + + test("should not kill new prompt after manual model switch", async () => { + const abortCalls: Array<{ path?: { id?: string } }> = [] + const retriedModels: string[] = [] + + const hook = createRuntimeFallbackHook( + createMockPluginInput({ + session: { + messages: async () => ({ + data: [{ info: { role: "user" }, parts: [{ type: "text", text: "hello" }] }], + }), + promptAsync: async (args: unknown) => { + const model = (args as { body?: { model?: { providerID?: string; modelID?: string } } })?.body?.model + if (model?.providerID && model?.modelID) { + retriedModels.push(`${model.providerID}/${model.modelID}`) + } + return {} + }, + abort: async (args: unknown) => { + abortCalls.push(args as { path?: { id?: string } }) + return {} + }, + }, + }), + { + config: createMockConfig({ notify_on_fallback: false, timeout_seconds: 30 }), + pluginConfig: createMockPluginConfigWithCategoryFallback([ + "github-copilot/claude-opus-4.6", + ]), + session_timeout_ms: 20, + } + ) + + const sessionID = "test-no-kill-after-manual-switch" + SessionCategoryRegistry.register(sessionID, "test") + + await hook.event({ + event: { + type: "session.created", + properties: { info: { id: sessionID, model: "google/gemini-2.5-pro" } }, + }, + }) + + await hook.event({ + event: { + type: "session.error", + properties: { + sessionID, + error: { + name: "ProviderAuthError", + data: { + providerID: "google", + message: "Google Generative AI API key is missing. Pass it using the 'apiKey' parameter or the GOOGLE_GENERATIVE_AI_API_KEY environment variable.", + }, + }, + }, + }, + }) + + const output: { message: { model?: { providerID: string; modelID: string } }; parts: Array<{ type: string; text?: string }> } = { + message: {}, + parts: [], + } + + await hook["chat.message"]?.( + { + sessionID, + model: { providerID: "openai", modelID: "gpt-5.2" }, + }, + output + ) + + const abortCountBeforeTimeout = abortCalls.length + + await new Promise((resolve) => setTimeout(resolve, 50)) + + const postSwitchAborts = abortCalls.slice(abortCountBeforeTimeout) + const timeoutAborts = postSwitchAborts.filter((c) => c.path?.id === sessionID) + expect(timeoutAborts).toHaveLength(0) + }) + }) + + describe("auto-retry failure cleanup", () => { + test("should clear awaiting state when promptAsync throws", async () => { + const hook = createRuntimeFallbackHook( + createMockPluginInput({ + session: { + messages: async () => ({ + data: [{ info: { role: "user" }, parts: [{ type: "text", text: "hello" }] }], + }), + promptAsync: async () => { + throw new Error("Network failure") + }, + }, + }), + { + config: createMockConfig({ notify_on_fallback: false, timeout_seconds: 30 }), + pluginConfig: createMockPluginConfigWithCategoryFallback([ + "github-copilot/claude-opus-4.6", + "openai/gpt-5.2", + ]), + session_timeout_ms: 20, + } + ) + + const sessionID = "test-retry-failure-cleanup" + SessionCategoryRegistry.register(sessionID, "test") + + await hook.event({ + event: { + type: "session.created", + properties: { info: { id: sessionID, model: "google/gemini-2.5-pro" } }, + }, + }) + + await hook.event({ + event: { + type: "session.error", + properties: { + sessionID, + error: { + name: "ProviderAuthError", + data: { + providerID: "google", + message: "Google Generative AI API key is missing. Pass it using the 'apiKey' parameter or the GOOGLE_GENERATIVE_AI_API_KEY environment variable.", + }, + }, + }, + }, + }) + + const retryFailedLog = logCalls.find((c) => c.msg.includes("Auto-retry failed")) + expect(retryFailedLog).toBeDefined() + + await hook.event({ + event: { + type: "session.idle", + properties: { sessionID }, + }, + }) + + const stuckLog = logCalls.find((c) => c.msg.includes("session.idle while awaiting fallback result")) + expect(stuckLog).toBeUndefined() + + await new Promise((resolve) => setTimeout(resolve, 50)) + + const timeoutLog = logCalls.find((c) => c.msg.includes("Session fallback timeout reached")) + expect(timeoutLog).toBeUndefined() + }) + }) + + describe("session.status retry handling", () => { + test("should log provider retry events", async () => { + const hook = createRuntimeFallbackHook( + createMockPluginInput(), + { + config: createMockConfig({ notify_on_fallback: false }), + pluginConfig: createMockPluginConfigWithCategoryFallback([ + "anthropic/claude-opus-4-6", + ]), + } + ) + + const sessionID = "test-retry-status" + SessionCategoryRegistry.register(sessionID, "test") + + await hook.event({ + event: { + type: "session.created", + properties: { info: { id: sessionID, model: "google/gemini-2.5-pro" } }, + }, + }) + + await hook.event({ + event: { + type: "session.status", + properties: { + sessionID, + status: { type: "retry", attempt: 1, message: "Rate limited, retrying", next: 5000 }, + }, + }, + }) + + const retryLog = logCalls.find((c) => c.msg.includes("Provider retry detected")) + expect(retryLog).toBeDefined() + expect((retryLog?.data as Record)?.attempt).toBe(1) + }) + + test("should show toast on retry when notify_on_fallback is enabled", async () => { + const hook = createRuntimeFallbackHook( + createMockPluginInput(), + { + config: createMockConfig({ notify_on_fallback: true }), + pluginConfig: createMockPluginConfigWithCategoryFallback([ + "anthropic/claude-opus-4-6", + ]), + } + ) + + const sessionID = "test-retry-toast" + SessionCategoryRegistry.register(sessionID, "test") + + await hook.event({ + event: { + type: "session.created", + properties: { info: { id: sessionID, model: "google/gemini-2.5-pro" } }, + }, + }) + + await hook.event({ + event: { + type: "session.status", + properties: { + sessionID, + status: { type: "retry", attempt: 2, message: "Server overloaded", next: 3000 }, + }, + }, + }) + + const retryToast = toastCalls.find((t) => t.title === "Retry Detected — Switching Model") + expect(retryToast).toBeDefined() + expect(retryToast?.message).toContain("Server overloaded") + expect(retryToast?.variant).toBe("warning") + }) + + test("should ignore non-retry session.status events", async () => { + const hook = createRuntimeFallbackHook( + createMockPluginInput(), + { config: createMockConfig() } + ) + + await hook.event({ + event: { + type: "session.status", + properties: { + sessionID: "test-busy", + status: { type: "busy" }, + }, + }, + }) + + const retryLog = logCalls.find((c) => c.msg.includes("Provider retry detected")) + expect(retryLog).toBeUndefined() + }) + }) + + describe("on-demand state creation in session.created", () => { + test("should always log on-demand state creation on session.created", async () => { + const hook = createRuntimeFallbackHook( + createMockPluginInput(), + { + config: createMockConfig(), + pluginConfig: createMockPluginConfigWithCategoryFallback([ + "anthropic/claude-opus-4-6", + ]), + } + ) + + const sessionID = "test-created-on-demand" + await hook.event({ + event: { + type: "session.created", + properties: { info: { id: sessionID } }, + }, + }) + + const onDemandLog = logCalls.find((c) => c.msg.includes("state will be created on-demand")) + expect(onDemandLog).toBeDefined() + }) + }) + + describe("enhanced toast notifications", () => { + test("should show attempt count in fallback toast", async () => { + const hook = createRuntimeFallbackHook( + createMockPluginInput(), + { + config: createMockConfig({ notify_on_fallback: true }), + pluginConfig: createMockPluginConfigWithCategoryFallback([ + "anthropic/claude-opus-4-6", + "openai/gpt-5.2", + ]), + } + ) + + const sessionID = "test-attempt-toast" + SessionCategoryRegistry.register(sessionID, "test") + + await hook.event({ + event: { + type: "session.created", + properties: { info: { id: sessionID, model: "google/gemini-2.5-pro" } }, + }, + }) + + await hook.event({ + event: { + type: "session.error", + properties: { sessionID, error: { statusCode: 429 } }, + }, + }) + + const fallbackToast = toastCalls.find((t) => t.title === "Model Fallback") + expect(fallbackToast).toBeDefined() + expect(fallbackToast?.message).toContain("attempt 1 of 2") + }) + }) + + describe("agent resolution priority (bug #2 fix)", () => { + function createMockPluginConfigWithAgentFallback2(agentName: string, fallbackModels: string[], initialModel = "anthropic/claude-opus-4-6"): OhMyOpenCodeConfig { + return { + agents: { + [agentName]: { + model: initialModel, + fallback_models: fallbackModels, + }, + }, + } + } + + test("should prefer session agent over stale SDK event agent (/start-work scenario)", async () => { + // given - atlas has its own fallback models, prometheus has different ones + const { updateSessionAgent, clearSessionAgent } = await import("../../features/claude-code-session-state") + const promptCalls: Array> = [] + const sessionID = "test-start-work-agent-resolution" + + const hook = createRuntimeFallbackHook( + createMockPluginInput({ + session: { + messages: async () => ({ + data: [ + { + info: { role: "user" }, + parts: [{ type: "text", text: "start work" }], + }, + ], + }), + promptAsync: async (args: unknown) => { + promptCalls.push(args as Record) + return {} + }, + }, + }), + { + config: createMockConfig({ notify_on_fallback: false }), + pluginConfig: { + agents: { + atlas: { + model: "anthropic/claude-opus-4-6", + fallback_models: ["openai/gpt-5.2"], + }, + prometheus: { + model: "anthropic/claude-opus-4-6", + fallback_models: ["google/gemini-3-pro"], + }, + }, + }, + }, + ) + + // when - /start-work sets session agent to atlas, but SDK event still says prometheus + updateSessionAgent(sessionID, "atlas") + + await hook.event({ + event: { + type: "session.error", + properties: { + sessionID, + model: "anthropic/claude-opus-4-6", + error: { statusCode: 429, message: "Rate limit exceeded" }, + agent: "prometheus", // stale SDK event agent + }, + }, + }) + + // then - should use atlas's fallback (openai/gpt-5.2), not prometheus's (google/gemini-3-pro) + expect(promptCalls.length).toBe(1) + const callBody = promptCalls[0]?.body as Record + expect(callBody?.model).toEqual({ providerID: "openai", modelID: "gpt-5.2" }) + + const fallbackLog = logCalls.find((c) => c.msg.includes("Preparing fallback")) + expect(fallbackLog).toBeDefined() + expect(fallbackLog?.data).toMatchObject({ to: "openai/gpt-5.2" }) + + // cleanup + clearSessionAgent(sessionID) + }) + + test("should fall back to event agent when no session agent is set", async () => { + // given - no session agent set, event agent is the only signal + const promptCalls: Array> = [] + const sessionID = "test-event-agent-fallback" + + const hook = createRuntimeFallbackHook( + createMockPluginInput({ + session: { + messages: async () => ({ + data: [ + { + info: { role: "user" }, + parts: [{ type: "text", text: "hello" }], + }, + ], + }), + promptAsync: async (args: unknown) => { + promptCalls.push(args as Record) + return {} + }, + }, + }), + { + config: createMockConfig({ notify_on_fallback: false }), + pluginConfig: { + agents: { + oracle: { + model: "openai/gpt-5.2", + fallback_models: ["anthropic/claude-opus-4-6"], + }, + }, + }, + }, + ) + + // when - event agent is oracle and no session agent override exists + await hook.event({ + event: { + type: "session.error", + properties: { + sessionID, + model: "openai/gpt-5.2", + error: { statusCode: 503, message: "Service unavailable" }, + agent: "oracle", + }, + }, + }) + + // then - should use oracle's fallback since event agent is the only signal + expect(promptCalls.length).toBe(1) + const callBody = promptCalls[0]?.body as Record + expect(callBody?.model).toEqual({ providerID: "anthropic", modelID: "claude-opus-4-6" }) + }) + }) + + describe("abort before promptAsync (Bug #1)", () => { + test("should call abort before promptAsync on session.error fallback", async () => { + // given - track call order to verify abort fires before promptAsync + const callOrder: string[] = [] + const sessionID = "test-abort-before-prompt" + + const hook = createRuntimeFallbackHook( + createMockPluginInput({ + session: { + messages: async () => ({ + data: [ + { + info: { role: "user" }, + parts: [{ type: "text", text: "hello" }], + }, + ], + }), + abort: async () => { + callOrder.push("abort") + return {} + }, + promptAsync: async (args: unknown) => { + callOrder.push("promptAsync") + return {} + }, + }, + }), + { + config: createMockConfig({ notify_on_fallback: false }), + pluginConfig: { + agents: { + sisyphus: { + model: "anthropic/claude-opus-4-6", + fallback_models: ["openai/gpt-5.2"], + }, + }, + }, + }, + ) + + SessionCategoryRegistry.register(sessionID, "test") + + await hook.event({ + event: { + type: "session.created", + properties: { info: { id: sessionID, model: "anthropic/claude-opus-4-6" } }, + }, + }) + + // when - retryable error triggers fallback + await hook.event({ + event: { + type: "session.error", + properties: { + sessionID, + model: "anthropic/claude-opus-4-6", + error: { statusCode: 429, message: "Rate limit exceeded" }, + }, + }, + }) + + // then - abort must fire before promptAsync + expect(callOrder).toContain("abort") + expect(callOrder).toContain("promptAsync") + const abortIndex = callOrder.indexOf("abort") + const promptIndex = callOrder.indexOf("promptAsync") + expect(abortIndex).toBeLessThan(promptIndex) + }) + }) + + describe("agent resolution fallback when chat.message never fires", () => { + test("should resolve agent from session.get when messages have no agent info", async () => { + const promptCalls: Array> = [] + const hook = createRuntimeFallbackHook( + createMockPluginInput({ + session: { + messages: async () => ({ + data: [{ info: { role: "user" }, parts: [{ type: "text", text: "test" }] }], + }), + get: async () => ({ data: { agent: "Atlas (Work Orchestrator)" } }), + promptAsync: async (args: unknown) => { + promptCalls.push(args as Record) + return {} + }, + }, + }), + { + config: createMockConfig({ notify_on_fallback: false }), + pluginConfig: { + agents: { + atlas: { + model: "anthropic/claude-opus-4-6", + fallback_models: ["openai/gpt-5.2"], + }, + sisyphus: { + model: "anthropic/claude-opus-4-6", + fallback_models: ["google/gemini-3-pro"], + }, + }, + }, + } + ) + const sessionID = "test-session-get-resolution" + SessionCategoryRegistry.register(sessionID, "test") + + await hook.event({ + event: { + type: "session.created", + properties: { info: { id: sessionID, model: "anthropic/claude-opus-4-6" } }, + }, + }) + + await hook.event({ + event: { + type: "session.error", + properties: { sessionID, error: { statusCode: 429 } }, + }, + }) + + const resolveLog = logCalls.find((c) => c.msg.includes("Resolved agent from session.get")) + expect(resolveLog).toBeDefined() + expect(resolveLog?.data).toMatchObject({ sessionID, agent: "atlas" }) + + expect(promptCalls.length).toBeGreaterThanOrEqual(1) + const callBody = promptCalls[0]?.body as Record + expect(callBody?.model).toEqual({ providerID: "openai", modelID: "gpt-5.2" }) + }) + + test("should fall back to sisyphus when all agent resolution methods fail", async () => { + const promptCalls: Array> = [] + const hook = createRuntimeFallbackHook( + createMockPluginInput({ + session: { + messages: async () => ({ + data: [{ info: { role: "user" }, parts: [{ type: "text", text: "test" }] }], + }), + get: async () => ({ data: {} }), // no agent, no modelID + promptAsync: async (args: unknown) => { + promptCalls.push(args as Record) + return {} + }, + }, + }), + { + config: createMockConfig({ notify_on_fallback: false }), + pluginConfig: { + agents: { + sisyphus: { + model: "anthropic/claude-opus-4-6", + fallback_models: ["openai/gpt-5.2"], + }, + }, + }, + } + ) + const sessionID = "test-session-no-agent-info" + SessionCategoryRegistry.register(sessionID, "test") + + await hook.event({ + event: { + type: "session.created", + properties: { info: { id: sessionID, model: "anthropic/claude-opus-4-6" } }, + }, + }) + + await hook.event({ + event: { + type: "session.error", + properties: { sessionID, error: { statusCode: 429 } }, + }, + }) + + const sisyphusLog = logCalls.find( + (c) => + c.msg.includes("Using sisyphus fallback models (no agent detected)") || + c.msg.includes("Using sisyphus model for state creation (no agent detected)") + ) + expect(sisyphusLog).toBeDefined() + + expect(promptCalls.length).toBeGreaterThanOrEqual(1) + const callBody = promptCalls[0]?.body as Record + expect(callBody?.model).toEqual({ providerID: "openai", modelID: "gpt-5.2" }) + }) + }) describe("race condition guards", () => { test("session.error is skipped while retry request is in flight", async () => { diff --git a/src/hooks/runtime-fallback/message-update-handler.ts b/src/hooks/runtime-fallback/message-update-handler.ts index 7e6130955a..087ab50a37 100644 --- a/src/hooks/runtime-fallback/message-update-handler.ts +++ b/src/hooks/runtime-fallback/message-update-handler.ts @@ -2,7 +2,7 @@ import type { HookDeps } from "./types" import type { AutoRetryHelpers } from "./auto-retry" import { HOOK_NAME } from "./constants" import { log } from "../../shared/logger" -import { extractStatusCode, extractErrorName, classifyErrorType, isRetryableError, extractAutoRetrySignal, containsErrorContent } from "./error-classifier" +import { extractStatusCode, extractErrorName, classifyErrorType, isRetryableError, extractAutoRetrySignal, containsErrorContent, extractErrorContentFromParts, detectErrorInTextParts } from "./error-classifier" import { createFallbackState, prepareFallback } from "./fallback-state" import { getFallbackModelsForSession } from "./fallback-models" @@ -50,6 +50,43 @@ export function hasVisibleAssistantResponse(extractAutoRetrySignalFn: typeof ext } } +async function checkLastAssistantForErrorContent( + ctx: HookDeps["ctx"], + sessionID: string, +): Promise { + try { + const messagesResp = await ctx.client.session.messages({ + path: { id: sessionID }, + query: { directory: ctx.directory }, + }) + + const msgs = (messagesResp as { + data?: Array<{ + info?: Record + parts?: Array<{ type?: string; text?: string }> + }> + }).data + + if (!msgs || msgs.length === 0) return undefined + + const lastAssistant = [...msgs].reverse().find((m) => m.info?.role === "assistant") + if (!lastAssistant) return undefined + + const parts = lastAssistant.parts ?? + (lastAssistant.info?.parts as Array<{ type?: string; text?: string }> | undefined) + + const result = extractErrorContentFromParts(parts) + if (result.hasError) return result.errorMessage + + const textResult = detectErrorInTextParts(parts) + if (textResult.hasError) return textResult.errorMessage + + return undefined + } catch { + return undefined + } +} + export function createMessageUpdateHandler(deps: HookDeps, helpers: AutoRetryHelpers) { const { ctx, config, pluginConfig, sessionStates, sessionLastAccess, sessionRetryInFlight, sessionAwaitingFallbackResult } = deps const checkVisibleResponse = hasVisibleAssistantResponse(extractAutoRetrySignal) @@ -62,11 +99,31 @@ export function createMessageUpdateHandler(deps: HookDeps, helpers: AutoRetryHel const timeoutEnabled = config.timeout_seconds > 0 const parts = props?.parts as Array<{ type?: string; text?: string }> | undefined const errorContentResult = containsErrorContent(parts) - const error = info?.error ?? + let error = info?.error ?? (retrySignal && timeoutEnabled ? { name: "ProviderRateLimitError", message: retrySignal } : undefined) ?? (errorContentResult.hasError ? { name: "MessageContentError", message: errorContentResult.errorMessage || "Message contains error content" } : undefined) const role = info?.role as string | undefined - const model = info?.model as string | undefined + const model = (info?.model as string | undefined) + ?? (typeof info?.providerID === "string" && typeof info?.modelID === "string" + ? `${info.providerID}/${info.modelID}` + : undefined) + + if (sessionID && role === "assistant") { + log(`[${HOOK_NAME}] message.updated received`, { + sessionID, + model, + hasInfoError: !!info?.error, + errorType: info?.error ? classifyErrorType(info.error) : undefined, + }) + } + + if (sessionID && role === "assistant" && !error) { + const errorContent = await checkLastAssistantForErrorContent(ctx, sessionID) + if (errorContent) { + log(`[${HOOK_NAME}] Detected error content in message parts`, { sessionID, errorContent: errorContent.slice(0, 200) }) + error = { name: "ContentError", message: errorContent } + } + } if (sessionID && role === "assistant" && !error) { if (!sessionAwaitingFallbackResult.has(sessionID)) { @@ -162,12 +219,18 @@ export function createMessageUpdateHandler(deps: HookDeps, helpers: AutoRetryHel } if (!initialModel) { - log(`[${HOOK_NAME}] message.updated missing model info, cannot fallback`, { - sessionID, - errorName: extractErrorName(error), - errorType: classifyErrorType(error), - }) - return + const sisyphusModel = pluginConfig?.agents?.sisyphus?.model as string | undefined + if (sisyphusModel) { + log(`[${HOOK_NAME}] Using sisyphus model for state creation (no agent detected)`, { sessionID, model: sisyphusModel }) + initialModel = sisyphusModel + } else { + log(`[${HOOK_NAME}] message.updated missing model info, cannot fallback`, { + sessionID, + errorName: extractErrorName(error), + errorType: classifyErrorType(error), + }) + return + } } state = createFallbackState(initialModel) diff --git a/src/hooks/start-work/start-work-hook.ts b/src/hooks/start-work/start-work-hook.ts index 77c76d240a..7a554eadb1 100644 --- a/src/hooks/start-work/start-work-hook.ts +++ b/src/hooks/start-work/start-work-hook.ts @@ -10,7 +10,7 @@ import { clearBoulderState, } from "../../features/boulder-state" import { log } from "../../shared/logger" -import { updateSessionAgent } from "../../features/claude-code-session-state" +import { pinSessionAgent } from "../../features/claude-code-session-state" export const HOOK_NAME = "start-work" as const @@ -71,7 +71,7 @@ export function createStartWorkHook(ctx: PluginInput) { sessionID: input.sessionID, }) - updateSessionAgent(input.sessionID, "atlas") // Always switch: fixes #1298 + pinSessionAgent(input.sessionID, "atlas") // Always switch: fixes #1298 const existingState = readBoulderState(ctx.directory) const sessionId = input.sessionID diff --git a/src/plugin-handlers/config-handler.test.ts b/src/plugin-handlers/config-handler.test.ts index cff6c97e2a..725165fc2f 100644 --- a/src/plugin-handlers/config-handler.test.ts +++ b/src/plugin-handlers/config-handler.test.ts @@ -1161,11 +1161,12 @@ describe("per-agent todowrite/todoread deny when task_system enabled", () => { getAgentDisplayName("sisyphus"), getAgentDisplayName("hephaestus"), getAgentDisplayName("atlas"), - ]) - const AGENTS_WITHOUT_TODO_DENY = new Set([ getAgentDisplayName("prometheus"), getAgentDisplayName("sisyphus-junior"), ]) + const AGENTS_WITHOUT_TODO_DENY = new Set([ + getAgentDisplayName("oracle"), + ]) test("denies todowrite and todoread for primary agents when task_system is enabled", async () => { //#given diff --git a/src/plugin/chat-message.ts b/src/plugin/chat-message.ts index f3c02297fb..23dae59470 100644 --- a/src/plugin/chat-message.ts +++ b/src/plugin/chat-message.ts @@ -3,7 +3,7 @@ import type { PluginContext } from "./types" import { hasConnectedProvidersCache } from "../shared" import { setSessionModel } from "../shared/session-model-state" -import { setSessionAgent } from "../features/claude-code-session-state" +import { updateSessionAgent, pinSessionAgent, unpinSessionAgent } from "../features/claude-code-session-state" import { applyUltraworkModelOverrideOnMessage } from "./ultrawork-model-override" import { parseRalphLoopArguments } from "../hooks/ralph-loop/command-arguments" @@ -71,8 +71,19 @@ export function createChatMessageHandler(args: { output: ChatMessageHandlerOutput ): Promise => { if (input.agent) { - setSessionAgent(input.sessionID, input.agent) + unpinSessionAgent(input.sessionID) + updateSessionAgent(input.sessionID, input.agent) } + const promptText = output.parts + ?.filter((p) => p.type === "text" && p.text) + .map((p) => p.text) + .join("\n") + .trim() || "" + const isStartWorkCommand = promptText.includes("") + if (isStartWorkCommand) { + pinSessionAgent(input.sessionID, "atlas") + } + if (firstMessageVariantGate.shouldOverride(input.sessionID)) { firstMessageVariantGate.markApplied(input.sessionID) diff --git a/src/shared/index.ts b/src/shared/index.ts index 09187602f0..02ac7353e0 100644 --- a/src/shared/index.ts +++ b/src/shared/index.ts @@ -53,6 +53,7 @@ export * from "./port-utils" export * from "./git-worktree" export * from "./safe-create-hook" export * from "./truncate-description" +export * from "./session-category-registry" export * from "./opencode-storage-paths" export * from "./opencode-message-dir" export * from "./normalize-sdk-response" diff --git a/src/shared/model-error-classifier.ts b/src/shared/model-error-classifier.ts index defcef6705..4eaf704998 100644 --- a/src/shared/model-error-classifier.ts +++ b/src/shared/model-error-classifier.ts @@ -36,6 +36,7 @@ const RETRYABLE_MESSAGE_PATTERNS = [ "rate_limit", "rate limit", "quota", + "quota protection", "not found", "unavailable", "insufficient", diff --git a/src/shared/model-resolver.ts b/src/shared/model-resolver.ts index e2e02fce3c..a9e450fb20 100644 --- a/src/shared/model-resolver.ts +++ b/src/shared/model-resolver.ts @@ -7,6 +7,17 @@ export type ModelResolutionInput = { systemDefault?: string } +/** + * Normalizes fallback_models to an array. + * Handles single string or array input, returns undefined for falsy values. + */ +export function normalizeFallbackModels( + fallbackModels: string | string[] | undefined | null +): string[] | undefined { + if (!fallbackModels) return undefined + return Array.isArray(fallbackModels) ? fallbackModels : [fallbackModels] +} + export type ModelSource = | "override" | "category-default" @@ -62,13 +73,3 @@ export function resolveModelWithFallback( variant: resolved.variant, } } - -/** - * Normalizes fallback_models config (which can be string or string[]) to string[] - * Centralized helper to avoid duplicated normalization logic - */ -export function normalizeFallbackModels(models: string | string[] | undefined): string[] | undefined { - if (!models) return undefined - if (typeof models === "string") return [models] - return models -}