diff --git a/packages/kosong/src/kosong/__init__.py b/packages/kosong/src/kosong/__init__.py index 7dfdd9414..12ff9c5ef 100644 --- a/packages/kosong/src/kosong/__init__.py +++ b/packages/kosong/src/kosong/__init__.py @@ -75,7 +75,7 @@ async def main() -> None: from loguru import logger -from kosong._generate import GenerateResult, generate +from kosong._generate import GenerateCancelled, GenerateResult, generate from kosong.chat_provider import ChatProvider, ChatProviderError, StreamedMessagePart, TokenUsage from kosong.message import Message, ToolCall from kosong.tooling import ToolResult, ToolResultFuture, Toolset @@ -98,9 +98,34 @@ async def main() -> None: "GenerateResult", "step", "StepResult", + "StepCancelled", ] +class StepCancelled(asyncio.CancelledError): + """CancelledError carrying a partial :class:`StepResult`. + + Raised by :func:`step` when the underlying generation is cancelled. The + partial result records: + + - ``message``: everything streamed so far (including a best-effort flush + of the pending part, so its ``tool_calls`` mirrors what the UI saw). + - ``tool_calls``: aligned with ``message.tool_calls`` so callers can + iterate them as the canonical "what UI saw" list. + - ``_tool_result_futures``: only the tools whose ``on_tool_call`` had + already fired before cancellation — these may still be running. The + caller decides whether to await or cancel them. + + Subclassing :class:`asyncio.CancelledError` keeps the cancellation + contract intact: callers that don't care about the partial result still + see a plain cancel. + """ + + def __init__(self, partial: "StepResult"): + super().__init__() + self.partial = partial + + async def step( chat_provider: ChatProvider, system_prompt: str, @@ -127,7 +152,9 @@ async def step( APIStatusError: If the API returns a status code of 4xx or 5xx. APIEmptyResponseError: If the API returns an empty response. ChatProviderError: If any other recognized chat provider error occurs. - asyncio.CancelledError: If the step is cancelled. + StepCancelled: When the step is cancelled by the caller. Carries a + partial :class:`StepResult` so callers can persist a well-formed + history (every TUI-visible tool_call paired with a tool_result). """ tool_calls: list[ToolCall] = [] @@ -163,7 +190,22 @@ async def on_tool_call(tool_call: ToolCall): on_message_part=on_message_part, on_tool_call=on_tool_call, ) - except (ChatProviderError, asyncio.CancelledError): + except GenerateCancelled as cancelled: + # Streaming was interrupted. The partial message reflects everything + # the UI saw. Align tool_calls with message.tool_calls so callers + # have a single canonical iteration list — message.tool_calls may + # include a flushed pending ToolCall that never fired on_tool_call + # (and hence has no future in tool_result_futures). + merged_tool_calls = list(cancelled.message.tool_calls or []) + partial = StepResult( + id=None, + message=cancelled.message, + usage=None, + tool_calls=merged_tool_calls, + _tool_result_futures=tool_result_futures, + ) + raise StepCancelled(partial) from None + except ChatProviderError: # cancel all the futures to avoid hanging tasks for future in tool_result_futures.values(): future.remove_done_callback(future_done_callback) diff --git a/packages/kosong/src/kosong/_generate.py b/packages/kosong/src/kosong/_generate.py index 1eb45013e..6c3d3d224 100644 --- a/packages/kosong/src/kosong/_generate.py +++ b/packages/kosong/src/kosong/_generate.py @@ -1,3 +1,4 @@ +import asyncio from collections.abc import Sequence from dataclasses import dataclass @@ -14,6 +15,23 @@ from kosong.utils.aio import Callback, callback +class GenerateCancelled(asyncio.CancelledError): + """CancelledError carrying the partial message accumulated up to the cancel point. + + Raised by :func:`generate` when its streaming loop is cancelled. The + ``message`` field reflects everything observed via the stream so far, with + any pending part flushed in best-effort fashion — including incomplete + ToolCall arguments. No ``on_tool_call`` callback is fired for the flushed + pending part (we don't want to start a tool the user just cancelled), so + callers must treat any tool_call in ``message`` without a paired result + future as "interrupted". + """ + + def __init__(self, message: Message): + super().__init__() + self.message = message + + async def generate( chat_provider: ChatProvider, system_prompt: str, @@ -45,31 +63,58 @@ async def generate( APIStatusError: If the API returns a status code of 4xx or 5xx. APIEmptyResponseError: If the API returns an empty response. ChatProviderError: If any other recognized chat provider error occurs. + GenerateCancelled: When the streaming loop is cancelled. Carries the + partial message observed so far so callers can decide whether to + persist it. """ message = Message(role="assistant", content=[]) pending_part: StreamedMessagePart | None = None # message part that is currently incomplete logger.trace("Generating with history: {history}", history=history) stream = await chat_provider.generate(system_prompt, tools, history) - async for part in stream: - logger.trace("Received part: {part}", part=part) - if on_message_part: - await callback(on_message_part, part.model_copy(deep=True)) - - if pending_part is None: - pending_part = part - elif not pending_part.merge_in_place(part): # try merge into the pending part - # unmergeable part must push the pending part to the buffer + try: + async for part in stream: + logger.trace("Received part: {part}", part=part) + if on_message_part: + await callback(on_message_part, part.model_copy(deep=True)) + + if pending_part is None: + pending_part = part + elif not pending_part.merge_in_place(part): # try merge into the pending part + # Unmergeable part: flush the previously pending one. + # + # Invariant: ``pending_part`` is non-None iff that part has + # NOT yet been appended to ``message``. Clearing it BEFORE + # the await means a CancelledError raised inside + # ``on_tool_call`` can't trick the except-handler below into + # re-appending the same part — a duplicate tool_call.id + # would propagate into the partial StepResult and produce + # duplicate paired tool_results downstream. + flushing = pending_part + pending_part = part + _message_append(message, flushing) + if isinstance(flushing, ToolCall) and on_tool_call: + await callback(on_tool_call, flushing) + + # end of message + if pending_part is not None: + # Same race protection as the mid-loop flush: drop the reference + # before awaiting on_tool_call so a cancel during the callback + # leaves ``pending_part`` as None in the except handler. + flushing = pending_part + pending_part = None + _message_append(message, flushing) + if isinstance(flushing, ToolCall) and on_tool_call: + await callback(on_tool_call, flushing) + except asyncio.CancelledError: + # Best-effort flush of the pending part so the caller sees everything + # the UI has seen. Crucially do NOT fire on_tool_call here — that + # callback starts the tool, and the user just asked us to stop. + # Thanks to the "clear before await" pattern above, pending_part is + # non-None here only if it was never appended. + if pending_part is not None: _message_append(message, pending_part) - if isinstance(pending_part, ToolCall) and on_tool_call: - await callback(on_tool_call, pending_part) - pending_part = part - - # end of message - if pending_part is not None: - _message_append(message, pending_part) - if isinstance(pending_part, ToolCall) and on_tool_call: - await callback(on_tool_call, pending_part) + raise GenerateCancelled(message) from None if not message.content and not message.tool_calls: raise APIEmptyResponseError("The API returned an empty response.") diff --git a/packages/kosong/tests/test_step.py b/packages/kosong/tests/test_step.py index 55f33633f..c572bf142 100644 --- a/packages/kosong/tests/test_step.py +++ b/packages/kosong/tests/test_step.py @@ -1,11 +1,22 @@ import asyncio -from typing import override +import copy +from collections.abc import AsyncIterator, Sequence +from typing import Self, override -from kosong import step -from kosong.chat_provider import StreamedMessagePart +import pytest + +from kosong import StepCancelled, step +from kosong._generate import GenerateCancelled, generate +from kosong.chat_provider import ( + ChatProvider, + StreamedMessage, + StreamedMessagePart, + ThinkingEffort, + TokenUsage, +) from kosong.chat_provider.mock import MockChatProvider -from kosong.message import TextPart, ToolCall -from kosong.tooling import CallableTool, ParametersType, ToolOk, ToolResult, ToolReturnValue +from kosong.message import Message, TextPart, ThinkPart, ToolCall +from kosong.tooling import CallableTool, ParametersType, Tool, ToolOk, ToolResult, ToolReturnValue from kosong.tooling.simple import SimpleToolset @@ -63,3 +74,246 @@ async def run(): assert output_parts == input_parts assert tool_results == [ToolResult(tool_call_id="plus#123", return_value=ToolOk(output="3"))] assert collected_tool_results == tool_results + + +class _BlockingStreamedMessage(StreamedMessage): + """Streams predetermined parts then blocks forever on the next __anext__. + + Used to deterministically reproduce the "user pressed ESC after the model + streamed some content but before the stream ended" scenario. The blocking + point can be released by setting ``parts_emitted``. + """ + + def __init__( + self, + parts: list[StreamedMessagePart], + parts_emitted: asyncio.Event, + ): + self._parts = parts + self._idx = 0 + self._parts_emitted = parts_emitted + + def __aiter__(self) -> AsyncIterator[StreamedMessagePart]: + return self + + async def __anext__(self) -> StreamedMessagePart: + if self._idx < len(self._parts): + part = self._parts[self._idx] + self._idx += 1 + if self._idx == len(self._parts): + self._parts_emitted.set() + return part + # parts exhausted — block until the test cancels us + await asyncio.Event().wait() + raise StopAsyncIteration # pragma: no cover + + @property + def id(self) -> str: + return "blocking" + + @property + def usage(self) -> TokenUsage | None: + return None + + +class _BlockingChatProvider(ChatProvider): + """ChatProvider that streams predetermined parts then blocks indefinitely. + + Implements the Protocol directly so the ``generate`` return type doesn't + have to match ``MockChatProvider``'s narrower ``MockStreamedMessage``. + """ + + name = "blocking-mock" + + def __init__( + self, + parts: list[StreamedMessagePart], + parts_emitted: asyncio.Event, + ): + self._stream_parts = parts + self._parts_emitted = parts_emitted + + @property + def model_name(self) -> str: + return "blocking-mock" + + @property + def thinking_effort(self) -> ThinkingEffort | None: + return None + + async def generate( + self, + system_prompt: str, + tools: Sequence[Tool], + history: Sequence[Message], + ) -> _BlockingStreamedMessage: + return _BlockingStreamedMessage(self._stream_parts, self._parts_emitted) + + def with_thinking(self, effort: ThinkingEffort) -> Self: + return copy.copy(self) + + +@pytest.mark.asyncio +async def test_step_streaming_cancel_raises_step_cancelled_with_partial_result(): + """ESC during streaming after a complete tool_call must surface a partial + StepResult so the caller can pair it with a synthetic tool_result. + """ + tool_call = ToolCall( + id="plus#abc", + function=ToolCall.FunctionBody(name="plus", arguments='{"a": 1, "b": 2}'), + ) + parts: list[StreamedMessagePart] = [ + ThinkPart(think="planning to add..."), + TextPart(text="Let me add them."), + tool_call, + ] + parts_emitted = asyncio.Event() + + class PlusTool(CallableTool): + name: str = "plus" + description: str = "Add two integers." + parameters: ParametersType = { + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"}, + }, + } + + @override + async def __call__(self, a: int, b: int) -> ToolReturnValue: + return ToolOk(output=str(a + b)) + + toolset = SimpleToolset([PlusTool()]) + chat_provider = _BlockingChatProvider(parts, parts_emitted) + + step_task = asyncio.create_task( + step( + chat_provider, + system_prompt="", + toolset=toolset, + history=[], + ) + ) + + # wait until the provider has streamed every part (including the tool_call) + await asyncio.wait_for(parts_emitted.wait(), timeout=1.0) + # one event-loop tick so step() can advance past the last yielded part + await asyncio.sleep(0) + step_task.cancel() + + with pytest.raises(StepCancelled) as exc_info: + await step_task + + partial = exc_info.value.partial + assert partial.id is None + assert partial.usage is None + # tool_calls in both the message and the partial result mirror what the + # callbacks saw — the on_tool_call for `tool_call` did fire (it was the + # last pending part and the stream was still emitting), so the future + # exists in _tool_result_futures. + assert partial.message.tool_calls is not None + visible_ids = {tc.id for tc in partial.message.tool_calls} + assert tool_call.id in visible_ids + assert {tc.id for tc in partial.tool_calls} == visible_ids + # The completed-before-cancel tool may still have a future entry; the + # caller decides how to handle it (await, cancel, synthesize). + futures = partial._tool_result_futures # pyright: ignore[reportPrivateUsage] + assert set(futures.keys()) <= visible_ids + + +@pytest.mark.asyncio +async def test_step_streaming_cancel_preserves_thinking_only_partial(): + """ESC mid-thinking (no tool_call yet, no text) should still yield a + partial result. Whether the caller persists it is its decision. + """ + parts: list[StreamedMessagePart] = [ + ThinkPart(think="hmm, considering..."), + ] + parts_emitted = asyncio.Event() + + chat_provider = _BlockingChatProvider(parts, parts_emitted) + toolset = SimpleToolset() + + step_task = asyncio.create_task( + step( + chat_provider, + system_prompt="", + toolset=toolset, + history=[], + ) + ) + await asyncio.wait_for(parts_emitted.wait(), timeout=1.0) + await asyncio.sleep(0) + step_task.cancel() + + with pytest.raises(StepCancelled) as exc_info: + await step_task + + partial = exc_info.value.partial + assert partial.tool_calls == [] + # ThinkPart was the pending part — best-effort flush on cancel + assert any(isinstance(p, ThinkPart) for p in partial.message.content) + + +@pytest.mark.asyncio +async def test_generate_cancel_during_on_tool_call_does_not_duplicate_pending(): + """Cancellation arriving while awaiting ``on_tool_call`` must NOT leave + duplicate tool_calls in the partial message. + + Reproduces the race: between ``_message_append(pending_part)`` and the + next-iteration's ``pending_part = part`` reassignment, the only thing + keeping the previous pending alive is an ``await callback(on_tool_call, + pending_part)``. If a CancelledError is delivered at that yield point, + the except-handler sees the same ``pending_part`` and — without the fix — + appends it again, producing two assistant tool_calls with the same id. + + Driven through ``generate()`` directly so we can install a custom + ``on_tool_call`` that genuinely awaits and yields control to the event + loop. Through ``step()`` the on_tool_call closure is synchronous and the + race is harder to provoke deterministically. + """ + tool_call_a = ToolCall( + id="tc#a", + function=ToolCall.FunctionBody(name="t", arguments="{}"), + ) + # Force tool_call_a out of pending by following it with an unmergeable part. + text_after = TextPart(text="follow-up") + parts: list[StreamedMessagePart] = [tool_call_a, text_after] + parts_emitted = asyncio.Event() + + chat_provider = _BlockingChatProvider(parts, parts_emitted) + on_tool_call_entered = asyncio.Event() + + async def slow_on_tool_call(_tc: ToolCall) -> None: + on_tool_call_entered.set() + # block forever — the test cancels us mid-await + await asyncio.Event().wait() + + gen_task = asyncio.create_task( + generate( + chat_provider, + system_prompt="", + tools=[], + history=[], + on_tool_call=slow_on_tool_call, + ) + ) + + # wait until generate() is awaiting on_tool_call(tool_call_a) + await asyncio.wait_for(on_tool_call_entered.wait(), timeout=1.0) + gen_task.cancel() + + with pytest.raises(GenerateCancelled) as exc_info: + await gen_task + + msg = exc_info.value.message + assert msg.tool_calls is not None + ids = [tc.id for tc in msg.tool_calls] + # Without the fix this would be ["tc#a", "tc#a"]. + assert ids == [tool_call_a.id], ( + f"expected exactly one tool_call entry, got {ids} (duplicate would corrupt history)" + ) + # The unmergeable follow-up text was the new pending part when cancel + # hit; the except-handler's best-effort flush should include it. + assert any(isinstance(p, TextPart) and p.text == "follow-up" for p in msg.content) diff --git a/src/kimi_cli/soul/kimisoul.py b/src/kimi_cli/soul/kimisoul.py index 6f27b4233..e6ebb8f17 100644 --- a/src/kimi_cli/soul/kimisoul.py +++ b/src/kimi_cli/soul/kimisoul.py @@ -19,6 +19,8 @@ RetryableChatProvider, ) from kosong.message import Message +from kosong.tooling import ToolError +from kosong.tooling.error import ToolRuntimeError from tenacity import RetryCallState, retry_if_exception, stop_after_attempt, wait_exponential_jitter from kimi_cli.approval_runtime import ( @@ -94,6 +96,7 @@ def type_check(soul: KimiSoul): SKILL_COMMAND_PREFIX = "skill:" FLOW_COMMAND_PREFIX = "flow:" DEFAULT_MAX_FLOW_MOVES = 1000 +INTERRUPTED_TOOL_MESSAGE = "[Interrupted By User]" def classify_api_error(e: Exception) -> tuple[str, int | None]: @@ -1117,6 +1120,12 @@ async def _kosong_step_with_retry() -> StepResult: t0 = time.monotonic() try: result = await _kosong_step_with_retry() + except kosong.StepCancelled as cancelled: + # Streaming was interrupted by user. The partial StepResult + # carries the tool_calls that were already visible in the TUI; + # we owe each of them a paired tool_result in history. + await self._finalize_interrupted_step(cancelled.partial) + raise asyncio.CancelledError() from None except Exception as _step_exc: # Attach known context so the outer loop can enrich api_error telemetry _ctx: dict[str, Any] = { @@ -1156,7 +1165,11 @@ async def _kosong_step_with_retry() -> StepResult: # ═══════════════════════════════════════════════════════════════════════ # wait for all tool results (may be interrupted) plan_mode_before_tools = self._plan_mode - results = await result.tool_results() + try: + results = await result.tool_results() + except asyncio.CancelledError: + await self._finalize_interrupted_step(result) + raise logger.debug("Got tool results: {results}", results=results) # Update dedup tracking for the next step @@ -1219,6 +1232,111 @@ async def _kosong_step_with_retry() -> StepResult: return None return StepOutcome(stop_reason="no_tool_calls", assistant_message=result.message) + async def _finalize_interrupted_step(self, step_result: StepResult) -> None: + """Unified cleanup when a step is interrupted by user cancel. + + Used by both interruption paths: + + - **Phase B (streaming)**: ``kosong.StepCancelled`` raised before + ``kosong.step()`` could return a complete StepResult. Any tools + whose ``on_tool_call`` already fired may still be running. + - **Phase C (tool wait)**: ``CancelledError`` from + ``result.tool_results()``. ``kosong``'s ``finally`` has already + cancelled the futures, so the work below is mostly a no-op for + this path. + + Invariant: any tool_call that became visible in the TUI (= present in + ``step_result.message.tool_calls``) ends up with a paired tool_result + in history — either the real one (if the tool actually completed) or + a synthetic ``[Interrupted]`` error. + + Persistence is shielded so the cancel signal can't tear it down + mid-write. + """ + futures: dict[str, Any] = getattr(step_result, "_tool_result_futures", {}) + + # Cancel any in-flight futures and wait for them to settle. After + # this point every future is either done-with-result or cancelled. + for future in futures.values(): + if not future.done(): + future.cancel() + if futures: + await asyncio.gather(*futures.values(), return_exceptions=True) + + message = step_result.message + tool_calls = list(message.tool_calls or []) + if not tool_calls and not message.content: + # Nothing observed — skip persisting an empty assistant message. + return + + results, interrupted_results = self._build_interrupted_tool_results(tool_calls, futures) + for tool_result in interrupted_results: + wire_send(tool_result) + + try: + await asyncio.shield(self._grow_context(step_result, results)) + except Exception: + logger.exception("Failed to record interrupted step in context.") + + @staticmethod + def _build_interrupted_tool_results( + tool_calls: list[Any], + futures: dict[str, Any], + ) -> tuple[list[ToolResult], list[ToolResult]]: + """Build model-facing tool results for an interrupted step. + + Returns ``(all_results, synthesized_interrupted)``: ``all_results`` + has one entry per tool_call in order; ``synthesized_interrupted`` is + the subset that we fabricated because the tool never completed. The + caller wire-sends only the synthesized subset (real completions were + already emitted via ``on_tool_result``). + """ + results: list[ToolResult] = [] + interrupted_results: list[ToolResult] = [] + + for tool_call in tool_calls: + tool_result = KimiSoul._completed_tool_result_or_none(tool_call.id, futures) + if tool_result is None: + # Leave brief empty so the TUI doesn't render a second line + # under "Used (...)" — the red bullet from is_error + # already communicates the interruption. The model still sees + # the [Interrupted By User] marker via `message`. + tool_result = ToolResult( + tool_call_id=tool_call.id, + return_value=ToolError( + message=INTERRUPTED_TOOL_MESSAGE, + brief="", + ), + ) + interrupted_results.append(tool_result) + results.append(tool_result) + + return results, interrupted_results + + @staticmethod + def _completed_tool_result_or_none( + tool_call_id: str, + futures: dict[str, Any], + ) -> ToolResult | None: + future = futures.get(tool_call_id) + if future is None or not future.done() or future.cancelled(): + return None + try: + tool_result = future.result() + except asyncio.CancelledError: + return None + except Exception as exc: + return ToolResult( + tool_call_id=tool_call_id, + return_value=ToolRuntimeError(str(exc)), + ) + if isinstance(tool_result, ToolResult): + return tool_result + return ToolResult( + tool_call_id=tool_call_id, + return_value=ToolRuntimeError(f"Invalid tool result: {type(tool_result).__name__}"), + ) + async def _grow_context(self, result: StepResult, tool_results: list[ToolResult]): logger.debug("Growing context with result: {result}", result=result) diff --git a/src/kimi_cli/soul/toolset.py b/src/kimi_cli/soul/toolset.py index 3868d3bdb..14fb85cea 100644 --- a/src/kimi_cli/soul/toolset.py +++ b/src/kimi_cli/soul/toolset.py @@ -99,6 +99,8 @@ def type_check(kimi_toolset: KimiToolset): "\n" ) +_CROSS_STEP_DEDUP_TRIGGER_COUNT = 7 + def _append_reminder_to_return_value(return_value: Any) -> Any: """Append dedup reminder text to a ToolReturnValue output.""" @@ -133,6 +135,8 @@ def __init__(self) -> None: # Deduplication state self._previous_step_calls: list[tuple[str, str]] = [] + self._consecutive_call_key: tuple[str, str] | None = None + self._consecutive_call_count: int = 0 self._current_step_calls: list[tuple[str, str]] = [] self._current_step_tasks: dict[tuple[str, str], asyncio.Task[ToolResult]] = {} self._dedup_triggered: bool = False @@ -184,6 +188,11 @@ def begin_step( ) -> None: """Called before each step to set up deduplication state.""" self._previous_step_calls = previous_calls + if previous_calls: + self._sync_consecutive_state_from_previous_calls(previous_calls) + else: + self._consecutive_call_key = None + self._consecutive_call_count = 0 self._current_step_calls = [] self._current_step_tasks = {} self._dedup_triggered = False @@ -192,11 +201,44 @@ def begin_step( def end_step(self) -> list[tuple[str, str]]: """Called after each step to capture the calls made in this step.""" - return list(self._current_step_calls) + current_calls = list(self._current_step_calls) + for call_key in current_calls: + if call_key == self._consecutive_call_key: + self._consecutive_call_count += 1 + else: + self._consecutive_call_key = call_key + self._consecutive_call_count = 1 + return current_calls + + def _sync_consecutive_state_from_previous_calls( + self, previous_calls: list[tuple[str, str]] + ) -> None: + last_call = previous_calls[-1] + if self._consecutive_call_key == last_call and self._consecutive_call_count > 0: + return + + trailing_count = 0 + for call_key in reversed(previous_calls): + if call_key != last_call: + break + trailing_count += 1 + self._consecutive_call_key = last_call + self._consecutive_call_count = trailing_count + + def _projected_consecutive_state(self) -> tuple[tuple[str, str] | None, int]: + call_key = self._consecutive_call_key + count = self._consecutive_call_count + for current_call in self._current_step_calls: + if current_call == call_key: + count += 1 + else: + call_key = current_call + count = 1 + return call_key, count @property def dedup_triggered(self) -> bool: - """Whether a cross-step duplicate was blocked in the current step.""" + """Whether a cross-step repeat reminder was triggered in the current step.""" return self._dedup_triggered def handle(self, tool_call: ToolCall) -> HandleResult: @@ -228,7 +270,11 @@ async def _await_dup() -> ToolResult: return asyncio.create_task(_await_dup()) - is_cross_step_dup = call_key in self._previous_step_calls + previous_call_key, previous_occurrences = self._projected_consecutive_state() + is_cross_step_dup = ( + call_key == previous_call_key + and previous_occurrences >= _CROSS_STEP_DEDUP_TRIGGER_COUNT - 1 + ) if is_cross_step_dup: from kimi_cli.telemetry import track diff --git a/src/kimi_cli/ui/shell/prompt.py b/src/kimi_cli/ui/shell/prompt.py index b820fe8a0..1aef3116f 100644 --- a/src/kimi_cli/ui/shell/prompt.py +++ b/src/kimi_cli/ui/shell/prompt.py @@ -845,6 +845,13 @@ def _load_history_entries(history_file: Path) -> list[_HistoryEntry]: return entries +def _is_browsing_history_entry(buffer: Buffer) -> bool: + working_lines = getattr(buffer, "_working_lines", None) + if working_lines is None: + return False + return buffer.working_index < len(working_lines) - 1 + + class PromptMode(Enum): AGENT = "agent" SHELL = "shell" @@ -1537,7 +1544,11 @@ def _(event: KeyPressEvent) -> None: def _(buffer: Buffer) -> None: self._last_input_activity_time = time.monotonic() self._input_activity_event.set() - if buffer.complete_while_typing() and not self._suppress_auto_completion: + if ( + buffer.complete_while_typing() + and not self._suppress_auto_completion + and not _is_browsing_history_entry(buffer) + ): buffer.start_completion() self._status_refresh_task: asyncio.Task[None] | None = None diff --git a/tests/core/test_kimisoul_tool_interrupt.py b/tests/core/test_kimisoul_tool_interrupt.py new file mode 100644 index 000000000..770d7b5dd --- /dev/null +++ b/tests/core/test_kimisoul_tool_interrupt.py @@ -0,0 +1,277 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any + +import pytest +from kosong import StepCancelled, StepResult +from kosong.message import Message +from kosong.tooling.empty import EmptyToolset + +import kimi_cli.soul.kimisoul as kimisoul_module +from kimi_cli.soul.agent import Agent, Runtime +from kimi_cli.soul.context import Context +from kimi_cli.soul.kimisoul import INTERRUPTED_TOOL_MESSAGE, KimiSoul +from kimi_cli.wire.types import TextPart, ToolCall, ToolResult + + +@pytest.mark.asyncio +async def test_cancel_during_tool_results_records_interrupted_tool_result( + runtime: Runtime, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent( + name="Tool Interrupt Agent", + system_prompt="Test prompt.", + toolset=EmptyToolset(), + runtime=runtime, + ) + soul = KimiSoul(agent, context=Context(file_backend=tmp_path / "history.jsonl")) + sent: list[Any] = [] + pending_future_ready = asyncio.Event() + + async def fake_step( + _chat_provider: Any, + _system_prompt: str, + _toolset: Any, + _history: Any, + *, + on_message_part: Any = None, + on_tool_result: Any = None, + ) -> StepResult: + del on_tool_result + tool_call = ToolCall( + id="tc-interrupted", + function=ToolCall.FunctionBody(name="SlowTool", arguments='{"seconds": 30}'), + ) + if on_message_part is not None: + on_message_part(TextPart(text="I will run the slow tool.")) + on_message_part(tool_call) + pending_future: asyncio.Future[ToolResult] = asyncio.get_running_loop().create_future() + pending_future_ready.set() + return StepResult( + "msg-interrupted", + Message( + role="assistant", + content=[TextPart(text="I will run the slow tool.")], + tool_calls=[tool_call], + ), + None, + [tool_call], + {tool_call.id: pending_future}, + ) + + monkeypatch.setattr(kimisoul_module.kosong, "step", fake_step) + monkeypatch.setattr(kimisoul_module, "wire_send", lambda msg: sent.append(msg)) + + task = asyncio.create_task(soul.run("run a slow tool")) + await asyncio.wait_for(pending_future_ready.wait(), timeout=1.0) + await asyncio.sleep(0) + + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + history = list(soul.context.history) + assistant_messages = [message for message in history if message.role == "assistant"] + tool_messages = [message for message in history if message.role == "tool"] + + assert assistant_messages + assert assistant_messages[-1].tool_calls + assert assistant_messages[-1].tool_calls[0].id == "tc-interrupted" + assert tool_messages + assert tool_messages[-1].tool_call_id == "tc-interrupted" + assert INTERRUPTED_TOOL_MESSAGE in tool_messages[-1].extract_text() + + interrupted_wire_results = [ + msg for msg in sent if isinstance(msg, ToolResult) and msg.tool_call_id == "tc-interrupted" + ] + assert interrupted_wire_results + assert interrupted_wire_results[-1].return_value.is_error is True + assert interrupted_wire_results[-1].return_value.brief == "" + + +@pytest.mark.asyncio +async def test_interrupted_tool_result_is_sent_to_next_model_step( + runtime: Runtime, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent( + name="Tool Interrupt Agent", + system_prompt="Test prompt.", + toolset=EmptyToolset(), + runtime=runtime, + ) + soul = KimiSoul(agent, context=Context(file_backend=tmp_path / "history.jsonl")) + pending_future_ready = asyncio.Event() + captured_next_history: list[Message] = [] + step_calls = 0 + + async def fake_step( + _chat_provider: Any, + _system_prompt: str, + _toolset: Any, + history: list[Message], + *, + on_message_part: Any = None, + on_tool_result: Any = None, + ) -> StepResult: + nonlocal step_calls, captured_next_history + del on_tool_result + step_calls += 1 + if step_calls == 2: + captured_next_history = list(history) + return StepResult( + "msg-next", + Message(role="assistant", content="I can see the interrupt marker."), + None, + [], + {}, + ) + + tool_call = ToolCall( + id="tc-interrupted", + function=ToolCall.FunctionBody(name="SlowTool", arguments='{"seconds": 30}'), + ) + if on_message_part is not None: + on_message_part(TextPart(text="I will run the slow tool.")) + on_message_part(tool_call) + pending_future: asyncio.Future[ToolResult] = asyncio.get_running_loop().create_future() + pending_future_ready.set() + return StepResult( + "msg-interrupted", + Message( + role="assistant", + content=[TextPart(text="I will run the slow tool.")], + tool_calls=[tool_call], + ), + None, + [tool_call], + {tool_call.id: pending_future}, + ) + + monkeypatch.setattr(kimisoul_module.kosong, "step", fake_step) + monkeypatch.setattr(kimisoul_module, "wire_send", lambda _msg: None) + + interrupted_task = asyncio.create_task(soul.run("run a slow tool")) + await asyncio.wait_for(pending_future_ready.wait(), timeout=1.0) + await asyncio.sleep(0) + + interrupted_task.cancel() + with pytest.raises(asyncio.CancelledError): + await interrupted_task + + await soul.run("can you see my cancel?") + + assert any( + message.role == "assistant" + and message.tool_calls + and message.tool_calls[0].id == "tc-interrupted" + for message in captured_next_history + ) + assert any( + message.role == "tool" + and message.tool_call_id == "tc-interrupted" + and INTERRUPTED_TOOL_MESSAGE in message.extract_text() + for message in captured_next_history + ) + + +@pytest.mark.asyncio +async def test_streaming_phase_cancel_persists_assistant_with_interrupted_tool_result( + runtime: Runtime, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ESC during LLM streaming (after a tool_call became visible in the TUI) + must still leave history well-formed: the assistant message is appended + AND every visible tool_call is paired with a synthetic [Interrupted] + tool_result. + + This is the "thinking + tool-call" case from the bug report: the model + streams a tool_call to the UI, the user presses ESC before the stream + finishes, kosong.step() never returns a complete StepResult — and the + previous implementation lost the entire assistant message. + """ + agent = Agent( + name="Tool Interrupt Agent", + system_prompt="Test prompt.", + toolset=EmptyToolset(), + runtime=runtime, + ) + soul = KimiSoul(agent, context=Context(file_backend=tmp_path / "history.jsonl")) + sent: list[Any] = [] + streaming_started = asyncio.Event() + + async def fake_step( + _chat_provider: Any, + _system_prompt: str, + _toolset: Any, + _history: Any, + *, + on_message_part: Any = None, + on_tool_result: Any = None, + ) -> StepResult: + del on_tool_result + # emit a tool_call to the wire — UI now sees it + tool_call = ToolCall( + id="tc-streamed", + function=ToolCall.FunctionBody(name="SlowTool", arguments="{}"), + ) + if on_message_part is not None: + on_message_part(TextPart(text="thinking about it...")) + on_message_part(tool_call) + streaming_started.set() + # Block until cancelled — mirroring real kosong behaviour, transform + # the incoming CancelledError into a StepCancelled carrying the + # partial StepResult so the caller can pair tool_calls with synthetic + # results. No futures because on_tool_call would have started the + # tool, which we explicitly skip in the cancel path. + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + partial = StepResult( + id=None, + message=Message( + role="assistant", + content=[TextPart(text="thinking about it...")], + tool_calls=[tool_call], + ), + usage=None, + tool_calls=[tool_call], + _tool_result_futures={}, + ) + raise StepCancelled(partial) from None + raise AssertionError("unreachable") # pragma: no cover + + monkeypatch.setattr(kimisoul_module.kosong, "step", fake_step) + monkeypatch.setattr(kimisoul_module, "wire_send", lambda msg: sent.append(msg)) + + task = asyncio.create_task(soul.run("call the slow tool")) + await asyncio.wait_for(streaming_started.wait(), timeout=1.0) + await asyncio.sleep(0) + + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + history = list(soul.context.history) + assistant_messages = [m for m in history if m.role == "assistant"] + tool_messages = [m for m in history if m.role == "tool"] + + assert assistant_messages, "assistant message must be persisted even on streaming cancel" + assert assistant_messages[-1].tool_calls + assert assistant_messages[-1].tool_calls[0].id == "tc-streamed" + assert tool_messages, "interrupted tool_result must be persisted to pair the tool_call" + assert tool_messages[-1].tool_call_id == "tc-streamed" + assert INTERRUPTED_TOOL_MESSAGE in tool_messages[-1].extract_text() + + interrupted_wire_results = [ + msg for msg in sent if isinstance(msg, ToolResult) and msg.tool_call_id == "tc-streamed" + ] + assert interrupted_wire_results + assert interrupted_wire_results[-1].return_value.is_error is True + assert interrupted_wire_results[-1].return_value.brief == "" diff --git a/tests/core/test_toolset.py b/tests/core/test_toolset.py index fbf618993..2bcb02191 100644 --- a/tests/core/test_toolset.py +++ b/tests/core/test_toolset.py @@ -230,14 +230,32 @@ async def test_same_step_dedup(): assert ts.end_step() == [("ToolA", args)] -async def test_cross_step_duplicate_appends_reminder(): - """A tool call identical to one in the previous step should execute and append reminder in output.""" +async def test_cross_step_duplicate_appends_reminder_on_seventh_occurrence(): + """A repeated tool call should execute normally six times and append reminder on the seventh.""" ts = _make_toolset() args = json.dumps({"value": "x"}) - ts.begin_step([("ToolA", args)]) + last_calls: list[tuple[str, str]] = [] + for occurrence in range(1, 7): + ts.begin_step(last_calls) + tool_call = ToolCall( + id=f"tc-repeat-{occurrence}", + function=ToolCall.FunctionBody( + name="ToolA", + arguments=args, + ), + ) + result = ts.handle(tool_call) + assert isinstance(result, asyncio.Task) + tr = await result + output = tr.return_value.output + assert output == "a" + assert ts.dedup_triggered is False + last_calls = ts.end_step() + + ts.begin_step(last_calls) tool_call = ToolCall( - id="tc-dedup-reminder", + id="tc-repeat-7", function=ToolCall.FunctionBody( name="ToolA", arguments=args, @@ -255,6 +273,111 @@ async def test_cross_step_duplicate_appends_reminder(): assert ts.end_step() == [("ToolA", args)] +async def test_cross_step_repeat_counter_resets_after_other_tool_call(): + """Only continuous repeats count; an intervening tool call resets the repeat counter.""" + ts = _make_toolset() + repeated_args = json.dumps({"value": "x"}) + other_args = json.dumps({"value": "y"}) + + last_calls: list[tuple[str, str]] = [] + for occurrence in range(1, 7): + ts.begin_step(last_calls) + tool_call = ToolCall( + id=f"tc-repeat-before-reset-{occurrence}", + function=ToolCall.FunctionBody( + name="ToolA", + arguments=repeated_args, + ), + ) + result = ts.handle(tool_call) + assert isinstance(result, asyncio.Task) + tr = await result + assert tr.return_value.output == "a" + assert ts.dedup_triggered is False + last_calls = ts.end_step() + + ts.begin_step(last_calls) + other_call = ToolCall( + id="tc-resetting-call", + function=ToolCall.FunctionBody( + name="ToolB", + arguments=other_args, + ), + ) + result = ts.handle(other_call) + assert isinstance(result, asyncio.Task) + tr = await result + assert tr.return_value.output == "b" + assert ts.dedup_triggered is False + last_calls = ts.end_step() + + ts.begin_step(last_calls) + tool_call = ToolCall( + id="tc-repeat-after-reset", + function=ToolCall.FunctionBody( + name="ToolA", + arguments=repeated_args, + ), + ) + result = ts.handle(tool_call) + assert isinstance(result, asyncio.Task) + tr = await result + assert tr.return_value.output == "a" + assert ts.dedup_triggered is False + assert ts.end_step() == [("ToolA", repeated_args)] + + +async def test_cross_step_repeat_counter_resets_with_intervening_call_in_same_step(): + """A different call earlier in the current step should reset the projected repeat count.""" + ts = _make_toolset() + repeated_args = json.dumps({"value": "x"}) + other_args = json.dumps({"value": "y"}) + + last_calls: list[tuple[str, str]] = [] + for occurrence in range(1, 7): + ts.begin_step(last_calls) + tool_call = ToolCall( + id=f"tc-repeat-before-same-step-reset-{occurrence}", + function=ToolCall.FunctionBody( + name="ToolA", + arguments=repeated_args, + ), + ) + result = ts.handle(tool_call) + assert isinstance(result, asyncio.Task) + tr = await result + assert tr.return_value.output == "a" + last_calls = ts.end_step() + + ts.begin_step(last_calls) + other_call = ToolCall( + id="tc-same-step-resetting-call", + function=ToolCall.FunctionBody( + name="ToolB", + arguments=other_args, + ), + ) + repeated_call = ToolCall( + id="tc-same-step-repeat-after-reset", + function=ToolCall.FunctionBody( + name="ToolA", + arguments=repeated_args, + ), + ) + + other_result = ts.handle(other_call) + repeated_result = ts.handle(repeated_call) + assert isinstance(other_result, asyncio.Task) + assert isinstance(repeated_result, asyncio.Task) + + other_tr = await other_result + repeated_tr = await repeated_result + assert other_tr.return_value.output == "b" + assert repeated_tr.return_value.output == "a" + assert ts.dedup_triggered is False + assert ts.end_step() == [("ToolB", other_args), ("ToolA", repeated_args)] + + async def test_non_duplicate_allowed(): """A tool call with different arguments should be allowed even if the tool name matches.""" ts = _make_toolset() diff --git a/tests/e2e/test_shell_pty_e2e.py b/tests/e2e/test_shell_pty_e2e.py index 503ad10b2..40501b843 100644 --- a/tests/e2e/test_shell_pty_e2e.py +++ b/tests/e2e/test_shell_pty_e2e.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hashlib import json import sys import textwrap @@ -86,6 +87,17 @@ def _exit_shell(shell) -> None: raise last_error +def _write_user_history(home_dir: Path, work_dir: Path, entries: list[str]) -> None: + work_dir_id = hashlib.md5(str(work_dir.resolve()).encode("utf-8")).hexdigest() + history_dir = home_dir / ".kimi" / "user-history" + history_dir.mkdir(parents=True, exist_ok=True) + history_file = (history_dir / work_dir_id).with_suffix(".jsonl") + history_file.write_text( + "".join(json.dumps({"content": entry}, ensure_ascii=False) + "\n" for entry in entries), + encoding="utf-8", + ) + + def test_shell_smoke_multiturn_scripted_echo(tmp_path: Path) -> None: config_path = write_scripted_config( tmp_path, @@ -137,6 +149,33 @@ def test_shell_smoke_multiturn_scripted_echo(tmp_path: Path) -> None: shell.close() +def test_shell_up_from_slash_history_continues_history_navigation(tmp_path: Path) -> None: + config_path = write_scripted_config(tmp_path, []) + work_dir = make_work_dir(tmp_path) + home_dir = make_home_dir(tmp_path) + _write_user_history(home_dir, work_dir, ["older-history-sentinel", "/help"]) + shell = start_shell_pty( + config_path=config_path, + work_dir=work_dir, + home_dir=home_dir, + yolo=True, + ) + + try: + shell.read_until_contains("Welcome to Kimi Code CLI!") + _read_until_prompt(shell, after=shell.mark()) + + first_up_mark = shell.mark() + shell.send_key("up") + shell.read_until_contains("/help", after=first_up_mark, timeout=5.0) + + second_up_mark = shell.mark() + shell.send_key("up") + shell.read_until_contains("older-history-sentinel", after=second_up_mark, timeout=5.0) + finally: + shell.close() + + def test_shell_running_prompt_preserves_unsubmitted_draft(tmp_path: Path) -> None: scripts = [ "\n".join( @@ -928,6 +967,12 @@ def test_shell_cancel_running_command_kills_process_and_recovers(tmp_path: Path) cancel_prompt_mark = shell.mark() _read_until_prompt(shell, after=cancel_prompt_mark) + session_dir = find_session_dir(home_dir, work_dir) + context_text = (session_dir / "context.jsonl").read_text(encoding="utf-8") + assert '"role":"assistant"' in context_text + assert '"tool_call_id":"tc-c1"' in context_text + assert "[Interrupted By User]" in context_text + time.sleep(2.3) assert not (work_dir / "cancel_output.txt").exists() diff --git a/tests/ui_and_conv/test_prompt_history.py b/tests/ui_and_conv/test_prompt_history.py index 27cbf88c5..f650658d7 100644 --- a/tests/ui_and_conv/test_prompt_history.py +++ b/tests/ui_and_conv/test_prompt_history.py @@ -1,8 +1,12 @@ from __future__ import annotations import json +from collections import deque +from types import SimpleNamespace +from typing import cast from PIL import Image +from prompt_toolkit.buffer import Buffer from kimi_cli.ui.shell import prompt as shell_prompt from kimi_cli.ui.shell.placeholders import AttachmentCache, PromptPlaceholderManager @@ -23,6 +27,16 @@ def _read_history_lines(path) -> list[dict[str, str]]: return [json.loads(line) for line in path.read_text(encoding="utf-8").splitlines()] +def _fake_history_buffer(*, working_index: int) -> Buffer: + return cast( + Buffer, + SimpleNamespace( + working_index=working_index, + _working_lines=deque(["previous prompt", "/help", ""]), + ), + ) + + def test_append_history_entry_expands_text_placeholders_but_preserves_images(tmp_path) -> None: manager = PromptPlaceholderManager(attachment_cache=AttachmentCache(root=tmp_path / "cache")) pasted_text = "\n".join([f"line{i}" for i in range(1, 16)]) @@ -66,3 +80,11 @@ def test_append_history_entry_writes_sanitized_surrogate_text(tmp_path) -> None: assert "\ud83d" not in lines[0]["content"] assert "\ufffd" in lines[0]["content"] assert lines[0]["content"].startswith("A" * 1000) + + +def test_current_history_working_line_allows_auto_completion() -> None: + assert not shell_prompt._is_browsing_history_entry(_fake_history_buffer(working_index=2)) + + +def test_recalled_history_entry_suppresses_auto_completion() -> None: + assert shell_prompt._is_browsing_history_entry(_fake_history_buffer(working_index=1))