From 5c41e2450e1675e5bb6fbc21010b17d3ff4e382a Mon Sep 17 00:00:00 2001 From: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Date: Mon, 17 Nov 2025 08:43:49 -0800 Subject: [PATCH 1/2] Port prepare dataset to trtllm-bench. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Add MacOSX DS_Store to gitignore. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Update imports. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Update click group. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Updates to CLI. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Rename. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Add name. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Renamed real dataset command. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Change to group. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Add docstring. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Remove pass_obj. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Fix context subscription. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Updates to output. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Updates to remove stdout. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Add deprecation flag. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Code clean up. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Fix generator call. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Update prepare_dataset in docs. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Update examples. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Update testing for trtllm-bench dataset. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Remove trtllm-bench dataset from run_ex. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Add missed __init__.py Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Re-add check for dataset subcommand. Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Fix execution of trtllm-bench dataset. 
Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> --- .gitignore | 3 + benchmarks/cpp/prepare_dataset.py | 2 +- ...practice_on_DeepSeek-R1_in_TensorRT-LLM.md | 21 +- docs/source/developer-guide/perf-analysis.md | 10 +- .../developer-guide/perf-benchmarking.md | 21 +- .../legacy/performance/perf-analysis.md | 8 +- .../legacy/performance/perf-benchmarking.md | 4 +- examples/llm-api/llm_mgmn_trtllm_bench.sh | 9 +- .../llm-api/out_of_tree_example/readme.md | 12 +- examples/models/core/deepseek_v3/README.md | 16 +- tensorrt_llm/bench/dataset/__init__.py | 0 tensorrt_llm/bench/dataset/prepare_dataset.py | 93 ++++++ .../bench/dataset/prepare_real_data.py | 305 ++++++++++++++++++ .../bench/dataset/prepare_synthetic_data.py | 104 ++++++ tensorrt_llm/bench/dataset/utils.py | 96 ++++++ tensorrt_llm/commands/bench.py | 2 + .../defs/perf/README_release_test.md | 20 +- tests/integration/defs/perf/test_perf.py | 22 +- tests/integration/defs/perf/utils.py | 4 +- tests/integration/defs/test_e2e.py | 62 ++-- .../unit/singlegpu/test_ad_trtllm_bench.py | 14 +- tests/unittest/tools/test_prepare_dataset.py | 36 ++- 22 files changed, 738 insertions(+), 126 deletions(-) create mode 100644 tensorrt_llm/bench/dataset/__init__.py create mode 100644 tensorrt_llm/bench/dataset/prepare_dataset.py create mode 100644 tensorrt_llm/bench/dataset/prepare_real_data.py create mode 100644 tensorrt_llm/bench/dataset/prepare_synthetic_data.py create mode 100644 tensorrt_llm/bench/dataset/utils.py diff --git a/.gitignore b/.gitignore index 568d4c712d6..78117704297 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,6 @@ compile_commands.json # Enroot sqsh files enroot/tensorrt_llm.devel.sqsh + +# MacOSX Files +.DS_Store diff --git a/benchmarks/cpp/prepare_dataset.py b/benchmarks/cpp/prepare_dataset.py index 2f7b5516b62..3b9665fd290 100644 --- a/benchmarks/cpp/prepare_dataset.py +++ b/benchmarks/cpp/prepare_dataset.py @@ -49,7 +49,7 @@ def validate_tokenizer(self): return self -@click.group() +@click.group(deprecated=True) @click.option( "--tokenizer", required=True, diff --git a/docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md b/docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md index da72ee54649..5a63b0b944b 100644 --- a/docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md +++ b/docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md @@ -248,13 +248,13 @@ To do the benchmark, run the following command: ```bash # generate synthetic dataset -python ${YOUR_WORK_PATH}/benchmarks/cpp/prepare_dataset.py \ - --stdout \ - --tokenizer nvidia/DeepSeek-R1-FP4 \ +trtllm-bench --model nvidia/DeepSeek-R1-FP4 \ + dataset \ + --output dataset.txt \ token-norm-dist \ --input-mean 1024 --output-mean 2048 \ --input-stdev 0 --output-stdev 0 \ - --num-requests 49152 > dataset.txt + --num-requests 49152 YOUR_DATA_PATH=./dataset.txt @@ -350,13 +350,14 @@ To do the benchmark, run the following command: ```bash # generate synthetic dataset -python ${YOUR_WORK_PATH}/benchmarks/cpp/prepare_dataset.py \ - --stdout \ - --tokenizer deepseek-ai/DeepSeek-R1 \ +trtllm-bench --model deepseek-ai/DeepSeek-R1 \ + dataset \ + --output dataset.txt \ token-norm-dist \ --input-mean 1024 --output-mean 2048 \ --input-stdev 0 --output-stdev 0 \ - --num-requests 5120 > dataset.txt + --num-requests 5120 + YOUR_DATA_PATH=./dataset.txt cat >./extra-llm-api-config.yml< /tmp/dataset.txt +trtllm-bench --model ${MODEL_PATH} \ + dataset \ + --output /tmp/dataset.txt \ + token-norm-dist 
\ + --num-requests=${NUM_SAMPLES} \ + --input-mean=1000 --output-mean=1000 --input-stdev=0 --output-stdev=0 # Benchmark and profile TLLM_PROFILE_START_STOP=100-150 nsys profile \ diff --git a/docs/source/developer-guide/perf-benchmarking.md b/docs/source/developer-guide/perf-benchmarking.md index 6fcf8b64fed..6d89c3981c1 100644 --- a/docs/source/developer-guide/perf-benchmarking.md +++ b/docs/source/developer-guide/perf-benchmarking.md @@ -150,7 +150,7 @@ directory. For example, to generate a synthetic dataset of 1000 requests with a 128/128 for [meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B), run: ```shell -python benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-3.1-8B token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 > /tmp/synthetic_128_128.txt +trtllm-bench --model meta-llama/Llama-3.1-8B dataset --output /tmp/synthetic_128_128.txt token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 ``` ### Running with the PyTorch Workflow @@ -231,13 +231,13 @@ The PyTorch workflow supports benchmarking with LoRA (Low-Rank Adaptation) adapt **Preparing LoRA Dataset** -Use `prepare_dataset.py` with LoRA-specific options to generate requests with LoRA metadata: +Use `trtllm-bench dataset` with LoRA-specific options to generate requests with LoRA metadata: ```shell -python3 benchmarks/cpp/prepare_dataset.py \ - --stdout \ +trtllm-bench \ + --model /path/to/tokenizer \ + dataset \ --rand-task-id 0 1 \ - --tokenizer /path/to/tokenizer \ --lora-dir /path/to/loras \ token-norm-dist \ --num-requests 100 \ @@ -308,17 +308,18 @@ Each subdirectory should contain the LoRA adapter files for that specific task. To benchmark multi-modal models with PyTorch workflow, you can follow the similar approach as above. First, prepare the dataset: ```python -python ./benchmarks/cpp/prepare_dataset.py \ - --tokenizer Qwen/Qwen2-VL-2B-Instruct \ - --stdout \ +```bash +trtllm-bench \ + --model Qwen/Qwen2-VL-2B-Instruct \ + dataset \ + --output mm_data.jsonl \ + real-dataset --dataset-name lmms-lab/MMMU \ --dataset-split test \ --dataset-image-key image \ --dataset-prompt-key question \ --num-requests 10 \ --output-len-dist 128,5 > mm_data.jsonl + --output-len-dist 128,5 ``` It will download the media files to `/tmp` directory and prepare the dataset with their paths. Note that the `prompt` fields are texts and not tokenized ids. This is due to the fact that the `prompt` and the media (image/video) are processed by a preprocessor for multimodal files. 
diff --git a/docs/source/legacy/performance/perf-analysis.md b/docs/source/legacy/performance/perf-analysis.md index f72437f4e9f..0c50d37aa5b 100644 --- a/docs/source/legacy/performance/perf-analysis.md +++ b/docs/source/legacy/performance/perf-analysis.md @@ -66,10 +66,10 @@ Say we want to profile iterations 100 to 150 on a trtllm-bench/trtllm-serve run, #!/bin/bash # Prepare dataset for the benchmark -python3 benchmarks/cpp/prepare_dataset.py \ - --tokenizer=${MODEL_PATH} \ - --stdout token-norm-dist --num-requests=${NUM_SAMPLES} \ - --input-mean=1000 --output-mean=1000 --input-stdev=0 --output-stdev=0 > /tmp/dataset.txt +trtllm-bench \ + --model=${MODEL_PATH} dataset \ + --output /tmp/dataset.txt token-norm-dist --num-requests=${NUM_SAMPLES} \ + --input-mean=1000 --output-mean=1000 --input-stdev=0 --output-stdev=0 # Benchmark and profile TLLM_PROFILE_START_STOP=100-150 nsys profile \ diff --git a/docs/source/legacy/performance/perf-benchmarking.md b/docs/source/legacy/performance/perf-benchmarking.md index 55caef07bab..2d504a1d450 100644 --- a/docs/source/legacy/performance/perf-benchmarking.md +++ b/docs/source/legacy/performance/perf-benchmarking.md @@ -110,7 +110,7 @@ of 128:128. To run the benchmark from start to finish, run the following commands: ```shell -python benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-3.1-8B token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 3000 > /tmp/synthetic_128_128.txt +trtllm-bench --model meta-llama/Llama-3.1-8B dataset --output /tmp/synthetic_128_128.txt token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 3000 trtllm-bench --model meta-llama/Llama-3.1-8B build --dataset /tmp/synthetic_128_128.txt --quantization FP8 trtllm-bench --model meta-llama/Llama-3.1-8B throughput --dataset /tmp/synthetic_128_128.txt --engine_dir /tmp/meta-llama/Llama-3.1-8B/tp_1_pp_1 ``` @@ -207,7 +207,7 @@ directory. For example, to generate a synthetic dataset of 1000 requests with a 128/128 for [meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B), run: ```shell -benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-3.1-8B token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 > /tmp/synthetic_128_128.txt +trtllm-bench --model meta-llama/Llama-3.1-8B dataset --output /tmp/synthetic_128_128.txt token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 ``` ### Building a Benchmark Engine diff --git a/examples/llm-api/llm_mgmn_trtllm_bench.sh b/examples/llm-api/llm_mgmn_trtllm_bench.sh index 43c126368dd..150618f76af 100644 --- a/examples/llm-api/llm_mgmn_trtllm_bench.sh +++ b/examples/llm-api/llm_mgmn_trtllm_bench.sh @@ -35,7 +35,6 @@ # not supported in Slurm mode, you need to download the model and put it in # the LOCAL_MODEL directory. -export prepare_dataset="$SOURCE_ROOT/benchmarks/cpp/prepare_dataset.py" export data_path="$WORKDIR/token-norm-dist.txt" echo "Preparing dataset..." @@ -50,14 +49,14 @@ srun -l \ --mpi=pmix \ bash -c " $PROLOGUE - python3 $prepare_dataset \ - --tokenizer=$LOCAL_MODEL \ - --stdout token-norm-dist \ + trtllm-bench --model=$LOCAL_MODEL dataset \ + --output $data_path \ + token-norm-dist \ --num-requests=100 \ --input-mean=128 \ --output-mean=128 \ --input-stdev=0 \ - --output-stdev=0 > $data_path + --output-stdev=0 " echo "Running benchmark..." 
diff --git a/examples/llm-api/out_of_tree_example/readme.md b/examples/llm-api/out_of_tree_example/readme.md index 1b26ea3cd67..e8ec7ffcc8d 100644 --- a/examples/llm-api/out_of_tree_example/readme.md +++ b/examples/llm-api/out_of_tree_example/readme.md @@ -42,7 +42,17 @@ Similar to the quickstart example, you can use the same CLI argument with `trtll Prepare the dataset: ``` -python ./benchmarks/cpp/prepare_dataset.py --tokenizer ./model_ckpt --stdout dataset --dataset-name lmms-lab/MMMU --dataset-split test --dataset-image-key image --dataset-prompt-key "question" --num-requests 100 --output-len-dist 128,5 > mm_data.jsonl +trtllm-bench \ + --model ./model_ckpt \ + dataset \ + --output mm_data.jsonl \ + real-dataset \ + --dataset-name lmms-lab/MMMU \ + --dataset-split test \ + --dataset-image-key image \ + --dataset-prompt-key question \ + --num-requests 10 \ + --output-len-dist 128,5 ``` diff --git a/examples/models/core/deepseek_v3/README.md b/examples/models/core/deepseek_v3/README.md index db88ec6ee2e..7c44bd84840 100644 --- a/examples/models/core/deepseek_v3/README.md +++ b/examples/models/core/deepseek_v3/README.md @@ -138,12 +138,13 @@ To avoid OOM (out of memory) error, you need to adjust the values of "--max_batc #### ISL-64k-OSL-1024 ```bash DS_R1_NVFP4_MODEL_PATH=/path/to/DeepSeek-R1 -python /app/tensorrt_llm/benchmarks/cpp/prepare_dataset.py \ - --stdout --tokenizer ${DS_R1_NVFP4_MODEL_PATH} \ +trtllm-bench --model ${DS_R1_NVFP4_MODEL_PATH} \ + dataset \ + --output /tmp/benchmarking_64k.txt \ token-norm-dist \ --input-mean 65536 --output-mean 1024 \ --input-stdev 0 --output-stdev 0 \ - --num-requests 24 > /tmp/benchmarking_64k.txt + --num-requests 24 cat < /tmp/extra-llm-api-config.yml cuda_graph_config: @@ -164,12 +165,13 @@ trtllm-bench -m deepseek-ai/DeepSeek-R1 --model_path ${DS_R1_NVFP4_MODEL_PATH} t #### ISL-128k-OSL-1024 ```bash DS_R1_NVFP4_MODEL_PATH=/path/to/DeepSeek-R1 -python /app/tensorrt_llm/benchmarks/cpp/prepare_dataset.py \ - --stdout --tokenizer ${DS_R1_NVFP4_MODEL_PATH} \ +trtllm-bench --model ${DS_R1_NVFP4_MODEL_PATH} \ + dataset \ + --output /tmp/benchmarking_128k.txt \ token-norm-dist \ --input-mean 131072 --output-mean 1024 \ --input-stdev 0 --output-stdev 0 \ - --num-requests 4 > /tmp/benchmarking_128k.txt + --num-requests 4 cat < /tmp/extra-llm-api-config.yml cuda_graph_config: @@ -336,7 +338,7 @@ curl http://localhost:8000/v1/completions \ }' ``` -For DeepSeek-R1 FP4, use the model name `nvidia/DeepSeek-R1-FP4-v2`. +For DeepSeek-R1 FP4, use the model name `nvidia/DeepSeek-R1-FP4-v2`. For DeepSeek-V3, use the model name `deepseek-ai/DeepSeek-V3`. ### Disaggregated Serving diff --git a/tensorrt_llm/bench/dataset/__init__.py b/tensorrt_llm/bench/dataset/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/bench/dataset/prepare_dataset.py b/tensorrt_llm/bench/dataset/prepare_dataset.py new file mode 100644 index 00000000000..6f024fb8f1a --- /dev/null +++ b/tensorrt_llm/bench/dataset/prepare_dataset.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path +from typing import Optional, Tuple + +import click +from pydantic import BaseModel, model_validator +from transformers import AutoTokenizer + +from tensorrt_llm.bench.dataset.prepare_real_data import real_dataset +from tensorrt_llm.bench.dataset.prepare_synthetic_data import token_norm_dist, token_unif_dist + + +class RootArgs(BaseModel): + tokenizer: str + output: str + random_seed: int + task_id: int + trust_remote_code: bool = False + rand_task_id: Optional[Tuple[int, int]] + lora_dir: Optional[str] = None + + @model_validator(mode="after") + def validate_tokenizer(self): + try: + tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer, padding_side="left", trust_remote_code=self.trust_remote_code + ) + except EnvironmentError as e: + raise ValueError( + "Cannot find a tokenizer from the given string because of " + f"{e}\nPlease set tokenizer to the directory that contains " + "the tokenizer, or set to a model name in HuggingFace." + ) + tokenizer.pad_token = tokenizer.eos_token + self.tokenizer = tokenizer + + return self + + +@click.group(name="dataset") +@click.option( + "--output", type=str, help="Output json filename.", default="preprocessed_dataset.json" +) +@click.option( + "--random-seed", required=False, type=int, help="random seed for token_ids", default=420 +) +@click.option("--task-id", type=int, default=-1, help="LoRA task id") +@click.option("--rand-task-id", type=int, default=None, nargs=2, help="Random LoRA Tasks") +@click.option("--lora-dir", type=str, default=None, help="Directory containing LoRA adapters") +@click.option( + "--log-level", default="info", type=click.Choice(["info", "debug"]), help="Logging level." 
+) +@click.option( + "--trust-remote-code", + is_flag=True, + default=False, + envvar="TRUST_REMOTE_CODE", + help="Trust remote code.", +) +@click.pass_context +def prepare_dataset(ctx, **kwargs): + """Prepare dataset for benchmarking with trtllm-bench.""" + model = ctx.obj.model or ctx.obj.checkpoint_path + output_path = Path(kwargs["output"]) + output_path.parent.mkdir(parents=True, exist_ok=True) + + ctx.obj = RootArgs( + tokenizer=model, + output=kwargs["output"], + random_seed=kwargs["random_seed"], + task_id=kwargs["task_id"], + rand_task_id=kwargs["rand_task_id"], + lora_dir=kwargs["lora_dir"], + trust_remote_code=kwargs["trust_remote_code"], + ) + + +prepare_dataset.add_command(real_dataset) +prepare_dataset.add_command(token_norm_dist) +prepare_dataset.add_command(token_unif_dist) diff --git a/tensorrt_llm/bench/dataset/prepare_real_data.py b/tensorrt_llm/bench/dataset/prepare_real_data.py new file mode 100644 index 00000000000..063650c926f --- /dev/null +++ b/tensorrt_llm/bench/dataset/prepare_real_data.py @@ -0,0 +1,305 @@ +import logging +import random +import re +import tempfile +from functools import partial +from typing import Optional + +import click +from datasets import load_dataset +from PIL import Image +from pydantic import BaseModel, model_validator + +from tensorrt_llm.bench.dataset.utils import ( + generate_multimodal_dataset, + generate_text_dataset, + get_norm_dist_lengths, + write_dataset_to_file, +) + + +def validate_output_len_dist(ctx, param, value): + """Validate the --output-len-dist option.""" + if value is None: + return value + m = re.match(r"(\d+),(\d+)", value) + if m: + return int(m.group(1)), int(m.group(2)) + else: + raise AssertionError( + "Incorrect specification for --output-len-dist. Correct format: " + "--output-len-dist ," + ) + + +class DatasetConfig(BaseModel): + """Dataset configurations.""" + + """Name of the dataset on HuggingFace.""" + name: str + """Config name of the dataset if existing.""" + config_name: Optional[str] = None + """Split of the dataset. Typical values: train, validation, test. Setting to None will include all splits.""" + split: Optional[str] + """The dataset dictionary used for the input sentence.""" + input_key: Optional[str] = None + """The dataset dictionary key used for the prompt of the input sentence. Must not be set when prompt is set.""" + image_key: Optional[str] = None + """The dataset dictionary key used for the images.""" + prompt_key: Optional[str] = None + """The prompt sentence to be added to the input sentence. Must not be set when prompt_key is set.""" + prompt: Optional[str] = None + """The dataset dictionary key used to derive the output sequence length. Set to None if no output key.""" + output_key: Optional[str] + + @model_validator(mode="after") + def check_prompt(self) -> "DatasetConfig": + if self.prompt_key and self.prompt: + raise AssertionError("--prompt-key and --prompt cannot be set at the same time.") + if (not self.prompt_key) and (not self.prompt): + raise AssertionError("Either --prompt-key or --prompt must be set.") + return self + + @property + def query(self): + """Generate the query for HuggingFace `datasets.load_dataset()`.""" + if self.config_name: + return [self.name, self.config_name] + else: + return [self.name] + + def get_prompt(self, req): + """Get the prompt sentence from the given request.""" + if self.prompt_key: + assert self.prompt_key in req, ( + f"Dataset {self.name} does not have key '{self.prompt_key}'. 
" + "Please set --prompt-key to one of the available keys: " + f"{req.keys()}" + ) + return req[self.prompt_key] + else: + return self.prompt + + def get_input(self, req): + """Get the input sentence from the given request.""" + assert self.input_key in req, ( + f"Dataset {self.name} does not have key '{self.input_key}'. " + "Please set --input-key to one of the available keys: " + f"{req.keys()}" + ) + return req[self.input_key] + + def get_images(self, req): + """Get the images from the given request.""" + image_keys = [self.image_key] + [f"{self.image_key}_{i}" for i in range(1, 8)] + assert any(key in req for key in image_keys), ( + f"Dataset {self.name} does not have key '{self.image_key}'. " + "Please set --dataset-image-key to one of the available keys: " + f"{req.keys()}" + ) + images = [] + for key in image_keys: + if key in req and req[key] is not None: + images.append(req[key]) + return images + + def get_output(self, req): + """Get the output sentence from the given request.""" + if self.output_key is None: + raise RuntimeError( + "--output-key is not set. Please either:\n" + "1. Define output length through --output-len-dist.\n" + f"2. If the dataset {self.name} has key for golden output and " + "you wish to set output length to the length of the golden " + "output, set --output-key." + ) + assert self.output_key in req, ( + f"Dataset {self.name} does not have key '{self.output_key}'. " + "Please set --output-key to one of the available keys: " + f"{req.keys()}" + ) + return req[self.output_key] + + +def load_dataset_from_hf(dataset_config: DatasetConfig): + """Load dataset from HuggingFace. + + Args: + dataset_config: A `DatasetConfig` object that defines the dataset to load. + + Returns: + Dataset iterator. + + Raises: + ValueError: When dataset loading fails due to incorrect dataset config setting. + """ + try: + dataset = iter( + load_dataset( + *dataset_config.query, + split=dataset_config.split, + streaming=True, + trust_remote_code=True, + ) + ) + except ValueError as e: + if "Config" in e: + e += "\n Please add the config name to the dataset config yaml." + elif "split" in e: + e += "\n Please specify supported split in the dataset config yaml." + raise ValueError(e) + + return dataset + + +@click.command(name="real-dataset") +@click.option("--dataset-name", required=True, type=str, help="Dataset name in HuggingFace.") +@click.option( + "--dataset-config-name", + type=str, + default=None, + help="Dataset config name in HuggingFace (if exists).", +) +@click.option("--dataset-split", type=str, required=True, help="Split of the dataset to use.") +@click.option("--dataset-input-key", type=str, help="The dataset dictionary key for input.") +@click.option( + "--dataset-image-key", type=str, default="image", help="The dataset dictionary key for images." +) +@click.option( + "--dataset-prompt-key", + type=str, + default=None, + help="The dataset dictionary key for prompt (if exists).", +) +@click.option( + "--dataset-prompt", + type=str, + default=None, + help="The prompt string when there is no prompt key for the dataset.", +) +@click.option( + "--dataset-output-key", + type=str, + default=None, + help="The dataset dictionary key for output (if exists).", +) +@click.option( + "--num-requests", + type=int, + default=None, + help="Number of requests to be generated. Will be capped to min(dataset.num_rows, num_requests).", +) +@click.option( + "--max-input-len", + type=int, + default=None, + help="Maximum input sequence length for a given request. 
This will be used to filter out the " + "requests with long input sequence length. Default will include all the requests.", +) +@click.option( + "--output-len-dist", + type=str, + default=None, + callback=validate_output_len_dist, + help="Output length distribution. Default will be the length of the golden output from " + "the dataset. Format: ,. E.g. 100,10 will randomize " + "the output length with mean=100 and variance=10.", +) +@click.pass_obj +def real_dataset(root_args, **kwargs): + """Prepare dataset from real dataset.""" + dataset_config = DatasetConfig( + **{k[8:]: v for k, v in kwargs.items() if k.startswith("dataset_")} + ) + + input_ids = [] + input_lens = [] + output_lens = [] + task_ids = [] + req_cnt = 0 + modality = None + multimodal_texts = [] + multimodal_image_paths = [] + for req in load_dataset_from_hf(dataset_config): + if any(key in req for key in ["image", "image_1", "video"]): + # multimodal input + if "video" in req and req["video"] is not None: + assert "Not supported yet" + assert kwargs["output_len_dist"] is not None, ( + "Output length distribution must be set for multimodal requests." + ) + modality = "image" + text = dataset_config.get_prompt(req) + images = dataset_config.get_images(req) + image_paths = [] + for image in images: + if image is not None: + if isinstance(image, str): + image_paths.append(image) + elif isinstance(image, Image.Image): + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file: + logging.debug(f"Saving image to {tmp_file.name}") + image = image.convert("RGB") + image.save(tmp_file, "JPEG") + filepath = tmp_file.name + image_paths.append(filepath) + else: + raise ValueError(f"Invalid image path: {image}") + multimodal_texts.append(text) + multimodal_image_paths.append(image_paths) + else: + # text input + prompt = dataset_config.get_prompt(req) + " " + dataset_config.get_input(req) + logging.debug(f"Input sequence: {prompt}") + line = root_args.tokenizer.encode(prompt) + if kwargs["max_input_len"] and len(line) > kwargs["max_input_len"]: + continue + input_ids.append(line) + input_lens.append(len(line)) + + # output if fetch from golden + if kwargs["output_len_dist"] is None: + output_lens.append(len(root_args.tokenizer.encode(dataset_config.get_output(req)))) + + # lora task id + task_id = root_args.task_id + if root_args.rand_task_id is not None: + min_id, max_id = root_args.rand_task_id + task_id = random.randint(min_id, max_id) + task_ids.append(task_id) + + req_cnt += 1 + if kwargs["num_requests"] and req_cnt >= kwargs["num_requests"]: + break + + if ( + kwargs["num_requests"] + and (len(input_ids) if modality is None else len(multimodal_texts)) < kwargs["num_requests"] + ): + logging.warning( + f"Number of requests={len(input_ids) if modality is None else len(multimodal_texts)} is" + f" smaller than the num-requests user set={kwargs['num_requests']}." 
+ ) + + # output if randomized + if kwargs["output_len_dist"] is not None: + osl_mean, osl_stdev = kwargs["output_len_dist"] + output_lens = get_norm_dist_lengths( + osl_mean, + osl_stdev, + len(input_ids) if modality is None else len(multimodal_texts), + root_args.random_seed, + ) + logging.debug(f"Input lengths: {[len(i) for i in input_ids]}") + logging.debug(f"Output lengths: {output_lens}") + if modality is not None: + logging.debug(f"Modality: {modality}") + + dataset_generator = None + if modality is not None: + dataset_generator = partial( + generate_multimodal_dataset, multimodal_texts, multimodal_image_paths + ) + else: + dataset_generator = partial(generate_text_dataset, input_ids) + write_dataset_to_file(dataset_generator(output_lens), root_args.output) diff --git a/tensorrt_llm/bench/dataset/prepare_synthetic_data.py b/tensorrt_llm/bench/dataset/prepare_synthetic_data.py new file mode 100644 index 00000000000..342aa514381 --- /dev/null +++ b/tensorrt_llm/bench/dataset/prepare_synthetic_data.py @@ -0,0 +1,104 @@ +import random +import warnings + +import click + +from tensorrt_llm.bench.dataset.utils import ( + gen_random_tokens, + generate_text_dataset, + get_norm_dist_lengths, + get_unif_dist_lengths, + write_dataset_to_file, +) + + +def _generate_task_ids_and_lora_config(root_args, num_reqs): + """Generate task IDs and determine LoRA configuration based on root_args.""" + if root_args.rand_task_id is None: + task_ids = [root_args.task_id for _ in range(num_reqs)] + else: + min_id, max_id = root_args.rand_task_id + task_ids = [random.randint(min_id, max_id) for _ in range(num_reqs)] + + use_task_ids = root_args.task_id != -1 or root_args.rand_task_id is not None + + # Determine if LoRA should be used (requires both task IDs and lora_dir) + use_lora = use_task_ids and root_args.lora_dir is not None + + # Warn if task IDs are specified but no LoRA directory is provided + if use_task_ids and not use_lora: + warnings.warn( + "Task IDs require LoRA directory. 
Use --lora-dir or omit task IDs.", UserWarning + ) + + return ( + task_ids, + task_ids if use_task_ids else None, + {"lora_dir": root_args.lora_dir} if use_lora else None, + ) + + +@click.command() +@click.option("--num-requests", required=True, type=int, help="Number of requests to be generated") +@click.option("--input-mean", required=True, type=int, help="normal dist mean for input tokens") +@click.option("--input-stdev", required=True, type=int, help="normal dist stdev for input tokens") +@click.option("--output-mean", required=True, type=int, help="normal dist mean for output tokens") +@click.option("--output-stdev", required=True, type=int, help="normal dist stdev for output tokens") +@click.pass_obj +def token_norm_dist(root_args, **kwargs): + """Prepare synthetic dataset by generating random tokens with normal dist lengths.""" + input_ids = [] + input_lens = [] + output_lens = [] + + input_lens = get_norm_dist_lengths( + kwargs["input_mean"], kwargs["input_stdev"], kwargs["num_requests"], root_args.random_seed + ) + + num_reqs = len(input_lens) + output_lens = get_norm_dist_lengths( + kwargs["output_mean"], kwargs["output_stdev"], num_reqs, root_args.random_seed + ) + input_ids = gen_random_tokens(input_lens, root_args.tokenizer, root_args.random_seed) + _, print_task_ids, lora_config = _generate_task_ids_and_lora_config(root_args, num_reqs) + dataset_generator = generate_text_dataset( + input_ids, output_lens, task_ids=print_task_ids, lora_config=lora_config + ) + write_dataset_to_file(dataset_generator, root_args.output) + + +@click.command() +@click.option("--num-requests", required=True, type=int, help="Number of requests to be generated") +@click.option( + "--input-min", required=True, type=int, help="uniform dist (inclusive) min for input tokens" +) +@click.option( + "--input-max", required=True, type=int, help="normal dist (inclusive) max for input tokens" +) +@click.option( + "--output-min", required=True, type=int, help="normal dist (inclusive) min for output tokens" +) +@click.option( + "--output-max", required=True, type=int, help="normal dist (inclusive) max for output tokens" +) +@click.pass_obj +def token_unif_dist(root_args, **kwargs): + """Prepare synthetic dataset by generating random tokens with normal uniformly lengths.""" + input_ids = [] + input_lens = [] + output_lens = [] + + input_lens = get_unif_dist_lengths( + kwargs["input_min"], kwargs["input_max"], kwargs["num_requests"], root_args.random_seed + ) + + num_reqs = len(input_lens) + output_lens = get_unif_dist_lengths( + kwargs["output_min"], kwargs["output_max"], num_reqs, root_args.random_seed + ) + input_ids = gen_random_tokens(input_lens, root_args.tokenizer, root_args.random_seed) + _, print_task_ids, lora_config = _generate_task_ids_and_lora_config(root_args, num_reqs) + dataset_generator = generate_text_dataset( + input_ids, output_lens, task_ids=print_task_ids, lora_config=lora_config + ) + write_dataset_to_file(dataset_generator, root_args.output) diff --git a/tensorrt_llm/bench/dataset/utils.py b/tensorrt_llm/bench/dataset/utils.py new file mode 100644 index 00000000000..15c91701953 --- /dev/null +++ b/tensorrt_llm/bench/dataset/utils.py @@ -0,0 +1,96 @@ +import json +import math +import os +import random +from pathlib import Path + +import numpy as np + + +def generate_text_dataset(input_ids, output_lens, task_ids=None, lora_config=None): + for i, input_tokens in enumerate(input_ids): + d = {"task_id": i, "input_ids": input_tokens, "output_tokens": output_lens[i]} + + # Add LoRA request if 
task_ids indicate LoRA usage + if task_ids is not None and lora_config is not None: + task_id = task_ids[i] + if task_id != -1: # -1 means no LoRA + d["lora_request"] = { + "lora_name": f"lora_{task_id}", + "lora_int_id": task_id, + "lora_path": os.path.join(lora_config.get("lora_dir", "loras"), str(task_id)), + } + + yield json.dumps(d, separators=(",", ":"), ensure_ascii=False) + + +def generate_multimodal_dataset(multimodal_texts, multimodal_image_paths, output_lens): + for i, (text, image_paths) in enumerate(zip(multimodal_texts, multimodal_image_paths)): + d = { + "task_id": i, + "prompt": text, + "media_paths": image_paths, + "output_tokens": output_lens[i], + } + yield json.dumps(d, separators=(",", ":"), ensure_ascii=False) + + +def get_list_of_delays(delay_dist, mean_time_bet_reqs, num_reqs, random_seed): + if delay_dist == "constant": + delays = [mean_time_bet_reqs] * num_reqs + elif delay_dist == "exponential_dist": + delays = get_exponential_dist_delays(mean_time_bet_reqs, num_reqs, random_seed) + + return delays + + +def get_exponential_dist_delays(mean_time_bet_reqs, num_reqs, random_seed): + # set seed for determinism + np.random.seed(random_seed) + return np.random.exponential(mean_time_bet_reqs, num_reqs).tolist() + + +def get_norm_dist_lengths(mean, stdev, num_reqs, random_seed): + # set seed for determinism + np.random.seed(random_seed) + numbers_list = np.random.normal(loc=mean, scale=stdev, size=num_reqs).tolist() + return [max(1, math.ceil(x)) for x in numbers_list] + + +def get_unif_dist_lengths(min_len, max_len, num_reqs, random_seed): + # set seed for determinism + rng = np.random.default_rng(random_seed) + numbers = rng.integers(low=min_len, high=max_len + 1, size=num_reqs) + return numbers.tolist() + + +def gen_random_tokens(ip_lens, tokenizer, random_seed): + def get_sample_from_population(population_range, sample_size): + # random.sample can not sample a value more than once. 
hence the check + if sample_size < len(population_range): + sample = random.sample(population_range, sample_size) + else: + sample = random.choices(population_range, k=sample_size) + + return sample + + input_ids = [] + random.seed(random_seed) + for ip_len in ip_lens: + start_ids = get_sample_from_population(range(0, tokenizer.vocab_size), ip_len) + # Make sure it does not contain EOS token + eos_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False) + while set(eos_id).issubset(start_ids): + tmp_id = (eos_id[0] + 1) % tokenizer.vocab_size + start_ids = [tmp_id if element == eos_id[0] else element for element in start_ids] + input_ids.append(start_ids) + + return input_ids + + +def write_dataset_to_file(dataset_generator, output_file): + output_file = Path(output_file) + os.makedirs(output_file.parent, exist_ok=True) + with open(output_file, "w") as f: + for item in dataset_generator: + f.write(item + "\n") diff --git a/tensorrt_llm/commands/bench.py b/tensorrt_llm/commands/bench.py index 4323438a4cf..2c79d21686e 100644 --- a/tensorrt_llm/commands/bench.py +++ b/tensorrt_llm/commands/bench.py @@ -6,6 +6,7 @@ from tensorrt_llm.bench.benchmark.throughput import throughput_command from tensorrt_llm.bench.build.build import build_command from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment +from tensorrt_llm.bench.dataset.prepare_dataset import prepare_dataset from tensorrt_llm.logger import logger, severity_map @@ -57,6 +58,7 @@ def main( main.add_command(build_command) main.add_command(throughput_command) main.add_command(latency_command) +main.add_command(prepare_dataset) if __name__ == "__main__": main() diff --git a/tests/integration/defs/perf/README_release_test.md b/tests/integration/defs/perf/README_release_test.md index 0fdf4eaa855..1ec6ff1cd40 100644 --- a/tests/integration/defs/perf/README_release_test.md +++ b/tests/integration/defs/perf/README_release_test.md @@ -24,27 +24,25 @@ For trtllm-bench, the test extracts the following key performance metrics from l #### Without LoRA ```python -prepare_data_script = os.path.join(self._llm_root, "benchmarks", "cpp", "prepare_dataset.py") data_cmd += [ - "python3", prepare_data_script, "--stdout", - f"--tokenizer={tokenizer_dir}", f"token-norm-dist", - f"--num-requests={self._config.num_reqs}", - f"--input-mean={input_len}", f"--output-mean={output_len}", - f"--input-stdev={istdev}", f"--output-stdev={ostdev}", - f" > {dataset_path}" + "trtllm-bench", f"--model={tokenizer_dir}", + "dataset", "--output", dataset_path, "token-norm-dist", + f"--num-requests={self._config.num_reqs}", + f"--input-mean={input_len}", f"--output-mean={output_len}", + f"--input-stdev={istdev}", f"--output-stdev={ostdev}" ] ``` #### With LoRA ```python -"python3", prepare_data_script, f"--stdout", +"trtllm-bench", f"--model={tokenizer_dir}", + "dataset", "--output", dataset_path, f"--rand-task-id 0 {nloras-1}", - f"--tokenizer={tokenizer_dir}", f"--lora-dir={lora_dir}", + f"--lora-dir={lora_dir}", f"token-norm-dist", f"--num-requests={self._config.num_reqs}", f"--input-mean={input_len}", f"--output-mean={output_len}", - f"--input-stdev={istdev}", f"--output-stdev={ostdev}", - f" > {dataset_path}" + f"--input-stdev={istdev}", f"--output-stdev={ostdev}" ``` ### 2.2 PyTorch Configuration Generation diff --git a/tests/integration/defs/perf/test_perf.py b/tests/integration/defs/perf/test_perf.py index 9b121205c6f..db6a98aa740 100644 --- a/tests/integration/defs/perf/test_perf.py +++ b/tests/integration/defs/perf/test_perf.py @@ -1176,6 
+1176,12 @@ def get_prepare_data_command(self, engine_dir, input_len, "llama-7b-hf") if not os.path.exists(engine_dir): os.makedirs(engine_dir, exist_ok=True) + if self._config.num_loras > 0: + istdev = 16 + ostdev = 24 + nloras = self._config.num_loras + dataset_path = os.path.join(engine_dir, "synthetic_data.json") + if self._config.num_loras > 0: istdev = 16 ostdev = 24 @@ -1201,14 +1207,13 @@ def get_prepare_data_command(self, engine_dir, input_len, self.lora_dirs.append(f"{lora_dir}/{i}") data_cmd += [f"ln -sf {lora_path} {lora_dir}/{i}", ";"] data_cmd += [ - "python3", prepare_data_script, f"--stdout", - f"--rand-task-id 0 {nloras-1}", - f"--tokenizer={tokenizer_dir}", f"--lora-dir={lora_dir}", + "trtllm-bench", f"--model={tokenizer_dir}", "dataset", + "--output", f"{dataset_path}", + f"--rand-task-id 0 {nloras-1}", f"--lora-dir={lora_dir}", f"token-norm-dist", f"--num-requests={self._config.num_reqs}", f"--input-mean={input_len}", f"--output-mean={output_len}", - f"--input-stdev={istdev}", f"--output-stdev={ostdev}", - f" > {dataset_path}" + f"--input-stdev={istdev}", f"--output-stdev={ostdev}" ] elif self._config.backend == "cppmanager": data_cmd += [ @@ -1243,12 +1248,11 @@ def get_prepare_data_command(self, engine_dir, input_len, dataset_path = os.path.join(engine_dir, "synthetic_data.json") if self._build_script == 'trtllm-bench': data_cmd += [ - "python3", prepare_data_script, "--stdout", - f"--tokenizer={tokenizer_dir}", f"token-norm-dist", + "trtllm-bench", f"--model={tokenizer_dir}", "dataset", + "--output", f"{dataset_path}", "token-norm-dist", f"--num-requests={self._config.num_reqs}", f"--input-mean={input_len}", f"--output-mean={output_len}", - f"--input-stdev={istdev}", f"--output-stdev={ostdev}", - f" > {dataset_path}" + f"--input-stdev={istdev}", f"--output-stdev={ostdev}" ] else: data_cmd += [ diff --git a/tests/integration/defs/perf/utils.py b/tests/integration/defs/perf/utils.py index f5cfb391e37..ebc0b75f22f 100644 --- a/tests/integration/defs/perf/utils.py +++ b/tests/integration/defs/perf/utils.py @@ -475,8 +475,8 @@ def run_ex(self, self._gpu_clock_lock = gpu_clock_lock tmpDir = temp_wd(self.get_working_dir()) - is_prepare_dataset_cmd = 'prepare_dataset' in commands.get_cmd_str( - cmd_idx) + cmd_str = commands.get_cmd_str(cmd_idx) + is_prepare_dataset_cmd = 'prepare_dataset' in cmd_str or "dataset --output" in cmd_str # Start the timer. self._start_timestamp = datetime.utcnow() diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index cd7b3aa755d..673dfed25e7 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -489,16 +489,15 @@ def __call__(self): return self.run_bench() def prepare_dataset(self): - dataset_tool = Path(self.llm_root, "benchmarks", "cpp", - "prepare_dataset.py") - # Generate a small dataset to run a test. self.work_dir.mkdir(parents=True) command = [ - f"{dataset_tool.resolve()}", - "--stdout", - "--tokenizer", + "trtllm-bench", + "--model", f"{self.model_path}", + "dataset", + "--output", + f"{self.dataset_path}", "token-norm-dist", "--input-mean", "128", @@ -512,13 +511,6 @@ def prepare_dataset(self): str(self.num_requests), ] print(f"Running command: {' '.join(command)}") - dataset_output = self.llm_venv.run_cmd( - command, - caller=check_output, - ) - # Grab the stdout and write it to a dataset file for passing to suite. 
- with open(self.dataset_path, "w") as dataset: - dataset.write(dataset_output) def build_engine(self): if self.skip_engine_build: @@ -769,7 +761,6 @@ def trtllm_bench_prolog( stream_mode = "streaming" if streaming else "non-streaming" benchmark_name = f"trtllm-bench-sanity-{quant_name}-{stream_mode}" benchmark_name += "-pytorch-backend" if skip_engine_build else benchmark_name - dataset_tool = Path(llm_root, "benchmarks", "cpp", "prepare_dataset.py") work_dir = Path(tempfile.TemporaryDirectory().name ) if skip_engine_build else Path(engine_dir) @@ -778,29 +769,26 @@ def trtllm_bench_prolog( shutil.rmtree(work_dir, ignore_errors=True) # Generate a small dataset to run a test. work_dir.mkdir(parents=True) - dataset_output = llm_venv.run_cmd( - [ - f"{dataset_tool.resolve()}", - "--stdout", - "--tokenizer", - f"{model_path}", - "token-norm-dist", - "--input-mean", - "128", - "--output-mean", - "128", - "--input-stdev", - "0", - "--output-stdev", - "0", - "--num-requests", - "10", - ], - caller=check_output, - ) - # Grab the stdout and write it to a dataset file for passing to suite. - with open(dataset_path, "w") as dataset: - dataset.write(dataset_output) + dataset_cmd = [ + "trtllm-bench", + "--model", + f"{model_path}", + "dataset", + "--output", + f"{dataset_path}", + "token-norm-dist", + "--input-mean", + "128", + "--output-mean", + "128", + "--input-stdev", + "0", + "--output-stdev", + "0", + "--num-requests", + "10", + ] + check_output(" ".join(dataset_cmd), shell=True) if not skip_engine_build: build_cmd = \ diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py index 7c4da257bfa..f5f055c3eb2 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py @@ -43,16 +43,16 @@ def run_benchmark( def prepare_dataset(root_dir: str, temp_dir: str, model_path_or_name: str): _DATASET_NAME = "synthetic_128_128.txt" dataset_path = Path(temp_dir, _DATASET_NAME) - dataset_tool = Path(root_dir, "benchmarks", "cpp", "prepare_dataset.py") script_dir = Path(root_dir, "benchmarks", "cpp") # Generate a small dataset to run a test - matching workload configuration command = [ - "python3", - f"{dataset_tool}", - "--stdout", - "--tokenizer", + "trtllm-bench", + "--model", model_path_or_name, + "dataset", + "--output", + f"{dataset_path}", "token-norm-dist", "--input-mean", "128", @@ -71,9 +71,7 @@ def prepare_dataset(root_dir: str, temp_dir: str, model_path_or_name: str): ) if result.returncode != 0: raise RuntimeError(f"Failed to prepare dataset: {result.stderr}") - # Grab the stdout and write it to a dataset file for passing to suite. - with open(dataset_path, "w") as dataset: - dataset.write(result.stdout) + return dataset_path diff --git a/tests/unittest/tools/test_prepare_dataset.py b/tests/unittest/tools/test_prepare_dataset.py index 05da19a5957..17458ea3213 100644 --- a/tests/unittest/tools/test_prepare_dataset.py +++ b/tests/unittest/tools/test_prepare_dataset.py @@ -49,12 +49,12 @@ def temp_lora_dir(self) -> str: task_dir.mkdir(parents=True, exist_ok=True) yield str(lora_dir) - def _build_base_command(self, llm_root: Path) -> List[str]: + def _build_base_command(self, output_path: Path) -> List[str]: """ Build the base command for running prepare_dataset.py. 
Args: - llm_root: Path to the TensorRT LLM root directory + output_path: Path to the output dataset file Returns: List[str]: Base command components @@ -62,8 +62,7 @@ def _build_base_command(self, llm_root: Path) -> List[str]: Raises: pytest.skip: If LLM_MODELS_ROOT is not available """ - script_path = llm_root / _PREPARE_DATASET_SCRIPT_PATH - cmd = ["python3", str(script_path)] + cmd = ["trtllm-bench"] # Add required tokenizer argument model_cache = llm_models_root() @@ -71,10 +70,10 @@ def _build_base_command(self, llm_root: Path) -> List[str]: pytest.skip("LLM_MODELS_ROOT not available") tokenizer_dir = model_cache / _TOKENIZER_SUBPATH - cmd.extend(["--tokenizer", str(tokenizer_dir)]) + cmd.extend(["--model", str(tokenizer_dir)]) # Always add --stdout flag since we parse stdout output - cmd.extend(["--stdout"]) + cmd.extend(["dataset", "--output", f"{output_path}"]) return cmd @@ -110,7 +109,7 @@ def _add_synthetic_data_arguments(self, cmd: List[str]) -> None: str(_DEFAULT_OUTPUT_STDEV) ]) - def _run_prepare_dataset(self, llm_root: Path, **kwargs) -> str: + def _run_prepare_dataset(self, **kwargs) -> str: """ Execute prepare_dataset.py with specified parameters and capture output. @@ -125,13 +124,20 @@ def _run_prepare_dataset(self, llm_root: Path, **kwargs) -> str: Raises: subprocess.CalledProcessError: If the command execution fails """ - cmd = self._build_base_command(llm_root) - self._add_lora_arguments(cmd, **kwargs) - self._add_synthetic_data_arguments(cmd) + with tempfile.TemporaryDirectory() as temp_dir: + output_path = Path(temp_dir) / "dataset.jsonl" + cmd = self._build_base_command(output_path) + self._add_lora_arguments(cmd, **kwargs) + self._add_synthetic_data_arguments(cmd) + + # Execute command and capture output + subprocess.run(cmd, check=True, cwd=temp_dir) + + data = "" + with open(output_path, "r") as f: + data = f.read() - # Execute command and capture output - result = subprocess.run(cmd, capture_output=True, text=True, check=True) - return result.stdout + return data def _parse_json_output(self, output: str) -> List[Dict[str, Any]]: """ @@ -199,7 +205,7 @@ def _validate_lora_request(self, }, id="random_task_id") ]) - def test_lora_metadata_generation(self, llm_root: Path, temp_lora_dir: str, + def test_lora_metadata_generation(self, temp_lora_dir: str, test_params: Dict) -> None: """Test LoRA metadata generation with various configurations.""" # Extract test parameters @@ -214,7 +220,7 @@ def test_lora_metadata_generation(self, llm_root: Path, temp_lora_dir: str, if rand_task_id is not None: kwargs["rand_task_id"] = rand_task_id - output = self._run_prepare_dataset(llm_root, **kwargs) + output = self._run_prepare_dataset(**kwargs) json_data = self._parse_json_output(output) assert len(json_data) > 0, f"No JSON data generated for {description}" From 24a1c3429f73cec517a7e1f1fa044d3a0598fd3b Mon Sep 17 00:00:00 2001 From: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Date: Thu, 20 Nov 2025 23:00:17 -0800 Subject: [PATCH 2/2] Run pre-commit Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> --- tensorrt_llm/bench/dataset/prepare_dataset.py | 44 +++--- .../bench/dataset/prepare_real_data.py | 98 +++++++------ .../bench/dataset/prepare_synthetic_data.py | 129 +++++++++++------- tensorrt_llm/bench/dataset/utils.py | 42 ++++-- 4 files changed, 191 insertions(+), 122 deletions(-) diff --git a/tensorrt_llm/bench/dataset/prepare_dataset.py b/tensorrt_llm/bench/dataset/prepare_dataset.py index 6f024fb8f1a..b0605c52068 100644 --- 
a/tensorrt_llm/bench/dataset/prepare_dataset.py +++ b/tensorrt_llm/bench/dataset/prepare_dataset.py @@ -20,7 +20,8 @@ from transformers import AutoTokenizer from tensorrt_llm.bench.dataset.prepare_real_data import real_dataset -from tensorrt_llm.bench.dataset.prepare_synthetic_data import token_norm_dist, token_unif_dist +from tensorrt_llm.bench.dataset.prepare_synthetic_data import (token_norm_dist, + token_unif_dist) class RootArgs(BaseModel): @@ -36,14 +37,14 @@ class RootArgs(BaseModel): def validate_tokenizer(self): try: tokenizer = AutoTokenizer.from_pretrained( - self.tokenizer, padding_side="left", trust_remote_code=self.trust_remote_code - ) + self.tokenizer, + padding_side="left", + trust_remote_code=self.trust_remote_code) except EnvironmentError as e: raise ValueError( "Cannot find a tokenizer from the given string because of " f"{e}\nPlease set tokenizer to the directory that contains " - "the tokenizer, or set to a model name in HuggingFace." - ) + "the tokenizer, or set to a model name in HuggingFace.") tokenizer.pad_token = tokenizer.eos_token self.tokenizer = tokenizer @@ -51,18 +52,29 @@ def validate_tokenizer(self): @click.group(name="dataset") -@click.option( - "--output", type=str, help="Output json filename.", default="preprocessed_dataset.json" -) -@click.option( - "--random-seed", required=False, type=int, help="random seed for token_ids", default=420 -) +@click.option("--output", + type=str, + help="Output json filename.", + default="preprocessed_dataset.json") +@click.option("--random-seed", + required=False, + type=int, + help="random seed for token_ids", + default=420) @click.option("--task-id", type=int, default=-1, help="LoRA task id") -@click.option("--rand-task-id", type=int, default=None, nargs=2, help="Random LoRA Tasks") -@click.option("--lora-dir", type=str, default=None, help="Directory containing LoRA adapters") -@click.option( - "--log-level", default="info", type=click.Choice(["info", "debug"]), help="Logging level." -) +@click.option("--rand-task-id", + type=int, + default=None, + nargs=2, + help="Random LoRA Tasks") +@click.option("--lora-dir", + type=str, + default=None, + help="Directory containing LoRA adapters") +@click.option("--log-level", + default="info", + type=click.Choice(["info", "debug"]), + help="Logging level.") @click.option( "--trust-remote-code", is_flag=True, diff --git a/tensorrt_llm/bench/dataset/prepare_real_data.py b/tensorrt_llm/bench/dataset/prepare_real_data.py index 063650c926f..6459a5d503e 100644 --- a/tensorrt_llm/bench/dataset/prepare_real_data.py +++ b/tensorrt_llm/bench/dataset/prepare_real_data.py @@ -10,12 +10,10 @@ from PIL import Image from pydantic import BaseModel, model_validator -from tensorrt_llm.bench.dataset.utils import ( - generate_multimodal_dataset, - generate_text_dataset, - get_norm_dist_lengths, - write_dataset_to_file, -) +from tensorrt_llm.bench.dataset.utils import (generate_multimodal_dataset, + generate_text_dataset, + get_norm_dist_lengths, + write_dataset_to_file) def validate_output_len_dist(ctx, param, value): @@ -28,13 +26,11 @@ def validate_output_len_dist(ctx, param, value): else: raise AssertionError( "Incorrect specification for --output-len-dist. 
Correct format: "
-            "--output-len-dist <output_len_mean>,<output_len_stdev>"
-        )
+            "--output-len-dist <output_len_mean>,<output_len_stdev>")
 
 
 class DatasetConfig(BaseModel):
     """Dataset configurations."""
-
     """Name of the dataset on HuggingFace."""
     name: str
     """Config name of the dataset if existing."""
@@ -55,7 +51,8 @@ class DatasetConfig(BaseModel):
     @model_validator(mode="after")
     def check_prompt(self) -> "DatasetConfig":
         if self.prompt_key and self.prompt:
-            raise AssertionError("--prompt-key and --prompt cannot be set at the same time.")
+            raise AssertionError(
+                "--prompt-key and --prompt cannot be set at the same time.")
         if (not self.prompt_key) and (not self.prompt):
             raise AssertionError("Either --prompt-key or --prompt must be set.")
         return self
@@ -74,8 +71,7 @@ def get_prompt(self, req):
             assert self.prompt_key in req, (
                 f"Dataset {self.name} does not have key '{self.prompt_key}'. "
                 "Please set --prompt-key to one of the available keys: "
-                f"{req.keys()}"
-            )
+                f"{req.keys()}")
             return req[self.prompt_key]
         else:
             return self.prompt
@@ -85,18 +81,17 @@ def get_input(self, req):
         assert self.input_key in req, (
             f"Dataset {self.name} does not have key '{self.input_key}'. "
             "Please set --input-key to one of the available keys: "
-            f"{req.keys()}"
-        )
+            f"{req.keys()}")
         return req[self.input_key]
 
     def get_images(self, req):
         """Get the images from the given request."""
-        image_keys = [self.image_key] + [f"{self.image_key}_{i}" for i in range(1, 8)]
+        image_keys = [self.image_key
+                      ] + [f"{self.image_key}_{i}" for i in range(1, 8)]
         assert any(key in req for key in image_keys), (
             f"Dataset {self.name} does not have key '{self.image_key}'. "
             "Please set --dataset-image-key to one of the available keys: "
-            f"{req.keys()}"
-        )
+            f"{req.keys()}")
         images = []
         for key in image_keys:
             if key in req and req[key] is not None:
@@ -111,13 +106,11 @@ def get_output(self, req):
                 "1. Define output length through --output-len-dist.\n"
                 f"2. If the dataset {self.name} has key for golden output and "
                 "you wish to set output length to the length of the golden "
-                "output, set --output-key."
-            )
+                "output, set --output-key.")
         assert self.output_key in req, (
             f"Dataset {self.name} does not have key '{self.output_key}'. "
             "Please set --output-key to one of the available keys: "
-            f"{req.keys()}"
-        )
+            f"{req.keys()}")
         return req[self.output_key]
 
 
@@ -140,8 +133,7 @@ def load_dataset_from_hf(dataset_config: DatasetConfig):
                 split=dataset_config.split,
                 streaming=True,
                 trust_remote_code=True,
-            )
-        )
+            ))
     except ValueError as e:
         if "Config" in e:
             e += "\n Please add the config name to the dataset config yaml."
@@ -153,18 +145,27 @@ def load_dataset_from_hf(dataset_config: DatasetConfig):
 
 
 @click.command(name="real-dataset")
-@click.option("--dataset-name", required=True, type=str, help="Dataset name in HuggingFace.")
+@click.option("--dataset-name",
+              required=True,
+              type=str,
+              help="Dataset name in HuggingFace.")
 @click.option(
     "--dataset-config-name",
     type=str,
     default=None,
     help="Dataset config name in HuggingFace (if exists).",
 )
-@click.option("--dataset-split", type=str, required=True, help="Split of the dataset to use.")
-@click.option("--dataset-input-key", type=str, help="The dataset dictionary key for input.")
-@click.option(
-    "--dataset-image-key", type=str, default="image", help="The dataset dictionary key for images."
-)
+@click.option("--dataset-split",
+              type=str,
+              required=True,
+              help="Split of the dataset to use.")
+@click.option("--dataset-input-key",
+              type=str,
+              help="The dataset dictionary key for input.")
+@click.option("--dataset-image-key",
+              type=str,
+              default="image",
+              help="The dataset dictionary key for images.")
 @click.option(
     "--dataset-prompt-key",
     type=str,
@@ -187,13 +188,15 @@ def load_dataset_from_hf(dataset_config: DatasetConfig):
     "--num-requests",
     type=int,
     default=None,
-    help="Number of requests to be generated. Will be capped to min(dataset.num_rows, num_requests).",
+    help=
+    "Number of requests to be generated. Will be capped to min(dataset.num_rows, num_requests).",
 )
 @click.option(
     "--max-input-len",
     type=int,
     default=None,
-    help="Maximum input sequence length for a given request. This will be used to filter out the "
+    help=
+    "Maximum input sequence length for a given request. This will be used to filter out the "
     "requests with long input sequence length. Default will include all the requests.",
 )
 @click.option(
@@ -201,16 +204,18 @@ def load_dataset_from_hf(dataset_config: DatasetConfig):
     type=str,
     default=None,
     callback=validate_output_len_dist,
-    help="Output length distribution. Default will be the length of the golden output from "
+    help=
+    "Output length distribution. Default will be the length of the golden output from "
    "the dataset. Format: <output_len_mean>,<output_len_stdev>. E.g. 100,10 will randomize "
     "the output length with mean=100 and variance=10.",
 )
 @click.pass_obj
 def real_dataset(root_args, **kwargs):
     """Prepare dataset from real dataset."""
-    dataset_config = DatasetConfig(
-        **{k[8:]: v for k, v in kwargs.items() if k.startswith("dataset_")}
-    )
+    dataset_config = DatasetConfig(**{
+        k[8:]: v
+        for k, v in kwargs.items() if k.startswith("dataset_")
+    })
 
     input_ids = []
     input_lens = []
@@ -237,7 +242,8 @@ def real_dataset(root_args, **kwargs):
                 if isinstance(image, str):
                     image_paths.append(image)
                 elif isinstance(image, Image.Image):
-                    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
+                    with tempfile.NamedTemporaryFile(
+                            suffix=".jpg", delete=False) as tmp_file:
                         logging.debug(f"Saving image to {tmp_file.name}")
                         image = image.convert("RGB")
                         image.save(tmp_file, "JPEG")
@@ -249,7 +255,8 @@ def real_dataset(root_args, **kwargs):
             multimodal_image_paths.append(image_paths)
         else:
             # text input
-            prompt = dataset_config.get_prompt(req) + " " + dataset_config.get_input(req)
+            prompt = dataset_config.get_prompt(
+                req) + " " + dataset_config.get_input(req)
             logging.debug(f"Input sequence: {prompt}")
             line = root_args.tokenizer.encode(prompt)
             if kwargs["max_input_len"] and len(line) > kwargs["max_input_len"]:
@@ -259,7 +266,10 @@ def real_dataset(root_args, **kwargs):
 
         # output if fetch from golden
         if kwargs["output_len_dist"] is None:
-            output_lens.append(len(root_args.tokenizer.encode(dataset_config.get_output(req))))
+            output_lens.append(
+                len(
+                    root_args.tokenizer.encode(
+                        dataset_config.get_output(req))))
 
         # lora task id
         task_id = root_args.task_id
@@ -272,10 +282,9 @@ def real_dataset(root_args, **kwargs):
         if kwargs["num_requests"] and req_cnt >= kwargs["num_requests"]:
             break
 
-    if (
-        kwargs["num_requests"]
-        and (len(input_ids) if modality is None else len(multimodal_texts)) < kwargs["num_requests"]
-    ):
+    if (kwargs["num_requests"]
+            and (len(input_ids) if modality is None else len(multimodal_texts))
+            < kwargs["num_requests"]):
         logging.warning(
             f"Number of requests={len(input_ids) if modality is None else len(multimodal_texts)} is"
             f" smaller than the num-requests user set={kwargs['num_requests']}."
@@ -297,9 +306,8 @@ def real_dataset(root_args, **kwargs):
 
     dataset_generator = None
     if modality is not None:
-        dataset_generator = partial(
-            generate_multimodal_dataset, multimodal_texts, multimodal_image_paths
-        )
+        dataset_generator = partial(generate_multimodal_dataset,
+                                    multimodal_texts, multimodal_image_paths)
     else:
         dataset_generator = partial(generate_text_dataset, input_ids)
     write_dataset_to_file(dataset_generator(output_lens), root_args.output)
diff --git a/tensorrt_llm/bench/dataset/prepare_synthetic_data.py b/tensorrt_llm/bench/dataset/prepare_synthetic_data.py
index 342aa514381..dc8de4aab10 100644
--- a/tensorrt_llm/bench/dataset/prepare_synthetic_data.py
+++ b/tensorrt_llm/bench/dataset/prepare_synthetic_data.py
@@ -3,13 +3,11 @@
 
 import click
 
-from tensorrt_llm.bench.dataset.utils import (
-    gen_random_tokens,
-    generate_text_dataset,
-    get_norm_dist_lengths,
-    get_unif_dist_lengths,
-    write_dataset_to_file,
-)
+from tensorrt_llm.bench.dataset.utils import (gen_random_tokens,
+                                              generate_text_dataset,
+                                              get_norm_dist_lengths,
+                                              get_unif_dist_lengths,
+                                              write_dataset_to_file)
 
 
 def _generate_task_ids_and_lora_config(root_args, num_reqs):
@@ -28,22 +26,39 @@ def _generate_task_ids_and_lora_config(root_args, num_reqs):
     # Warn if task IDs are specified but no LoRA directory is provided
     if use_task_ids and not use_lora:
         warnings.warn(
-            "Task IDs require LoRA directory. Use --lora-dir or omit task IDs.", UserWarning
-        )
+            "Task IDs require LoRA directory. Use --lora-dir or omit task IDs.",
+            UserWarning)
 
     return (
         task_ids,
         task_ids if use_task_ids else None,
-        {"lora_dir": root_args.lora_dir} if use_lora else None,
+        {
+            "lora_dir": root_args.lora_dir
+        } if use_lora else None,
     )
 
 
 @click.command()
-@click.option("--num-requests", required=True, type=int, help="Number of requests to be generated")
-@click.option("--input-mean", required=True, type=int, help="normal dist mean for input tokens")
-@click.option("--input-stdev", required=True, type=int, help="normal dist stdev for input tokens")
-@click.option("--output-mean", required=True, type=int, help="normal dist mean for output tokens")
-@click.option("--output-stdev", required=True, type=int, help="normal dist stdev for output tokens")
+@click.option("--num-requests",
+              required=True,
+              type=int,
+              help="Number of requests to be generated")
+@click.option("--input-mean",
+              required=True,
+              type=int,
+              help="normal dist mean for input tokens")
+@click.option("--input-stdev",
+              required=True,
+              type=int,
+              help="normal dist stdev for input tokens")
+@click.option("--output-mean",
+              required=True,
+              type=int,
+              help="normal dist mean for output tokens")
+@click.option("--output-stdev",
+              required=True,
+              type=int,
+              help="normal dist stdev for output tokens")
 @click.pass_obj
 def token_norm_dist(root_args, **kwargs):
     """Prepare synthetic dataset by generating random tokens with normal dist lengths."""
@@ -51,36 +66,47 @@ def token_norm_dist(root_args, **kwargs):
     input_lens = []
     output_lens = []
 
-    input_lens = get_norm_dist_lengths(
-        kwargs["input_mean"], kwargs["input_stdev"], kwargs["num_requests"], root_args.random_seed
-    )
+    input_lens = get_norm_dist_lengths(kwargs["input_mean"],
+                                       kwargs["input_stdev"],
+                                       kwargs["num_requests"],
+                                       root_args.random_seed)
     num_reqs = len(input_lens)
-    output_lens = get_norm_dist_lengths(
-        kwargs["output_mean"], kwargs["output_stdev"], num_reqs, root_args.random_seed
-    )
-    input_ids = gen_random_tokens(input_lens, root_args.tokenizer, root_args.random_seed)
-    _, print_task_ids, lora_config = _generate_task_ids_and_lora_config(root_args, num_reqs)
-    dataset_generator = generate_text_dataset(
-        input_ids, output_lens, task_ids=print_task_ids, lora_config=lora_config
-    )
+    output_lens = get_norm_dist_lengths(kwargs["output_mean"],
+                                        kwargs["output_stdev"], num_reqs,
+                                        root_args.random_seed)
+    input_ids = gen_random_tokens(input_lens, root_args.tokenizer,
+                                  root_args.random_seed)
+    _, print_task_ids, lora_config = _generate_task_ids_and_lora_config(
+        root_args, num_reqs)
+    dataset_generator = generate_text_dataset(input_ids,
+                                              output_lens,
+                                              task_ids=print_task_ids,
+                                              lora_config=lora_config)
     write_dataset_to_file(dataset_generator, root_args.output)
 
 
 @click.command()
-@click.option("--num-requests", required=True, type=int, help="Number of requests to be generated")
-@click.option(
-    "--input-min", required=True, type=int, help="uniform dist (inclusive) min for input tokens"
-)
-@click.option(
-    "--input-max", required=True, type=int, help="normal dist (inclusive) max for input tokens"
-)
-@click.option(
-    "--output-min", required=True, type=int, help="normal dist (inclusive) min for output tokens"
-)
-@click.option(
-    "--output-max", required=True, type=int, help="normal dist (inclusive) max for output tokens"
-)
+@click.option("--num-requests",
+              required=True,
+              type=int,
+              help="Number of requests to be generated")
+@click.option("--input-min",
+              required=True,
+              type=int,
+              help="uniform dist (inclusive) min for input tokens")
+@click.option("--input-max",
+              required=True,
+              type=int,
+              help="normal dist (inclusive) max for input tokens")
+@click.option("--output-min",
+              required=True,
+              type=int,
+              help="normal dist (inclusive) min for output tokens")
+@click.option("--output-max",
+              required=True,
+              type=int,
+              help="normal dist (inclusive) max for output tokens")
 @click.pass_obj
 def token_unif_dist(root_args, **kwargs):
     """Prepare synthetic dataset by generating random tokens with normal uniformly lengths."""
@@ -88,17 +114,20 @@ def token_unif_dist(root_args, **kwargs):
     input_lens = []
     output_lens = []
 
-    input_lens = get_unif_dist_lengths(
-        kwargs["input_min"], kwargs["input_max"], kwargs["num_requests"], root_args.random_seed
-    )
+    input_lens = get_unif_dist_lengths(kwargs["input_min"], kwargs["input_max"],
+                                       kwargs["num_requests"],
+                                       root_args.random_seed)
     num_reqs = len(input_lens)
-    output_lens = get_unif_dist_lengths(
-        kwargs["output_min"], kwargs["output_max"], num_reqs, root_args.random_seed
-    )
-    input_ids = gen_random_tokens(input_lens, root_args.tokenizer, root_args.random_seed)
-    _, print_task_ids, lora_config = _generate_task_ids_and_lora_config(root_args, num_reqs)
-    dataset_generator = generate_text_dataset(
-        input_ids, output_lens, task_ids=print_task_ids, lora_config=lora_config
-    )
+    output_lens = get_unif_dist_lengths(kwargs["output_min"],
+                                        kwargs["output_max"], num_reqs,
+                                        root_args.random_seed)
+    input_ids = gen_random_tokens(input_lens, root_args.tokenizer,
+                                  root_args.random_seed)
+    _, print_task_ids, lora_config = _generate_task_ids_and_lora_config(
+        root_args, num_reqs)
+    dataset_generator = generate_text_dataset(input_ids,
+                                              output_lens,
+                                              task_ids=print_task_ids,
+                                              lora_config=lora_config)
     write_dataset_to_file(dataset_generator, root_args.output)
diff --git a/tensorrt_llm/bench/dataset/utils.py b/tensorrt_llm/bench/dataset/utils.py
index 15c91701953..cb727065f50 100644
--- a/tensorrt_llm/bench/dataset/utils.py
+++ b/tensorrt_llm/bench/dataset/utils.py
@@ -7,25 +7,38 @@
 import numpy as np
 
 
-def generate_text_dataset(input_ids, output_lens, task_ids=None, lora_config=None):
+def generate_text_dataset(input_ids,
+                          output_lens,
+                          task_ids=None,
+                          lora_config=None):
     for i, input_tokens in enumerate(input_ids):
-        d = {"task_id": i, "input_ids": input_tokens, "output_tokens": output_lens[i]}
+        d = {
+            "task_id": i,
+            "input_ids": input_tokens,
+            "output_tokens": output_lens[i]
+        }
 
         # Add LoRA request if task_ids indicate LoRA usage
        if task_ids is not None and lora_config is not None:
             task_id = task_ids[i]
             if task_id != -1:  # -1 means no LoRA
                 d["lora_request"] = {
-                    "lora_name": f"lora_{task_id}",
-                    "lora_int_id": task_id,
-                    "lora_path": os.path.join(lora_config.get("lora_dir", "loras"), str(task_id)),
+                    "lora_name":
+                    f"lora_{task_id}",
+                    "lora_int_id":
+                    task_id,
+                    "lora_path":
+                    os.path.join(lora_config.get("lora_dir", "loras"),
+                                 str(task_id)),
                 }
 
         yield json.dumps(d, separators=(",", ":"), ensure_ascii=False)
 
 
-def generate_multimodal_dataset(multimodal_texts, multimodal_image_paths, output_lens):
-    for i, (text, image_paths) in enumerate(zip(multimodal_texts, multimodal_image_paths)):
+def generate_multimodal_dataset(multimodal_texts, multimodal_image_paths,
+                                output_lens):
+    for i, (text, image_paths) in enumerate(
+            zip(multimodal_texts, multimodal_image_paths)):
         d = {
             "task_id": i,
             "prompt": text,
@@ -39,7 +52,8 @@ def get_list_of_delays(delay_dist, mean_time_bet_reqs, num_reqs, random_seed):
     if delay_dist == "constant":
         delays = [mean_time_bet_reqs] * num_reqs
     elif delay_dist == "exponential_dist":
-        delays = get_exponential_dist_delays(mean_time_bet_reqs, num_reqs, random_seed)
+        delays = get_exponential_dist_delays(mean_time_bet_reqs, num_reqs,
+                                             random_seed)
 
     return delays
 
@@ -53,7 +67,8 @@ def get_exponential_dist_delays(mean_time_bet_reqs, num_reqs, random_seed):
 def get_norm_dist_lengths(mean, stdev, num_reqs, random_seed):
     # set seed for determinism
     np.random.seed(random_seed)
-    numbers_list = np.random.normal(loc=mean, scale=stdev, size=num_reqs).tolist()
+    numbers_list = np.random.normal(loc=mean, scale=stdev,
+                                    size=num_reqs).tolist()
     return [max(1, math.ceil(x)) for x in numbers_list]
 
 
@@ -65,6 +80,7 @@ def get_unif_dist_lengths(min_len, max_len, num_reqs, random_seed):
 
 
 def gen_random_tokens(ip_lens, tokenizer, random_seed):
+
     def get_sample_from_population(population_range, sample_size):
         # random.sample can not sample a value more than once. hence the check
         if sample_size < len(population_range):
@@ -77,12 +93,16 @@ def get_sample_from_population(population_range, sample_size):
     input_ids = []
     random.seed(random_seed)
     for ip_len in ip_lens:
-        start_ids = get_sample_from_population(range(0, tokenizer.vocab_size), ip_len)
+        start_ids = get_sample_from_population(range(0, tokenizer.vocab_size),
+                                               ip_len)
         # Make sure it does not contain EOS token
         eos_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)
         while set(eos_id).issubset(start_ids):
             tmp_id = (eos_id[0] + 1) % tokenizer.vocab_size
-            start_ids = [tmp_id if element == eos_id[0] else element for element in start_ids]
+            start_ids = [
+                tmp_id if element == eos_id[0] else element
+                for element in start_ids
+            ]
         input_ids.append(start_ids)
 
     return input_ids