3 changes: 3 additions & 0 deletions .gitignore
@@ -79,3 +79,6 @@ compile_commands.json

# Enroot sqsh files
enroot/tensorrt_llm.devel.sqsh

# MacOSX Files
.DS_Store
2 changes: 1 addition & 1 deletion benchmarks/cpp/prepare_dataset.py
@@ -49,7 +49,7 @@ def validate_tokenizer(self):
return self


@click.group()
@click.group(deprecated=True)
@click.option(
"--tokenizer",
required=True,
@@ -248,13 +248,13 @@ To do the benchmark, run the following command:

```bash
# generate synthetic dataset
python ${YOUR_WORK_PATH}/benchmarks/cpp/prepare_dataset.py \
--stdout \
--tokenizer nvidia/DeepSeek-R1-FP4 \
trtllm-bench --model nvidia/DeepSeek-R1-FP4 \
dataset \
--output dataset.txt \
token-norm-dist \
--input-mean 1024 --output-mean 2048 \
--input-stdev 0 --output-stdev 0 \
--num-requests 49152 > dataset.txt
--num-requests 49152

YOUR_DATA_PATH=./dataset.txt

@@ -350,13 +350,14 @@ To do the benchmark, run the following command:

```bash
# generate synthetic dataset
python ${YOUR_WORK_PATH}/benchmarks/cpp/prepare_dataset.py \
--stdout \
--tokenizer deepseek-ai/DeepSeek-R1 \
trtllm-bench --model nvidia/DeepSeek-R1-FP4 \
dataset \
--output dataset.txt \
token-norm-dist \
--input-mean 1024 --output-mean 2048 \
--input-stdev 0 --output-stdev 0 \
--num-requests 5120 > dataset.txt
--num-requests 5120

YOUR_DATA_PATH=./dataset.txt

cat >./extra-llm-api-config.yml<<EOF
Expand Down Expand Up @@ -401,7 +402,7 @@ Average request latency (ms): 181540.5739

## Exploring more ISL/OSL combinations

To benchmark TensorRT LLM on DeepSeek models with more ISL/OSL combinations, you can use `prepare_dataset.py` to generate the dataset and use similar commands mentioned in the previous section. TensorRT LLM is working on enhancements that can make the benchmark process smoother.
To benchmark TensorRT LLM on DeepSeek models with more ISL/OSL combinations, you can use the `trtllm-bench dataset` subcommand to generate the dataset and reuse the commands from the previous section. TensorRT LLM is working on enhancements to make the benchmark process smoother.
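
For instance, here is a minimal sketch of generating a 4096/1024 dataset with the subcommand; the model path and request count are placeholders to adjust for your own run:

```bash
# Sketch: synthetic dataset for a different ISL/OSL combination (4096/1024).
# Adjust the model and --num-requests for your target configuration.
trtllm-bench --model nvidia/DeepSeek-R1-FP4 \
    dataset \
    --output dataset_4k_1k.txt \
    token-norm-dist \
    --input-mean 4096 --output-mean 1024 \
    --input-stdev 0 --output-stdev 0 \
    --num-requests 2048
```
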
### WIP: Enable more features by default

Currently, some features must be enabled through a user-defined file `extra-llm-api-config.yml`, such as CUDA graphs, the overlap scheduler, and attention DP. We're working on enabling these features by default, so that users get good out-of-the-box performance on DeepSeek models.
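
As a reference, a minimal sketch of such a file follows; the key names mirror the examples elsewhere in this guide, but treat the exact spellings and values as assumptions to verify against your TensorRT LLM version:

```bash
# Sketch of an extra-llm-api-config.yml enabling the features mentioned above.
# NOTE: key names and values are assumptions -- confirm them for your version.
cat > ./extra-llm-api-config.yml <<EOF
cuda_graph_config:
  enable_padding: true
enable_attention_dp: true
EOF
```
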
@@ -414,7 +415,7 @@ For more details on `max_batch_size` and `max_num_tokens`, refer to [Tuning Max

### MLA chunked context

MLA currently supports the chunked context feature on both Hopper and Blackwell GPUs. You can use `--enable_chunked_context` to enable it. This feature is primarily designed to reduce TPOT (Time Per Output Token). The default chunk size is set to `max_num_tokens`. If you want to achieve a lower TPOT, you can appropriately reduce the chunk size. However, please note that this will also decrease overall throughput. Therefore, a trade-off needs to be considered.

For more details on `max_num_tokens`, refer to [Tuning Max Batch Size and Max Num Tokens](../performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.md).

10 changes: 6 additions & 4 deletions docs/source/developer-guide/perf-analysis.md
@@ -72,10 +72,12 @@ Say we want to profile iterations 100 to 150 on a `trtllm-bench`/`trtllm-serve`
#!/bin/bash

# Prepare dataset for the benchmark
python3 benchmarks/cpp/prepare_dataset.py \
--tokenizer=${MODEL_PATH} \
--stdout token-norm-dist --num-requests=${NUM_SAMPLES} \
--input-mean=1000 --output-mean=1000 --input-stdev=0 --output-stdev=0 > /tmp/dataset.txt
trtllm-bench --model ${MODEL_PATH} \
dataset \
--output dataset.txt \
token-norm-dist \
--num-requests=${NUM_SAMPLES} \
--input-mean=1000 --output-mean=1000 --input-stdev=0 --output-stdev=0

# Benchmark and profile
TLLM_PROFILE_START_STOP=100-150 nsys profile \
21 changes: 11 additions & 10 deletions docs/source/developer-guide/perf-benchmarking.md
@@ -150,7 +150,7 @@ directory. For example, to generate a synthetic dataset of 1000 requests with a
128/128 for [meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B), run:

```shell
python benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-3.1-8B token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 > /tmp/synthetic_128_128.txt
trtllm-bench --model meta-llama/Llama-3.1-8B dataset --output /tmp/synthetic_128_128.txt token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000
```

### Running with the PyTorch Workflow
@@ -231,13 +231,13 @@ The PyTorch workflow supports benchmarking with LoRA (Low-Rank Adaptation) adapt

**Preparing LoRA Dataset**

Use `prepare_dataset.py` with LoRA-specific options to generate requests with LoRA metadata:
Use `trtllm-bench dataset` with LoRA-specific options to generate requests with LoRA metadata:

```shell
python3 benchmarks/cpp/prepare_dataset.py \
--stdout \
trtllm-bench \
--model /path/to/tokenizer \
dataset \
--rand-task-id 0 1 \
--tokenizer /path/to/tokenizer \
--lora-dir /path/to/loras \
token-norm-dist \
--num-requests 100 \
@@ -308,17 +308,18 @@ Each subdirectory should contain the LoRA adapter files for that specific task.
To benchmark multi-modal models with the PyTorch workflow, you can follow a similar approach to the one above.

First, prepare the dataset:
```python
python ./benchmarks/cpp/prepare_dataset.py \
--tokenizer Qwen/Qwen2-VL-2B-Instruct \
--stdout \
```bash
trtllm-bench \
--model Qwen/Qwen2-VL-2B-Instruct \
dataset \
--output mm_data.jsonl \
real-dataset \
--dataset-name lmms-lab/MMMU \
--dataset-split test \
--dataset-image-key image \
--dataset-prompt-key question \
--num-requests 10 \
--output-len-dist 128,5 > mm_data.jsonl
--output-len-dist 128,5
```
This downloads the media files to the `/tmp` directory and prepares the dataset with their local paths. Note that the `prompt` fields contain raw text rather than tokenized IDs, because the prompt and the media (image/video) are processed together by a multimodal preprocessor.
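
To sanity-check the prepared file, you can pretty-print the first record (a quick sketch; the exact fields depend on the dataset format that was generated):

```bash
# Peek at the first prepared multimodal request: the prompt should be raw text
# and the media entries should point to files downloaded under /tmp.
head -n 1 mm_data.jsonl | python3 -m json.tool
```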
8 changes: 4 additions & 4 deletions docs/source/legacy/performance/perf-analysis.md
@@ -66,10 +66,10 @@ Say we want to profile iterations 100 to 150 on a trtllm-bench/trtllm-serve run,
#!/bin/bash

# Prepare dataset for the benchmark
python3 benchmarks/cpp/prepare_dataset.py \
--tokenizer=${MODEL_PATH} \
--stdout token-norm-dist --num-requests=${NUM_SAMPLES} \
--input-mean=1000 --output-mean=1000 --input-stdev=0 --output-stdev=0 > /tmp/dataset.txt
trtllm-bench \
--model=${MODEL_PATH} dataset \
--output /tmp/dataset.txt token-norm-dist --num-requests=${NUM_SAMPLES} \
--input-mean=1000 --output-mean=1000 --input-stdev=0 --output-stdev=0

# Benchmark and profile
TLLM_PROFILE_START_STOP=100-150 nsys profile \
4 changes: 2 additions & 2 deletions docs/source/legacy/performance/perf-benchmarking.md
@@ -110,7 +110,7 @@ of 128:128.
To run the benchmark from start to finish, run the following commands:

```shell
python benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-3.1-8B token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 3000 > /tmp/synthetic_128_128.txt
trtllm-bench --model meta-llama/Llama-3.1-8B dataset --output /tmp/synthetic_128_128.txt token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 3000
trtllm-bench --model meta-llama/Llama-3.1-8B build --dataset /tmp/synthetic_128_128.txt --quantization FP8
trtllm-bench --model meta-llama/Llama-3.1-8B throughput --dataset /tmp/synthetic_128_128.txt --engine_dir /tmp/meta-llama/Llama-3.1-8B/tp_1_pp_1
```
@@ -207,7 +207,7 @@ directory. For example, to generate a synthetic dataset of 1000 requests with a
128/128 for [meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B), run:

```shell
benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-3.1-8B token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 > /tmp/synthetic_128_128.txt
trtllm-bench --model meta-llama/Llama-3.1-8B dataset --output /tmp/synthetic_128_128.txt token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000
```

### Building a Benchmark Engine
9 changes: 4 additions & 5 deletions examples/llm-api/llm_mgmn_trtllm_bench.sh
@@ -35,7 +35,6 @@
# not supported in Slurm mode, you need to download the model and put it in
# the LOCAL_MODEL directory.

export prepare_dataset="$SOURCE_ROOT/benchmarks/cpp/prepare_dataset.py"
export data_path="$WORKDIR/token-norm-dist.txt"

echo "Preparing dataset..."
@@ -50,14 +49,14 @@ srun -l \
--mpi=pmix \
bash -c "
$PROLOGUE
python3 $prepare_dataset \
--tokenizer=$LOCAL_MODEL \
--stdout token-norm-dist \
trtllm-bench --model=$LOCAL_MODEL dataset \
--output $data_path \
token-norm-dist \
--num-requests=100 \
--input-mean=128 \
--output-mean=128 \
--input-stdev=0 \
--output-stdev=0 > $data_path
--output-stdev=0
"

echo "Running benchmark..."
12 changes: 11 additions & 1 deletion examples/llm-api/out_of_tree_example/readme.md
@@ -42,7 +42,17 @@ Similar to the quickstart example, you can use the same CLI argument with `trtll

Prepare the dataset:
```
python ./benchmarks/cpp/prepare_dataset.py --tokenizer ./model_ckpt --stdout dataset --dataset-name lmms-lab/MMMU --dataset-split test --dataset-image-key image --dataset-prompt-key "question" --num-requests 100 --output-len-dist 128,5 > mm_data.jsonl
trtllm-bench \
--model ./model_ckpt \
dataset \
--output mm_data.jsonl \
real-dataset \
--dataset-name lmms-lab/MMMU \
--dataset-split test \
--dataset-image-key image \
--dataset-prompt-key question \
--num-requests 10 \
--output-len-dist 128,5
```


16 changes: 9 additions & 7 deletions examples/models/core/deepseek_v3/README.md
@@ -138,12 +138,13 @@ To avoid OOM (out of memory) error, you need to adjust the values of "--max_batc
#### ISL-64k-OSL-1024
```bash
DS_R1_NVFP4_MODEL_PATH=/path/to/DeepSeek-R1
python /app/tensorrt_llm/benchmarks/cpp/prepare_dataset.py \
--stdout --tokenizer ${DS_R1_NVFP4_MODEL_PATH} \
trtllm-bench --model ${DS_R1_NVFP4_MODEL_PATH} \
dataset \
--output /tmp/benchmarking_64k.txt \
token-norm-dist \
--input-mean 65536 --output-mean 1024 \
--input-stdev 0 --output-stdev 0 \
--num-requests 24 > /tmp/benchmarking_64k.txt
--num-requests 24

cat <<EOF > /tmp/extra-llm-api-config.yml
cuda_graph_config:
@@ -164,12 +165,13 @@ trtllm-bench -m deepseek-ai/DeepSeek-R1 --model_path ${DS_R1_NVFP4_MODEL_PATH} t
#### ISL-128k-OSL-1024
```bash
DS_R1_NVFP4_MODEL_PATH=/path/to/DeepSeek-R1
python /app/tensorrt_llm/benchmarks/cpp/prepare_dataset.py \
--stdout --tokenizer ${DS_R1_NVFP4_MODEL_PATH} \
trtllm-bench --model ${DS_R1_NVFP4_MODEL_PATH} \
dataset \
--output /tmp/benchmarking_128k.txt \
token-norm-dist \
--input-mean 131072 --output-mean 1024 \
--input-stdev 0 --output-stdev 0 \
--num-requests 4 > /tmp/benchmarking_128k.txt
--num-requests 4

cat <<EOF > /tmp/extra-llm-api-config.yml
cuda_graph_config:
@@ -336,7 +338,7 @@ curl http://localhost:8000/v1/completions \
}'
```

For DeepSeek-R1 FP4, use the model name `nvidia/DeepSeek-R1-FP4-v2`.
For DeepSeek-V3, use the model name `deepseek-ai/DeepSeek-V3`.
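
For example, a request against DeepSeek-V3 might look like the sketch below; the payload follows the standard OpenAI-compatible `/v1/completions` schema, and the prompt and token limit are placeholders:

```bash
# Sketch: OpenAI-compatible completion request using the DeepSeek-V3 model name.
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "deepseek-ai/DeepSeek-V3",
        "prompt": "Where is New York?",
        "max_tokens": 16
    }'
```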

### Disaggregated Serving
Empty file.
105 changes: 105 additions & 0 deletions tensorrt_llm/bench/dataset/prepare_dataset.py
@@ -0,0 +1,105 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Optional, Tuple

import click
from pydantic import BaseModel, model_validator
from transformers import AutoTokenizer

from tensorrt_llm.bench.dataset.prepare_real_data import real_dataset
from tensorrt_llm.bench.dataset.prepare_synthetic_data import (token_norm_dist,
token_unif_dist)


class RootArgs(BaseModel):
tokenizer: str
output: str
random_seed: int
task_id: int
trust_remote_code: bool = False
rand_task_id: Optional[Tuple[int, int]]
lora_dir: Optional[str] = None

@model_validator(mode="after")
def validate_tokenizer(self):
try:
tokenizer = AutoTokenizer.from_pretrained(
self.tokenizer,
padding_side="left",
trust_remote_code=self.trust_remote_code)
except EnvironmentError as e:
raise ValueError(
"Cannot find a tokenizer from the given string because of "
f"{e}\nPlease set tokenizer to the directory that contains "
"the tokenizer, or set to a model name in HuggingFace.")
tokenizer.pad_token = tokenizer.eos_token
self.tokenizer = tokenizer

return self


@click.group(name="dataset")
@click.option("--output",
type=str,
help="Output json filename.",
default="preprocessed_dataset.json")
@click.option("--random-seed",
required=False,
type=int,
help="random seed for token_ids",
default=420)
@click.option("--task-id", type=int, default=-1, help="LoRA task id")
@click.option("--rand-task-id",
type=int,
default=None,
nargs=2,
help="Random LoRA Tasks")
@click.option("--lora-dir",
type=str,
default=None,
help="Directory containing LoRA adapters")
@click.option("--log-level",
default="info",
type=click.Choice(["info", "debug"]),
help="Logging level.")
@click.option(
"--trust-remote-code",
is_flag=True,
default=False,
envvar="TRUST_REMOTE_CODE",
help="Trust remote code.",
)
@click.pass_context
def prepare_dataset(ctx, **kwargs):
"""Prepare dataset for benchmarking with trtllm-bench."""
model = ctx.obj.model or ctx.obj.checkpoint_path
output_path = Path(kwargs["output"])
output_path.parent.mkdir(parents=True, exist_ok=True)

ctx.obj = RootArgs(
tokenizer=model,
output=kwargs["output"],
random_seed=kwargs["random_seed"],
task_id=kwargs["task_id"],
rand_task_id=kwargs["rand_task_id"],
lora_dir=kwargs["lora_dir"],
trust_remote_code=kwargs["trust_remote_code"],
)


prepare_dataset.add_command(real_dataset)
prepare_dataset.add_command(token_norm_dist)
prepare_dataset.add_command(token_unif_dist)
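
End to end, the new `dataset` subcommand slots into the existing `trtllm-bench` flow. Below is a hedged sketch of a typical invocation; the model name, ISL/OSL sizes, and `throughput` flags are placeholders and may differ by workflow:

```bash
# 1) Generate a synthetic dataset with the new `dataset` subcommand.
trtllm-bench --model meta-llama/Llama-3.1-8B \
    dataset \
    --output /tmp/synthetic_128_128.txt \
    token-norm-dist \
    --input-mean 128 --output-mean 128 \
    --input-stdev 0 --output-stdev 0 \
    --num-requests 1000

# 2) Benchmark throughput on the prepared dataset.
trtllm-bench --model meta-llama/Llama-3.1-8B \
    throughput \
    --dataset /tmp/synthetic_128_128.txt
```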