diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
index 775ef628d2d..c9b2acff836 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
@@ -29,6 +29,8 @@ class CacheConfig:
     """A dataclass to hold information how to configure the cache."""

     dtype: Optional[torch.dtype] = None
+    # mamba_dtype: Optional[torch.dtype] = None
+    mamba_dtype: Optional[torch.dtype] = torch.float32


 class SequenceInfo:
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
index ccd24e7ec00..0908e7c9fb1 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
@@ -325,7 +325,7 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
-        # Returns (seq_len, seq_start, slot_idx)
+        # Returns (seq_len, seq_start, slot_idx, use_initial_states)
         return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4

     @classmethod
@@ -339,6 +339,9 @@ def get_cache_initializers(
         num_heads = hs_fake.shape[-2]
         head_dim = hs_fake.shape[-1]

+        # dtype from node itself
+        dtype = source_attn_node.meta["val"].dtype
+
         # Infer state size by assuming B has shape [b, s, n_groups * ssm_state_size]
         # During runtime we pass [b, s, n_groups, ssm_state_size]; both give the same last dim product.
         if B_fake.ndim >= 4:
@@ -354,7 +357,7 @@ def _get_ssm_cache(si: SequenceInfo):
                 head_dim,
                 ssm_state_size,
                 device=si.device,
-                dtype=cache_config.dtype or hs_fake.dtype,
+                dtype=cache_config.mamba_dtype or dtype,
             )

         return {"ssm_state_cache": _get_ssm_cache}
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
index 64b62419162..2f016267fd9 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
@@ -1,27 +1,14 @@
-from typing import List, Tuple
+from typing import List

 import torch
-from torch._ops import OpOverloadPacket
-from torch.fx import Node

 # Triton kernels
 from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chunk_indices_offsets
 from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update
 from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined

-from ...utils.node_utils import extract_op_args
-from ..attention_interface import (
-    AttentionDescriptor,
-    AttentionLayout,
-    AttentionRegistry,
-    BufferInitializerDict,
-    CacheConfig,
-    CacheInitializerDict,
-    Constant,
-    MHACallable,
-    PrepareMetadataCallable,
-    SequenceInfo,
-)
+from ..attention_interface import AttentionRegistry, MHACallable
+from .torch_backend_mamba import TorchBackendSSM


 @torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={})
@@ -51,6 +38,14 @@ def _triton_cached_ssm(
     - Prefill: run one varlen combined scan over concatenated prefill tokens and update final states per slot.
     - Decode: batch single-token updates with selective_state_update and update states per slot.
     """
+    hidden_states = hidden_states.to(torch.float32)
+    A = A.to(torch.float32)
+    B = B.to(torch.float32)
+    C = C.to(torch.float32)
+    D = D.to(torch.float32)
+    dt = dt.to(torch.float32)
+    dt_bias = dt_bias.to(torch.float32)
+
     b, s = hidden_states.shape[:2]
     num_seq = seq_len.shape[0]

     # Flatten tokens for indexing/scatter
@@ -202,70 +197,7 @@ def _triton_cached_ssm_fake(


 @AttentionRegistry.register("triton_ssm")
-class TritonBackendSSM(AttentionDescriptor):
-    @classmethod
-    def is_paged(cls) -> bool:
-        return True
-
-    @classmethod
-    def get_attention_layout(cls) -> AttentionLayout:
-        # Hidden states follow [b, s, n, d]
-        return "bsnd"
-
-    @classmethod
-    def get_num_qkv_args(cls) -> int:
-        # torch_ssm_transform signature has 7 node/state arguments
-        return 7
-
-    @classmethod
-    def get_source_attention_op(cls) -> OpOverloadPacket:
-        # Keep source op unchanged (used for uncached pre-export)
-        return torch.ops.auto_deploy.torch_ssm
-
+class TritonBackendSSM(TorchBackendSSM):
     @classmethod
     def get_cached_attention_op(cls) -> MHACallable:
         return torch.ops.auto_deploy.triton_cached_ssm
-
-    @classmethod
-    def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
-        # Returns (seq_len, seq_start, slot_idx, use_initial_states)
-        return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4
-
-    @classmethod
-    def get_cache_initializers(
-        cls, source_attn_node: Node, cache_config: CacheConfig
-    ) -> CacheInitializerDict:
-        # Shapes from fake tensors
-        hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
-        B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]
-
-        num_heads = hs_fake.shape[-2]
-        head_dim = hs_fake.shape[-1]
-
-        if B_fake.ndim >= 4:
-            ssm_state_size = B_fake.shape[-1]
-        else:
-            ssm_state_size = max(1, B_fake.shape[-1])
-
-        def _get_ssm_cache(si: SequenceInfo):
-            return torch.empty(
-                si.max_batch_size,
-                num_heads,
-                head_dim,
-                ssm_state_size,
-                device=si.device,
-                dtype=cache_config.dtype or hs_fake.dtype,
-            )
-
-        return {"ssm_state_cache": _get_ssm_cache}
-
-    @classmethod
-    def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
-        return {}
-
-    @classmethod
-    def get_constants(cls, source_attn_node: Node) -> List[Constant]:
-        time_step_limit, chunk_size = extract_op_args(
-            source_attn_node, "time_step_limit", "chunk_size"
-        )
-        return [time_step_limit, chunk_size]
diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py
index acda26b511c..82dfe185832 100644
--- a/tensorrt_llm/serve/chat_utils.py
+++ b/tensorrt_llm/serve/chat_utils.py
@@ -1,3 +1,4 @@
+import json
 import uuid
 from functools import partial
 from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
@@ -185,6 +186,36 @@ def parse_chat_message_content(
         content,
         mm_data_tracker,
     )
+    if role == "assistant":
+        result.update(**_parse_assistant_message_content(message))
+    elif role == "tool":
+        result.update(**_parse_tool_message_content(message))
+    return result
+
+
+# Adapted from: https://github.com/vllm-project/vllm/blob/4574d48bab9c4e38b7c0a830eeefc8f0980e8c58/vllm/entrypoints/chat_utils.py#L1406
+def _parse_assistant_message_content(message: Dict[str, Any]) -> Dict[str, Any]:
+    result = {}
+    tool_calls = message.get("tool_calls")
+    if tool_calls is not None:
+        result["tool_calls"] = []
+        for item in tool_calls:
+            if content := item["function"].get("arguments"):
+                if isinstance(content, str):
+                    item["function"]["arguments"] = json.loads(content)
+                else:
+                    item["function"]["arguments"] = content
+            else:
+                item["function"]["arguments"] = {}
+            result["tool_calls"].append(item)
+
+    return result
+
+
+def _parse_tool_message_content(message: Dict[str, Any]) -> Dict[str, Any]:
+    result = {}
+    if "tool_call_id" in message:
+        result["tool_call_id"] = message["tool_call_id"]
     return result


diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index af8111d1f07..80695c7366f 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -396,6 +396,12 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):

 class CustomChatCompletionMessageParam(TypedDict, total=False):
     """Enables custom roles in the Chat Completion API."""
+
+    # This is so custom fields not in any of the `ChatCompletionMessageParam` defined by OpenAI
+    # are still allowed.
+    # Examples include: assistant messages with `reasoning` / `reasoning_content`.
+    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore
+
     role: Required[str]
     """The role of the message's author."""

diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index aad50e9c7d8..e735d229b7c 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -515,6 +515,10 @@ async def create_chat_response(
                 chat_template=request.chat_template,
                 chat_template_kwargs=request.chat_template_kwargs or {},
             )
+            logger.debug(
+                "Rendered chat template:\n"
+                f"{prompt!r}"
+            )
             prompt = prompt_inputs(prompt)

             mm_data = await mm_coroutines
diff --git a/tensorrt_llm/serve/tool_parser/qwen3_coder_parser.py b/tensorrt_llm/serve/tool_parser/qwen3_coder_parser.py
new file mode 100644
index 00000000000..7011768c070
--- /dev/null
+++ b/tensorrt_llm/serve/tool_parser/qwen3_coder_parser.py
@@ -0,0 +1,346 @@
+# Adapted from: https://raw.githubusercontent.com/sgl-project/sglang/d8fcbaa38da95201914a1277971044ee66837b26/python/sglang/srt/function_call/qwen3_coder_detector.py
+
+import ast
+import html
+import json
+import re
+from typing import Any, Dict, List, Tuple
+
+from tensorrt_llm.logger import logger
+from tensorrt_llm.serve.openai_protocol import ChatCompletionToolsParam as Tool
+from tensorrt_llm.serve.tool_parser.base_tool_parser import BaseToolParser
+from tensorrt_llm.serve.tool_parser.core_types import (StreamingParseResult,
+                                                       ToolCallItem,
+                                                       _GetInfoFunc)
+
+
+def _safe_val(raw: str) -> Any:
+    raw = html.unescape(raw.strip())
+    try:
+        return json.loads(raw)
+    except Exception:
+        try:
+            return ast.literal_eval(raw)
+        except Exception:
+            return raw
+
+
+class Qwen3CoderToolParser(BaseToolParser):
+    """Tool parser for Qwen 3 models.
+
+    Assumes function call format:
+        <tool_call>
+        <function=execute_bash>
+        <parameter=command>
+        pwd && ls
+        </parameter>
+        </function>
+        </tool_call>
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.tool_call_start_token: str = "<tool_call>"
+        self.tool_call_end_token: str = "</tool_call>"
+        self.tool_call_prefix: str = "<function="
+        self.tool_call_regex = re.compile(
+            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL)
+        self.tool_call_function_regex = re.compile(
+            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
+        self.tool_call_parameter_regex = re.compile(
+            r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL)
+        self._buf: str = ""
+
+        # Streaming state
+        self._current_function_name: str = ""
+        self._current_parameters: Dict[str, Any] = {}
+        self._streamed_parameters: Dict[str, Any] = {}
+        self._in_tool_call: bool = False
+        self._function_name_sent: bool = False
+
+    def has_tool_call(self, text: str) -> bool:
+        return self.tool_call_start_token in text
+
+    def detect_and_parse(self, text: str,
+                         tools: List[Tool]) -> StreamingParseResult:
+        normal, calls = self._extract(text, tools)
+        return StreamingParseResult(normal_text=normal, calls=calls)
+
+    def parse_streaming_increment(self, new_text: str,
+                                  tools: List[Tool]) -> StreamingParseResult:
+        self._buf += new_text
+        normal = ""
+        calls: List[ToolCallItem] = []
+
+        # Build tool indices for validation
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = self._get_tool_indices(tools)
+
+        while True:
+            # If we're not in a tool call and don't see a start token, return normal text
+            if not self._in_tool_call and self.tool_call_start_token not in self._buf:
+                normal += self._buf
+                self._buf = ""
+                break
+
+            # Look for tool call start
+            if not self._in_tool_call:
+                s = self._buf.find(self.tool_call_start_token)
+                if s == -1:
+                    normal += self._buf
+                    self._buf = ""
+                    break
+
+                normal += self._buf[:s]
+                self._buf = self._buf[s:]
+
+                self._in_tool_call = True
+                self._function_name_sent = False
+                self._current_function_name = ""
+                self._current_parameters = {}
+                self._streamed_parameters = {}
+
+                # Remove the start token
+                self._buf = self._buf[len(self.tool_call_start_token):]
+                continue
+
+            # We're in a tool call, try to parse function name if not sent yet
+            if not self._function_name_sent:
+                # Look for function name pattern: <function=name>
+                function_match = re.search(r"<function=([^>]+)>", self._buf)
+                if function_match:
+                    function_name = function_match.group(1).strip()
+
+                    # Validate function name
+                    if function_name in self._tool_indices:
+                        self._current_function_name = function_name
+                        self._function_name_sent = True
+
+                        # Initialize tool call tracking
+                        if self.current_tool_id == -1:
+                            self.current_tool_id = 0
+
+                        # Ensure tracking arrays are large enough
+                        while len(self.prev_tool_call_arr
+                                  ) <= self.current_tool_id:
+                            self.prev_tool_call_arr.append({})
+                        while len(self.streamed_args_for_tool
+                                  ) <= self.current_tool_id:
+                            self.streamed_args_for_tool.append("")
+
+                        # Store tool call info
+                        self.prev_tool_call_arr[self.current_tool_id] = {
+                            "name": function_name,
+                            "arguments": {},
+                        }
+
+                        # Send tool name with empty parameters
+                        calls.append(
+                            ToolCallItem(
+                                tool_index=self.current_tool_id,
+                                name=function_name,
+                                parameters="",
+                            ))
+
+                        # Remove the processed function declaration
+                        self._buf = self._buf[function_match.end():]
+                        continue
+                    else:
+                        # Invalid function name, reset state
+                        logger.warning(
+                            f"Invalid function name: {function_name}")
+                        self._reset_streaming_state()
+                        normal += self._buf
+                        self._buf = ""
+                        break
+                else:
+                    # Function name not complete yet, wait for more text
+                    break
+
+            # Parse parameters incrementally
+            if self._function_name_sent:
+                # Process parameters and get any calls to emit
+                parameter_calls = self._parse_and_stream_parameters(self._buf)
+                calls.extend(parameter_calls)
+
+                # Check if tool call is complete
+                if self.tool_call_end_token in self._buf:
+                    end_pos = self._buf.find(self.tool_call_end_token)
+
+                    # Add closing brace to complete the JSON object
+                    current_streamed = self.streamed_args_for_tool[
+                        self.current_tool_id]
+                    if current_streamed:
+                        # Count opening and closing braces to check if JSON is complete
+                        open_braces = current_streamed.count("{")
+                        close_braces = current_streamed.count("}")
+                        if open_braces > close_braces:
+                            calls.append(
+                                ToolCallItem(
+                                    tool_index=self.current_tool_id,
+                                    name=None,
+                                    parameters="}",
+                                ))
+                            self.streamed_args_for_tool[
+                                self.current_tool_id] = (current_streamed + "}")
+
+                    # Complete the tool call
+                    self._buf = self._buf[end_pos +
+                                          len(self.tool_call_end_token):]
+                    self._reset_streaming_state()
+                    self.current_tool_id += 1
+                    continue
+                else:
+                    # Tool call not complete yet, wait for more text
+                    break
+
+        return StreamingParseResult(normal_text=normal, calls=calls)
+
+    def _parse_and_stream_parameters(self,
+                                     text_to_parse: str) -> List[ToolCallItem]:
+        """
+        Parse complete parameter blocks from text and return any tool call items to emit.
+
+        This method:
+        1. Finds all complete <parameter=...> blocks
+        2. Parses them into a dictionary
+        3. Compares with current parameters and generates diff if needed
+        4. Updates internal state
+
+        Args:
+            text_to_parse: The text to search for parameter blocks
+
+        Returns:
+            List of ToolCallItem objects to emit (may be empty)
+        """
+        calls: List[ToolCallItem] = []
+
+        # Find all complete parameter patterns
+        param_matches = list(
+            re.finditer(r"<parameter=([^>]+)>(.*?)</parameter>", text_to_parse,
+                        re.DOTALL))
+
+        # Build new parameters dictionary
+        new_params = {}
+        for match in param_matches:
+            param_name = match.group(1).strip()
+            param_value = match.group(2)
+            new_params[param_name] = _safe_val(param_value)
+
+        # Calculate parameter diff to stream with proper incremental JSON building
+        if new_params != self._current_parameters:
+            previous_args_json = self.streamed_args_for_tool[
+                self.current_tool_id]
+
+            # Build incremental JSON properly
+            if not self._current_parameters:
+                # First parameter(s) - start JSON object but don't close it yet
+                items = []
+                for key, value in new_params.items():
+                    items.append(
+                        f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}"
+                    )
+                json_fragment = "{" + ", ".join(items)
+
+                calls.append(
+                    ToolCallItem(
+                        tool_index=self.current_tool_id,
+                        name=None,
+                        parameters=json_fragment,
+                    ))
+                self.streamed_args_for_tool[
+                    self.current_tool_id] = json_fragment
+
+            else:
+                # Additional parameters - add them incrementally
+                new_keys = set(new_params.keys()) - set(
+                    self._current_parameters.keys())
+                if new_keys:
+                    # Build the continuation part (no closing brace yet)
+                    continuation_parts = []
+                    for key in new_keys:
+                        value = new_params[key]
+                        continuation_parts.append(
+                            f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}"
+                        )
+
+                    json_fragment = ", " + ", ".join(continuation_parts)
+
+                    calls.append(
+                        ToolCallItem(
+                            tool_index=self.current_tool_id,
+                            name=None,
+                            parameters=json_fragment,
+                        ))
+                    self.streamed_args_for_tool[self.current_tool_id] = (
+                        previous_args_json + json_fragment)
+
+            # Update current state
+            self._current_parameters = new_params
+            self.prev_tool_call_arr[
+                self.current_tool_id]["arguments"] = new_params
+
+        return calls
+
+    def _reset_streaming_state(self):
+        """Reset streaming state for the next tool call"""
+        self._in_tool_call = False
+        self._function_name_sent = False
+        self._current_function_name = ""
+        self._current_parameters = {}
+        self._streamed_parameters = {}
+        self.current_tool_name_sent = False
+
+    def _extract(self, text: str,
+                 tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
+        normal_parts: List[str] = []
+        calls: List[ToolCallItem] = []
+        cursor = 0
+        while True:
+            s = text.find(self.tool_call_start_token, cursor)
+            if s == -1:
+                normal_parts.append(text[cursor:])
+                break
+            normal_parts.append(text[cursor:s])
+            e = text.find(self.tool_call_end_token, s)
+            if e == -1:
+                normal_parts.append(text[s:])
+                break
+            block = text[s:e + len(self.tool_call_end_token)]
+            cursor = e + len(self.tool_call_end_token)
+            calls.extend(self._parse_block(block, tools))
+        return "".join(normal_parts), calls
+
+    def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
+        res: List[ToolCallItem] = []
+        for m in self.tool_call_function_regex.findall(block):
+            txt = m[0] if m[0] else m[1]
+            if ">" not in txt:
+                continue
+            idx = txt.index(">")
+            fname = txt[:idx].strip()
+            body = txt[idx + 1:]
+            params: Dict[str, Any] = {}
+            for pm in self.tool_call_parameter_regex.findall(body):
+                ptxt = pm[0] if pm[0] else pm[1]
+                if ">" not in ptxt:
+                    continue
+                pidx = ptxt.index(">")
+                pname = ptxt[:pidx].strip()
+                pval = ptxt[pidx + 1:].lstrip("\n").rstrip("\n")
+                params[pname] = _safe_val(pval)
+            raw = {"name": fname, "arguments": params}
+            try:
+                # TODO: fix idx in function call, the index for a function
+                # call will always be -1 in parse_base_json
+                res.extend(self.parse_base_json(raw, tools))
+            except Exception:
+                logger.warning("invalid tool call for %s dropped", fname)
+        return res
+
+    def supports_structural_tag(self) -> bool:
+        return False
+
+    def structure_info(self) -> _GetInfoFunc:
+        raise NotImplementedError
diff --git a/tensorrt_llm/serve/tool_parser/tool_parser_factory.py b/tensorrt_llm/serve/tool_parser/tool_parser_factory.py
index 73b02510a67..8a9bbe298c1 100644
--- a/tensorrt_llm/serve/tool_parser/tool_parser_factory.py
+++ b/tensorrt_llm/serve/tool_parser/tool_parser_factory.py
@@ -1,12 +1,14 @@
 from typing import Type

 from .base_tool_parser import BaseToolParser
+from .qwen3_coder_parser import Qwen3CoderToolParser
 from .qwen3_tool_parser import Qwen3ToolParser


 class ToolParserFactory:
     parsers: dict[str, Type[BaseToolParser]] = {
         "qwen3": Qwen3ToolParser,
+        "qwen3_coder": Qwen3CoderToolParser,
     }

     @staticmethod
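Usage sketch (not part of the diff): the snippet below exercises the new `Qwen3CoderToolParser.detect_and_parse` on the tag format documented in the class docstring. The tool definition is a hypothetical OpenAI-style function schema; the exact field layout accepted by `ChatCompletionToolsParam` and the contents of the returned `ToolCallItem`s (filled in by `BaseToolParser.parse_base_json`) are assumptions here, not guarantees of this PR.

from tensorrt_llm.serve.openai_protocol import ChatCompletionToolsParam
from tensorrt_llm.serve.tool_parser.qwen3_coder_parser import Qwen3CoderToolParser

# Hypothetical tool definition following the OpenAI "function" tool schema.
tools = [
    ChatCompletionToolsParam(
        type="function",
        function={
            "name": "execute_bash",
            "description": "Run a shell command.",
            "parameters": {
                "type": "object",
                "properties": {"command": {"type": "string"}},
            },
        },
    )
]

# Model output using the tag format from the Qwen3CoderToolParser docstring.
text = (
    "Let me check the working directory.\n"
    "<tool_call>\n"
    "<function=execute_bash>\n"
    "<parameter=command>\n"
    "pwd && ls\n"
    "</parameter>\n"
    "</function>\n"
    "</tool_call>"
)

parser = Qwen3CoderToolParser()
result = parser.detect_and_parse(text, tools)
print(result.normal_text)  # text preceding the first <tool_call> tag
print([(c.name, c.parameters) for c in result.calls])  # expected: one "execute_bash" call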