Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 78 additions & 7 deletions src/kimi_cli/approval_runtime/runtime.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import asyncio
import time
import uuid
from contextvars import ContextVar, Token
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

from kimi_cli.utils.logging import logger
from kimi_cli.wire.types import ApprovalRequest, ApprovalResponse
Expand Down Expand Up @@ -45,13 +46,24 @@ def reset_current_approval_source(token: Token[ApprovalSource | None]) -> None:
_current_approval_source.reset(token)


_MAX_RETAINED_REQUESTS = 100
_RESOLVED_GRACE_SECONDS = 30.0
_MAX_RESOLVED_CACHE = 200
_ResolvedEntry = tuple[Literal["resolved", "cancelled"], ApprovalResponseKind, str]


class ApprovalRuntime:
def __init__(self) -> None:
self._requests: dict[str, ApprovalRequestRecord] = {}
self._waiters: dict[str, asyncio.Future[tuple[ApprovalResponseKind, str]]] = {}
self._waiter_counts: dict[str, int] = {}
self._subscribers: dict[str, Callable[[ApprovalRuntimeEvent], None]] = {}
self._root_wire_hub: RootWireHub | None = None
# Cache recent resolved/cancelled results so wait_for_response()
# can still answer after the request record has been pruned.
# Each entry is (status, response, feedback) where status is
# "resolved" or "cancelled".
self._resolved_cache: dict[str, _ResolvedEntry] = {}

def bind_root_wire_hub(self, root_wire_hub: RootWireHub) -> None:
if self._root_wire_hub is root_wire_hub:
Expand Down Expand Up @@ -89,6 +101,13 @@ async def wait_for_response(
waiter = self._waiters.get(request_id)
request = self._requests.get(request_id)
if request is None:
# The request may have been pruned; check the resolved cache.
cached = self._resolved_cache.get(request_id)
if cached is not None:
status, response, feedback = cached
if status == "cancelled":
raise ApprovalCancelledError(request_id)
return response, feedback
raise KeyError(f"Approval request not found: {request_id}")
if waiter is None:
if request.status == "cancelled":
Expand Down Expand Up @@ -133,20 +152,21 @@ def resolve(self, request_id: str, response: ApprovalResponseKind, feedback: str
request.status = "resolved"
request.response = response
request.feedback = feedback
import time

request.resolved_at = time.time()
waiter = self._waiters.pop(request_id, None)
if waiter is not None and not waiter.done():
waiter.set_result((response, feedback))
self._publish_event(ApprovalRuntimeEvent(kind="request_resolved", request=request))
self._publish_wire_response(request_id, response, feedback)
self._resolved_cache[request_id] = ("resolved", response, feedback)
self._prune_requests()
# Cap resolved cache independently of request dict
while len(self._resolved_cache) > _MAX_RESOLVED_CACHE:
self._resolved_cache.pop(next(iter(self._resolved_cache)))
return True

def _cancel_request(self, request_id: str, feedback: str = "") -> None:
"""Cancel a single pending request by ID."""
import time

request = self._requests.get(request_id)
if request is None or request.status != "pending":
return
Expand All @@ -159,11 +179,14 @@ def _cancel_request(self, request_id: str, feedback: str = "") -> None:
waiter.set_exception(ApprovalCancelledError(request_id))
self._publish_event(ApprovalRuntimeEvent(kind="request_resolved", request=request))
self._publish_wire_response(request_id, "reject", feedback)
self._resolved_cache[request_id] = ("cancelled", "reject", feedback)
self._prune_requests()
while len(self._resolved_cache) > _MAX_RESOLVED_CACHE:
self._resolved_cache.pop(next(iter(self._resolved_cache)))

def cancel_by_source(self, source_kind: ApprovalSourceKind, source_id: str) -> int:
cancelled = 0
import time

to_remove: list[str] = []
for request_id, request in self._requests.items():
if request.status != "pending":
continue
Expand All @@ -177,9 +200,57 @@ def cancel_by_source(self, source_kind: ApprovalSourceKind, source_id: str) -> i
waiter.set_exception(ApprovalCancelledError(request_id))
self._publish_event(ApprovalRuntimeEvent(kind="request_resolved", request=request))
self._publish_wire_response(request_id, "reject")
to_remove.append(request_id)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Cache source-cancelled approvals before pruning

When cancel_by_source() cancels requests it now calls _prune_requests(), but this path never inserts those cancellations into _resolved_cache the way resolve() and _cancel_request() do. If a source cleanup cancels more than the retained-request window before a consumer calls wait_for_response() for one of the older request IDs, pruning can remove the request and the later wait raises KeyError instead of the expected ApprovalCancelledError, breaking callers that treat source lifecycle cancellation as a rejected approval.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in this commit. cancel_by_source() calls self._cache_cancelled_requests(to_remove) before _prune_requests(). Newly cancelled requests are cached with resolved_at=time.time(), so they survive the grace-period check.

cancelled += 1

self._cache_cancelled_requests(to_remove)
self._prune_requests()
while len(self._resolved_cache) > _MAX_RESOLVED_CACHE:
self._resolved_cache.pop(next(iter(self._resolved_cache)))
return cancelled

def _cache_cancelled_requests(self, request_ids: list[str]) -> None:
"""Cache cancelled request IDs so wait_for_response() can raise
ApprovalCancelledError even after the original record is pruned."""
for request_id in request_ids:
self._resolved_cache[request_id] = ("cancelled", "reject", "")

def _prune_requests(self) -> None:
"""Remove resolved/cancelled requests to prevent unbounded growth.

Only removes requests that have been resolved/cancelled for longer
than _RESOLVED_GRACE_SECONDS so that recent results remain
queryable via wait_for_response().
"""
if len(self._requests) <= _MAX_RETAINED_REQUESTS:
return
now = time.time()
# Collect resolved/cancelled request IDs that are past the grace period
to_remove = [
req_id
for req_id, req in self._requests.items()
if req.status in ("resolved", "cancelled")
and req.resolved_at is not None
and (now - req.resolved_at) > _RESOLVED_GRACE_SECONDS
]
# Remove enough to get back under the limit
excess = len(self._requests) - _MAX_RETAINED_REQUESTS
for req_id in to_remove[:excess]:
self._requests.pop(req_id, None)
# Fallback: if we are still over the limit (e.g. all resolved very
# recently), evict the oldest resolved/cancelled entries regardless
# of grace period so the dict cannot grow unbounded.
if len(self._requests) > _MAX_RETAINED_REQUESTS:
done = [
(req_id, req)
for req_id, req in self._requests.items()
if req.status in ("resolved", "cancelled")
]
done.sort(key=lambda item: item[1].resolved_at or 0)
still_excess = len(self._requests) - _MAX_RETAINED_REQUESTS
for req_id, _ in done[:still_excess]:
self._requests.pop(req_id, None)

def list_pending(self) -> list[ApprovalRequestRecord]:
pending = [request for request in self._requests.values() if request.status == "pending"]
pending.sort(key=lambda request: request.created_at)
Expand Down
7 changes: 6 additions & 1 deletion src/kimi_cli/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,12 @@ def get_work_dir_meta(self, path: KaosPath) -> WorkDirMeta | None:

def new_work_dir_meta(self, path: KaosPath) -> WorkDirMeta:
"""Create a new work directory metadata."""
wd_meta = WorkDirMeta(path=str(path), kaos=get_current_kaos().name)
kaos_name = get_current_kaos().name
wd_meta = WorkDirMeta(path=str(path), kaos=kaos_name)
# Deduplicate: remove existing entry for the same path + kaos
self.work_dirs = [
wd for wd in self.work_dirs if not (wd.path == str(path) and wd.kaos == kaos_name)
]
self.work_dirs.append(wd_meta)
return wd_meta

Expand Down
6 changes: 6 additions & 0 deletions src/kimi_cli/ui/shell/placeholders.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
self._legacy_roots = tuple(legacy_roots or (_LEGACY_PROMPT_CACHE_ROOT,))
self._dir_map: dict[CachedAttachmentKind, str] = {"image": "images"}
self._payload_map: dict[tuple[CachedAttachmentKind, str, str], CachedAttachment] = {}
self._max_payload_map = 1000

def _dir_for(self, kind: CachedAttachmentKind, *, root: Path | None = None) -> Path:
return (self._root if root is None else root) / self._dir_map[kind]
Expand Down Expand Up @@ -164,6 +165,11 @@ def store_bytes(

cached = CachedAttachment(kind=kind, attachment_id=attachment_id, path=path)
self._payload_map[cache_key] = cached
# Prevent unbounded in-memory growth. We intentionally do NOT
# delete files on disk because image tokens in prompt history
# may still reference them.
while len(self._payload_map) > self._max_payload_map:
self._payload_map.pop(next(iter(self._payload_map)))
return cached

def store_image(self, image: Image.Image) -> CachedAttachment | None:
Expand Down
78 changes: 47 additions & 31 deletions src/kimi_cli/utils/aioqueue.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,86 @@
from __future__ import annotations

import asyncio
import contextlib
import sys

if sys.version_info >= (3, 13):
QueueShutDown = asyncio.QueueShutDown # type: ignore[assignment]

class Queue[T](asyncio.Queue[T]):
"""Asyncio Queue with shutdown support."""
"""Asyncio Queue with shutdown support (Python 3.13+ native)."""

def __init__(self, *, maxsize: int = 0) -> None:
super().__init__(maxsize=maxsize)

else:

class QueueShutDown(Exception):
"""Raised when operating on a shut down queue."""

class _Shutdown:
"""Sentinel for queue shutdown."""

_SHUTDOWN = _Shutdown()

class Queue[T](asyncio.Queue[T | _Shutdown]):
class Queue[T](asyncio.Queue[T]):
"""Asyncio Queue with shutdown support for Python < 3.13."""

def __init__(self) -> None:
super().__init__()
def __init__(self, *, maxsize: int = 0) -> None:
super().__init__(maxsize=maxsize)
self._shutdown = False

def shutdown(self, immediate: bool = False) -> None:
if self._shutdown:
return
self._shutdown = True
if immediate:
self._queue.clear()

getters = list(getattr(self, "_getters", []))
count = max(1, len(getters))
self._enqueue_shutdown(count)

def _enqueue_shutdown(self, count: int) -> None:
for _ in range(count):
try:
super().put_nowait(_SHUTDOWN)
except asyncio.QueueFull:
self._queue.clear()
super().put_nowait(_SHUTDOWN)
self._queue.clear() # type: ignore[attr-defined]

# Wake all getters so they can check the shutdown flag and
# raise QueueShutDown instead of re-blocking forever.
# NOTE: _wakeup_next is a private asyncio.Queue method that has
# been stable since Python 3.7. We use it because there is no
# public API to wake a specific waiter.
while getattr(self, "_getters", []):
with contextlib.suppress(IndexError):
self._wakeup_next(self._getters) # type: ignore[attr-defined]
Comment on lines +40 to +42
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Make shutdown fail already-blocked getters

On supported Python 3.12, a consumer that is already suspended inside await queue.get() when shutdown() is called is only woken here; asyncio.Queue.get() then re-enters its empty-queue wait without returning to this wrapper's pre-check, so it never raises QueueShutDown. Fresh evidence in this revision is that the sentinel path was replaced with _wakeup_next, but get() still has no post-wake shutdown check; this can hang Wire.shutdown()/wire.join() when the recorder is waiting on an empty queue at session shutdown.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shutdown() enqueues _Shutdown sentinel(s) and wakes all blocked getters. A getter blocked on await queue.get() receives the sentinel, raises QueueShutDown, and exits cleanly.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in bd9fdc9. Removed the sentinel approach entirely and instead overrode get() to add the post-wake shutdown check, matching Python 3.13's native Queue behavior. After shutdown() wakes blocked getters, get() now checks self._shutdown before re-blocking and raises QueueShutDown immediately.


# Wake all putters so they re-check shutdown instead of
# hanging on a full bounded queue.
while getattr(self, "_putters", []):
with contextlib.suppress(IndexError):
self._wakeup_next(self._putters) # type: ignore[attr-defined]

async def get(self) -> T:
if self._shutdown and self.empty():
raise QueueShutDown
item = await super().get()
if isinstance(item, _Shutdown):
raise QueueShutDown
return item
while self.empty():
getter = asyncio.get_running_loop().create_future()
self._getters.append(getter)
try:
await getter
finally:
with contextlib.suppress(ValueError):
self._getters.remove(getter)
if self._shutdown and self.empty():
raise QueueShutDown
return super().get_nowait()

def get_nowait(self) -> T:
if self._shutdown and self.empty():
raise QueueShutDown
item = super().get_nowait()
if isinstance(item, _Shutdown):
raise QueueShutDown
return item
return super().get_nowait()

async def put(self, item: T) -> None:
if self._shutdown:
raise QueueShutDown
await super().put(item)
while self.full():
putter = asyncio.get_running_loop().create_future()
self._putters.append(putter)
try:
await putter
finally:
with contextlib.suppress(ValueError):
self._putters.remove(putter)
if self._shutdown:
raise QueueShutDown
super().put_nowait(item)

def put_nowait(self, item: T) -> None:
if self._shutdown:
Expand Down
52 changes: 39 additions & 13 deletions src/kimi_cli/utils/broadcast.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
import asyncio
import contextlib

from kimi_cli.utils.aioqueue import Queue


class BroadcastQueue[T]:
"""
A broadcast queue that allows multiple subscribers to receive published items.
"""A broadcast queue that allows multiple subscribers to receive published items.

Each subscriber gets its own queue. By default queues are bounded
(``maxsize=1000``) to prevent unbounded memory growth; an unbounded
queue can be requested with ``maxsize=0``. Critical consumers
(e.g. wire recorders and waitable request paths) should use an
unbounded queue.
"""

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Bound actual production broadcast queues

With the default still set to 0, every production caller I found (WireMessageQueue() and RootWireHub both construct BroadcastQueue() without a maxsize, and their subscribers also omit one) continues to allocate unbounded per-subscriber queues. That means the slow UI/root-hub subscriber memory leak this change is meant to prevent is not fixed outside tests that explicitly pass a positive maxsize; either the default or the production call sites need a real bound while keeping recorder-critical paths explicitly unbounded.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Production queues are bounded. RootWireHub.subscribe() defaults to maxsize=1000. Wire.ui_side() explicitly passes maxsize=1000 for UI consumers. Only the internal recorder uses maxsize=0 (unbounded) by design.

def __init__(self) -> None:
def __init__(self, *, maxsize: int = 1000) -> None:
self._queues: set[Queue[T]] = set()

def subscribe(self) -> Queue[T]:
"""Create a new subscription queue."""
queue: Queue[T] = Queue()
self._maxsize = maxsize

def subscribe(self, *, maxsize: int | None = None) -> Queue[T]:
"""Create a new subscription queue.

Args:
maxsize: Maximum queue size. ``None`` uses the broadcast
queue's default (``1000``). ``0`` means unbounded.
Pass a positive value for lossy consumers that may fall
behind and can tolerate dropped messages.
"""
queue: Queue[T] = Queue(maxsize=maxsize if maxsize is not None else self._maxsize)
self._queues.add(queue)
return queue

Expand All @@ -22,16 +36,28 @@ def unsubscribe(self, queue: Queue[T]) -> None:
self._queues.discard(queue)

async def publish(self, item: T) -> None:
"""Publish an item to all subscription queues."""
await asyncio.gather(*(queue.put(item) for queue in self._queues))
"""Publish an item to all subscription queues, awaiting space.

This blocks until every subscriber has room for the item.
"""
for queue in list(self._queues):
await queue.put(item)

def publish_nowait(self, item: T) -> None:
"""Publish an item to all subscription queues without waiting."""
for queue in self._queues:
queue.put_nowait(item)
"""Publish an item to all subscription queues without waiting.

If a single subscriber's queue is full, that subscriber is
skipped so that later subscribers still receive the item.
Callers that require guaranteed delivery (e.g. waitable
requests) should use an unbounded queue so no subscriber is
ever skipped.
"""
for queue in list(self._queues):
with contextlib.suppress(asyncio.QueueFull):
queue.put_nowait(item)

def shutdown(self, immediate: bool = False) -> None:
"""Close all subscription queues."""
for queue in self._queues:
for queue in list(self._queues):
queue.shutdown(immediate=immediate)
self._queues.clear()
Loading