diff --git a/orchestrator/agent_runner.py b/orchestrator/agent_runner.py index b547325..ae2a0c1 100644 --- a/orchestrator/agent_runner.py +++ b/orchestrator/agent_runner.py @@ -8,7 +8,14 @@ import time from collections.abc import AsyncIterator from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import ( + TYPE_CHECKING, + Any, + Final, + Literal, + NamedTuple, + cast, +) from claude_agent_sdk import ( AssistantMessage, @@ -36,6 +43,44 @@ # Max retries when SDK raises for unknown message types (e.g. rate_limit_event) _MAX_UNKNOWN_MSG_RETRIES = 10 +_BASE_ALLOWED_TOOLS: Final[tuple[str, ...]] = ( + "Read", + "Write", + "Edit", + "Bash", + "Glob", + "Grep", + "Task", + "WebSearch", + "WebFetch", + "mcp__tracker__tracker_get_issue", + "mcp__tracker__tracker_add_comment", + "mcp__tracker__tracker_get_comments", + "mcp__tracker__tracker_get_checklist", + "mcp__tracker__tracker_request_info", + "mcp__tracker__tracker_signal_blocked", + "mcp__tracker__tracker_create_subtask", + "mcp__tracker__propose_improvement", + "mcp__tracker__tracker_get_attachments", + "mcp__tracker__tracker_download_attachment", + "mcp__tracker__tracker_create_workpad", + "mcp__tracker__tracker_update_workpad", + "mcp__tracker__tracker_mark_complete", +) + +_WORKSPACE_TOOLS: Final[tuple[str, ...]] = ( + "mcp__workspace__list_available_repos", + "mcp__workspace__request_worktree", +) + +_COMM_TOOLS: Final[tuple[str, ...]] = ( + "mcp__comm__list_running_agents", + "mcp__comm__send_message_to_agent", + "mcp__comm__send_request_to_agent", + "mcp__comm__reply_to_message", + "mcp__comm__check_messages", +) + async def receive_response_safe( client: ClaudeSDKClient, @@ -95,6 +140,21 @@ def total_tokens(self) -> int: return self.input_tokens + self.output_tokens +class _ResultData(NamedTuple): + """Data extracted from a ResultMessage.""" + + cost: float | None + input_tokens: int + output_tokens: int + + +class _ToolStateSnapshot(NamedTuple): + """Snapshot of ToolState side-channel data.""" + + needs_info: bool + proposals: list[dict[str, str]] + + def merge_results(base: AgentResult, update: AgentResult) -> AgentResult: """Merge drain result into base: accumulate costs, prefer latest data. @@ -179,6 +239,122 @@ def get_pending_message(self) -> str | None: except asyncio.QueueEmpty: return None + async def _process_assistant_message( + self, + message: AssistantMessage, + output_parts: list[str], + ) -> bool: + """Process an AssistantMessage: collect text, detect rate limit. + + Returns True if rate_limit error detected. + """ + is_rate_limited = message.error == "rate_limit" + for block in message.content: + if isinstance(block, TextBlock): + output_parts.append(block.text) + if self._event_bus: + await self._event_bus.publish( + Event( + type=EventType.AGENT_OUTPUT, + task_key=self._issue_key, + data={"text": block.text}, + ) + ) + return is_rate_limited + + async def _apply_result_message( + self, + message: ResultMessage, + start: float, + ) -> _ResultData: + """Extract cost/tokens from ResultMessage, publish event. + + Side effects: updates ``_session_id`` and cumulative token + counts on the session instance. + """ + cost = getattr(message, "total_cost_usd", None) + self._session_id = getattr(message, "session_id", None) + usage = getattr(message, "usage", None) + input_tokens = usage.get("input_tokens", 0) if usage else 0 + output_tokens = usage.get("output_tokens", 0) if usage else 0 + # Tokens: latest-wins (SDK includes prior context) + self.cumulative_input_tokens = input_tokens + self.cumulative_output_tokens = output_tokens + if self._event_bus: + elapsed_ms = (time.monotonic() - start) * 1000 + await self._event_bus.publish( + Event( + type=EventType.AGENT_RESULT, + task_key=self._issue_key, + data={ + "cost": cost, + "duration_ms": elapsed_ms, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + }, + ) + ) + return _ResultData( + cost=cost, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + + def _read_tool_state(self) -> _ToolStateSnapshot: + """Read and reset tool_state side-channel flags. + + Returns a snapshot; resets ``needs_info_requested`` and + ``proposals`` on the underlying ToolState so they are not + consumed twice. + """ + if not self._tool_state: + return _ToolStateSnapshot( + needs_info=False, + proposals=[], + ) + needs_info = self._tool_state.needs_info_requested + if needs_info: + self._tool_state.needs_info_requested = False + proposals = list(self._tool_state.proposals) + if proposals: + self._tool_state.proposals.clear() + return _ToolStateSnapshot( + needs_info=needs_info, + proposals=proposals, + ) + + def _build_success_result( + self, + output_parts: list[str], + result_data: _ResultData, + is_rate_limited: bool, + duration: float, + ) -> AgentResult: + """Build successful AgentResult from collected data.""" + output = "\n".join(output_parts) if output_parts else "No output" + pr_match = PR_URL_PATTERN.search(output) + logger.info( + "Agent session for %s completed in %.0fs (cost: %s, tokens: %d in/%d out)", + self._issue_key, + duration, + result_data.cost, + result_data.input_tokens, + result_data.output_tokens, + ) + ts = self._read_tool_state() + return AgentResult( + success=True, + output=output, + cost_usd=result_data.cost, + duration_seconds=duration, + pr_url=(pr_match.group(0) if pr_match else None), + needs_info=ts.needs_info, + is_rate_limited=is_rate_limited, + proposals=ts.proposals, + input_tokens=result_data.input_tokens, + output_tokens=result_data.output_tokens, + ) + async def send(self, prompt: str) -> AgentResult: """Send a prompt and collect response (re-usable).""" if self._closed: @@ -188,53 +364,24 @@ async def send(self, prompt: str) -> AgentResult: ) start = time.monotonic() output_parts: list[str] = [] - cost = None - input_tokens = 0 - output_tokens = 0 + result_data = _ResultData(None, 0, 0) is_rate_limited = False try: await self._client.query(prompt) - async for message in receive_response_safe(self._client): - if isinstance(message, AssistantMessage): - # Check for rate limit error from SDK - if message.error == "rate_limit": + async for msg in receive_response_safe( + self._client, + ): + if isinstance(msg, AssistantMessage): + if await self._process_assistant_message( + msg, + output_parts, + ): is_rate_limited = True - for block in message.content: - if isinstance(block, TextBlock): - output_parts.append(block.text) - if self._event_bus: - await self._event_bus.publish( - Event( - type=EventType.AGENT_OUTPUT, - task_key=self._issue_key, - data={"text": block.text}, - ) - ) - elif isinstance(message, ResultMessage): - cost = getattr(message, "total_cost_usd", None) - self._session_id = getattr(message, "session_id", None) - usage = getattr(message, "usage", None) - if usage: - input_tokens = usage.get("input_tokens", 0) - output_tokens = usage.get("output_tokens", 0) - # Tokens: latest-wins (SDK includes prior - # context) - self.cumulative_input_tokens = input_tokens - self.cumulative_output_tokens = output_tokens - if self._event_bus: - elapsed_ms = (time.monotonic() - start) * 1000 - await self._event_bus.publish( - Event( - type=EventType.AGENT_RESULT, - task_key=self._issue_key, - data={ - "cost": cost, - "duration_ms": elapsed_ms, - "input_tokens": (input_tokens), - "output_tokens": (output_tokens), - }, - ) - ) + elif isinstance(msg, ResultMessage): + result_data = await self._apply_result_message( + msg, + start, + ) except Exception as e: logger.error( "Agent session failed for %s: %s", @@ -246,40 +393,11 @@ async def send(self, prompt: str) -> AgentResult: output=str(e), duration_seconds=time.monotonic() - start, ) - - output = "\n".join(output_parts) if output_parts else "No output" - pr_match = PR_URL_PATTERN.search(output) - duration = time.monotonic() - start - logger.info( - "Agent session for %s completed in %.0fs (cost: %s, tokens: %d in/%d out)", - self._issue_key, - duration, - cost, - input_tokens, - output_tokens, - ) - - needs_info = False - if self._tool_state and self._tool_state.needs_info_requested: - needs_info = True - self._tool_state.needs_info_requested = False - - proposals: list[dict[str, str]] = [] - if self._tool_state and self._tool_state.proposals: - proposals = list(self._tool_state.proposals) - self._tool_state.proposals.clear() - - return AgentResult( - success=True, - output=output, - cost_usd=cost, - duration_seconds=duration, - pr_url=pr_match.group(0) if pr_match else None, - needs_info=needs_info, - is_rate_limited=is_rate_limited, - proposals=proposals, - input_tokens=input_tokens, - output_tokens=output_tokens, + return self._build_success_result( + output_parts, + result_data, + is_rate_limited, + duration=time.monotonic() - start, ) async def drain_pending_messages( @@ -356,84 +474,82 @@ def __init__( self._tracker = tracker self._storage = storage - def _build_options( + def _build_mcp_servers( self, issue: TrackerIssue, - tool_state: ToolState | None = None, - model: str | None = None, - workspace_server: object | None = None, - cwd: str | None = None, - resume_session_id: str | None = None, - mailbox: AgentMailbox | None = None, - ) -> ClaudeAgentOptions: - """Build ClaudeAgentOptions for a task.""" - cfg = self._config - workflow_content = build_system_prompt_append(cfg.workflow_prompt_path) + tool_state: ToolState | None, + workspace_server: object | None, + mailbox: AgentMailbox | None, + ) -> dict[str, Any]: + """Build MCP server dict for agent options.""" tracker_server = build_tracker_server( self._tracker, issue.key, tool_state=tool_state, - config=cfg, + config=self._config, issue_components=issue.components, storage=self._storage, ) - - mcp_servers: dict[str, Any] = {"tracker": tracker_server} + servers: dict[str, Any] = {"tracker": tracker_server} if workspace_server: - mcp_servers["workspace"] = workspace_server + servers["workspace"] = workspace_server if mailbox is not None: + # Lazy import: avoids SDK resolution issues + # under autouse mock_sdk fixture in tests. from orchestrator.comm_tools import build_comm_server - comm_server = build_comm_server(mailbox, issue.key, issue.summary) - mcp_servers["comm"] = comm_server - - allowed_tools = [ - "Read", - "Write", - "Edit", - "Bash", - "Glob", - "Grep", - "Task", - "WebSearch", - "WebFetch", - "mcp__tracker__tracker_get_issue", - "mcp__tracker__tracker_add_comment", - "mcp__tracker__tracker_get_comments", - "mcp__tracker__tracker_get_checklist", - "mcp__tracker__tracker_request_info", - "mcp__tracker__tracker_signal_blocked", - "mcp__tracker__tracker_create_subtask", - "mcp__tracker__propose_improvement", - "mcp__tracker__tracker_get_attachments", - "mcp__tracker__tracker_download_attachment", - "mcp__tracker__tracker_create_workpad", - "mcp__tracker__tracker_update_workpad", - "mcp__tracker__tracker_mark_complete", - ] - if workspace_server: - allowed_tools.extend( - [ - "mcp__workspace__list_available_repos", - "mcp__workspace__request_worktree", - ] + servers["comm"] = build_comm_server( + mailbox, + issue.key, + issue.summary, ) + return servers + + def _build_allowed_tools( + self, + workspace_server: object | None, + mailbox: AgentMailbox | None, + ) -> list[str]: + """Build allowed tools list based on available servers.""" + tools = list(_BASE_ALLOWED_TOOLS) + if workspace_server: + tools.extend(_WORKSPACE_TOOLS) if mailbox is not None: - allowed_tools.extend( - [ - "mcp__comm__list_running_agents", - "mcp__comm__send_message_to_agent", - "mcp__comm__send_request_to_agent", - "mcp__comm__reply_to_message", - "mcp__comm__check_messages", - ] - ) + tools.extend(_COMM_TOOLS) + return tools + + def _build_options( + self, + issue: TrackerIssue, + tool_state: ToolState | None = None, + model: str | None = None, + workspace_server: object | None = None, + cwd: str | None = None, + resume_session_id: str | None = None, + mailbox: AgentMailbox | None = None, + ) -> ClaudeAgentOptions: + """Build ClaudeAgentOptions for a task.""" + cfg = self._config + workflow_content = build_system_prompt_append( + cfg.workflow_prompt_path, + ) + mcp_servers = self._build_mcp_servers( + issue, + tool_state, + workspace_server, + mailbox, + ) + allowed_tools = self._build_allowed_tools( + workspace_server, + mailbox, + ) resume_kwargs: dict[str, Any] = {} if resume_session_id: resume_kwargs["resume"] = resume_session_id resume_kwargs["fork_session"] = True + max_budget = float(cfg.agent_max_budget_usd) if cfg.agent_max_budget_usd is not None else None return ClaudeAgentOptions( model=model or cfg.agent_model, system_prompt={ @@ -444,11 +560,17 @@ def _build_options( mcp_servers=mcp_servers, allowed_tools=allowed_tools, permission_mode=cast( - Literal["default", "acceptEdits", "plan", "bypassPermissions"] | None, + Literal[ + "default", + "acceptEdits", + "plan", + "bypassPermissions", + ] + | None, cfg.agent_permission_mode, ), cwd=cwd or "/tmp", - max_budget_usd=float(cfg.agent_max_budget_usd) if cfg.agent_max_budget_usd is not None else None, + max_budget_usd=max_budget, hooks={}, env=cfg.agent_env, setting_sources=["project"], diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index 9cb285a..e749996 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -1260,3 +1260,166 @@ def test_merge_false_false_stays_false(self) -> None: ) merged = merge_results(base, update) assert merged.externally_resolved is False + + +class TestBuildOptionsWithMailbox: + """Characterization: mailbox/comm branch in _build_options.""" + + def test_mailbox_adds_comm_server_and_tools( + self, + mock_sdk, + tmp_path, + ) -> None: + """When mailbox is provided, comm MCP server and tools are added.""" + from orchestrator.agent_runner import AgentRunner + + wf = tmp_path / "workflow.md" + wf.write_text("# Workflow") + + config = _make_config(workflow_prompt_path=str(wf)) + tracker = MagicMock(spec=TrackerClient) + runner = AgentRunner(config, tracker) + + mailbox = MagicMock() + runner._build_options( + _make_issue(), + mailbox=mailbox, + cwd="/tmp", + ) + + call_kwargs = mock_sdk.ClaudeAgentOptions.call_args.kwargs + assert "comm" in call_kwargs["mcp_servers"] + for tool_name in ( + "mcp__comm__list_running_agents", + "mcp__comm__send_message_to_agent", + "mcp__comm__send_request_to_agent", + "mcp__comm__reply_to_message", + "mcp__comm__check_messages", + ): + assert tool_name in call_kwargs["allowed_tools"] + + def test_no_mailbox_no_comm_tools( + self, + mock_sdk, + tmp_path, + ) -> None: + """Without mailbox, no comm server or tools.""" + from orchestrator.agent_runner import AgentRunner + + wf = tmp_path / "workflow.md" + wf.write_text("# Workflow") + + config = _make_config(workflow_prompt_path=str(wf)) + tracker = MagicMock(spec=TrackerClient) + runner = AgentRunner(config, tracker) + + runner._build_options(_make_issue(), cwd="/tmp") + + call_kwargs = mock_sdk.ClaudeAgentOptions.call_args.kwargs + assert "comm" not in call_kwargs["mcp_servers"] + assert "mcp__comm__list_running_agents" not in (call_kwargs["allowed_tools"]) + + +class TestSendToolStateConsumption: + """Characterization: send() reads and resets ToolState.""" + + async def test_send_reads_needs_info(self, mock_sdk) -> None: + """send() returns needs_info=True and resets the flag.""" + from orchestrator.agent_runner import AgentSession + from orchestrator.tracker_tools import ToolState + + mock_client = AsyncMock() + + async def mock_receive(): + return + yield + + mock_client.receive_response = mock_receive + mock_client.query = AsyncMock() + + ts = ToolState(needs_info_requested=True) + session = AgentSession( + mock_client, + "QR-1", + tool_state=ts, + ) + result = await session.send("do something") + + assert result.needs_info is True + assert ts.needs_info_requested is False + + async def test_send_reads_proposals(self, mock_sdk) -> None: + """send() returns proposals and clears the list.""" + from orchestrator.agent_runner import AgentSession + from orchestrator.tracker_tools import ToolState + + mock_client = AsyncMock() + + async def mock_receive(): + return + yield + + mock_client.receive_response = mock_receive + mock_client.query = AsyncMock() + + proposal = {"title": "Idea", "description": "Details"} + ts = ToolState(proposals=[proposal]) + session = AgentSession( + mock_client, + "QR-1", + tool_state=ts, + ) + result = await session.send("do something") + + assert result.proposals == [proposal] + assert ts.proposals == [] + + +class TestAgentOutputEvent: + """Characterization: send() publishes AGENT_OUTPUT per TextBlock.""" + + async def test_send_publishes_agent_output_event( + self, + mock_sdk, + ) -> None: + """Each TextBlock in AssistantMessage publishes AGENT_OUTPUT.""" + from orchestrator.agent_runner import AgentSession + from orchestrator.event_bus import EventBus + + AssistantMessage = mock_sdk.AssistantMessage + TextBlock = mock_sdk.TextBlock + + mock_client = AsyncMock() + + block1 = MagicMock(spec=TextBlock) + block1.__class__ = TextBlock + block1.text = "Hello" + + block2 = MagicMock(spec=TextBlock) + block2.__class__ = TextBlock + block2.text = " world" + + msg = MagicMock(spec=AssistantMessage) + msg.__class__ = AssistantMessage + msg.error = None + msg.content = [block1, block2] + + async def mock_receive(): + yield msg + + mock_client.receive_response = mock_receive + mock_client.query = AsyncMock() + + event_bus = EventBus() + session = AgentSession( + mock_client, + "QR-1", + event_bus=event_bus, + ) + await session.send("do something") + + history = event_bus.get_task_history("QR-1") + output_events = [e for e in history if e.type == "agent_output"] + assert len(output_events) == 2 + assert output_events[0].data["text"] == "Hello" + assert output_events[1].data["text"] == " world"