Commit 2aade46

[TRTLLM-8214][feat] Support Qwen3 tool parser (#8216)
Signed-off-by: Pengyun Lin <[email protected]>
1 parent 7411839 commit 2aade46

16 files changed: +1405 −19 lines

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -78,3 +78,4 @@ nvidia-cutlass-dsl==4.2.1; python_version >= "3.10"
 numba-cuda>=0.19.0 # WAR for nvbugs/5501820
 plotly
 numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
+partial_json_parser
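
The new dependency presumably serves the tool parser's streaming path, where function-call arguments arrive as incomplete JSON prefixes. A minimal sketch of the library's behavior (the printed value is illustrative):

# partial_json_parser recovers a best-effort object from an incomplete JSON
# prefix, which is the shape of streamed tool-call argument deltas.
from partial_json_parser import loads

print(loads('{"city": "Par'))  # e.g. {'city': 'Par'}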

tensorrt_llm/commands/serve.py

Lines changed: 13 additions & 3 deletions
@@ -33,6 +33,7 @@
 from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
 from tensorrt_llm.logger import logger, severity_map
 from tensorrt_llm.serve import OpenAIDisaggServer, OpenAIServer
+from tensorrt_llm.serve.tool_parser import ToolParserFactory
 
 # Global variable to store the Popen object of the child process
 _child_p_global: Optional[subprocess.Popen] = None
@@ -150,6 +151,7 @@ def launch_server(
     host: str,
     port: int,
     llm_args: dict,
+    tool_parser: Optional[str] = None,
     metadata_server_cfg: Optional[MetadataServerConfig] = None,
     server_role: Optional[ServerRole] = None,
     disagg_cluster_config: Optional[DisaggClusterConfig] = None,
@@ -173,6 +175,7 @@ def launch_server(
 
     server = OpenAIServer(llm=llm,
                           model=model,
+                          tool_parser=tool_parser,
                           server_role=server_role,
                           metadata_server_cfg=metadata_server_cfg,
                           disagg_cluster_config=disagg_cluster_config,
@@ -311,6 +314,12 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
     default=None,
     help="[Experimental] Specify the parser for reasoning models.",
 )
+@click.option(
+    "--tool_parser",
+    type=click.Choice(ToolParserFactory.parsers.keys()),
+    default=None,
+    help="[Experimental] Specify the parser for tool models.",
+)
 @click.option("--metadata_server_config_file",
               type=str,
               default=None,
@@ -352,7 +361,8 @@ def serve(
     gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float,
     num_postprocess_workers: int, trust_remote_code: bool,
     extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
-    metadata_server_config_file: Optional[str], server_role: Optional[str],
+    tool_parser: Optional[str], metadata_server_config_file: Optional[str],
+    server_role: Optional[str],
     fail_fast_on_attention_window_too_large: bool,
     otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
     disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str]):
@@ -423,8 +433,8 @@ def serve(
 
     multimodal_server_config = MultimodalServerConfig(
         media_io_kwargs=parsed_media_io_kwargs)
-    launch_server(host, port, llm_args, metadata_server_cfg, server_role,
-                  disagg_cluster_config, multimodal_server_config)
+    launch_server(host, port, llm_args, tool_parser, metadata_server_cfg,
+                  server_role, disagg_cluster_config, multimodal_server_config)
 
 
 @click.command("mm_embedding_serve")
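
With this change, trtllm-serve accepts an experimental --tool_parser flag whose valid values are the keys registered on the factory. A quick way to list the available choices (a sketch; assumes tensorrt_llm with this commit is importable):

from tensorrt_llm.serve.tool_parser import ToolParserFactory

# The same registry backs click.Choice(...) for the new --tool_parser flag.
print(list(ToolParserFactory.parsers.keys()))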

tensorrt_llm/serve/chat_utils.py

Lines changed: 9 additions & 0 deletions
@@ -1,3 +1,4 @@
+import uuid
 from functools import partial
 from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
                     Optional, Tuple, TypeAlias, TypedDict, Union, cast)
@@ -220,3 +221,11 @@ def check_multiple_response(n: int, backend: Optional[str]):
     if n > 1 and backend == "pytorch":
         raise ValueError(
             "Multiple response is not supported in PyTorch workflow")
+
+
+def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
+    if id_type == "kimi_k2":
+        return f"functions.{func_name}:{idx}"
+    else:
+        # by default return random
+        return f"chatcmpl-tool-{uuid.uuid4().hex}"

tensorrt_llm/serve/openai_server.py

Lines changed: 3 additions & 0 deletions
@@ -78,12 +78,14 @@ class OpenAIServer:
     def __init__(self,
                  llm: Union[LLM, MultimodalEncoder],
                  model: str,
+                 tool_parser: Optional[str],
                  server_role: Optional[ServerRole],
                  metadata_server_cfg: MetadataServerConfig,
                  disagg_cluster_config: Optional[DisaggClusterConfig] = None,
                  multimodal_server_config: Optional[MultimodalServerConfig] = None):
         self.llm = llm
         self.tokenizer = llm.tokenizer
+        self.tool_parser = tool_parser
         self.metadata_server = create_metadata_server(metadata_server_cfg)
         self.disagg_cluster_config = disagg_cluster_config
         self.multimodal_server_config = multimodal_server_config
@@ -532,6 +534,7 @@ async def create_chat_response(
             prompt["multi_modal_data"] = mm_data
 
         postproc_args.reasoning_parser = self.llm.args.reasoning_parser
+        postproc_args.tool_parser = self.tool_parser
         if conversation and conversation[-1].get(
                 "content") and conversation[-1].get("role") == get_role():
             postproc_args.last_message_content = conversation[-1]["content"]
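
End to end, this wiring lets a standard OpenAI-style request that carries tools come back with parsed tool_calls. A hedged client-side sketch (endpoint, model name, and tool schema are illustrative, not part of this commit):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
resp = client.chat.completions.create(
    model="Qwen3",  # placeholder: whatever model the server was launched with
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
        },
    }],
)
choice = resp.choices[0]
# When the server-side parser detects a call, finish_reason becomes
# "tool_calls" and the parsed call appears in message.tool_calls
# (per the postprocessing changes below).
print(choice.finish_reason, choice.message.tool_calls)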

tensorrt_llm/serve/postprocess_handlers.py

Lines changed: 97 additions & 15 deletions
@@ -10,6 +10,7 @@
                                        ReasoningParserFactory)
 from ..llmapi.tokenizer import TransformersTokenizer
 # yapf: disable
+from .chat_utils import make_tool_call_id
 from .harmony_adapter import (handle_non_streaming_response,
                               handle_streaming_response)
 from .openai_protocol import (ChatCompletionLogProbs,
@@ -23,18 +24,22 @@
                               CompletionRequest, CompletionResponse,
                               CompletionResponseChoice,
                               CompletionResponseStreamChoice,
-                              CompletionStreamResponse, DeltaMessage,
-                              FunctionCall, PromptTokensDetails, StreamOptions,
-                              ToolCall, UsageInfo, to_disaggregated_params)
+                              CompletionStreamResponse, DeltaFunctionCall,
+                              DeltaMessage, DeltaToolCall, FunctionCall,
+                              PromptTokensDetails, StreamOptions, ToolCall,
+                              UsageInfo, to_disaggregated_params)
+from .tool_parser.base_tool_parser import BaseToolParser
+from .tool_parser.core_types import ToolCallItem
+from .tool_parser.tool_parser_factory import ToolParserFactory
 
 # yapf: enable
 
 
 @dataclass(kw_only=True)
 class ChatPostprocArgs(PostprocArgs):
     echo: bool = False
-    role: str = None
-    model: str = None
+    role: str
+    model: str
     num_choices: int = 1
     tools: Optional[List[ChatCompletionToolsParam]] = None
     tool_choice: Optional[Union[Literal["none"],
@@ -44,8 +49,11 @@ class ChatPostprocArgs(PostprocArgs):
     stream_options: Optional[StreamOptions] = None
     last_message_content: Optional[str] = None
     reasoning_parser: Optional[str] = None
+    tool_parser: Optional[str] = None
     reasoning_parser_dict: dict[int, BaseReasoningParser] = field(
         default_factory=dict)
+    tool_parser_dict: dict[int, BaseToolParser] = field(default_factory=dict)
+    has_tool_call: dict[int, bool] = field(default_factory=dict)
 
     @classmethod
     def from_request(cls, request: ChatCompletionRequest):
@@ -116,6 +124,31 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
     return content, reasoning_content
 
 
+def apply_tool_parser(args: ChatPostprocArgs, output_index: int, text: str,
+                      streaming: bool) -> Tuple[str, List[ToolCallItem]]:
+    tool_parser = None
+    tools = args.tools
+    if args.tool_parser is not None and tools is not None:
+        if output_index not in args.tool_parser_dict:
+            args.tool_parser_dict[
+                output_index] = ToolParserFactory.create_tool_parser(
+                    args.tool_parser)
+        tool_parser = args.tool_parser_dict[output_index]
+
+    if tool_parser is not None and tools is not None:
+        if not streaming:
+            result = tool_parser.detect_and_parse(text, tools)
+        else:
+            result = tool_parser.parse_streaming_increment(text, tools)
+        normal_text, calls = result.normal_text, result.calls
+        if result.calls:
+            args.has_tool_call[output_index] = True
+    else:
+        normal_text, calls = text, []
+
+    return normal_text, calls
+
+
 @nvtx_range_debug("chat_stream_post_processor")
 def chat_stream_post_processor(rsp: GenerationResultBase,
                                args: ChatPostprocArgs) -> List[str]:
@@ -176,27 +209,63 @@ def yield_first_chat(num_tokens: int,
         if args.tool_choice and type(
                 args.tool_choice) is ChatCompletionNamedToolChoiceParam:
             delta_message = DeltaMessage(tool_calls=[
-                ToolCall(function=FunctionCall(
-                    name=args.tool_choice.function.name, arguments=delta_text))
-            ])
+                DeltaToolCall(
+                    function=DeltaFunctionCall(
+                        name=args.tool_choice.function.name,
+                        arguments=delta_text),
+                    index=i,
+                ),
+            ], )
         else:
-            delta_message = DeltaMessage(content=delta_text,
-                                         reasoning_content=reasoning_delta_text)
+            delta_text, calls = apply_tool_parser(args, i, delta_text, True)
+            tool_calls = []
+            for call_item in calls:
+                # Tool call ID should be generated only once per tool call
+                if call_item.name:
+                    # First chunk: include ID and function name
+                    tool_call_id = make_tool_call_id()
+                    function_name = call_item.name
+                else:
+                    # Subsequent chunks: null ID and name for argument deltas
+                    tool_call_id = None
+                    function_name = None
+
+                tool_calls.append(
+                    DeltaToolCall(
+                        id=tool_call_id,
+                        index=call_item.tool_index,
+                        function=DeltaFunctionCall(
+                            name=function_name,
+                            arguments=call_item.parameters,
+                        ),
+                    ))
+            if tool_calls or delta_text or reasoning_delta_text or output.finish_reason:
+                delta_message = DeltaMessage(
+                    content=delta_text,
+                    reasoning_content=reasoning_delta_text,
+                    tool_calls=tool_calls if tool_calls else None)
+            else:
+                continue
 
         choice = ChatCompletionResponseStreamChoice(
             index=i,
             delta=delta_message,
-            finish_reason=None,
             avg_decoded_tokens_per_iter=getattr(rsp,
                                                 'avg_decoded_tokens_per_iter',
-                                                None))
+                                                None),
+            stop_reason=output.stop_reason,
+        )
         if args.return_logprobs:
             logprobs = output.logprobs_diff
             token_ids = output.token_ids_diff
             choice.logprobs = create_logprobs(token_ids, args.tokenizer,
                                               logprobs, args.top_logprobs)
         if output.finish_reason is not None:
-            choice.finish_reason = output.finish_reason
+            if output.finish_reason == "stop" and args.has_tool_call.get(
+                    i, False):
+                choice.finish_reason = "tool_calls"
+            else:
+                choice.finish_reason = output.finish_reason
             choice.stop_reason = output.stop_reason
             finish_reason_sent[i] = True
         chunk = ChatCompletionStreamResponse(choices=[choice], model=args.model)
@@ -247,21 +316,34 @@ def chat_response_post_processor(
                     name=args.tool_choice.function.name, arguments=text))
             ])
         else:
+            if text is None:
+                text = ""
+            text, calls = apply_tool_parser(args, output.index, text, False)
+            tool_calls = [
+                ToolCall(function=FunctionCall(name=call.name or "",
+                                               arguments=call.parameters))
+                for call in calls
+            ]
             message = ChatMessage(role=role,
                                   content=text,
-                                  reasoning_content=reasoning_text)
+                                  reasoning_content=reasoning_text,
+                                  tool_calls=tool_calls)
         disaggregated_params = to_disaggregated_params(
             output.disaggregated_params)
         choice = ChatCompletionResponseChoice(
             index=output.index,
            message=message,
-            finish_reason=output.finish_reason,
             stop_reason=output.stop_reason,
            disaggregated_params=disaggregated_params,
            avg_decoded_tokens_per_iter=getattr(rsp,
                                                'avg_decoded_tokens_per_iter',
                                                None),
        )
+        if output.finish_reason == "stop" and args.has_tool_call.get(
+                output.index, False):
+            choice.finish_reason = "tool_calls"
+        else:
+            choice.finish_reason = output.finish_reason
 
         if args.return_logprobs:
             choice.logprobs = create_logprobs(output.token_ids, args.tokenizer,
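
For clarity, here is the parser lifecycle that apply_tool_parser drives, as a standalone sketch. The "qwen3" registry key and the input variables are assumptions; the methods and result fields (normal_text, calls, and ToolCallItem's tool_index/name/parameters) are the ones this commit uses:

from tensorrt_llm.serve.tool_parser import ToolParserFactory

parser = ToolParserFactory.create_tool_parser("qwen3")  # assumed registry key

# Non-streaming: one-shot detection over the full generated text.
full_text = "<tool_call>...</tool_call>"  # placeholder model output
tools = []  # placeholder: ChatCompletionToolsParam list from the request
result = parser.detect_and_parse(full_text, tools)
print(result.normal_text)
for call in result.calls:  # ToolCallItem
    print(call.tool_index, call.name, call.parameters)

# Streaming: feed text deltas. The first increment of a call carries .name;
# later increments carry only argument fragments, which is why the stream
# post-processor mints a tool call ID only when call_item.name is set.
result = parser.parse_streaming_increment("<tool_call>", tools)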
tensorrt_llm/serve/tool_parser/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from .tool_parser_factory import ToolParserFactory
+
+__all__ = ["ToolParserFactory"]
