diff --git a/src/kimi_cli/approval_runtime/runtime.py b/src/kimi_cli/approval_runtime/runtime.py index 22199b2da..1eff000f2 100644 --- a/src/kimi_cli/approval_runtime/runtime.py +++ b/src/kimi_cli/approval_runtime/runtime.py @@ -1,9 +1,10 @@ from __future__ import annotations import asyncio +import time import uuid from contextvars import ContextVar, Token -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from kimi_cli.utils.logging import logger from kimi_cli.wire.types import ApprovalRequest, ApprovalResponse @@ -45,6 +46,12 @@ def reset_current_approval_source(token: Token[ApprovalSource | None]) -> None: _current_approval_source.reset(token) +_MAX_RETAINED_REQUESTS = 100 +_RESOLVED_GRACE_SECONDS = 30.0 +_MAX_RESOLVED_CACHE = 200 +_ResolvedEntry = tuple[Literal["resolved", "cancelled"], ApprovalResponseKind, str] + + class ApprovalRuntime: def __init__(self) -> None: self._requests: dict[str, ApprovalRequestRecord] = {} @@ -52,6 +59,11 @@ def __init__(self) -> None: self._waiter_counts: dict[str, int] = {} self._subscribers: dict[str, Callable[[ApprovalRuntimeEvent], None]] = {} self._root_wire_hub: RootWireHub | None = None + # Cache recent resolved/cancelled results so wait_for_response() + # can still answer after the request record has been pruned. + # Each entry is (status, response, feedback) where status is + # "resolved" or "cancelled". + self._resolved_cache: dict[str, _ResolvedEntry] = {} def bind_root_wire_hub(self, root_wire_hub: RootWireHub) -> None: if self._root_wire_hub is root_wire_hub: @@ -89,6 +101,13 @@ async def wait_for_response( waiter = self._waiters.get(request_id) request = self._requests.get(request_id) if request is None: + # The request may have been pruned; check the resolved cache. 
+ cached = self._resolved_cache.get(request_id) + if cached is not None: + status, response, feedback = cached + if status == "cancelled": + raise ApprovalCancelledError(request_id) + return response, feedback raise KeyError(f"Approval request not found: {request_id}") if waiter is None: if request.status == "cancelled": @@ -133,20 +152,21 @@ def resolve(self, request_id: str, response: ApprovalResponseKind, feedback: str request.status = "resolved" request.response = response request.feedback = feedback - import time - request.resolved_at = time.time() waiter = self._waiters.pop(request_id, None) if waiter is not None and not waiter.done(): waiter.set_result((response, feedback)) self._publish_event(ApprovalRuntimeEvent(kind="request_resolved", request=request)) self._publish_wire_response(request_id, response, feedback) + self._resolved_cache[request_id] = ("resolved", response, feedback) + self._prune_requests() + # Cap resolved cache independently of request dict + while len(self._resolved_cache) > _MAX_RESOLVED_CACHE: + self._resolved_cache.pop(next(iter(self._resolved_cache))) return True def _cancel_request(self, request_id: str, feedback: str = "") -> None: """Cancel a single pending request by ID.""" - import time - request = self._requests.get(request_id) if request is None or request.status != "pending": return @@ -159,11 +179,14 @@ def _cancel_request(self, request_id: str, feedback: str = "") -> None: waiter.set_exception(ApprovalCancelledError(request_id)) self._publish_event(ApprovalRuntimeEvent(kind="request_resolved", request=request)) self._publish_wire_response(request_id, "reject", feedback) + self._resolved_cache[request_id] = ("cancelled", "reject", feedback) + self._prune_requests() + while len(self._resolved_cache) > _MAX_RESOLVED_CACHE: + self._resolved_cache.pop(next(iter(self._resolved_cache))) def cancel_by_source(self, source_kind: ApprovalSourceKind, source_id: str) -> int: cancelled = 0 - import time - + to_remove: list[str] = [] for 
def _prune_requests(self) -> None:
    """Evict finished requests once the registry exceeds its retention cap.

    Nothing happens while ``len(self._requests)`` is at or below
    ``_MAX_RETAINED_REQUESTS``.  Above the cap, the finished
    ("resolved" / "cancelled") records with the smallest ``resolved_at``
    are dropped until the cap is met or no finished record is left.
    Pending records are never evicted, so the dict may stay above the cap
    while that many requests are still awaiting a decision.

    Eviction is strictly oldest-first.  Records that finished long ago are
    therefore always removed before recently finished ones, and a recent
    record is only dropped when the cap cannot be met any other way.
    """
    excess = len(self._requests) - _MAX_RETAINED_REQUESTS
    if excess <= 0:
        return

    finished = [
        req_id
        for req_id, req in self._requests.items()
        if req.status in ("resolved", "cancelled")
    ]
    # A missing timestamp sorts first so such a record cannot be retained
    # forever; the sort is stable, so ties keep their insertion order.
    finished.sort(key=lambda req_id: self._requests[req_id].resolved_at or 0)

    for req_id in finished[:excess]:
        del self._requests[req_id]
cached = CachedAttachment(kind=kind, attachment_id=attachment_id, path=path) self._payload_map[cache_key] = cached + # Prevent unbounded in-memory growth. We intentionally do NOT + # delete files on disk because image tokens in prompt history + # may still reference them. + while len(self._payload_map) > self._max_payload_map: + self._payload_map.pop(next(iter(self._payload_map))) return cached def store_image(self, image: Image.Image) -> CachedAttachment | None: diff --git a/src/kimi_cli/utils/aioqueue.py b/src/kimi_cli/utils/aioqueue.py index 92756f662..9108eff08 100644 --- a/src/kimi_cli/utils/aioqueue.py +++ b/src/kimi_cli/utils/aioqueue.py @@ -1,29 +1,28 @@ from __future__ import annotations import asyncio +import contextlib import sys if sys.version_info >= (3, 13): QueueShutDown = asyncio.QueueShutDown # type: ignore[assignment] class Queue[T](asyncio.Queue[T]): - """Asyncio Queue with shutdown support.""" + """Asyncio Queue with shutdown support (Python 3.13+ native).""" + + def __init__(self, *, maxsize: int = 0) -> None: + super().__init__(maxsize=maxsize) else: class QueueShutDown(Exception): """Raised when operating on a shut down queue.""" - class _Shutdown: - """Sentinel for queue shutdown.""" - - _SHUTDOWN = _Shutdown() - - class Queue[T](asyncio.Queue[T | _Shutdown]): + class Queue[T](asyncio.Queue[T]): """Asyncio Queue with shutdown support for Python < 3.13.""" - def __init__(self) -> None: - super().__init__() + def __init__(self, *, maxsize: int = 0) -> None: + super().__init__(maxsize=maxsize) self._shutdown = False def shutdown(self, immediate: bool = False) -> None: @@ -31,40 +30,57 @@ def shutdown(self, immediate: bool = False) -> None: return self._shutdown = True if immediate: - self._queue.clear() - - getters = list(getattr(self, "_getters", [])) - count = max(1, len(getters)) - self._enqueue_shutdown(count) - - def _enqueue_shutdown(self, count: int) -> None: - for _ in range(count): - try: - super().put_nowait(_SHUTDOWN) - except 
async def get(self) -> T:
    """Remove and return the next item, waiting while the queue is empty.

    Raises:
        QueueShutDown: The queue has been shut down and holds no items,
            either on entry or once this waiter has been woken.
    """
    if self._shutdown and self.empty():
        raise QueueShutDown
    while self.empty():
        getter = asyncio.get_running_loop().create_future()
        self._getters.append(getter)
        try:
            await getter
        except BaseException:
            # Interrupted (typically cancelled) while parked; always re-raised.
            getter.cancel()  # no-op when the future already holds a result
            with contextlib.suppress(ValueError):
                self._getters.remove(getter)
            # A producer may already have spent its wake-up on this waiter
            # for an item it will now never take.  Hand the wake-up to the
            # next parked getter so it does not sleep beside a non-empty
            # queue until some later put happens to arrive.
            if not self.empty() and not getter.cancelled():
                self._wakeup_next(self._getters)
            raise
        # Woken either by a producer or by shutdown(); only the latter
        # leaves the queue empty with the flag set.
        if self._shutdown and self.empty():
            raise QueueShutDown
    return super().get_nowait()
b/src/kimi_cli/utils/broadcast.py index 296ddfd0e..b9c10593f 100644 --- a/src/kimi_cli/utils/broadcast.py +++ b/src/kimi_cli/utils/broadcast.py @@ -1,19 +1,33 @@ import asyncio +import contextlib from kimi_cli.utils.aioqueue import Queue class BroadcastQueue[T]: - """ - A broadcast queue that allows multiple subscribers to receive published items. + """A broadcast queue that allows multiple subscribers to receive published items. + + Each subscriber gets its own queue. By default queues are bounded + (``maxsize=1000``) to prevent unbounded memory growth; an unbounded + queue can be requested with ``maxsize=0``. Critical consumers + (e.g. wire recorders and waitable request paths) should use an + unbounded queue. """ - def __init__(self) -> None: + def __init__(self, *, maxsize: int = 1000) -> None: self._queues: set[Queue[T]] = set() - - def subscribe(self) -> Queue[T]: - """Create a new subscription queue.""" - queue: Queue[T] = Queue() + self._maxsize = maxsize + + def subscribe(self, *, maxsize: int | None = None) -> Queue[T]: + """Create a new subscription queue. + + Args: + maxsize: Maximum queue size. ``None`` uses the broadcast + queue's default (``1000``). ``0`` means unbounded. + Pass a positive value for lossy consumers that may fall + behind and can tolerate dropped messages. + """ + queue: Queue[T] = Queue(maxsize=maxsize if maxsize is not None else self._maxsize) self._queues.add(queue) return queue @@ -22,16 +36,28 @@ def unsubscribe(self, queue: Queue[T]) -> None: self._queues.discard(queue) async def publish(self, item: T) -> None: - """Publish an item to all subscription queues.""" - await asyncio.gather(*(queue.put(item) for queue in self._queues)) + """Publish an item to all subscription queues, awaiting space. + + This blocks until every subscriber has room for the item. 
+ """ + for queue in list(self._queues): + await queue.put(item) def publish_nowait(self, item: T) -> None: - """Publish an item to all subscription queues without waiting.""" - for queue in self._queues: - queue.put_nowait(item) + """Publish an item to all subscription queues without waiting. + + If a single subscriber's queue is full, that subscriber is + skipped so that later subscribers still receive the item. + Callers that require guaranteed delivery (e.g. waitable + requests) should use an unbounded queue so no subscriber is + ever skipped. + """ + for queue in list(self._queues): + with contextlib.suppress(asyncio.QueueFull): + queue.put_nowait(item) def shutdown(self, immediate: bool = False) -> None: """Close all subscription queues.""" - for queue in self._queues: + for queue in list(self._queues): queue.shutdown(immediate=immediate) self._queues.clear() diff --git a/src/kimi_cli/web/store/sessions.py b/src/kimi_cli/web/store/sessions.py index fb6f43bc1..6630f1f24 100644 --- a/src/kimi_cli/web/store/sessions.py +++ b/src/kimi_cli/web/store/sessions.py @@ -34,6 +34,7 @@ # Cache configuration CACHE_TTL = 5.0 # seconds - balance between freshness and performance +MAX_CACHED_SESSIONS = 100 # hard limit to prevent unbounded memory growth # Auto-archive configuration AUTO_ARCHIVE_DAYS = 15 # Sessions older than this will be auto-archived @@ -288,11 +289,23 @@ def _load_sessions_index_cached() -> list[SessionIndexEntry]: return _sessions_index_cache -def load_all_sessions() -> list[JointSession]: - """Load all sessions from all work directories.""" +def load_all_sessions(limit: int | None = None) -> list[JointSession]: + """Load all sessions from all work directories. + + Args: + limit: If given, only the most-recently-updated *limit* sessions + are fully materialised. This avoids building expensive + :class:`JointSession` objects for sessions that will be + discarded anyway. 
# Public alias kept next to the factory so that any code still importing or
# annotating with ``WireMessageQueue`` keeps working.
WireMessageQueue = BroadcastQueue[WireMessage]


def _WireMessageQueue() -> BroadcastQueue[WireMessage]:
    """Create a broadcast queue carrying wire messages.

    The queue is built with ``BroadcastQueue``'s default capacity, so
    subscribers are bounded unless they ask otherwise.  Subscribers that
    need guaranteed delivery (the recorder, waitable requests) should pass
    ``maxsize=0`` when subscribing.
    """
    return WireMessageQueue()
def subscribe(self, *, maxsize: int | None = None) -> Queue[WireMessage]:
    """Register a subscriber on the hub and return its queue.

    Args:
        maxsize: Capacity of the subscriber's queue.  ``None`` is replaced
            by ``1000`` so that a slow consumer cannot grow memory without
            limit; any other value, including ``0`` for an unbounded
            queue, is forwarded unchanged.

    NOTE(review): a capacity is always passed explicitly to the underlying
    broadcast queue, so that queue's own default is never consulted from
    here -- confirm that default subscribers are meant to be bounded.
    """
    capacity = 1000 if maxsize is None else maxsize
    return self._queue.subscribe(maxsize=capacity)
assert len(runtime._requests) <= 100 + # All remaining should still be resolvable + assert all(req.status == "resolved" for req in runtime._requests.values()) + + +@pytest.mark.asyncio +async def test_wait_for_response_still_works_after_prune() -> None: + """wait_for_response() returns cached result even after the request is pruned.""" + runtime = ApprovalRuntime() + req = runtime.create_request( + request_id="req-target", + tool_call_id="call-target", + sender="Shell", + action="run", + description="target", + display=[], + source=ApprovalSource(kind="foreground_turn", id="turn-1"), + ) + runtime.resolve(req.id, "approve") + + # Flood with 120 more requests to force prune + for i in range(120): + r = runtime.create_request( + request_id=f"req-{i}", + tool_call_id=f"call-{i}", + sender="Shell", + action="run", + description="flood", + display=[], + source=ApprovalSource(kind="foreground_turn", id="turn-1"), + ) + runtime.resolve(r.id, "approve") + + # The target request should have been pruned (oldest resolved) + assert req.id not in runtime._requests + + # But wait_for_response should still return the result without error + response, feedback = await runtime.wait_for_response(req.id) + assert response == "approve" + assert feedback == "" diff --git a/tests/utils/test_broadcast_queue.py b/tests/utils/test_broadcast_queue.py index f0b52ba5d..605b65500 100644 --- a/tests/utils/test_broadcast_queue.py +++ b/tests/utils/test_broadcast_queue.py @@ -75,3 +75,104 @@ async def test_publish_to_empty_queue(): # Should not raise any exception await broadcast.publish("no_subscribers") broadcast.publish_nowait("no_subscribers") + + +async def test_publish_nowait_skips_full_subscriber(): + """publish_nowait skips a full subscriber but continues to others.""" + broadcast = BroadcastQueue(maxsize=2) + bounded = broadcast.subscribe() + unbounded = broadcast.subscribe(maxsize=0) + + broadcast.publish_nowait("msg_1") + broadcast.publish_nowait("msg_2") + assert bounded.qsize() == 2 + 
assert unbounded.qsize() == 2 + + # bounded is full, but unbounded should still receive msg_3. + broadcast.publish_nowait("msg_3") + assert bounded.qsize() == 2 + assert unbounded.qsize() == 3 + + +async def test_default_maxsize_is_bounded(): + """Default BroadcastQueue is bounded (maxsize=1000) for production safety.""" + broadcast = BroadcastQueue() + queue = broadcast.subscribe() + assert queue.maxsize == 1000 + + +async def test_graceful_shutdown_preserves_items(): + """shutdown(immediate=False) must not drop pending items from a full queue.""" + broadcast = BroadcastQueue(maxsize=3) + queue = broadcast.subscribe() + + # Fill the queue via async publish (put) since put_nowait would skip. + await broadcast.publish("keep_1") + await broadcast.publish("keep_2") + await broadcast.publish("keep_3") + assert queue.qsize() == 3 + + # Graceful shutdown should preserve the three items. + broadcast.shutdown(immediate=False) + + assert await queue.get() == "keep_1" + assert await queue.get() == "keep_2" + assert await queue.get() == "keep_3" + + with pytest.raises(QueueShutDown): + queue.get_nowait() + + +async def test_immediate_shutdown_clears_items(): + """shutdown(immediate=True) may drop pending items to unblock immediately.""" + broadcast = BroadcastQueue(maxsize=3) + queue = broadcast.subscribe() + + await broadcast.publish("drop_1") + await broadcast.publish("drop_2") + await broadcast.publish("drop_3") + + broadcast.shutdown(immediate=True) + + with pytest.raises(QueueShutDown): + queue.get_nowait() + + +async def test_publish_blocks_until_space(): + """async publish() blocks on a full queue until space is available.""" + broadcast = BroadcastQueue(maxsize=2) + queue = broadcast.subscribe() + + broadcast.publish_nowait("old_1") + broadcast.publish_nowait("old_2") + assert queue.qsize() == 2 + + # publish() should block until a consumer frees space. 
@pytest.mark.anyio
async def test_load_all_sessions_cached_respects_max_limit(
    isolated_share_dir: Path,
    work_dir: KaosPath,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The sessions cache is truncated to ``MAX_CACHED_SESSIONS`` entries."""
    from kimi_cli.web.store import sessions as store

    limit = 5
    # monkeypatch undoes this at teardown, so no manual restore is needed.
    monkeypatch.setattr(store, "MAX_CACHED_SESSIONS", limit)

    # Create more sessions than the patched limit lets the cache hold.
    for _ in range(limit * 2):
        await Session.create(work_dir)

    # Drop any cached list so the next call rebuilds it under the new limit.
    store.invalidate_sessions_cache()
    try:
        cached = store.load_all_sessions_cached()
        assert len(cached) == limit, f"Expected {limit} cached sessions, got {len(cached)}"
    finally:
        # Never leave the truncated list cached for later tests, even when
        # the assertion above fails.
        store.invalidate_sessions_cache()