1010 ReasoningParserFactory )
1111from ..llmapi .tokenizer import TransformersTokenizer
1212# yapf: disable
13+ from .chat_utils import make_tool_call_id
1314from .harmony_adapter import (handle_non_streaming_response ,
1415 handle_streaming_response )
1516from .openai_protocol import (ChatCompletionLogProbs ,
2324 CompletionRequest , CompletionResponse ,
2425 CompletionResponseChoice ,
2526 CompletionResponseStreamChoice ,
26- CompletionStreamResponse , DeltaMessage ,
27- FunctionCall , PromptTokensDetails , StreamOptions ,
28- ToolCall , UsageInfo , to_disaggregated_params )
27+ CompletionStreamResponse , DeltaFunctionCall ,
28+ DeltaMessage , DeltaToolCall , FunctionCall ,
29+ PromptTokensDetails , StreamOptions , ToolCall ,
30+ UsageInfo , to_disaggregated_params )
31+ from .tool_parser .base_tool_parser import BaseToolParser
32+ from .tool_parser .core_types import ToolCallItem
33+ from .tool_parser .tool_parser_factory import ToolParserFactory
2934
3035# yapf: enable
3136
3237
3338@dataclass (kw_only = True )
3439class ChatPostprocArgs (PostprocArgs ):
3540 echo : bool = False
36- role : str = None
37- model : str = None
41+ role : str
42+ model : str
3843 num_choices : int = 1
3944 tools : Optional [List [ChatCompletionToolsParam ]] = None
4045 tool_choice : Optional [Union [Literal ["none" ],
@@ -44,8 +49,11 @@ class ChatPostprocArgs(PostprocArgs):
4449 stream_options : Optional [StreamOptions ] = None
4550 last_message_content : Optional [str ] = None
4651 reasoning_parser : Optional [str ] = None
52+ tool_parser : Optional [str ] = None
4753 reasoning_parser_dict : dict [int , BaseReasoningParser ] = field (
4854 default_factory = dict )
55+ tool_parser_dict : dict [int , BaseToolParser ] = field (default_factory = dict )
56+ has_tool_call : dict [int , bool ] = field (default_factory = dict )
4957
5058 @classmethod
5159 def from_request (cls , request : ChatCompletionRequest ):
@@ -116,6 +124,31 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
116124 return content , reasoning_content
117125
118126
def apply_tool_parser(args: ChatPostprocArgs, output_index: int, text: str,
                      streaming: bool) -> Tuple[str, List[ToolCallItem]]:
    """Run the request's configured tool-call parser over one choice's text.

    Args:
        args: Post-processing state; supplies ``tool_parser`` (parser name),
            ``tools`` (tool schemas from the request), the per-choice parser
            cache ``tool_parser_dict``, and the ``has_tool_call`` flags.
        output_index: Index of the generation choice being processed. Each
            choice gets its own parser instance because streaming parsers
            keep incremental state.
        text: Model output to parse — the full text when ``streaming`` is
            False, or a single delta chunk when True.
        streaming: Whether to parse an incremental chunk instead of a
            complete response.

    Returns:
        ``(normal_text, calls)`` — the non-tool-call text remaining after
        parsing, and the parsed ``ToolCallItem`` list. When no tool parser
        is configured or the request declared no tools, ``text`` is returned
        unchanged with an empty call list.
    """
    # Tool parsing applies only when both a parser name and tool schemas
    # were provided with the request.
    if args.tool_parser is None or args.tools is None:
        return text, []

    # Lazily create one parser per choice; a single dict lookup instead of
    # the check-then-index pattern.
    tool_parser = args.tool_parser_dict.get(output_index)
    if tool_parser is None:
        tool_parser = ToolParserFactory.create_tool_parser(args.tool_parser)
        args.tool_parser_dict[output_index] = tool_parser

    if streaming:
        result = tool_parser.parse_streaming_increment(text, args.tools)
    else:
        result = tool_parser.detect_and_parse(text, args.tools)

    if result.calls:
        # Flag is only ever set to True; the response builders read it to
        # rewrite finish_reason "stop" into "tool_calls" for this choice.
        args.has_tool_call[output_index] = True

    return result.normal_text, result.calls
150+
151+
119152@nvtx_range_debug ("chat_stream_post_processor" )
120153def chat_stream_post_processor (rsp : GenerationResultBase ,
121154 args : ChatPostprocArgs ) -> List [str ]:
@@ -176,27 +209,63 @@ def yield_first_chat(num_tokens: int,
176209 if args .tool_choice and type (
177210 args .tool_choice ) is ChatCompletionNamedToolChoiceParam :
178211 delta_message = DeltaMessage (tool_calls = [
179- ToolCall (function = FunctionCall (
180- name = args .tool_choice .function .name , arguments = delta_text ))
181- ])
212+ DeltaToolCall (
213+ function = DeltaFunctionCall (
214+ name = args .tool_choice .function .name ,
215+ arguments = delta_text ),
216+ index = i ,
217+ ),
218+ ], )
182219 else :
183- delta_message = DeltaMessage (content = delta_text ,
184- reasoning_content = reasoning_delta_text )
220+ delta_text , calls = apply_tool_parser (args , i , delta_text , True )
221+ tool_calls = []
222+ for call_item in calls :
223+ # Tool call ID should be generated only once per tool call
224+ if call_item .name :
225+ # First chunk: include ID and function name
226+ tool_call_id = make_tool_call_id ()
227+ function_name = call_item .name
228+ else :
229+ # Subsequent chunks: null ID and name for argument deltas
230+ tool_call_id = None
231+ function_name = None
232+
233+ tool_calls .append (
234+ DeltaToolCall (
235+ id = tool_call_id ,
236+ index = call_item .tool_index ,
237+ function = DeltaFunctionCall (
238+ name = function_name ,
239+ arguments = call_item .parameters ,
240+ ),
241+ ))
242+ if tool_calls or delta_text or reasoning_delta_text or output .finish_reason :
243+ delta_message = DeltaMessage (
244+ content = delta_text ,
245+ reasoning_content = reasoning_delta_text ,
246+ tool_calls = tool_calls if tool_calls else None )
247+ else :
248+ continue
185249
186250 choice = ChatCompletionResponseStreamChoice (
187251 index = i ,
188252 delta = delta_message ,
189- finish_reason = None ,
190253 avg_decoded_tokens_per_iter = getattr (rsp ,
191254 'avg_decoded_tokens_per_iter' ,
192- None ))
255+ None ),
256+ stop_reason = output .stop_reason ,
257+ )
193258 if args .return_logprobs :
194259 logprobs = output .logprobs_diff
195260 token_ids = output .token_ids_diff
196261 choice .logprobs = create_logprobs (token_ids , args .tokenizer ,
197262 logprobs , args .top_logprobs )
198263 if output .finish_reason is not None :
199- choice .finish_reason = output .finish_reason
264+ if output .finish_reason == "stop" and args .has_tool_call .get (
265+ i , False ):
266+ choice .finish_reason = "tool_calls"
267+ else :
268+ choice .finish_reason = output .finish_reason
200269 choice .stop_reason = output .stop_reason
201270 finish_reason_sent [i ] = True
202271 chunk = ChatCompletionStreamResponse (choices = [choice ], model = args .model )
@@ -247,21 +316,34 @@ def chat_response_post_processor(
247316 name = args .tool_choice .function .name , arguments = text ))
248317 ])
249318 else :
319+ if text is None :
320+ text = ""
321+ text , calls = apply_tool_parser (args , output .index , text , False )
322+ tool_calls = [
323+ ToolCall (function = FunctionCall (name = call .name or "" ,
324+ arguments = call .parameters ))
325+ for call in calls
326+ ]
250327 message = ChatMessage (role = role ,
251328 content = text ,
252- reasoning_content = reasoning_text )
329+ reasoning_content = reasoning_text ,
330+ tool_calls = tool_calls )
253331 disaggregated_params = to_disaggregated_params (
254332 output .disaggregated_params )
255333 choice = ChatCompletionResponseChoice (
256334 index = output .index ,
257335 message = message ,
258- finish_reason = output .finish_reason ,
259336 stop_reason = output .stop_reason ,
260337 disaggregated_params = disaggregated_params ,
261338 avg_decoded_tokens_per_iter = getattr (rsp ,
262339 'avg_decoded_tokens_per_iter' ,
263340 None ),
264341 )
342+ if output .finish_reason == "stop" and args .has_tool_call .get (
343+ output .index , False ):
344+ choice .finish_reason = "tool_calls"
345+ else :
346+ choice .finish_reason = output .finish_reason
265347
266348 if args .return_logprobs :
267349 choice .logprobs = create_logprobs (output .token_ids , args .tokenizer ,
0 commit comments