diff --git a/src/kimi_cli/soul/kimisoul.py b/src/kimi_cli/soul/kimisoul.py index 6f27b4233..6c83928e1 100644 --- a/src/kimi_cli/soul/kimisoul.py +++ b/src/kimi_cli/soul/kimisoul.py @@ -22,6 +22,7 @@ from tenacity import RetryCallState, retry_if_exception, stop_after_attempt, wait_exponential_jitter from kimi_cli.approval_runtime import ( + ApprovalRuntimeEvent, ApprovalSource, get_current_approval_source_or_none, reset_current_approval_source, @@ -209,8 +210,13 @@ def __init__( ] self._hook_engine: HookEngine = HookEngine() self._stop_hook_active: bool = False + self._approval_hook_subscription: str | None = None if self.is_root: self._runtime.notifications.ack_ids("llm", extract_notification_ids(context.history)) + if self._runtime.approval_runtime is not None: + self._approval_hook_subscription = self._runtime.approval_runtime.subscribe( + self._on_approval_runtime_event + ) # Bind plan mode state to tools that support it self._bind_plan_mode_tools() @@ -276,6 +282,40 @@ def set_hook_engine(self, engine: HookEngine) -> None: if isinstance(self._agent.toolset, KimiToolset): self._agent.toolset.set_hook_engine(engine) + def _on_approval_runtime_event(self, event: ApprovalRuntimeEvent) -> None: + if event.kind != "request_created": + return + + from kimi_cli.hooks import events + + request = event.request + input_data = { + **events.notification( + session_id=self._runtime.session.id, + cwd=str(Path.cwd()), + sink="user", + notification_type="permission_prompt", + title=f"Approval required: {request.sender}", + body=request.description, + severity="warning", + ), + "request_id": request.id, + "tool_call_id": request.tool_call_id, + "sender": request.sender, + "action": request.action, + "description": request.description, + "source_kind": request.source.kind, + "source_id": request.source.id, + } + try: + self._hook_engine.fire_and_forget_trigger( + "Notification", + matcher_value="permission_prompt", + input_data=input_data, + ) + except Exception: + logger.exception("Failed to trigger approval notification hook") + def add_injection_provider(self, provider: DynamicInjectionProvider) -> None: """Register an additional dynamic injection provider.""" self._injection_providers.append(provider) @@ -507,6 +547,13 @@ def runtime(self) -> Runtime: def context(self) -> Context: return self._context + def close(self) -> None: + if self._approval_hook_subscription is None: + return + if self._runtime.approval_runtime is not None: + self._runtime.approval_runtime.unsubscribe(self._approval_hook_subscription) + self._approval_hook_subscription = None + @property def _context_usage(self) -> float: if self._runtime.llm is not None: diff --git a/src/kimi_cli/soul/slash.py b/src/kimi_cli/soul/slash.py index 7155271fc..5261e41e3 100644 --- a/src/kimi_cli/soul/slash.py +++ b/src/kimi_cli/soul/slash.py @@ -42,7 +42,10 @@ async def init(soul: KimiSoul, args: str): with tempfile.TemporaryDirectory() as temp_dir: tmp_context = Context(file_backend=Path(temp_dir) / "context.jsonl") tmp_soul = KimiSoul(soul.agent, context=tmp_context) - await tmp_soul.run(prompts.INIT) + try: + await tmp_soul.run(prompts.INIT) + finally: + tmp_soul.close() agents_md = await load_agents_md(soul.runtime.builtin_args.KIMI_WORK_DIR) system_message = system( diff --git a/tests/core/test_approval_runtime.py b/tests/core/test_approval_runtime.py index fdb2eba86..692bfcff4 100644 --- a/tests/core/test_approval_runtime.py +++ b/tests/core/test_approval_runtime.py @@ -298,6 +298,95 @@ async def _drain_ui_messages(wire: Wire) -> None: return +@pytest.mark.asyncio +async def test_kimisoul_triggers_notification_hook_for_approval_requests( + runtime, tmp_path, monkeypatch: pytest.MonkeyPatch +) -> None: + assert runtime.approval_runtime is not None + + soul = KimiSoul( + SoulAgent( + name="test", + system_prompt="test prompt", + toolset=EmptyToolset(), + runtime=runtime, + ), + context=Context(file_backend=tmp_path / "history.jsonl"), + ) + + calls: list[tuple[str, str, dict]] = [] + + async def _done() -> list: + return [] + + def fake_fire_and_forget(event: str, *, matcher_value: str, input_data: dict): + calls.append((event, matcher_value, input_data)) + return asyncio.create_task(_done()) + + monkeypatch.setattr(soul.hook_engine, "fire_and_forget_trigger", fake_fire_and_forget) + + runtime.approval_runtime.create_request( + request_id="req-notification-hook", + tool_call_id="call-notification-hook", + sender="WriteFile", + action="edit file", + description="Write file /tmp/test.txt", + display=[], + source=ApprovalSource(kind="foreground_turn", id="turn-notification-hook"), + ) + + assert len(calls) == 1 + event, matcher_value, payload = calls[0] + assert event == "Notification" + assert matcher_value == "permission_prompt" + assert payload["notification_type"] == "permission_prompt" + assert payload["sender"] == "WriteFile" + assert payload["description"] == "Write file /tmp/test.txt" + + +@pytest.mark.asyncio +async def test_kimisoul_close_unsubscribes_approval_hook( + runtime, tmp_path, monkeypatch: pytest.MonkeyPatch +) -> None: + assert runtime.approval_runtime is not None + + soul = KimiSoul( + SoulAgent( + name="test", + system_prompt="test prompt", + toolset=EmptyToolset(), + runtime=runtime, + ), + context=Context(file_backend=tmp_path / "history.jsonl"), + ) + + calls: list[tuple[str, str, dict]] = [] + + async def _done() -> list: + return [] + + def fake_fire_and_forget(event: str, *, matcher_value: str, input_data: dict): + calls.append((event, matcher_value, input_data)) + return asyncio.create_task(_done()) + + monkeypatch.setattr(soul.hook_engine, "fire_and_forget_trigger", fake_fire_and_forget) + + soul.close() + soul.close() + + runtime.approval_runtime.create_request( + request_id="req-closed-notification-hook", + tool_call_id="call-closed-notification-hook", + sender="WriteFile", + action="edit file", + description="Write file /tmp/test.txt", + display=[], + source=ApprovalSource(kind="foreground_turn", id="turn-closed-notification-hook"), + ) + + assert calls == [] + + @pytest.mark.asyncio async def test_kimisoul_run_preserves_existing_approval_source( runtime, tmp_path, monkeypatch