diff --git a/examples/01_standalone_sdk/45_parallel_tool_execution.py b/examples/01_standalone_sdk/45_parallel_tool_execution.py new file mode 100644 index 0000000000..8949e3bfc9 --- /dev/null +++ b/examples/01_standalone_sdk/45_parallel_tool_execution.py @@ -0,0 +1,219 @@ +"""Example: Parallel tool execution with tool_concurrency_limit. + +Demonstrates how setting tool_concurrency_limit on an Agent enables +concurrent tool execution within a single step. The orchestrator agent +delegates to multiple sub-agents in parallel, and each sub-agent itself +runs tools concurrently. This stress-tests the parallel execution system +end-to-end. +""" + +import json +import os +import tempfile +from collections import defaultdict +from pathlib import Path + +from openhands.sdk import ( + LLM, + Agent, + AgentContext, + Conversation, + Tool, + register_agent, +) +from openhands.sdk.context import Skill +from openhands.tools.delegate import DelegationVisualizer +from openhands.tools.file_editor import FileEditorTool +from openhands.tools.task import TaskToolSet +from openhands.tools.terminal import TerminalTool + + +llm = LLM( + model=os.getenv("LLM_MODEL", "anthropic/claude-sonnet-4-5-20250929"), + api_key=os.getenv("LLM_API_KEY"), + base_url=os.getenv("LLM_BASE_URL"), + usage_id="parallel-tools-demo", +) + + +# --- Sub-agents --- + + +def create_code_analyst(llm: LLM) -> Agent: + """Sub-agent that analyzes code structure.""" + return Agent( + llm=llm, + tools=[ + Tool(name=TerminalTool.name), + Tool(name=FileEditorTool.name), + ], + tool_concurrency_limit=4, + agent_context=AgentContext( + skills=[ + Skill( + name="code_analysis", + content=( + "You analyze code structure. Use the terminal to count files, " + "lines of code, and list directory structure. Use the file " + "editor to read key files. Run multiple commands at once." + ), + trigger=None, + ) + ], + system_message_suffix="Be concise. 
Report findings in bullet points.", + ), + ) + + +def create_doc_reviewer(llm: LLM) -> Agent: + """Sub-agent that reviews documentation.""" + return Agent( + llm=llm, + tools=[ + Tool(name=TerminalTool.name), + Tool(name=FileEditorTool.name), + ], + tool_concurrency_limit=4, + agent_context=AgentContext( + skills=[ + Skill( + name="doc_review", + content=( + "You review project documentation. Check README files, " + "docstrings, and inline comments. Use the terminal and " + "file editor to inspect files. Run multiple commands at once." + ), + trigger=None, + ) + ], + system_message_suffix="Be concise. Report findings in bullet points.", + ), + ) + + +def create_dependency_checker(llm: LLM) -> Agent: + """Sub-agent that checks project dependencies.""" + return Agent( + llm=llm, + tools=[ + Tool(name=TerminalTool.name), + Tool(name=FileEditorTool.name), + ], + tool_concurrency_limit=4, + agent_context=AgentContext( + skills=[ + Skill( + name="dependency_check", + content=( + "You analyze project dependencies. Read pyproject.toml, " + "requirements files, and package configs. Summarize key " + "dependencies, their purposes, and any version constraints. " + "Run multiple commands at once." + ), + trigger=None, + ) + ], + system_message_suffix="Be concise. 
Report findings in bullet points.", + ), + ) + + +# Register sub-agents +register_agent( + name="code_analyst", + factory_func=create_code_analyst, + description="Analyzes code structure, file counts, and directory layout.", +) +register_agent( + name="doc_reviewer", + factory_func=create_doc_reviewer, + description="Reviews documentation quality and completeness.", +) +register_agent( + name="dependency_checker", + factory_func=create_dependency_checker, + description="Checks and summarizes project dependencies.", +) +# --- Orchestrator agent with parallel execution --- +main_agent = Agent( + llm=llm, + tools=[ + Tool(name=TaskToolSet.name), + Tool(name=TerminalTool.name), + Tool(name=FileEditorTool.name), + ], + tool_concurrency_limit=8, +) + +persistence_dir = Path(tempfile.mkdtemp(prefix="parallel_example_")) + +conversation = Conversation( + agent=main_agent, + workspace=Path.cwd(), + visualizer=DelegationVisualizer(name="Orchestrator"), + persistence_dir=persistence_dir, +) + +print("=" * 80) +print("Parallel Tool Execution Stress Test") +print("=" * 80) + +conversation.send_message(""" +Analyze the current project by delegating to ALL THREE sub-agents IN PARALLEL: + +1. code_analyst: Analyze the project structure (file counts, key directories) +2. doc_reviewer: Review documentation quality (README, docstrings) +3. dependency_checker: Check dependencies (pyproject.toml, requirements) + +IMPORTANT: Delegate to all three agents at the same time using parallel tool calls. +Do NOT delegate one at a time - call all three delegate tools in a single response. + +Once all three have reported back, write a consolidated summary to +project_analysis_report.txt in the working directory. The report should have +three sections (Code Structure, Documentation, Dependencies) with the key +findings from each sub-agent. 
+""") +conversation.run() + +# --- Analyze persisted events for parallelism --- +# +# Walk the persistence directory to find all conversations (main + sub-agents). +# Each conversation stores events as event-*.json files under an events/ dir. +# We parse ActionEvent entries and group by llm_response_id — batches with 2+ +# actions sharing the same response ID prove the LLM requested parallel calls +# and the executor handled them concurrently. + +print("\n" + "=" * 80) +print("Parallelism Report") +print("=" * 80) + + +def _analyze_conversation(events_dir: Path) -> dict[str, list[str]]: + """Return {llm_response_id: [tool_name, ...]} for multi-tool batches.""" + batches: dict[str, list[str]] = defaultdict(list) + for event_file in sorted(events_dir.glob("event-*.json")): + data = json.loads(event_file.read_text()) + if data.get("kind") == "ActionEvent" and "llm_response_id" in data: + batches[data["llm_response_id"]].append(data.get("tool_name", "?")) + return {rid: tools for rid, tools in batches.items() if len(tools) >= 2} + + +for events_dir in sorted(persistence_dir.rglob("events")): + if not events_dir.is_dir(): + continue + # Derive a label from the path (main conv vs sub-agent) + rel = events_dir.parent.relative_to(persistence_dir) + is_subagent = "subagents" in rel.parts + label = "sub-agent" if is_subagent else "main agent" + + multi_batches = _analyze_conversation(events_dir) + if multi_batches: + for resp_id, tools in multi_batches.items(): + print(f"\n {label} batch ({resp_id[:16]}...):") + print(f" Parallel tools: {tools}") + else: + print(f"\n {label}: no parallel batches") + +cost = conversation.conversation_stats.get_combined_metrics().accumulated_cost +print(f"\nTotal cost: ${cost:.4f}") +print(f"EXAMPLE_COST: {cost:.4f}") diff --git a/openhands-sdk/openhands/sdk/agent/agent.py b/openhands-sdk/openhands/sdk/agent/agent.py index bd85e44182..e6d6c8e4b8 100644 --- a/openhands-sdk/openhands/sdk/agent/agent.py +++ 
b/openhands-sdk/openhands/sdk/agent/agent.py @@ -1,11 +1,14 @@ import json +from collections.abc import Callable +from dataclasses import dataclass, field -from pydantic import ValidationError, model_validator +from pydantic import PrivateAttr, ValidationError, model_validator import openhands.sdk.security.analyzer as analyzer import openhands.sdk.security.risk as risk from openhands.sdk.agent.base import AgentBase from openhands.sdk.agent.critic_mixin import CriticMixin +from openhands.sdk.agent.parallel_executor import ParallelToolExecutor from openhands.sdk.agent.utils import ( fix_malformed_tool_arguments, make_llm_completion, @@ -22,6 +25,7 @@ from openhands.sdk.event import ( ActionEvent, AgentErrorEvent, + Event, MessageEvent, ObservationEvent, SystemPromptEvent, @@ -72,6 +76,133 @@ INIT_STATE_PREFIX_SCAN_WINDOW = 3 +@dataclass(frozen=True, slots=True) +class _ActionBatch: + """Immutable result of preparing a batch of actions for execution. + + Owns the full lifecycle of a tool-call batch: preparation (truncation, + blocked-action partitioning, execution), event emission, and post-batch + state transitions. Agent-specific logic (iterative refinement, state + mutation) is injected via callables so the batch stays decoupled from + the Agent class. + """ + + action_events: list[ActionEvent] + has_finish: bool + blocked_reasons: dict[str, str] = field(default_factory=dict) + results_by_id: dict[str, list[Event]] = field(default_factory=dict) + + @staticmethod + def _truncate_at_finish( + action_events: list[ActionEvent], + ) -> tuple[list[ActionEvent], bool]: + """ + Return (events[:finish+1], True) or (events, False). + Discards and logs any calls after FinishTool. 
+ """ + finish_idx = next( + ( + i + for i, ae in enumerate(action_events) + if ae.tool_name == FinishTool.name + ), + None, + ) + if finish_idx is None: + return action_events, False + + discarded = action_events[finish_idx + 1 :] + if discarded: + names = [ae.tool_name for ae in discarded] + logger.warning( + f"Discarding {len(discarded)} tool call(s) " + f"after FinishTool: {', '.join(names)}" + ) + return action_events[: finish_idx + 1], True + + @classmethod + def prepare( + cls, + action_events: list[ActionEvent], + state: ConversationState, + executor: ParallelToolExecutor, + tool_runner: Callable[[ActionEvent], list[Event]], + ) -> "_ActionBatch": + """Truncate, partition blocked actions, execute the rest, return the batch.""" + action_events, has_finish = cls._truncate_at_finish(action_events) + + blocked_reasons: dict[str, str] = {} + executable: list[ActionEvent] = [] + for ae in action_events: + reason = state.pop_blocked_action(ae.id) + if reason is not None: + blocked_reasons[ae.id] = reason + else: + executable.append(ae) + + executed_results = executor.execute_batch(executable, tool_runner) + results_by_id = dict(zip([ae.id for ae in executable], executed_results)) + + return cls( + action_events=action_events, + has_finish=has_finish, + blocked_reasons=blocked_reasons, + results_by_id=results_by_id, + ) + + def emit(self, on_event: ConversationCallbackType) -> None: + """Emit all events in original action order.""" + for ae in self.action_events: + reason = self.blocked_reasons.get(ae.id) + if reason is not None: + logger.info(f"Action '{ae.tool_name}' blocked by hook: {reason}") + on_event( + UserRejectObservation( + action_id=ae.id, + tool_name=ae.tool_name, + tool_call_id=ae.tool_call_id, + rejection_reason=reason, + rejection_source="hook", + ) + ) + else: + for event in self.results_by_id[ae.id]: + on_event(event) + + def finalize( + self, + on_event: ConversationCallbackType, + check_iterative_refinement: Callable[[ActionEvent], tuple[bool, 
str | None]], + mark_finished: Callable[[], None], + ) -> None: + """Transition state after FinishTool, or inject iterative-refinement followup. + + Args: + on_event: Callback for emitting events. + check_iterative_refinement: Returns (should_continue, followup) + for a FinishTool action event. + mark_finished: Called to set the conversation execution status + to FINISHED when the agent is done. + """ + # Nothing to finalise: no FinishTool, or it was blocked by a hook. + if not self.has_finish or self.action_events[-1].id in self.blocked_reasons: + return + + should_continue, followup = check_iterative_refinement(self.action_events[-1]) + if should_continue and followup: + on_event( + MessageEvent( + source="user", + llm_message=Message( + role="user", + content=[TextContent(text=followup)], + ), + ) + ) + else: + mark_finished() + + class Agent(CriticMixin, AgentBase): """Main agent implementation for OpenHands. @@ -97,6 +228,16 @@ class Agent(CriticMixin, AgentBase): ``` """ + _parallel_executor: ParallelToolExecutor = PrivateAttr( + default_factory=ParallelToolExecutor + ) + + def model_post_init(self, __context: object) -> None: + super().model_post_init(__context) + self._parallel_executor = ParallelToolExecutor( + max_workers=self.tool_concurrency_limit + ) + @model_validator(mode="before") @classmethod def _add_security_prompt_as_default(cls, data): @@ -258,9 +399,27 @@ def _execute_actions( conversation: LocalConversation, action_events: list[ActionEvent], on_event: ConversationCallbackType, - ): - for action_event in action_events: - self._execute_action_event(conversation, action_event, on_event=on_event) + ) -> None: + """Prepare a batch, emit results, and handle finish.""" + state = conversation.state + batch = _ActionBatch.prepare( + action_events, + state=state, + executor=self._parallel_executor, + tool_runner=lambda ae: self._execute_action_event(conversation, ae), + ) + batch.emit(on_event) + batch.finalize( + on_event=on_event, + 
check_iterative_refinement=lambda ae: ( + self._check_iterative_refinement(conversation, ae) + ), + mark_finished=lambda: setattr( + state, + "execution_status", + ConversationExecutionStatus.FINISHED, + ), + ) @observe(name="agent.step", ignore_inputs=["state", "on_event"]) def step( @@ -659,38 +818,26 @@ def _get_action_event( on_event(action_event) return action_event - @observe(ignore_inputs=["state", "on_event"]) + @observe() def _execute_action_event( self, conversation: LocalConversation, action_event: ActionEvent, - on_event: ConversationCallbackType, - ): - """Execute an action event and update the conversation state. + ) -> list[Event]: + """Execute a single tool and return the resulting events. - It will call the tool's executor and update the state & call callback fn - with the observation. - - If the action was blocked by a PreToolUse hook (recorded in - state.blocked_actions), a UserRejectObservation is emitted instead - of executing the action. - """ - state = conversation.state + Called from parallel threads by _execute_actions. This method must + not mutate shared conversation state (blocked_actions, + execution_status) — those transitions are handled by the caller + on the main thread. - # Check if this action was blocked by a PreToolUse hook - reason = state.pop_blocked_action(action_event.id) - if reason is not None: - logger.info(f"Action '{action_event.tool_name}' blocked by hook: {reason}") - rejection = UserRejectObservation( - action_id=action_event.id, - tool_name=action_event.tool_name, - tool_call_id=action_event.tool_call_id, - rejection_reason=reason, - rejection_source="hook", - ) - on_event(rejection) - return rejection + Note: the tool itself receives ``conversation`` and may mutate it + (e.g. filesystem, working directory). Thread safety of individual + tools is the tool's responsibility. + Returns a list of events (observation or error). Events are NOT + emitted here — the caller is responsible for emitting them in order. 
+ """ tool = self.tools_map.get(action_event.tool_name, None) if tool is None: raise RuntimeError( @@ -720,8 +867,7 @@ def _execute_action_event( tool_name=tool.name, tool_call_id=action_event.tool_call.id, ) - on_event(error_event) - return error_event + return [error_event] obs_event = ObservationEvent( observation=observation, @@ -729,27 +875,7 @@ def _execute_action_event( tool_name=tool.name, tool_call_id=action_event.tool_call.id, ) - on_event(obs_event) - - # Set conversation state - if tool.name == FinishTool.name: - # Check if iterative refinement should continue - should_continue, followup = self._check_iterative_refinement( - conversation, action_event - ) - if should_continue and followup: - # Send follow-up message and continue agent loop - followup_msg = MessageEvent( - source="user", - llm_message=Message( - role="user", content=[TextContent(text=followup)] - ), - ) - on_event(followup_msg) - # Don't set FINISHED - let the agent continue - else: - state.execution_status = ConversationExecutionStatus.FINISHED - return obs_event + return [obs_event] def _maybe_emit_vllm_tokens( self, llm_response: LLMResponse, on_event: ConversationCallbackType diff --git a/openhands-sdk/openhands/sdk/agent/base.py b/openhands-sdk/openhands/sdk/agent/base.py index fd0826ad51..89f2c404c3 100644 --- a/openhands-sdk/openhands/sdk/agent/base.py +++ b/openhands-sdk/openhands/sdk/agent/base.py @@ -187,6 +187,17 @@ class AgentBase(DiscriminatedUnionMixin, ABC): examples=[{"kind": "AgentFinishedCritic"}], ) + tool_concurrency_limit: int = Field( + default=1, + ge=1, + description=( + "Maximum number of tool calls to execute concurrently within a single " + "agent step. Default is 1 (sequential). Values > 1 enable parallel " + "execution; concurrent tools share the conversation object, filesystem, " + "and working directory, so mutations to shared state may race." 
+ ), + ) + # Runtime materialized tools; private and non-serializable _tools: dict[str, ToolDefinition] = PrivateAttr(default_factory=dict) _initialized: bool = PrivateAttr(default=False) diff --git a/openhands-sdk/openhands/sdk/agent/parallel_executor.py b/openhands-sdk/openhands/sdk/agent/parallel_executor.py new file mode 100644 index 0000000000..0f02655542 --- /dev/null +++ b/openhands-sdk/openhands/sdk/agent/parallel_executor.py @@ -0,0 +1,117 @@ +"""Parallel tool execution for agent. + +This module provides utilities for executing multiple tool calls concurrently +with a configurable per-agent concurrency limit. + +.. warning:: Thread safety of individual tools + + When ``tool_concurrency_limit > 1``, multiple tools run in parallel + threads sharing the same ``conversation`` object. Tools are **not** + thread-safe by default — concurrent mutations to working directory, + filesystem, or conversation state can race. Callers opting into + parallelism must ensure the tools in use are safe for concurrent + execution. +""" + +from __future__ import annotations + +from collections.abc import Callable, Sequence +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING + +from openhands.sdk.event.llm_convertible import AgentErrorEvent +from openhands.sdk.logger import get_logger + + +if TYPE_CHECKING: + from openhands.sdk.event.base import Event + from openhands.sdk.event.llm_convertible import ActionEvent + +logger = get_logger(__name__) + + +class ParallelToolExecutor: + """Executes a batch of tool calls concurrently. + + Each instance has its own thread pool and concurrency limit, so + nested execution (e.g., subagents) cannot deadlock the parent. + + .. warning:: + + When concurrency > 1, tools share the ``conversation`` object + across threads. Tools are not thread-safe by default — concurrent + mutations to filesystem, working directory, or conversation state + can cause race conditions. 
+ """ + + def __init__(self, max_workers: int = 1) -> None: + self._max_workers = max_workers + + def execute_batch( + self, + action_events: Sequence[ActionEvent], + tool_runner: Callable[[ActionEvent], list[Event]], + ) -> list[list[Event]]: + """Execute a batch of action events concurrently. + + Args: + action_events: Sequence of ActionEvent objects to execute. + tool_runner: A callable that takes an ActionEvent and returns + a list of Event objects produced by the execution. + + Returns: + List of event lists in the same order as the input action_events. + """ + if not action_events: + return [] + + if len(action_events) == 1 or self._max_workers == 1: + return [self._run_safe(action, tool_runner) for action in action_events] + + with ThreadPoolExecutor(max_workers=self._max_workers) as executor: + futures = [ + executor.submit(self._run_safe, action, tool_runner) + for action in action_events + ] + + return [future.result() for future in futures] + + @staticmethod + def _run_safe( + action: ActionEvent, + tool_runner: Callable[[ActionEvent], list[Event]], + ) -> list[Event]: + """Run tool_runner, converting exceptions to AgentErrorEvent. + + All exceptions are caught so that one failing tool in a parallel + batch cannot crash the agent or prevent sibling tools from + completing. ValueErrors are expected tool errors (bad arguments, + validation failures); anything else is likely a programming bug + and is logged at ERROR with a full traceback. + """ + try: + return tool_runner(action) + except ValueError as e: + # Expected tool errors (invalid arguments, precondition failures, etc.) + logger.info(f"Tool error in '{action.tool_name}': {e}") + return [ + AgentErrorEvent( + error=f"Error executing tool '{action.tool_name}': {e}", + tool_name=action.tool_name, + tool_call_id=action.tool_call_id, + ) + ] + except Exception as e: + # Unexpected errors — likely bugs in tool implementations. + # Logged at ERROR with traceback to aid debugging. 
+ logger.error( + f"Unexpected error in tool '{action.tool_name}': {e}", + exc_info=True, + ) + return [ + AgentErrorEvent( + error=f"Error executing tool '{action.tool_name}': {e}", + tool_name=action.tool_name, + tool_call_id=action.tool_call_id, + ) + ] diff --git a/tests/sdk/agent/test_action_batch.py b/tests/sdk/agent/test_action_batch.py new file mode 100644 index 0000000000..43904a951b --- /dev/null +++ b/tests/sdk/agent/test_action_batch.py @@ -0,0 +1,225 @@ +"""Unit tests for _ActionBatch.""" + +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from openhands.sdk.agent.agent import _ActionBatch +from openhands.sdk.event import ActionEvent, ObservationEvent +from openhands.sdk.event.llm_convertible import UserRejectObservation +from openhands.sdk.tool.builtins import FinishTool + + +def _ae(tool_name: str = "tool", action_id: str | None = None) -> ActionEvent: + """Minimal ActionEvent mock (typed as ActionEvent for static analysis).""" + ae = MagicMock(spec=ActionEvent) + ae.tool_name = tool_name + ae.id = action_id or str(id(ae)) + ae.tool_call_id = f"tc-{ae.id}" + return ae # type: ignore[return-value] + + +_F = FinishTool.name + + +@pytest.mark.parametrize( + "names, expected_names, expected_finish", + [ + ([], [], False), + (["a", "b"], ["a", "b"], False), + ([_F], [_F], True), + (["a", _F], ["a", _F], True), + (["a", _F, "b", "c"], ["a", _F], True), + ], + ids=["empty", "no_finish", "finish_only", "finish_last", "discards_after_finish"], +) +def test_truncate_at_finish(names, expected_names, expected_finish): + events = [_ae(n) for n in names] + result, has_finish = _ActionBatch._truncate_at_finish(events) + assert [e.tool_name for e in result] == expected_names + assert has_finish == expected_finish + + +def _make_state(blocked: dict[str, str] | None = None): + """Mock ConversationState with pop_blocked_action support.""" + blocked = dict(blocked or {}) + state = MagicMock() + state.pop_blocked_action = lambda aid: 
blocked.pop(aid, None) + return state + + +def _make_executor(side_effect: Any = None) -> Any: + """Mock ParallelToolExecutor.""" + executor = MagicMock() + if side_effect: + executor.execute_batch = side_effect + else: + executor.execute_batch = lambda actions, runner: [runner(a) for a in actions] + return executor + + +def _run(ae: ActionEvent) -> list[Any]: + return [f"result-{ae.id}"] + + +def test_prepare_simple(): + events = [_ae("a", "1"), _ae("b", "2")] + batch = _ActionBatch.prepare(events, _make_state(), _make_executor(), _run) + + assert batch.action_events == events + assert not batch.has_finish + assert batch.blocked_reasons == {} + assert batch.results_by_id == {"1": ["result-1"], "2": ["result-2"]} + + +def test_prepare_with_blocked(): + events = [_ae("a", "1"), _ae("b", "2"), _ae("c", "3")] + state = _make_state({"2": "denied by policy"}) + executed = [] + + def tracking_runner(ae: ActionEvent) -> list[Any]: + executed.append(ae.id) + return [f"ok-{ae.id}"] + + batch = _ActionBatch.prepare(events, state, _make_executor(), tracking_runner) + + assert batch.blocked_reasons == {"2": "denied by policy"} + assert "2" not in batch.results_by_id + assert set(executed) == {"1", "3"} + + +def test_prepare_truncates_before_blocking(): + """FinishTool truncation happens before blocked partitioning.""" + events = [_ae("a", "1"), _ae(FinishTool.name, "2"), _ae("c", "3")] + state = _make_state({"3": "should not appear"}) + + batch = _ActionBatch.prepare(events, state, _make_executor(), _run) + + assert batch.has_finish + assert len(batch.action_events) == 2 + assert "3" not in batch.blocked_reasons # truncated before we checked + + +def test_prepare_all_blocked(): + events = [_ae("a", "1"), _ae("b", "2")] + state = _make_state({"1": "no", "2": "no"}) + executor = MagicMock() + executor.execute_batch = MagicMock(return_value=[]) + + batch = _ActionBatch.prepare(events, state, executor, _run) + + assert len(batch.blocked_reasons) == 2 + assert batch.results_by_id 
== {} + assert executor.execute_batch.call_args[0][0] == [] + + +def test_prepare_empty(): + batch = _ActionBatch.prepare([], _make_state(), _make_executor(), _run) + assert batch.action_events == [] + assert not batch.has_finish + assert batch.results_by_id == {} + + +# ── emit ────────────────────────────────────────────────────────── + + +def _obs(label: str) -> ObservationEvent: + """Create a minimal ObservationEvent stub for testing.""" + obs = MagicMock(spec=ObservationEvent) + obs._label = label + return obs # type: ignore[return-value] + + +def test_emit_results_in_order(): + o1, o2a, o2b = _obs("o1"), _obs("o2a"), _obs("o2b") + events = [_ae("a", "1"), _ae("b", "2")] + batch = _ActionBatch( + action_events=events, + has_finish=False, + results_by_id={"1": [o1], "2": [o2a, o2b]}, + ) + emitted: list[Any] = [] + batch.emit(emitted.append) + assert emitted == [o1, o2a, o2b] + + +def test_emit_blocked_produces_rejection(): + o2 = _obs("o2") + events = [_ae("a", "1"), _ae("b", "2")] + batch = _ActionBatch( + action_events=events, + has_finish=False, + blocked_reasons={"1": "policy"}, + results_by_id={"2": [o2]}, + ) + emitted: list[Any] = [] + batch.emit(emitted.append) + + assert len(emitted) == 2 + assert isinstance(emitted[0], UserRejectObservation) + assert emitted[0].rejection_reason == "policy" + assert emitted[1] is o2 + + +# ── finalize ────────────────────────────────────────────────────── + + +def test_finalize_noop_when_no_finish(): + batch = _ActionBatch(action_events=[_ae("a", "1")], has_finish=False) + finished: list[bool] = [] + batch.finalize( + on_event=lambda e: None, + check_iterative_refinement=lambda ae: (False, None), + mark_finished=lambda: finished.append(True), + ) + assert finished == [] + + +def test_finalize_marks_finished(): + events = [_ae(_F, "1")] + batch = _ActionBatch( + action_events=events, + has_finish=True, + results_by_id={"1": [_obs("o")]}, + ) + finished: list[bool] = [] + batch.finalize( + on_event=lambda e: None, + 
check_iterative_refinement=lambda ae: (False, None), + mark_finished=lambda: finished.append(True), + ) + assert finished == [True] + + +def test_finalize_emits_followup_on_refinement(): + events = [_ae(_F, "1")] + batch = _ActionBatch( + action_events=events, + has_finish=True, + results_by_id={"1": [_obs("o")]}, + ) + emitted: list[Any] = [] + batch.finalize( + on_event=emitted.append, + check_iterative_refinement=lambda ae: (True, "try again"), + mark_finished=lambda: None, + ) + assert len(emitted) == 1 + assert emitted[0].llm_message.content[0].text == "try again" + + +def test_finalize_noop_when_finish_blocked(): + events = [_ae(_F, "1")] + batch = _ActionBatch( + action_events=events, + has_finish=True, + blocked_reasons={"1": "denied"}, + ) + finished: list[bool] = [] + batch.finalize( + on_event=lambda e: None, + check_iterative_refinement=lambda ae: (False, None), + mark_finished=lambda: finished.append(True), + ) + assert finished == [] diff --git a/tests/sdk/agent/test_agent_immutability.py b/tests/sdk/agent/test_agent_immutability.py index b32eedad4d..9b7499f121 100644 --- a/tests/sdk/agent/test_agent_immutability.py +++ b/tests/sdk/agent/test_agent_immutability.py @@ -113,8 +113,11 @@ def test_multiple_agents_are_independent(self): llm=self.llm, tools=[], system_prompt_filename="system_prompt.j2" ) - # They should have the same configuration - assert agent1 == agent2 + # Compare via model_dump() because direct equality (agent1 == agent2) + # fails: each agent has its own ParallelToolExecutor instance via + # PrivateAttr(default_factory=...), and Pydantic frozen models include + # private attrs in __eq__. 
+ assert agent1.model_dump() == agent2.model_dump() assert agent1.system_prompt_filename == agent2.system_prompt_filename # But they should be different instances diff --git a/tests/sdk/agent/test_parallel_execution_integration.py b/tests/sdk/agent/test_parallel_execution_integration.py new file mode 100644 index 0000000000..c1ab4a8408 --- /dev/null +++ b/tests/sdk/agent/test_parallel_execution_integration.py @@ -0,0 +1,423 @@ +"""Integration tests for parallel tool execution within the agent. + +These tests verify that the agent correctly executes tool calls in parallel +when tool_concurrency_limit > 1, including event ordering, state transitions, +FinishTool truncation, and blocked action handling. +""" + +import threading +import time +from collections.abc import Sequence +from typing import TYPE_CHECKING, Self + +import pytest +from pydantic import Field, ValidationError + +from openhands.sdk.agent import Agent +from openhands.sdk.conversation import Conversation +from openhands.sdk.conversation.state import ConversationExecutionStatus +from openhands.sdk.event import ActionEvent, AgentErrorEvent, ObservationEvent +from openhands.sdk.llm import Message, MessageToolCall, TextContent +from openhands.sdk.testing import TestLLM +from openhands.sdk.tool import Action, Observation, Tool, ToolExecutor, register_tool +from openhands.sdk.tool.tool import ToolDefinition + + +if TYPE_CHECKING: + from openhands.sdk.conversation.base import BaseConversation + from openhands.sdk.conversation.state import ConversationState + + +# --- Test tools --- + + +class SlowAction(Action): + delay: float = Field(default=0.05) + label: str = Field(default="") + + +class SlowObservation(Observation): + label: str = Field(default="") + thread_name: str = Field(default="") + + +class SlowExecutor(ToolExecutor[SlowAction, SlowObservation]): + def __call__( + self, action: SlowAction, conversation: "BaseConversation | None" = None + ) -> SlowObservation: + time.sleep(action.delay) + return 
SlowObservation.from_text( + text=f"done-{action.label}", + label=action.label, + thread_name=threading.current_thread().name, + ) + + +class SlowTool(ToolDefinition[SlowAction, SlowObservation]): + name = "slow_tool" + + @classmethod + def create(cls, conv_state: "ConversationState | None" = None) -> Sequence[Self]: + return [ + cls( + description="A slow tool for testing parallelism", + action_type=SlowAction, + observation_type=SlowObservation, + executor=SlowExecutor(), + ) + ] + + +class ParallelFailingAction(Action): + value: str = "" + + +class ParallelFailingObservation(Observation): + result: str = "" + + +class ParallelFailingExecutor( + ToolExecutor[ParallelFailingAction, ParallelFailingObservation] +): + def __call__( + self, + action: ParallelFailingAction, + conversation: "BaseConversation | None" = None, + ) -> ParallelFailingObservation: + raise ValueError(f"Tool failed: {action.value}") + + +class ParallelFailingTool( + ToolDefinition[ParallelFailingAction, ParallelFailingObservation] +): + name = "parallel_failing_tool" + + @classmethod + def create(cls, conv_state: "ConversationState | None" = None) -> Sequence[Self]: + return [ + cls( + description="A tool that always fails", + action_type=ParallelFailingAction, + observation_type=ParallelFailingObservation, + executor=ParallelFailingExecutor(), + ) + ] + + +register_tool("SlowTool", SlowTool) +register_tool("ParallelFailingTool", ParallelFailingTool) + + +# --- Helper --- + + +def _tool_call(call_id: str, name: str, arguments: str) -> MessageToolCall: + return MessageToolCall( + id=call_id, name=name, arguments=arguments, origin="completion" + ) + + +def _run_step(agent, conversation, collected_events): + """Run a single agent step and return collected events.""" + agent.step(conversation, on_event=lambda e: collected_events.append(e)) + + +# --- Tests --- + + +def test_parallel_execution_multiple_tools(): + """Multiple tool calls execute in parallel and events are emitted in order.""" + llm = 
TestLLM.from_messages(
        [
            Message(
                role="assistant",
                content=[TextContent(text="Running tools")],
                tool_calls=[
                    _tool_call("call_0", "slow_tool", '{"delay": 0.05, "label": "a"}'),
                    _tool_call("call_1", "slow_tool", '{"delay": 0.05, "label": "b"}'),
                    _tool_call("call_2", "slow_tool", '{"delay": 0.05, "label": "c"}'),
                ],
            ),
            Message(role="assistant", content=[TextContent(text="Done")]),
        ]
    )
    agent = Agent(llm=llm, tools=[Tool(name="SlowTool")], tool_concurrency_limit=4)

    collected = []
    conversation = Conversation(agent=agent, callbacks=[lambda e: collected.append(e)])
    conversation.send_message(Message(role="user", content=[TextContent(text="Go")]))
    _run_step(agent, conversation, collected)

    # Observations must come out in the original tool-call order even though
    # tool_concurrency_limit=4 lets the tools run concurrently.
    obs_events = [e for e in collected if isinstance(e, ObservationEvent)]
    assert len(obs_events) == 3
    assert obs_events[0].tool_call_id == "call_0"
    assert obs_events[1].tool_call_id == "call_1"
    assert obs_events[2].tool_call_id == "call_2"


def test_parallel_execution_faster_than_sequential():
    """Parallel execution completes faster than sequential would."""
    llm = TestLLM.from_messages(
        [
            Message(
                role="assistant",
                content=[TextContent(text="")],
                tool_calls=[
                    _tool_call("call_0", "slow_tool", '{"delay": 0.1, "label": "a"}'),
                    _tool_call("call_1", "slow_tool", '{"delay": 0.1, "label": "b"}'),
                    _tool_call("call_2", "slow_tool", '{"delay": 0.1, "label": "c"}'),
                    _tool_call("call_3", "slow_tool", '{"delay": 0.1, "label": "d"}'),
                ],
            ),
            Message(role="assistant", content=[TextContent(text="Done")]),
        ]
    )
    agent = Agent(llm=llm, tools=[Tool(name="SlowTool")], tool_concurrency_limit=4)

    collected = []
    conversation = Conversation(agent=agent, callbacks=[lambda e: collected.append(e)])
    conversation.send_message(Message(role="user", content=[TextContent(text="Go")]))

    start = time.monotonic()
    _run_step(agent, conversation, collected)
    elapsed = time.monotonic() - start

    # 4 tools x 0.1s each = 0.4s sequential, should be ~0.1s parallel.
    # NOTE(review): wall-clock threshold — could be flaky on a loaded CI host.
    assert elapsed < 0.3, f"Expected parallel execution, took {elapsed:.2f}s"


def test_sequential_execution_with_default_limit():
    """With default tool_concurrency_limit=1, tools execute sequentially."""
    llm = TestLLM.from_messages(
        [
            Message(
                role="assistant",
                content=[TextContent(text="")],
                tool_calls=[
                    _tool_call("call_0", "slow_tool", '{"delay": 0.02, "label": "a"}'),
                    _tool_call("call_1", "slow_tool", '{"delay": 0.02, "label": "b"}'),
                ],
            ),
            Message(role="assistant", content=[TextContent(text="Done")]),
        ]
    )
    # No tool_concurrency_limit passed: exercises the default.
    agent = Agent(llm=llm, tools=[Tool(name="SlowTool")])

    collected = []
    conversation = Conversation(agent=agent, callbacks=[lambda e: collected.append(e)])
    conversation.send_message(Message(role="user", content=[TextContent(text="Go")]))
    _run_step(agent, conversation, collected)

    obs_events = [e for e in collected if isinstance(e, ObservationEvent)]
    assert len(obs_events) == 2
    assert obs_events[0].tool_call_id == "call_0"
    assert obs_events[1].tool_call_id == "call_1"


def test_limit_one_preserves_sequential_semantics():
    """Regression: tool_concurrency_limit=1 must preserve old sequential behavior.

    With the default limit of 1, multi-tool batches must:
    1. Run each tool on the caller's thread (not a pool thread).
    2. Execute tools strictly in order.

    SlowTool already records threading.current_thread().name in its
    observation, so we can verify thread affinity end-to-end.
+ """ + llm = TestLLM.from_messages( + [ + Message( + role="assistant", + content=[TextContent(text="")], + tool_calls=[ + _tool_call("call_0", "slow_tool", '{"delay": 0.0, "label": "a"}'), + _tool_call("call_1", "slow_tool", '{"delay": 0.0, "label": "b"}'), + _tool_call("call_2", "slow_tool", '{"delay": 0.0, "label": "c"}'), + ], + ), + Message(role="assistant", content=[TextContent(text="Done")]), + ] + ) + # Default tool_concurrency_limit=1 + agent = Agent(llm=llm, tools=[Tool(name="SlowTool")]) + + collected = [] + conversation = Conversation(agent=agent, callbacks=[lambda e: collected.append(e)]) + conversation.send_message(Message(role="user", content=[TextContent(text="Go")])) + + caller_thread = threading.current_thread().name + _run_step(agent, conversation, collected) + + obs_events = [e for e in collected if isinstance(e, ObservationEvent)] + assert len(obs_events) == 3 + + # Property 1: every tool ran on the caller's thread, not a pool thread + labels: list[str] = [] + for obs in obs_events: + observation = obs.observation + assert isinstance(observation, SlowObservation) + assert observation.thread_name == caller_thread, ( + f"Tool '{observation.label}' ran on " + f"{observation.thread_name}, expected {caller_thread}" + ) + labels.append(observation.label) + + # Property 2: tools executed in original order + assert labels == ["a", "b", "c"] + + +def test_finish_tool_truncates_subsequent_tools(): + """Tools after FinishTool are discarded and never executed.""" + llm = TestLLM.from_messages( + [ + Message( + role="assistant", + content=[TextContent(text="")], + tool_calls=[ + _tool_call( + "call_0", "slow_tool", '{"delay": 0.01, "label": "before"}' + ), + _tool_call("call_finish", "finish", '{"message": "All done"}'), + _tool_call( + "call_2", "slow_tool", '{"delay": 0.01, "label": "after"}' + ), + ], + ), + ] + ) + agent = Agent(llm=llm, tools=[Tool(name="SlowTool")], tool_concurrency_limit=4) + + collected = [] + conversation = 
Conversation(agent=agent, callbacks=[lambda e: collected.append(e)]) + conversation.send_message(Message(role="user", content=[TextContent(text="Go")])) + _run_step(agent, conversation, collected) + + # Only slow_tool "before" and finish should have executed + action_events = [e for e in collected if isinstance(e, ActionEvent)] + tool_names = [e.tool_name for e in action_events] + assert "slow_tool" in tool_names + assert "finish" in tool_names + + # The "after" tool call should not exist + obs_events = [e for e in collected if isinstance(e, ObservationEvent)] + obs_tool_calls = [e.tool_call_id for e in obs_events] + assert "call_2" not in obs_tool_calls + + # Conversation should be finished + with conversation.state: + assert ( + conversation.state.execution_status == ConversationExecutionStatus.FINISHED + ) + + +def test_error_in_parallel_batch_preserves_other_results(): + """ + A failing tool in a parallel batch doesn't + prevent other tools from completing. + """ + llm = TestLLM.from_messages( + [ + Message( + role="assistant", + content=[TextContent(text="")], + tool_calls=[ + _tool_call( + "call_0", "slow_tool", '{"delay": 0.01, "label": "ok1"}' + ), + _tool_call("call_1", "parallel_failing_tool", '{"value": "boom"}'), + _tool_call( + "call_2", "slow_tool", '{"delay": 0.01, "label": "ok2"}' + ), + ], + ), + Message(role="assistant", content=[TextContent(text="Recovered")]), + ] + ) + agent = Agent( + llm=llm, + tools=[Tool(name="SlowTool"), Tool(name="ParallelFailingTool")], + tool_concurrency_limit=4, + ) + + collected = [] + conversation = Conversation(agent=agent, callbacks=[lambda e: collected.append(e)]) + conversation.send_message(Message(role="user", content=[TextContent(text="Go")])) + _run_step(agent, conversation, collected) + + # Should have 2 observations and 1 error, in order + obs_events = [e for e in collected if isinstance(e, ObservationEvent)] + error_events = [e for e in collected if isinstance(e, AgentErrorEvent)] + + assert len(obs_events) 
== 2
    assert len(error_events) == 1
    assert "boom" in error_events[0].error

    # Events should be in original order: obs_0, error_1, obs_2
    result_events = [
        e for e in collected if isinstance(e, (ObservationEvent, AgentErrorEvent))
    ]
    assert result_events[0].tool_call_id == "call_0"
    assert result_events[1].tool_call_id == "call_1"
    assert result_events[2].tool_call_id == "call_2"

    # Conversation should NOT be finished
    with conversation.state:
        assert (
            conversation.state.execution_status != ConversationExecutionStatus.FINISHED
        )


def test_blocked_action_with_parallel_execution():
    """
    Blocked actions produce rejections while
    non-blocked actions execute in parallel.
    """
    llm = TestLLM.from_messages(
        [
            Message(
                role="assistant",
                content=[TextContent(text="")],
                tool_calls=[
                    _tool_call("call_0", "slow_tool", '{"delay": 0.01, "label": "a"}'),
                    _tool_call("call_1", "slow_tool", '{"delay": 0.01, "label": "b"}'),
                ],
            ),
            Message(role="assistant", content=[TextContent(text="Done")]),
        ]
    )
    agent = Agent(llm=llm, tools=[Tool(name="SlowTool")], tool_concurrency_limit=4)

    collected = []
    conversation = Conversation(agent=agent, callbacks=[lambda e: collected.append(e)])
    conversation.send_message(Message(role="user", content=[TextContent(text="Go")]))

    # Run one step to get the action events so we know their IDs
    _run_step(agent, conversation, collected)

    # For this test, we verify the mechanism works by checking that
    # both observations were emitted (no blocking configured).
    # NOTE(review): no action is actually blocked here — the docstring's
    # rejection path is not exercised; consider a follow-up with a real
    # blocking confirmation policy.
    obs_events = [e for e in collected if isinstance(e, ObservationEvent)]
    assert len(obs_events) == 2


def test_tool_concurrency_limit_wires_to_executor():
    """Agent.tool_concurrency_limit is wired through to the ParallelToolExecutor."""
    llm = TestLLM.from_messages(
        [Message(role="assistant", content=[TextContent(text="Done")])]
    )
    agent = Agent(llm=llm, tools=[], tool_concurrency_limit=6)
    # Reaches into private attributes deliberately: this is a wiring check.
    assert agent._parallel_executor._max_workers == 6

    agent_default = Agent(llm=llm, tools=[])
    assert agent_default._parallel_executor._max_workers == 1


@pytest.mark.parametrize("value", [0, -1, -100])
def test_tool_concurrency_limit_rejects_invalid_values(value):
    """Pydantic validates tool_concurrency_limit >= 1 at construction time."""
    llm = TestLLM.from_messages(
        [Message(role="assistant", content=[TextContent(text="Done")])]
    )
    with pytest.raises(ValidationError):
        Agent(llm=llm, tools=[], tool_concurrency_limit=value)
diff --git a/tests/sdk/agent/test_parallel_executor.py b/tests/sdk/agent/test_parallel_executor.py
new file mode 100644
index 0000000000..8436a04c5c
--- /dev/null
+++ b/tests/sdk/agent/test_parallel_executor.py
@@ -0,0 +1,240 @@
"""Tests for ParallelToolExecutor."""

import threading
import time
from typing import Any
from unittest.mock import MagicMock

from openhands.sdk.agent.parallel_executor import ParallelToolExecutor
from openhands.sdk.event.llm_convertible import AgentErrorEvent


def test_default_max_workers():
    # Default executor is effectively sequential.
    executor = ParallelToolExecutor()
    assert executor._max_workers == 1


def test_custom_max_workers():
    executor = ParallelToolExecutor(max_workers=4)
    assert executor._max_workers == 4


def test_empty_batch():
    """An empty batch returns an empty result list without calling the runner."""
    executor = ParallelToolExecutor()
    results = executor.execute_batch([], lambda x: [MagicMock()])
    assert results == []


def test_single_action_bypasses_thread_pool():
    """A single action needs no pool; it is executed directly."""
    executor = ParallelToolExecutor()
    action: Any = MagicMock()
    event = MagicMock()

    results = \
executor.execute_batch([action], lambda a: [event]) + assert len(results) == 1 + assert results[0] == [event] + + +def test_multi_action_limit_one_runs_sequentially_on_caller_thread(): + """ + When max_workers=1, multiple actions run on the calling thread, + not a pool thread. + """ + executor = ParallelToolExecutor(max_workers=1) + actions: list[Any] = [MagicMock() for _ in range(3)] + caller_thread = threading.current_thread().name + observed_threads: list[str] = [] + + def tool_runner(action: Any) -> list: + observed_threads.append(threading.current_thread().name) + return [MagicMock()] + + executor.execute_batch(actions, tool_runner) + + # All calls should have run on the caller's thread, not a pool thread + assert all(t == caller_thread for t in observed_threads), ( + f"Expected all calls on {caller_thread}, got {observed_threads}" + ) + + +def test_result_ordering_preserved_despite_variable_duration(): + """Results are in input order even when later actions finish first.""" + executor = ParallelToolExecutor() + actions: list[Any] = [MagicMock() for _ in range(5)] + + def tool_runner(action: Any) -> list: + idx = actions.index(action) + time.sleep((5 - idx) * 0.01) # First action sleeps longest + return [f"result-{idx}"] + + results = executor.execute_batch(actions, tool_runner) + + assert results == [ + ["result-0"], + ["result-1"], + ["result-2"], + ["result-3"], + ["result-4"], + ] + + +def test_actions_run_concurrently(): + """Verify that actions actually run in parallel, not sequentially.""" + executor = ParallelToolExecutor(max_workers=4) + actions: list[Any] = [MagicMock() for _ in range(4)] + max_concurrent = [0] + current = [0] + lock = threading.Lock() + + def tool_runner(action: Any) -> list: + with lock: + current[0] += 1 + max_concurrent[0] = max(max_concurrent[0], current[0]) + time.sleep(0.05) + with lock: + current[0] -= 1 + return [MagicMock()] + + executor.execute_batch(actions, tool_runner) + + assert max_concurrent[0] > 1 + + +def 
test_concurrency_limited_by_max_workers():
    """Concurrency does not exceed the configured limit."""
    executor = ParallelToolExecutor(max_workers=2)
    actions: list[Any] = [MagicMock() for _ in range(6)]
    # Snapshot of the in-flight count taken at each call start.
    concurrent_count: list[int] = []
    lock = threading.Lock()
    current = [0]

    def tool_runner(action: Any) -> list:
        with lock:
            current[0] += 1
            concurrent_count.append(current[0])
        time.sleep(0.02)
        with lock:
            current[0] -= 1
        return [MagicMock()]

    executor.execute_batch(actions, tool_runner)

    # Never more than max_workers in flight at once.
    assert max(concurrent_count) <= 2


def test_multiple_events_per_action():
    """tool_runner can return multiple events for a single action."""
    executor = ParallelToolExecutor()
    actions: list[Any] = [MagicMock(), MagicMock()]

    def tool_runner(action: Any) -> list:
        return [MagicMock(name="obs"), MagicMock(name="followup")]

    results = executor.execute_batch(actions, tool_runner)

    # One result slot per action, each carrying both events.
    assert len(results) == 2
    assert len(results[0]) == 2
    assert len(results[1]) == 2


def _make_action(name: str = "test_tool", tool_call_id: str = "call_1") -> Any:
    """Create a mock ActionEvent with required fields."""
    action = MagicMock()
    action.tool_name = name
    action.tool_call_id = tool_call_id
    return action


def test_error_returns_agent_error_event_for_single_action():
    """Single action errors are wrapped in AgentErrorEvent."""
    executor = ParallelToolExecutor()
    action = _make_action("my_tool", "call_1")

    def tool_runner(a: Any) -> list:
        raise ValueError("Test error")

    results = executor.execute_batch([action], tool_runner)
    assert len(results) == 1
    assert len(results[0]) == 1
    assert isinstance(results[0][0], AgentErrorEvent)
    assert "Test error" in results[0][0].error


def test_error_returns_agent_error_event_in_batch():
    """
    ValueErrors in a batch produce AgentErrorEvents
    successful results are preserved.
    """
    executor = ParallelToolExecutor()
    actions = [
        _make_action("tool_a", "call_0"),
        _make_action("tool_b", "call_1"),
        _make_action("tool_c", "call_2"),
    ]
    success_event = MagicMock()

    def tool_runner(action: Any) -> list:
        # Middle action fails; neighbors sleep briefly then succeed.
        if action.tool_call_id == "call_1":
            raise ValueError("action 1 failed")
        time.sleep(0.02)
        return [success_event]

    results = executor.execute_batch(actions, tool_runner)

    assert len(results) == 3
    assert results[0] == [success_event]
    assert len(results[1]) == 1
    assert isinstance(results[1][0], AgentErrorEvent)
    assert "action 1 failed" in results[1][0].error
    assert results[2] == [success_event]


def test_all_exceptions_wrapped_in_agent_error_event():
    """All exceptions are caught and converted to AgentErrorEvent."""
    executor = ParallelToolExecutor()
    actions = [
        _make_action("tool_a", "call_0"),
        _make_action("tool_b", "call_1"),
    ]
    success_event = MagicMock()

    def tool_runner(action: Any) -> list:
        # RuntimeError (not just ValueError) must also be wrapped.
        if action.tool_call_id == "call_1":
            raise RuntimeError("something broke")
        return [success_event]

    results = executor.execute_batch(actions, tool_runner)

    assert len(results) == 2
    assert results[0] == [success_event]
    assert isinstance(results[1][0], AgentErrorEvent)
    assert "something broke" in results[1][0].error


def test_nested_execution_no_deadlock():
    """Nested execute_batch (subagent scenario) does not deadlock.

    The outer executor has max_workers=1. The subagent tool creates its
    own executor — since pools are per-instance, no thread starvation.
    """
    outer_executor = ParallelToolExecutor(max_workers=1)

    def inner_tool_runner(action: Any) -> list:
        return [f"inner-{action}"]

    def outer_tool_runner(action: Any) -> list:
        if action == "subagent":
            # Inner pool is independent of the outer one.
            inner_executor = ParallelToolExecutor(max_workers=2)
            inner_results = inner_executor.execute_batch(
                ["a", "b"],  # type: ignore[arg-type]
                inner_tool_runner,
            )
            return [item for sublist in inner_results for item in sublist]
        return [f"leaf-{action}"]

    results = outer_executor.execute_batch(
        ["subagent"],  # type: ignore[arg-type]
        outer_tool_runner,
    )

    assert results == [["inner-a", "inner-b"]]