Skip to content

Commit 3dc7b09

Browse files
committed
Propagate model reasoning through AgentRun protocols
Constraint: Reasoning exposure must follow MODEL_PARAMETER_RULES.thinking from the agent environment. Rejected: Querying model metadata per invoke | The requirement pins behavior to agent env and would widen runtime behavior. Confidence: high Scope-risk: moderate Directive: Keep protocol gating centralized so future protocol handlers do not leak reasoning when thinking is disabled. Tested: uv run --extra server pytest tests/unittests/server -q; uv run --extra server ruff check agentrun/utils/reasoning.py agentrun/server/model.py agentrun/server/invoker.py agentrun/server/openai_protocol.py agentrun/server/agui_protocol.py scripts/smoke_reasoning_protocol.py tests/unittests/server/test_reasoning.py tests/unittests/server/test_openai_protocol.py tests/unittests/server/test_agui_protocol.py; real-model smoke with thinking=true and thinking=false for OpenAI and AG-UI Change-Id: I3b1ae025db4c0d26631cf4b4bb8e322dec77ae18 Not-tested: Remote CI and hosted preprod endpoint Signed-off-by: congxiao.wxx <congxiao.wxx@alibaba-inc.com>
1 parent a26d8b6 commit 3dc7b09

9 files changed

Lines changed: 1011 additions & 8 deletions

File tree

agentrun/server/agui_protocol.py

Lines changed: 131 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
from dataclasses import dataclass, field
11+
import json
1112
from typing import (
1213
Any,
1314
AsyncIterator,
@@ -30,6 +31,10 @@
3031
import pydash
3132

3233
from ..utils.helper import merge, MergeOptions
34+
from ..utils.reasoning import (
35+
get_reasoning_content,
36+
is_thinking_enabled_from_env,
37+
)
3338
from .model import (
3439
AgentEvent,
3540
AgentRequest,
@@ -60,6 +65,14 @@ class TextState:
6065
message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
6166

6267

68+
def _new_reasoning_id() -> str:
    """Return a fresh random UUID4 rendered in its canonical string form."""
    return str(uuid.uuid4())


@dataclass
class ReasoningState:
    """Bookkeeping for the reasoning phase that is currently being streamed.

    The two flags record which boundary events have already been emitted;
    the two ids are generated independently for every new instance.
    """

    # True once REASONING_START has been emitted for this phase.
    started: bool = False
    # True once REASONING_MESSAGE_START has been emitted for this message.
    message_started: bool = False
    # Id sent with the REASONING_START / REASONING_END pair.
    phase_id: str = field(default_factory=_new_reasoning_id)
    # Id sent with the REASONING_MESSAGE_* events.
    message_id: str = field(default_factory=_new_reasoning_id)
74+
75+
6376
@dataclass
6477
class ToolCallState:
6578
name: str = ""
@@ -72,6 +85,7 @@ class ToolCallState:
7285
@dataclass
7386
class StreamStateMachine:
7487
text: TextState = field(default_factory=TextState)
88+
reasoning: ReasoningState = field(default_factory=ReasoningState)
7589
tool_call_states: Dict[str, ToolCallState] = field(default_factory=dict)
7690
tool_result_chunks: Dict[str, List[str]] = field(default_factory=dict)
7791
run_errored: bool = False
@@ -121,6 +135,43 @@ def cache_tool_result_chunk(self, tool_id: str, delta: str) -> None:
121135
def pop_tool_result_chunks(self, tool_id: str) -> str:
122136
return "".join(self.tool_result_chunks.pop(tool_id, []))
123137

138+
def ensure_reasoning_started(self) -> Iterator[str]:
    """Yield whichever reasoning "start" boundary events are still missing.

    Emits REASONING_START (tagged with the phase id) when the phase is not
    open yet, then REASONING_MESSAGE_START (tagged with the message id and
    role "reasoning") when the message is not open yet. Each flag is set
    only after its event has been yielded, so nothing is emitted twice.

    Yields:
        SSE-framed event strings, in the order they must be sent.
    """
    if not self.reasoning.started:
        phase_start = _encode_reasoning_event(
            "REASONING_START",
            messageId=self.reasoning.phase_id,
        )
        yield phase_start
        self.reasoning.started = True

    # Message already open: nothing more to announce.
    if self.reasoning.message_started:
        return

    message_start = _encode_reasoning_event(
        "REASONING_MESSAGE_START",
        messageId=self.reasoning.message_id,
        role="reasoning",
    )
    yield message_start
    self.reasoning.message_started = True
152+
153+
def end_reasoning_if_open(self) -> Iterator[str]:
    """Yield the reasoning "end" boundary events for anything still open.

    Emits REASONING_MESSAGE_END when a reasoning message is open, then
    REASONING_END when the phase is open. After REASONING_END the state is
    replaced by a brand-new ``ReasoningState`` so that a later reasoning
    phase starts with fresh ids.

    Yields:
        SSE-framed event strings, in the order they must be sent.
    """
    if self.reasoning.message_started:
        message_end = _encode_reasoning_event(
            "REASONING_MESSAGE_END",
            messageId=self.reasoning.message_id,
        )
        yield message_end
        self.reasoning.message_started = False

    # No open phase: leave the current state (and its ids) untouched.
    if not self.reasoning.started:
        return

    phase_end = _encode_reasoning_event(
        "REASONING_END",
        messageId=self.reasoning.phase_id,
    )
    yield phase_end
    self.reasoning = ReasoningState()
166+
167+
168+
def _encode_reasoning_event(event_type: str, **payload: Any) -> str:
    """Serialize a reasoning event as a single SSE ``data:`` frame.

    Args:
        event_type: Value stored under the ``"type"`` key of the JSON body.
        **payload: Extra keys merged into the JSON body after ``"type"``;
            a ``type`` entry in ``payload`` overrides ``event_type``.

    Returns:
        ``"data: <json>"`` followed by a blank line. Non-ASCII characters
        are written as-is rather than escaped.
    """
    body = {"type": event_type}
    body.update(payload)
    serialized = json.dumps(body, ensure_ascii=False)
    return f"data: {serialized}\n\n"
174+
124175

125176
class AGUIProtocolHandler(BaseProtocolHandler):
126177
"""AG-UI 协议处理器
@@ -376,6 +427,10 @@ async def _format_stream(
376427
if state.run_errored:
377428
return
378429

430+
# 结束未结束的 reasoning 消息
431+
for sse_data in state.end_reasoning_if_open():
432+
yield sse_data
433+
379434
# 结束所有未结束的工具调用
380435
for sse_data in state.end_all_tools(self._encoder):
381436
yield sse_data
@@ -399,8 +454,6 @@ def _process_event_with_boundaries(
399454
state: StreamStateMachine,
400455
) -> Iterator[str]:
401456
"""处理事件并注入边界事件"""
402-
import json
403-
404457
from ag_ui.core import CustomEvent as AguiCustomEvent
405458
from ag_ui.core import (
406459
RunErrorEvent,
@@ -413,6 +466,8 @@ def _process_event_with_boundaries(
413466
ToolCallStartEvent,
414467
)
415468

469+
thinking_enabled = is_thinking_enabled_from_env()
470+
416471
# RAW 事件直接透传
417472
if event.event == EventType.RAW:
418473
raw_data = event.data.get("raw", "")
@@ -422,9 +477,46 @@ def _process_event_with_boundaries(
422477
yield raw_data
423478
return
424479

480+
if event.event == EventType.REASONING:
481+
if thinking_enabled:
482+
reasoning_content = (
483+
event.data.get("delta")
484+
or get_reasoning_content(event.data)
485+
or ""
486+
)
487+
if reasoning_content:
488+
for sse_data in state.end_text_if_open(self._encoder):
489+
yield sse_data
490+
for sse_data in state.end_all_tools(self._encoder):
491+
yield sse_data
492+
for sse_data in state.ensure_reasoning_started():
493+
yield sse_data
494+
yield _encode_reasoning_event(
495+
"REASONING_MESSAGE_CONTENT",
496+
messageId=state.reasoning.message_id,
497+
delta=reasoning_content,
498+
)
499+
return
500+
425501
# TEXT 事件:在首个 TEXT 前注入 TEXT_MESSAGE_START
426502
# AG-UI 协议要求:发送 TEXT_MESSAGE_START 前必须先结束所有未结束的 TOOL_CALL
427503
if event.event == EventType.TEXT:
504+
addition = self._strip_reasoning_from_addition(
505+
event.addition, thinking_enabled
506+
)
507+
addition_reasoning = get_reasoning_content(event.addition or {})
508+
if thinking_enabled and addition_reasoning:
509+
for sse_data in state.ensure_reasoning_started():
510+
yield sse_data
511+
yield _encode_reasoning_event(
512+
"REASONING_MESSAGE_CONTENT",
513+
messageId=state.reasoning.message_id,
514+
delta=addition_reasoning,
515+
)
516+
517+
for sse_data in state.end_reasoning_if_open():
518+
yield sse_data
519+
428520
for sse_data in state.end_all_tools(self._encoder):
429521
yield sse_data
430522

@@ -435,13 +527,13 @@ def _process_event_with_boundaries(
435527
message_id=state.text.message_id,
436528
delta=event.data.get("delta", ""),
437529
)
438-
if event.addition:
530+
if addition:
439531
event_dict = agui_event.model_dump(
440532
by_alias=True, exclude_none=True
441533
)
442534
event_dict = self._apply_addition(
443535
event_dict,
444-
event.addition,
536+
addition,
445537
event.addition_merge_options,
446538
)
447539
json_str = json.dumps(event_dict, ensure_ascii=False)
@@ -455,6 +547,9 @@ def _process_event_with_boundaries(
455547
tool_id = event.data.get("id", "")
456548
tool_name = event.data.get("name", "")
457549

550+
for sse_data in state.end_reasoning_if_open():
551+
yield sse_data
552+
458553
for sse_data in state.end_text_if_open(self._encoder):
459554
yield sse_data
460555

@@ -491,6 +586,9 @@ def _process_event_with_boundaries(
491586
tool_name = event.data.get("name", "")
492587
tool_args = event.data.get("args", "")
493588

589+
for sse_data in state.end_reasoning_if_open():
590+
yield sse_data
591+
494592
for sse_data in state.end_text_if_open(self._encoder):
495593
yield sse_data
496594

@@ -541,6 +639,9 @@ def _process_event_with_boundaries(
541639
timeout = event.data.get("timeout")
542640
schema = event.data.get("schema")
543641

642+
for sse_data in state.end_reasoning_if_open():
643+
yield sse_data
644+
544645
for sse_data in state.end_text_if_open(self._encoder):
545646
yield sse_data
546647

@@ -601,6 +702,9 @@ def _process_event_with_boundaries(
601702
tool_id = event.data.get("id", "")
602703
tool_name = event.data.get("name", "")
603704

705+
for sse_data in state.end_reasoning_if_open():
706+
yield sse_data
707+
604708
for sse_data in state.end_text_if_open(self._encoder):
605709
yield sse_data
606710

@@ -767,6 +871,29 @@ def _apply_addition(
767871

768872
return merge(event_data, addition, **(merge_options or {}))
769873

874+
def _strip_reasoning_from_addition(
    self,
    addition: Optional[Dict[str, Any]],
    thinking_enabled: bool,
) -> Optional[Dict[str, Any]]:
    """Return a copy of ``addition`` with every ``reasoning_content`` removed.

    The key is dropped both at the top level and inside a nested
    ``additional_kwargs`` dict; an ``additional_kwargs`` left empty by that
    removal is dropped as well. The caller's dicts are never mutated.

    Args:
        addition: Extra fields attached to an event, or a falsy value.
        thinking_enabled: Only affects how an empty result is reported.

    Returns:
        ``addition`` itself when it is falsy. Otherwise the stripped copy;
        when the copy ends up empty, ``None`` is returned if
        ``thinking_enabled`` is true and the empty dict if it is false.
    """
    if not addition:
        return addition

    cleaned = {
        name: value
        for name, value in addition.items()
        if name != "reasoning_content"
    }

    nested = cleaned.get("additional_kwargs")
    if isinstance(nested, dict):
        nested_cleaned = {
            name: value
            for name, value in nested.items()
            if name != "reasoning_content"
        }
        if nested_cleaned:
            cleaned["additional_kwargs"] = nested_cleaned
        else:
            del cleaned["additional_kwargs"]

    # Only the thinking-enabled path collapses an empty result to None.
    if thinking_enabled and not cleaned:
        return None
    return cleaned
896+
770897
async def _error_stream(self, message: str) -> AsyncIterator[str]:
771898
"""生成错误事件流
772899

agentrun/server/invoker.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
InvokeAgentHandler,
3030
SyncInvokeAgentHandler,
3131
)
32+
from agentrun.utils.reasoning import get_reasoning_content
3233

3334

3435
class AgentInvoker:
@@ -124,6 +125,9 @@ async def invoke_stream(
124125
# 处理用户返回的事件
125126
for processed_event in self._process_user_event(item):
126127
yield processed_event
128+
else:
129+
for processed_event in self._wrap_model_chunk(item):
130+
yield processed_event
127131
else:
128132
# 非流式结果
129133
results = self._wrap_non_stream(raw_result)
@@ -238,6 +242,11 @@ def _wrap_non_stream(self, result: Any) -> List[AgentEvent]:
238242
data={"delta": item},
239243
)
240244
)
245+
else:
246+
results.extend(self._wrap_model_chunk(item))
247+
248+
else:
249+
results.extend(self._wrap_model_chunk(result))
241250

242251
return results
243252

@@ -267,6 +276,9 @@ async def _wrap_stream(
267276
elif isinstance(item, AgentEvent):
268277
for processed_event in self._process_user_event(item):
269278
yield processed_event
279+
else:
280+
for processed_event in self._wrap_model_chunk(item):
281+
yield processed_event
270282

271283
async def _iterate_async(
272284
self, content: Union[Iterator[Any], AsyncIterator[Any]]
@@ -307,3 +319,31 @@ def _is_iterator(self, obj: Any) -> bool:
307319
if isinstance(obj, (str, bytes, dict, list, AgentEvent)):
308320
return False
309321
return hasattr(obj, "__iter__") or hasattr(obj, "__aiter__")
322+
323+
def _wrap_model_chunk(self, item: Any) -> List[AgentEvent]:
    """Translate a model chunk into zero, one or two ``AgentEvent`` objects.

    A REASONING event is produced when ``get_reasoning_content(item)``
    returns something truthy, and a TEXT event when the chunk's ``content``
    (dict key or attribute) is a non-empty string. Reasoning always comes
    before text in the returned list.

    Args:
        item: The chunk to inspect; either a dict or an arbitrary object.

    Returns:
        The generated events, possibly an empty list.
    """
    wrapped: List[AgentEvent] = []

    thinking = get_reasoning_content(item)
    if thinking:
        reasoning_event = AgentEvent(
            event=EventType.REASONING,
            data={"delta": thinking},
        )
        wrapped.append(reasoning_event)

    text = self._read_attr_or_key(item, "content")
    # Non-string content (e.g. a list) is deliberately not forwarded as text.
    if isinstance(text, str) and text:
        text_event = AgentEvent(
            event=EventType.TEXT,
            data={"delta": text},
        )
        wrapped.append(text_event)

    return wrapped
345+
346+
def _read_attr_or_key(self, obj: Any, key: str) -> Any:
    """Fetch ``key`` from ``obj`` by dict lookup or attribute access.

    Dicts are read with ``.get``; every other object is read with
    ``getattr``. In both cases a missing entry yields ``None``.
    """
    return obj.get(key) if isinstance(obj, dict) else getattr(obj, key, None)

agentrun/server/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Iterator,
1515
List,
1616
Optional,
17-
TYPE_CHECKING,
1817
Union,
1918
)
2019

@@ -91,6 +90,7 @@ class Message(BaseModel):
9190
id: Optional[str] = None
9291
role: MessageRole
9392
content: Optional[Union[str, List[Dict[str, Any]]]] = None
93+
reasoning_content: Optional[str] = None
9494
name: Optional[str] = None
9595
tool_calls: Optional[List[ToolCall]] = None
9696
tool_call_id: Optional[str] = None
@@ -125,6 +125,7 @@ class EventType(str, Enum):
125125
# 核心事件(用户主要使用)
126126
# =========================================================================
127127
TEXT = "TEXT" # 文本内容块
128+
REASONING = "REASONING" # 模型思考内容块
128129
TOOL_CALL = "TOOL_CALL" # 完整工具调用(含 id, name, args)
129130
TOOL_CALL_CHUNK = "TOOL_CALL_CHUNK" # 工具调用参数片段(流式场景)
130131
TOOL_RESULT = "TOOL_RESULT" # 工具执行结果(最终结果,标识流式输出结束)

0 commit comments

Comments
 (0)