diff --git a/src/open_agent_sdk/agent.py b/src/open_agent_sdk/agent.py index 70cc9cf..1757abe 100644 --- a/src/open_agent_sdk/agent.py +++ b/src/open_agent_sdk/agent.py @@ -211,6 +211,7 @@ async def query( include_partial_messages=opts.include_partial_messages, thinking=opts.thinking, json_schema=opts.json_schema, + abort_signal=opts.abort_signal, debug=opts.debug, extra_args=opts.extra_args, betas=opts.betas, @@ -260,9 +261,10 @@ def clear(self) -> None: self._history.clear() async def interrupt(self) -> None: - """Abort current query.""" - # Would need abort controller integration - pass + """Abort current query by setting the abort signal event.""" + signal = self._options.abort_signal + if signal is not None and hasattr(signal, "set"): + signal.set() async def set_model(self, model: str) -> None: """Switch model during session.""" diff --git a/src/open_agent_sdk/engine.py b/src/open_agent_sdk/engine.py index f149ab4..5e19413 100644 --- a/src/open_agent_sdk/engine.py +++ b/src/open_agent_sdk/engine.py @@ -191,6 +191,12 @@ async def submit_message( turns_remaining -= 1 self._turn_count += 1 + if config.abort_signal is not None and hasattr(config.abort_signal, "is_set") and config.abort_signal.is_set(): + yield self._make_result_event( + SDKResultStatus.ERROR_DURING_EXECUTION, start_time + ) + return + # Check budget if config.max_budget_usd and self._total_cost >= config.max_budget_usd: yield self._make_result_event( @@ -395,7 +401,7 @@ async def _call_api( async def _do_call(): return await provider.create_message(params) - response = await with_retry(_do_call) + response = await with_retry(_do_call, abort_signal=config.abort_signal) # Wrap CreateMessageResponse in a duck-typed object compatible with # the rest of the engine (which expects response.content as list of @@ -409,7 +415,7 @@ async def _execute_tools( ) -> list[ToolResult]: """Execute tool calls, concurrent for read-only, serial for mutations.""" config = self._config - context = ToolContext(cwd=config.cwd, env=config.env) + context = ToolContext(cwd=config.cwd, env=config.env, abort_signal=config.abort_signal) # Partition into read-only (concurrent) and mutations (serial) read_only: list[dict[str, Any]] = [] diff --git a/src/open_agent_sdk/utils/retry.py b/src/open_agent_sdk/utils/retry.py index 5956ccd..7ea2846 100644 --- a/src/open_agent_sdk/utils/retry.py +++ b/src/open_agent_sdk/utils/retry.py @@ -75,6 +75,8 @@ async def with_retry( last_error: Exception | None = None for attempt in range(cfg.max_retries + 1): + if abort_signal is not None and hasattr(abort_signal, "is_set") and abort_signal.is_set(): + raise RuntimeError("Aborted") try: return await fn() except Exception as e: