Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions packages/kosong/src/kosong/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def main() -> None:

from loguru import logger

from kosong._generate import GenerateResult, generate
from kosong._generate import GenerateCancelled, GenerateResult, generate
from kosong.chat_provider import ChatProvider, ChatProviderError, StreamedMessagePart, TokenUsage
from kosong.message import Message, ToolCall
from kosong.tooling import ToolResult, ToolResultFuture, Toolset
Expand All @@ -98,9 +98,34 @@ async def main() -> None:
"GenerateResult",
"step",
"StepResult",
"StepCancelled",
]


class StepCancelled(asyncio.CancelledError):
    """CancelledError carrying a partial :class:`StepResult`.

    Raised by :func:`step` when the underlying generation is cancelled. The
    partial result records:

    - ``message``: everything streamed so far (including a best-effort flush
      of the pending part, so its ``tool_calls`` mirrors what the UI saw).
    - ``tool_calls``: aligned with ``message.tool_calls`` so callers can
      iterate them as the canonical "what UI saw" list.
    - ``_tool_result_futures``: only the tools whose ``on_tool_call`` had
      already fired before cancellation — these may still be running. The
      caller decides whether to await or cancel them.

    Subclassing :class:`asyncio.CancelledError` keeps the cancellation
    contract intact: callers that don't care about the partial result still
    see a plain cancel.

    Attributes:
        partial: The partial :class:`StepResult` captured at the cancel point.
    """

    def __init__(self, partial: "StepResult"):
        # Deliberately pass no positional args to the base class so ``args``
        # and ``str(exc)`` look exactly like a plain CancelledError.
        super().__init__()
        self.partial = partial

    def __reduce__(self) -> tuple[object, ...]:
        """Rebuild the exception from ``partial`` when copied or pickled.

        The default ``BaseException.__reduce__`` reconstructs the instance
        from ``self.args``, which is empty here, so ``copy.copy`` and
        ``pickle`` would call ``StepCancelled()`` and fail with a
        ``TypeError`` for the missing ``partial`` argument.
        """
        return type(self), (self.partial,), dict(vars(self))


async def step(
chat_provider: ChatProvider,
system_prompt: str,
Expand All @@ -127,7 +152,9 @@ async def step(
APIStatusError: If the API returns a status code of 4xx or 5xx.
APIEmptyResponseError: If the API returns an empty response.
ChatProviderError: If any other recognized chat provider error occurs.
asyncio.CancelledError: If the step is cancelled.
StepCancelled: When the step is cancelled by the caller. Carries a
partial :class:`StepResult` so callers can persist a well-formed
history (every TUI-visible tool_call paired with a tool_result).
"""

tool_calls: list[ToolCall] = []
Expand Down Expand Up @@ -163,7 +190,22 @@ async def on_tool_call(tool_call: ToolCall):
on_message_part=on_message_part,
on_tool_call=on_tool_call,
)
except (ChatProviderError, asyncio.CancelledError):
except GenerateCancelled as cancelled:
# Streaming was interrupted. The partial message reflects everything
# the UI saw. Align tool_calls with message.tool_calls so callers
# have a single canonical iteration list — message.tool_calls may
# include a flushed pending ToolCall that never fired on_tool_call
# (and hence has no future in tool_result_futures).
merged_tool_calls = list(cancelled.message.tool_calls or [])
partial = StepResult(
id=None,
message=cancelled.message,
usage=None,
tool_calls=merged_tool_calls,
_tool_result_futures=tool_result_futures,
)
raise StepCancelled(partial) from None
except ChatProviderError:
# cancel all the futures to avoid hanging tasks
for future in tool_result_futures.values():
future.remove_done_callback(future_done_callback)
Expand Down
81 changes: 63 additions & 18 deletions packages/kosong/src/kosong/_generate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from collections.abc import Sequence
from dataclasses import dataclass

Expand All @@ -14,6 +15,23 @@
from kosong.utils.aio import Callback, callback


class GenerateCancelled(asyncio.CancelledError):
    """CancelledError carrying the partial message accumulated up to the cancel point.

    Raised by :func:`generate` when its streaming loop is cancelled. The
    ``message`` field reflects everything observed via the stream so far, with
    any pending part flushed in best-effort fashion — including incomplete
    ToolCall arguments. No ``on_tool_call`` callback is fired for the flushed
    pending part (we don't want to start a tool the user just cancelled), so
    callers must treat any tool_call in ``message`` without a paired result
    future as "interrupted".

    Attributes:
        message: The partial assistant message captured at the cancel point.
    """

    def __init__(self, message: Message):
        # Deliberately pass no positional args to the base class so ``args``
        # and ``str(exc)`` look exactly like a plain CancelledError.
        super().__init__()
        self.message = message

    def __reduce__(self) -> tuple[object, ...]:
        """Rebuild the exception from ``message`` when copied or pickled.

        The default ``BaseException.__reduce__`` reconstructs the instance
        from ``self.args``, which is empty here, so ``copy.copy`` and
        ``pickle`` would call ``GenerateCancelled()`` and fail with a
        ``TypeError`` for the missing ``message`` argument.
        """
        return type(self), (self.message,), dict(vars(self))


async def generate(
chat_provider: ChatProvider,
system_prompt: str,
Expand Down Expand Up @@ -45,31 +63,58 @@ async def generate(
APIStatusError: If the API returns a status code of 4xx or 5xx.
APIEmptyResponseError: If the API returns an empty response.
ChatProviderError: If any other recognized chat provider error occurs.
GenerateCancelled: When the streaming loop is cancelled. Carries the
partial message observed so far so callers can decide whether to
persist it.
"""
message = Message(role="assistant", content=[])
pending_part: StreamedMessagePart | None = None # message part that is currently incomplete

logger.trace("Generating with history: {history}", history=history)
stream = await chat_provider.generate(system_prompt, tools, history)
async for part in stream:
logger.trace("Received part: {part}", part=part)
if on_message_part:
await callback(on_message_part, part.model_copy(deep=True))

if pending_part is None:
pending_part = part
elif not pending_part.merge_in_place(part): # try merge into the pending part
# unmergeable part must push the pending part to the buffer
try:
async for part in stream:
logger.trace("Received part: {part}", part=part)
if on_message_part:
await callback(on_message_part, part.model_copy(deep=True))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Record streamed parts before awaiting callbacks

When on_message_part is an async callback, a user cancellation delivered while this await is in progress raises GenerateCancelled before the newly streamed part has been merged into pending_part or message. In that scenario the callback may already have rendered a text/tool-call chunk to the UI, but the partial result used to repair history omits it, so a visible tool call can still be persisted without the required synthetic tool result (or lost entirely).

Useful? React with 👍 / 👎.


if pending_part is None:
pending_part = part
elif not pending_part.merge_in_place(part): # try merge into the pending part
# Unmergeable part: flush the previously pending one.
#
# Invariant: ``pending_part`` is non-None iff that part has
# NOT yet been appended to ``message``. Clearing it BEFORE
# the await means a CancelledError raised inside
# ``on_tool_call`` can't trick the except-handler below into
# re-appending the same part — a duplicate tool_call.id
# would propagate into the partial StepResult and produce
# duplicate paired tool_results downstream.
flushing = pending_part
pending_part = part
_message_append(message, flushing)
if isinstance(flushing, ToolCall) and on_tool_call:
await callback(on_tool_call, flushing)

# end of message
if pending_part is not None:
# Same race protection as the mid-loop flush: drop the reference
# before awaiting on_tool_call so a cancel during the callback
# leaves ``pending_part`` as None in the except handler.
flushing = pending_part
pending_part = None
_message_append(message, flushing)
if isinstance(flushing, ToolCall) and on_tool_call:
await callback(on_tool_call, flushing)
except asyncio.CancelledError:
# Best-effort flush of the pending part so the caller sees everything
# the UI has seen. Crucially do NOT fire on_tool_call here — that
# callback starts the tool, and the user just asked us to stop.
# Thanks to the "clear before await" pattern above, pending_part is
# non-None here only if it was never appended.
if pending_part is not None:
_message_append(message, pending_part)
if isinstance(pending_part, ToolCall) and on_tool_call:
await callback(on_tool_call, pending_part)
pending_part = part

# end of message
if pending_part is not None:
_message_append(message, pending_part)
if isinstance(pending_part, ToolCall) and on_tool_call:
await callback(on_tool_call, pending_part)
raise GenerateCancelled(message) from None

if not message.content and not message.tool_calls:
raise APIEmptyResponseError("The API returned an empty response.")
Expand Down
Loading
Loading