|
| 1 | +"""LangGraph/LangChain Agent 事件转换器 |
| 2 | +
|
| 3 | +将 LangGraph/LangChain astream_events 的单个事件转换为 AgentRun 事件。 |
| 4 | +
|
| 5 | +支持两种事件格式: |
| 6 | +1. on_chat_model_stream - LangGraph create_react_agent 的流式输出 |
| 7 | +2. on_chain_stream - LangChain create_agent 的输出 |
| 8 | +
|
| 9 | +Example: |
| 10 | + >>> async def invoke_agent(request: AgentRequest): |
| 11 | + ... async for event in agent.astream_events(input_data, version="v2"): |
| 12 | + ... for item in convert(event, request.hooks): |
| 13 | + ... yield item |
| 14 | +""" |
| 15 | + |
| 16 | +import json |
| 17 | +from typing import Any, Dict, Generator, List, Optional, Union |
| 18 | + |
| 19 | +from agentrun.server.model import AgentEvent, AgentLifecycleHooks |
| 20 | + |
| 21 | + |
| 22 | +def convert( |
| 23 | + event: Dict[str, Any], |
| 24 | + hooks: Optional[AgentLifecycleHooks] = None, |
| 25 | +) -> Generator[Union[AgentEvent, str], None, None]: |
| 26 | + """转换单个 astream_events 事件 |
| 27 | +
|
| 28 | + Args: |
| 29 | + event: LangGraph/LangChain astream_events 的单个事件 |
| 30 | + hooks: AgentLifecycleHooks,用于创建工具调用事件 |
| 31 | +
|
| 32 | + Yields: |
| 33 | + str (文本内容) 或 AgentEvent (工具调用事件) |
| 34 | + """ |
| 35 | + event_type = event.get("event", "") |
| 36 | + data = event.get("data", {}) |
| 37 | + |
| 38 | + # 1. LangGraph 格式: on_chat_model_stream |
| 39 | + if event_type == "on_chat_model_stream": |
| 40 | + chunk = data.get("chunk") |
| 41 | + if chunk: |
| 42 | + content = _get_content(chunk) |
| 43 | + if content: |
| 44 | + yield content |
| 45 | + |
| 46 | + # 流式工具调用参数 |
| 47 | + if hooks: |
| 48 | + for tc in _get_tool_chunks(chunk): |
| 49 | + tc_id = tc.get("id") or str(tc.get("index", "")) |
| 50 | + if tc.get("name") and tc_id: |
| 51 | + yield hooks.on_tool_call_start( |
| 52 | + id=tc_id, name=tc["name"] |
| 53 | + ) |
| 54 | + if tc.get("args") and tc_id: |
| 55 | + yield hooks.on_tool_call_args_delta( |
| 56 | + id=tc_id, delta=tc["args"] |
| 57 | + ) |
| 58 | + |
| 59 | + # 2. LangChain 格式: on_chain_stream (来自 create_agent) |
| 60 | + # 只处理 name="model" 的事件,避免重复(LangGraph 会发送 name="model" 和 name="LangGraph" 两个相同内容的事件) |
| 61 | + elif event_type == "on_chain_stream" and event.get("name") == "model": |
| 62 | + chunk_data = data.get("chunk", {}) |
| 63 | + if isinstance(chunk_data, dict): |
| 64 | + # chunk 格式: {"messages": [AIMessage(...)]} |
| 65 | + messages = chunk_data.get("messages", []) |
| 66 | + |
| 67 | + for msg in messages: |
| 68 | + # 提取文本内容 |
| 69 | + content = _get_content(msg) |
| 70 | + if content: |
| 71 | + yield content |
| 72 | + |
| 73 | + # 提取工具调用 |
| 74 | + if hooks: |
| 75 | + tool_calls = _get_tool_calls(msg) |
| 76 | + for tc in tool_calls: |
| 77 | + tc_id = tc.get("id", "") |
| 78 | + tc_name = tc.get("name", "") |
| 79 | + tc_args = tc.get("args", {}) |
| 80 | + if tc_id and tc_name: |
| 81 | + yield hooks.on_tool_call_start( |
| 82 | + id=tc_id, name=tc_name |
| 83 | + ) |
| 84 | + if tc_args: |
| 85 | + yield hooks.on_tool_call_args( |
| 86 | + id=tc_id, args=_to_json(tc_args) |
| 87 | + ) |
| 88 | + |
| 89 | + # 3. 工具开始 (LangGraph) |
| 90 | + elif event_type == "on_tool_start" and hooks: |
| 91 | + run_id = event.get("run_id", "") |
| 92 | + tool_name = event.get("name", "") |
| 93 | + tool_input = data.get("input", {}) |
| 94 | + |
| 95 | + if run_id: |
| 96 | + yield hooks.on_tool_call_start(id=run_id, name=tool_name) |
| 97 | + if tool_input: |
| 98 | + yield hooks.on_tool_call_args( |
| 99 | + id=run_id, args=_to_json(tool_input) |
| 100 | + ) |
| 101 | + |
| 102 | + # 4. 工具结束 (LangGraph) |
| 103 | + elif event_type == "on_tool_end" and hooks: |
| 104 | + run_id = event.get("run_id", "") |
| 105 | + output = data.get("output", "") |
| 106 | + |
| 107 | + if run_id: |
| 108 | + yield hooks.on_tool_call_result( |
| 109 | + id=run_id, result=str(output) if output else "" |
| 110 | + ) |
| 111 | + yield hooks.on_tool_call_end(id=run_id) |
| 112 | + |
| 113 | + |
| 114 | +def _get_content(obj: Any) -> Optional[str]: |
| 115 | + """提取文本内容""" |
| 116 | + if obj is None: |
| 117 | + return None |
| 118 | + |
| 119 | + # 字符串 |
| 120 | + if isinstance(obj, str): |
| 121 | + return obj if obj else None |
| 122 | + |
| 123 | + # 有 content 属性的对象 (AIMessage, AIMessageChunk, etc.) |
| 124 | + if hasattr(obj, "content"): |
| 125 | + c = obj.content |
| 126 | + if isinstance(c, str) and c: |
| 127 | + return c |
| 128 | + if isinstance(c, list): |
| 129 | + parts = [] |
| 130 | + for item in c: |
| 131 | + if isinstance(item, str): |
| 132 | + parts.append(item) |
| 133 | + elif isinstance(item, dict): |
| 134 | + parts.append(item.get("text", "")) |
| 135 | + return "".join(parts) or None |
| 136 | + |
| 137 | + return None |
| 138 | + |
| 139 | + |
| 140 | +def _get_tool_chunks(chunk: Any) -> List[Dict[str, Any]]: |
| 141 | + """提取工具调用增量 (AIMessageChunk.tool_call_chunks)""" |
| 142 | + result: List[Dict[str, Any]] = [] |
| 143 | + if hasattr(chunk, "tool_call_chunks") and chunk.tool_call_chunks: |
| 144 | + for tc in chunk.tool_call_chunks: |
| 145 | + if isinstance(tc, dict): |
| 146 | + result.append(tc) |
| 147 | + else: |
| 148 | + result.append({ |
| 149 | + "id": getattr(tc, "id", None), |
| 150 | + "name": getattr(tc, "name", None), |
| 151 | + "args": getattr(tc, "args", None), |
| 152 | + "index": getattr(tc, "index", None), |
| 153 | + }) |
| 154 | + return result |
| 155 | + |
| 156 | + |
| 157 | +def _get_tool_calls(msg: Any) -> List[Dict[str, Any]]: |
| 158 | + """提取完整工具调用 (AIMessage.tool_calls)""" |
| 159 | + result: List[Dict[str, Any]] = [] |
| 160 | + if hasattr(msg, "tool_calls") and msg.tool_calls: |
| 161 | + for tc in msg.tool_calls: |
| 162 | + if isinstance(tc, dict): |
| 163 | + result.append(tc) |
| 164 | + else: |
| 165 | + result.append({ |
| 166 | + "id": getattr(tc, "id", None), |
| 167 | + "name": getattr(tc, "name", None), |
| 168 | + "args": getattr(tc, "args", None), |
| 169 | + }) |
| 170 | + return result |
| 171 | + |
| 172 | + |
| 173 | +def _to_json(obj: Any) -> str: |
| 174 | + """转 JSON 字符串""" |
| 175 | + if isinstance(obj, str): |
| 176 | + return obj |
| 177 | + return json.dumps(obj, ensure_ascii=False) |
0 commit comments