16 changes: 12 additions & 4 deletions tensorrt_llm/commands/serve.py
@@ -151,6 +151,7 @@ def launch_server(
                   port: int,
                   llm_args: dict,
                   tool_parser: Optional[str] = None,
+                  chat_template: Optional[str] = None,
                   metadata_server_cfg: Optional[MetadataServerConfig] = None,
                   server_role: Optional[ServerRole] = None,
                   disagg_cluster_config: Optional[DisaggClusterConfig] = None,
@@ -180,7 +181,8 @@ def launch_server(
         server_role=server_role,
         metadata_server_cfg=metadata_server_cfg,
         disagg_cluster_config=disagg_cluster_config,
-        multimodal_server_config=multimodal_server_config)
+        multimodal_server_config=multimodal_server_config,
+        chat_template=chat_template)
 
     # Optionally disable GC (default: not disabled)
     if os.getenv("TRTLLM_SERVER_DISABLE_GC", "0") == "1":
@@ -354,6 +356,10 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
               type=str,
               default=None,
               help="Keyword arguments for media I/O.")
+@click.option("--chat_template",
+              type=str,
+              default=None,
+              help="[Experimental] Specify the chat template.")
 def serve(
         model: str, tokenizer: Optional[str], host: str, port: int,
         log_level: str, backend: str, max_beam_width: int, max_batch_size: int,
@@ -366,7 +372,8 @@ def serve(
         server_role: Optional[str],
         fail_fast_on_attention_window_too_large: bool,
         otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
-        disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str]):
+        disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str],
+        chat_template: Optional[str]):
     """Running an OpenAI API compatible server
 
     MODEL: model name | HF checkpoint path | TensorRT engine path
@@ -434,8 +441,9 @@ def serve(
 
     multimodal_server_config = MultimodalServerConfig(
         media_io_kwargs=parsed_media_io_kwargs)
-    launch_server(host, port, llm_args, tool_parser, metadata_server_cfg,
-                  server_role, disagg_cluster_config, multimodal_server_config)
+    launch_server(host, port, llm_args, tool_parser, chat_template,
+                  metadata_server_cfg, server_role, disagg_cluster_config,
+                  multimodal_server_config)
 
 
 @click.command("mm_embedding_serve")
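Note that serve() calls launch_server() positionally, so the new chat_template parameter must occupy the same slot (right after tool_parser) in both the signature and the call site above. A minimal runnable sketch of that threading, using hypothetical reduced signatures rather than the real ones:

from typing import Optional

# Hypothetical reduced sketch: because serve() passes arguments positionally,
# chat_template has to sit in the same slot in the signature and the call.
def launch_server_sketch(host: str, port: int, llm_args: dict,
                         tool_parser: Optional[str] = None,
                         chat_template: Optional[str] = None) -> None:
    print(f"serving on {host}:{port} with chat_template={chat_template!r}")

# Mirrors the updated call in serve(): chat_template rides in the fifth slot.
launch_server_sketch("localhost", 8000, {}, None, "template.jinja")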
50 changes: 49 additions & 1 deletion tensorrt_llm/serve/chat_utils.py
@@ -1,6 +1,7 @@
 import json
 import uuid
-from functools import partial
+from functools import lru_cache, partial
+from pathlib import Path
 from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
                     Optional, Tuple, TypeAlias, TypedDict, Union, cast)
 
@@ -254,3 +255,50 @@ def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
     else:
         # by default return random
         return f"chatcmpl-tool-{uuid.uuid4().hex}"
+
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/44b5ce956d3cf28841615a58c1c0873af87bcfe2/vllm/entrypoints/chat_utils.py
+def _load_chat_template(
+    chat_template: Path | str | None,
+    *,
+    is_literal: bool = False,
+) -> str | None:
+    if chat_template is None:
+        return None
+
+    if is_literal:
+        if isinstance(chat_template, Path):
+            raise TypeError(
+                "chat_template is expected to be read directly from its value")
+
+        return chat_template
+
+    try:
+        with open(chat_template) as f:
+            return f.read()
+    except OSError as e:
+        if isinstance(chat_template, Path):
+            raise
+
+        JINJA_CHARS = "{}\n"
+        if not any(c in chat_template for c in JINJA_CHARS):
+            msg = (f"The supplied chat template ({chat_template}) "
+                   f"looks like a file path, but it failed to be "
+                   f"opened. Reason: {e}")
+            raise ValueError(msg) from e
+
+        # If opening the file fails, fall back to treating chat_template as
+        # a literal template string so escapes are interpreted correctly.
+        return _load_chat_template(chat_template, is_literal=True)
+
+
+_cached_load_chat_template = lru_cache(_load_chat_template)
+
+
+def load_chat_template(
+    chat_template: Path | str | None,
+    *,
+    is_literal: bool = False,
+) -> str | None:
+    return _cached_load_chat_template(chat_template, is_literal=is_literal)
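load_chat_template() wraps _load_chat_template() in functools.lru_cache, so repeated lookups of the same path or literal are memoized for the lifetime of the process (a template file edited in place is not re-read). A short usage sketch of the resolution rules, assuming tensorrt_llm is importable:

import tempfile

from tensorrt_llm.serve.chat_utils import load_chat_template

# A readable file path: the file's contents are returned.
with tempfile.NamedTemporaryFile("w", suffix=".jinja", delete=False) as f:
    f.write("{{ messages }}")
    path = f.name
assert load_chat_template(path) == "{{ messages }}"

# A Jinja-looking string (contains '{', '}', or a newline): the failed open()
# falls through to the literal branch and the string is returned verbatim.
assert load_chat_template("{{ messages }}") == "{{ messages }}"

# None passes straight through.
assert load_chat_template(None) is None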
9 changes: 6 additions & 3 deletions tensorrt_llm/serve/openai_server.py
@@ -34,7 +34,8 @@
 from tensorrt_llm.llmapi.llm import RequestOutput
 from tensorrt_llm.logger import logger
 from tensorrt_llm.metrics.collector import MetricsCollector
-from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines
+from tensorrt_llm.serve.chat_utils import (load_chat_template,
+                                           parse_chat_messages_coroutines)
 from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client
 from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterWorker
 from tensorrt_llm.serve.metadata_server import create_metadata_server
@@ -81,13 +82,15 @@ def __init__(self,
                  server_role: Optional[ServerRole],
                  metadata_server_cfg: MetadataServerConfig,
                  disagg_cluster_config: Optional[DisaggClusterConfig] = None,
-                 multimodal_server_config: Optional[MultimodalServerConfig] = None):
+                 multimodal_server_config: Optional[MultimodalServerConfig] = None,
+                 chat_template: Optional[str] = None):
         self.llm = llm
         self.tokenizer = llm.tokenizer
         self.tool_parser = tool_parser
         self.metadata_server = create_metadata_server(metadata_server_cfg)
         self.disagg_cluster_config = disagg_cluster_config
         self.multimodal_server_config = multimodal_server_config
+        self.chat_template = load_chat_template(chat_template)
         self.server_role = server_role
         # Will be set in __call__
         self.binding_addr = None
@@ -510,7 +513,7 @@ async def create_chat_response(
                 mm_placeholder_counts=mm_placeholder_counts,
                 tools=tool_dicts,
                 documents=request.documents,
-                chat_template=request.chat_template,
+                chat_template=request.chat_template or self.chat_template,
                 chat_template_kwargs=request.chat_template_kwargs or {},
             )
             prompt = prompt_inputs(prompt)
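The request.chat_template or self.chat_template expression gives a per-request template precedence over the server-wide --chat_template value; because the check is truthiness rather than "is None", an empty string in the request also falls back to the server default. A reduced sketch of that precedence (the helper name is hypothetical):

from typing import Optional

def resolve_chat_template(request_template: Optional[str],
                          server_template: Optional[str]) -> Optional[str]:
    # Same fallback as the hunk above: the request value wins when truthy.
    return request_template or server_template

assert resolve_chat_template("{{ A }}", "{{ B }}") == "{{ A }}"
assert resolve_chat_template(None, "{{ B }}") == "{{ B }}"
assert resolve_chat_template("", "{{ B }}") == "{{ B }}"  # "" is falsy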
54 changes: 53 additions & 1 deletion tests/unittest/llmapi/apps/test_chat_utils.py
@@ -1,8 +1,10 @@
+import os
+import tempfile
 from unittest.mock import MagicMock
 
 import pytest
 
-from tensorrt_llm.serve.chat_utils import parse_chat_message_content
+from tensorrt_llm.serve.chat_utils import load_chat_template, parse_chat_message_content
 
 
 @pytest.fixture
@@ -177,3 +179,53 @@ def test_tool_message_without_tool_call_id(self, mock_mm_data_tracker):
 
         expected = {**message, "media": []}
         assert result == expected
+
+
+# ruff: noqa: E501
+TEMPLATE_CHATML = """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
+{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
+
+
+@pytest.fixture
+def chat_template_path():
+    """Return the path to the chat template."""
+    temp_dir = tempfile.gettempdir()
+    temp_file_path = os.path.join(temp_dir, "chat_template.jinja")
+    try:
+        with open(temp_file_path, "w") as f:
+            f.write(TEMPLATE_CHATML)
+        yield temp_file_path
+    finally:
+        if os.path.exists(temp_file_path):
+            os.remove(temp_file_path)
+
+
+class TestLoadChatTemplate:
+    """Test suite for loading chat templates."""
+
+    def test_load_chat_template_from_path(self, chat_template_path):
+        """Test loading a chat template from a path."""
+        template = load_chat_template(chat_template_path)
+        assert template == TEMPLATE_CHATML
+
+    def test_load_chat_template_from_string(self):
+        """Test loading a chat template from a string."""
+        text = "Hello, how can I help you?"
+        template = load_chat_template(text, is_literal=True)
+        assert template == text
+
+    def test_load_chat_template_from_none(self):
+        """Test loading a chat template from None."""
+        template = load_chat_template(None)
+        assert template is None
+
+    def test_load_chat_template_from_path_with_invalid_path(self):
+        """Test loading a chat template from an invalid path."""
+        with pytest.raises(ValueError, match="looks like a file path"):
+            load_chat_template("invalid/path/to/chat_template.jinja")
+
+    def test_jinjalike_literal(self):
+        """Test loading a chat template from a Jinja-like string."""
+        template = "{{ messages }}"
+        template_content = load_chat_template(template)
+        assert template_content == template
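For reference, the fixture's ChatML template should render roughly as sketched below when fed through Jinja directly. This check is illustrative only, not part of the PR, and assumes the jinja2 package is available:

import jinja2

# Same template as TEMPLATE_CHATML above, rebuilt so the snippet is
# self-contained ('\\n' keeps the escapes inside the Jinja string literals).
TEMPLATE_CHATML = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}"
    "{% if (loop.last and add_generation_prompt) or not loop.last %}"
    "{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}\n"
    "{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}"
    "{{ '<|im_start|>assistant\\n' }}{% endif %}"
)

out = jinja2.Template(TEMPLATE_CHATML).render(
    messages=[{"role": "user", "content": "Hi"}],
    add_generation_prompt=True,
)
assert "<|im_start|>user\nHi<|im_end|>" in out
assert out.endswith("<|im_start|>assistant\n")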