Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from ..flashinfer_utils import ENABLE_PDL, IS_FLASHINFER_AVAILABLE
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE, get_env_enable_pdl

if IS_FLASHINFER_AVAILABLE:
from flashinfer.activation import silu_and_mul
Expand All @@ -11,7 +11,7 @@
# Warp this into custom op since flashinfer didn't warp it properly and we want to avoid graph break between mlp layer for user buffer optimization
@torch.library.custom_op("trtllm::flashinfer_silu_and_mul", mutates_args=())
def flashinfer_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
return silu_and_mul(x, enable_pdl=ENABLE_PDL)
return silu_and_mul(x, enable_pdl=get_env_enable_pdl())

@flashinfer_silu_and_mul.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
Expand All @@ -21,7 +21,7 @@ def _(x: torch.Tensor) -> torch.Tensor:
@torch.library.custom_op("trtllm::flashinfer_rmsnorm", mutates_args=())
def flashinfer_rmsnorm(input: torch.Tensor, weight: torch.Tensor,
eps: float) -> torch.Tensor:
return rmsnorm(input, weight, eps, enable_pdl=ENABLE_PDL)
return rmsnorm(input, weight, eps, enable_pdl=get_env_enable_pdl())

@flashinfer_rmsnorm.register_fake
def _(input: torch.Tensor, weight: torch.Tensor,
Expand All @@ -32,7 +32,10 @@ def _(input: torch.Tensor, weight: torch.Tensor,
mutates_args=())
def flashinfer_gemma_rmsnorm(input: torch.Tensor, weight: torch.Tensor,
eps: float) -> torch.Tensor:
return gemma_rmsnorm(input, weight, eps, enable_pdl=ENABLE_PDL)
return gemma_rmsnorm(input,
weight,
eps,
enable_pdl=get_env_enable_pdl())

@flashinfer_gemma_rmsnorm.register_fake
def _(input: torch.Tensor, weight: torch.Tensor,
Expand All @@ -44,7 +47,11 @@ def _(input: torch.Tensor, weight: torch.Tensor,
def flashinfer_fused_add_rmsnorm(input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor, eps: float) -> None:
fused_add_rmsnorm(input, residual, weight, eps, enable_pdl=ENABLE_PDL)
fused_add_rmsnorm(input,
residual,
weight,
eps,
enable_pdl=get_env_enable_pdl())

@torch.library.custom_op("trtllm::flashinfer_gemma_fused_add_rmsnorm",
mutates_args=("input", "residual"))
Expand All @@ -56,7 +63,7 @@ def flashinfer_gemma_fused_add_rmsnorm(input: torch.Tensor,
residual,
weight,
eps,
enable_pdl=ENABLE_PDL)
enable_pdl=get_env_enable_pdl())

@torch.library.custom_op(
"trtllm::flashinfer_apply_rope_with_cos_sin_cache_inplace",
Expand Down
10 changes: 5 additions & 5 deletions tensorrt_llm/_torch/flashinfer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@


def get_env_enable_pdl():
return os.environ.get("TRTLLM_ENABLE_PDL", "0") == "1"
enabled = os.environ.get("TRTLLM_ENABLE_PDL", "0") == "1"
if enabled and not getattr(get_env_enable_pdl, "_printed", False):
logger.info("PDL enabled")
setattr(get_env_enable_pdl, "_printed", True)
return enabled


ENABLE_PDL = get_env_enable_pdl()
if ENABLE_PDL:
logger.info("PDL is enabled")

if platform.system() != "Windows":
try:
import flashinfer
Expand Down
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
else:
from typing_extensions import override

from ..flashinfer_utils import ENABLE_PDL
from ..flashinfer_utils import get_env_enable_pdl
from .sampling_utils import (
GREEDY,
GroupedStrategySampler,
Expand Down Expand Up @@ -112,7 +112,7 @@ def _prepare_probs_with_temperature(
probs = flashinfer.sampling.softmax(
logits,
temperature,
enable_pdl=ENABLE_PDL,
enable_pdl=get_env_enable_pdl(),
)
return probs

Expand Down
27 changes: 16 additions & 11 deletions tensorrt_llm/bench/benchmark/low_latency.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import os
from functools import partial
from pathlib import Path

Expand Down Expand Up @@ -46,12 +45,14 @@
help="Path to a serialized TRT-LLM engine.",
)
@optgroup.option(
"--config",
"--extra_llm_api_options",
"extra_llm_api_options",
type=str,
default=None,
help=
"Path to a YAML file that overwrites the parameters specified by trtllm-bench."
)
"Path to a YAML file that overwrites the parameters specified by trtllm-bench. "
"Can be specified as either --config or --extra_llm_api_options.")
@optgroup.option(
"--backend",
type=click.Choice(ALL_SUPPORTED_BACKENDS),
Expand Down Expand Up @@ -192,6 +193,7 @@ def latency_command(
) -> None:
"""Run a latency test on a TRT-LLM engine."""
logger.info("Preparing to run latency benchmark...")

# Parameters from CLI
# Model, experiment, and engine params
options = get_general_cli_options(params, bench_env)
Expand Down Expand Up @@ -262,14 +264,6 @@ def latency_command(
exec_settings["settings_config"][
"scheduler_policy"] = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT

# Set environment variables for setting runtime options.
# TODO: Once passing of variables is fixed, these should work
# when using MPI in C++ runtime.
os.environ["TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG"] = "1"
os.environ["TRTLLM_MMHA_KERNEL_BLOCK_SIZE"] = "256"
os.environ["FORCE_MULTI_BLOCK_MODE"] = "1"
os.environ["TRTLLM_ENABLE_PDL"] = "1"

# Performance options
exec_settings["performance_options"]["cuda_graphs"] = True
exec_settings["performance_options"]["multi_block_mode"] = True
Expand All @@ -289,6 +283,17 @@ def latency_command(
kwargs = kwargs | runtime_config.get_llm_args()
kwargs['backend'] = options.backend

# Set environment variables for setting runtime options.
default_env_overrides = {
"TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG": "1",
"TRTLLM_MMHA_KERNEL_BLOCK_SIZE": "256",
"FORCE_MULTI_BLOCK_MODE": "1",
"TRTLLM_ENABLE_PDL": "1",
}
# Update defaults with existing overrides (user preference takes priority)
default_env_overrides.update(kwargs.get("env_overrides", {}))
kwargs["env_overrides"] = default_env_overrides

try:
logger.info("Setting up latency benchmark.")

Expand Down
7 changes: 5 additions & 2 deletions tensorrt_llm/bench/benchmark/throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,14 @@
help="Paths to custom module directories to import.",
)
@optgroup.option(
"--config",
"--extra_llm_api_options",
"extra_llm_api_options",
type=str,
default=None,
help=
"Path to a YAML file that overwrites the parameters specified by trtllm-bench."
)
"Path to a YAML file that overwrites the parameters specified by trtllm-bench. "
"Can be specified as either --config or --extra_llm_api_options.")
@optgroup.option("--sampler_options",
type=click.Path(exists=True,
readable=True,
Expand Down Expand Up @@ -293,6 +295,7 @@ def throughput_command(
) -> None:
"""Run a throughput test on a TRT-LLM engine."""
logger.info("Preparing to run throughput benchmark...")

# Parameters from CLI
image_data_format: str = params.get("image_data_format", "pt")
data_device: str = params.get("data_device", "cpu")
Expand Down
7 changes: 5 additions & 2 deletions tensorrt_llm/commands/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,13 @@
is_flag=True,
default=False,
help="Flag for HF transformers.")
@click.option("--extra_llm_api_options",
@click.option("--config",
"--extra_llm_api_options",
"extra_llm_api_options",
type=str,
default=None,
help="Path to a YAML file that overwrites the parameters")
help="Path to a YAML file that overwrites the parameters. "
"Can be specified as either --config or --extra_llm_api_options.")
@click.option("--disable_kv_cache_reuse",
is_flag=True,
default=False,
Expand Down
6 changes: 4 additions & 2 deletions tensorrt_llm/commands/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,14 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
default=False,
help="Flag for HF transformers.")
@click.option(
"--config",
"--extra_llm_api_options",
"extra_llm_api_options",
type=str,
default=None,
help=
"Path to a YAML file that overwrites the parameters specified by trtllm-serve."
)
"Path to a YAML file that overwrites the parameters specified by trtllm-serve. "
"Can be specified as either --config or --extra_llm_api_options.")
@click.option(
"--reasoning_parser",
type=click.Choice(ReasoningParserFactory.parsers.keys()),
Expand Down
9 changes: 9 additions & 0 deletions tensorrt_llm/executor/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,17 @@ def worker_main(
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[BaseLlmArgs] = None,
) -> None:

mpi_comm().barrier()

if llm_args is not None and llm_args.env_overrides:
# this is needed because MPI_Init seems to cache the env at import time.
# The cached env snapshot is used to spawn workers.
# Any env overrides to the main process after tensorrt_llm import
# may not get reflected in the spawned worker process, no matter how early,
# unless we update it explicitly here.
os.environ.update(llm_args.env_overrides)

if llm_args is not None and llm_args.trust_remote_code:
_init_hf_modules()

Expand Down
22 changes: 22 additions & 0 deletions tensorrt_llm/llmapi/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def __init__(self,
logger.set_level("info") # force display the backend

try:
env_overrides = kwargs.get("env_overrides", None)
self._process_env_overrides(env_overrides)

backend = kwargs.get('backend', None)
if backend == "pytorch":
logger.info("Using LLM with PyTorch backend")
Expand Down Expand Up @@ -587,6 +590,25 @@ def get_kv_cache_events_async(self,
'''
return self._executor.aget_kv_events(timeout=timeout)

def _process_env_overrides(self,
env_overrides: Optional[dict[str, str]]) -> None:
if env_overrides is None:
return
logger.info("Processing LLM API environment variable overrides")
# TODO: If an env var is cached at import-time in code, overriding os.environ will
# unfortunately not update wherever the var is used.
# This is a known issue and only way to fix it is at every such usage to access it
# from os.environ on-demand.
for key, value in env_overrides.items():
str_value = str(value)
if key in os.environ:
old_value = os.environ[key]
os.environ[key] = str_value
logger.info(f"Overriding {key}: '{old_value}' -> '{str_value}'")
else:
os.environ[key] = str_value
logger.info(f"Setting {key}='{str_value}'")

def _prepare_sampling_params(
self,
sampling_params: Optional[SamplingParams] = None) -> SamplingParams:
Expand Down
6 changes: 6 additions & 0 deletions tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1922,6 +1922,12 @@ class BaseLlmArgs(StrictBaseModel):
status="prototype",
)

env_overrides: Optional[Dict[str, str]] = Field(
default=None,
description=
"[EXPERIMENTAL] Environment variable overrides. NOTE: import-time-cached env vars in the code won’t update unless the code fetches them from os.environ on demand.",
status="prototype")

_parallel_config: Optional[_ParallelConfig] = PrivateAttr(default=None)
_model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None)
_speculative_model: Optional[str] = PrivateAttr(default=None)
Expand Down
Loading