
Commit ca41a71

[TRTLLM-8948][test] Add long bench case (#9165)
Signed-off-by: Ivy Zhang <[email protected]>
1 parent 8e001dd commit ca41a71

File tree

4 files changed, +222 -1 lines changed


tests/integration/defs/accuracy/accuracy_core.py

Lines changed: 102 additions & 0 deletions
@@ -433,6 +433,108 @@ class PassKeyRetrieval128k(AccuracyTask):
     MAX_OUTPUT_LEN = 50
 
 
+class LongBenchV2(AccuracyTask):
+    DATASET = "longbench_v2"
+    DATASET_DIR = f"{llm_models_root()}/zai-org/LongBench-v2"
+
+    ALPHA = 0.05
+    BETA = 0.2
+    SIGMA = 50.0
+    NUM_SAMPLES = 215
+
+    MAX_BATCH_SIZE = 32
+    MAX_INPUT_LEN = 1280000
+    MAX_OUTPUT_LEN = 32000
+
+    EVALUATOR_CLS = tensorrt_llm.evaluate.LongBenchV2
+    EVALUATOR_KWARGS = dict(
+        dataset_path=DATASET_DIR,
+        length="medium",
+        max_len=1280000,
+        apply_chat_template=True,
+        random_seed=0,
+    )
+
+    @staticmethod
+    def create_modified_model_dir(original_model_dir: str,
+                                  max_position_embeddings: int = 1280000,
+                                  model_max_length: int = 1280000) -> str:
+        """
+        Create a temporary directory with modified config files for long-context evaluation.
+
+        This method creates a temporary directory with symlinks to all model files
+        except the config files, which are copied and modified to support longer
+        context lengths. This is useful for evaluating models on long-context tasks
+        that exceed the original model's max_position_embeddings.
+
+        Args:
+            original_model_dir: Path to the original model directory.
+            max_position_embeddings: New value for max_position_embeddings in config.json.
+            model_max_length: New value for model_max_length in tokenizer_config.json.
+
+        Returns:
+            Path to the temporary modified model directory.
+
+        Note:
+            The caller is responsible for cleaning up the temporary directory after use.
+        """
+        import tempfile
+
+        # Create the temporary model directory.
+        temp_dir = tempfile.mkdtemp(prefix="longbench_v2_modified_model_")
+        logger.info(f"Created temporary model directory: {temp_dir}")
+
+        # Symlink every file except the config files, which are handled separately.
+        for item in os.listdir(original_model_dir):
+            src = os.path.join(original_model_dir, item)
+            dst = os.path.join(temp_dir, item)
+
+            if item in ["config.json", "tokenizer_config.json"]:
+                continue
+
+            os.symlink(src, dst)
+            logger.info(f"  Symlinked: {item}")
+
+        # Copy config.json with a modified max_position_embeddings.
+        config_src = os.path.join(original_model_dir, "config.json")
+        config_dst = os.path.join(temp_dir, "config.json")
+        if os.path.exists(config_src):
+            with open(config_src, 'r', encoding='utf-8') as f:
+                config = json.load(f)
+
+            original_max_pos = config.get('max_position_embeddings')
+            config['max_position_embeddings'] = max_position_embeddings
+            logger.info(
+                f"  Modified config.json: max_position_embeddings {original_max_pos} -> {max_position_embeddings}"
+            )
+
+            with open(config_dst, 'w', encoding='utf-8') as f:
+                json.dump(config, f, indent=2, ensure_ascii=False)
+
+        # Copy tokenizer_config.json with a modified model_max_length.
+        tokenizer_config_src = os.path.join(original_model_dir,
+                                            "tokenizer_config.json")
+        tokenizer_config_dst = os.path.join(temp_dir, "tokenizer_config.json")
+        if os.path.exists(tokenizer_config_src):
+            with open(tokenizer_config_src, 'r', encoding='utf-8') as f:
+                tokenizer_config = json.load(f)
+
+            original_max_len = tokenizer_config.get('model_max_length')
+            tokenizer_config['model_max_length'] = model_max_length
+            logger.info(
+                f"  Modified tokenizer_config.json: model_max_length {original_max_len} -> {model_max_length}"
+            )
+
+            with open(tokenizer_config_dst, 'w', encoding='utf-8') as f:
+                json.dump(tokenizer_config, f, indent=2, ensure_ascii=False)
+
+        return temp_dir
+
+
 class CliFlowAccuracyTestHarness:
     # Model
     MODEL_NAME = None
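
Note: ALPHA, BETA, SIGMA, and NUM_SAMPLES parameterize the statistical pass/fail check that AccuracyTask subclasses inherit; the formula itself lives elsewhere in accuracy_core and is not part of this diff. As a rough, illustrative reading of what such constants typically encode, a standard one-sided two-sample power analysis gives the smallest accuracy regression a check like this can reliably detect:

    import math

    from scipy.stats import norm  # illustrative; not necessarily how accuracy_core computes it

    ALPHA, BETA, SIGMA, NUM_SAMPLES = 0.05, 0.2, 50.0, 215

    # Smallest detectable drop in mean accuracy between a reference run and a
    # test run of NUM_SAMPLES each, at false-alarm rate ALPHA and miss rate
    # BETA, when per-sample scores have standard deviation SIGMA.
    theta = (norm.ppf(1 - ALPHA) + norm.ppf(1 - BETA)) * SIGMA * math.sqrt(2 / NUM_SAMPLES)
    print(f"theta ~ {theta:.1f} accuracy points")  # ~12.0 for these values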
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+DeepSeek-R1-0528:
+- quant_algo: FP8_BLOCK_SCALES
+  kv_cache_quant_algo: FP8
+  spec_dec_algo: MTP
+  accuracy: 52.093
+- quant_algo: NVFP4
+  kv_cache_quant_algo: FP8
+  spec_dec_algo: MTP
+  accuracy: 52.093
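
Note: this new reference file (its repository path is not shown in this view) keys the expected LongBench-v2 accuracy on the quantization and speculative-decoding configuration, so both the FP8 and NVFP4 variants below assert against 52.093. A minimal sketch of the kind of lookup such a file supports, with the file name and helper function entirely hypothetical:

    import yaml  # PyYAML

    def lookup_reference(path: str, model: str, quant_algo: str,
                         spec_dec_algo: str) -> float:
        """Hypothetical helper: return the reference accuracy for one config."""
        with open(path, encoding="utf-8") as f:
            refs = yaml.safe_load(f)
        for entry in refs[model]:
            if (entry.get("quant_algo") == quant_algo
                    and entry.get("spec_dec_algo") == spec_dec_algo):
                return entry["accuracy"]
        raise KeyError(f"No reference for {model}/{quant_algo}/{spec_dec_algo}")

    # lookup_reference("longbench_v2.yaml", "DeepSeek-R1-0528", "NVFP4", "MTP")
    # -> 52.093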

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 109 additions & 1 deletion
@@ -33,7 +33,8 @@
     skip_post_blackwell, skip_pre_ada, skip_pre_blackwell,
     skip_pre_hopper, skip_ray)
 from .accuracy_core import (GSM8K, MMLU, CnnDailymail, GPQADiamond,
-                            JsonModeEval, LlmapiAccuracyTestHarness)
+                            JsonModeEval, LlmapiAccuracyTestHarness,
+                            LongBenchV2)
 
 
 class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
@@ -4136,3 +4137,110 @@ def test_auto_dtype(self):
                            extra_evaluator_kwargs=dict(
                                apply_chat_template=True,
                                chat_template_kwargs=chat_template_kwargs))
+
+
+@skip_pre_blackwell
+@pytest.mark.skip_less_device_memory(183000)
+@pytest.mark.timeout(28800)
+class TestDeepSeekR1LongBenchV2(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "DeepSeek-R1-0528"
+
+    @pytest.mark.skip_less_mpi_world_size(8)
+    def test_fp8_8gpus(self):
+        original_model_dir = f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-0528"
+        if not os.path.exists(original_model_dir):
+            pytest.skip(f"Model directory {original_model_dir} does not exist")
+
+        temp_dir = None
+        try:
+            # Create the modified model directory via the LongBenchV2 static method.
+            # This is a workaround (WAR) for the model config not being modified to
+            # support long context.
+            # TODO: remove this once the model config supports long context.
+            temp_dir = LongBenchV2.create_modified_model_dir(original_model_dir)
+
+            # Configure model settings.
+            kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8,
+                                            enable_block_reuse=True,
+                                            enable_partial_reuse=False,
+                                            dtype="fp8")
+
+            cuda_graph_config = CudaGraphConfig(enable_padding=True,
+                                                max_batch_size=32)
+
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3)
+
+            moe_config = MoeConfig(backend='DEEPGEMM', max_num_tokens=32000)
+
+            pytorch_config = dict(cuda_graph_config=cuda_graph_config,
+                                  kv_cache_config=kv_cache_config,
+                                  speculative_config=mtp_config,
+                                  moe_config=moe_config,
+                                  enable_chunked_prefill=True,
+                                  enable_autotuner=True)
+
+            # Create the LLM instance and evaluate.
+            with LLM(temp_dir,
+                     tensor_parallel_size=8,
+                     moe_expert_parallel_size=8,
+                     max_num_tokens=32000,
+                     max_batch_size=32,
+                     **pytorch_config) as llm:
+
+                task = LongBenchV2(self.MODEL_NAME)
+
+                sampling_params = SamplingParams(max_tokens=32000)
+
+                task.evaluate(llm, sampling_params=sampling_params)
+
+        finally:
+            # Clean up temporary files.
+            if temp_dir and os.path.exists(temp_dir):
+                import shutil
+                shutil.rmtree(temp_dir, ignore_errors=True)
+
+    @pytest.mark.skip_less_mpi_world_size(4)
+    def test_nvfp4_4gpus(self):
+        original_model_dir = f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-0528-FP4"
+        temp_dir = None
+        try:
+            # Create the modified model directory via the LongBenchV2 static method.
+            temp_dir = LongBenchV2.create_modified_model_dir(original_model_dir)
+
+            # Configure model settings (no MoE config for the FP4 variant).
+            kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8,
+                                            enable_block_reuse=True,
+                                            enable_partial_reuse=False,
+                                            dtype="fp8")
+
+            cuda_graph_config = CudaGraphConfig(enable_padding=True,
+                                                max_batch_size=32)
+
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3)
+
+            pytorch_config = dict(cuda_graph_config=cuda_graph_config,
+                                  kv_cache_config=kv_cache_config,
+                                  speculative_config=mtp_config,
+                                  enable_chunked_prefill=True,
+                                  enable_autotuner=True)
+
+            # Create the LLM instance and evaluate.
+            with LLM(temp_dir,
+                     tensor_parallel_size=4,
+                     moe_expert_parallel_size=4,
+                     max_num_tokens=32000,
+                     max_batch_size=32,
+                     **pytorch_config) as llm:
+
+                assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+
+                task = LongBenchV2(self.MODEL_NAME)
+
+                sampling_params = SamplingParams(max_tokens=32000)
+
+                task.evaluate(llm, sampling_params=sampling_params)
+
+        finally:
+            # Clean up temporary files.
+            if temp_dir and os.path.exists(temp_dir):
+                import shutil
+                shutil.rmtree(temp_dir, ignore_errors=True)
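
Note: both tests repeat the same try/finally cleanup around create_modified_model_dir. A context-manager wrapper could factor that out; a minimal sketch, with the helper name hypothetical:

    import contextlib
    import os
    import shutil

    from .accuracy_core import LongBenchV2

    @contextlib.contextmanager
    def modified_model_dir(original_model_dir: str):
        """Hypothetical helper: scope the temporary model directory."""
        temp_dir = LongBenchV2.create_modified_model_dir(original_model_dir)
        try:
            yield temp_dir
        finally:
            # Removes only symlinks plus the two patched JSON copies.
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir, ignore_errors=True)

    # Usage: with modified_model_dir(original_model_dir) as temp_dir: ...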
Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
 stress_test/stress_test.py::test_run_stress_test[DeepSeek-V3_tp8-stress_time_3600s_timeout_5400s-GUARANTEED_NO_EVICT-pytorch-stress-test-with-accuracy]
 stress_test/stress_test.py::test_run_stress_test[DeepSeek-V3_tp8-stress_time_3600s_timeout_5400s-MAX_UTILIZATION-pytorch-stress-test-with-accuracy]
 stress_test/stress_test.py::test_run_stress_test[DeepSeek-R1_tp8-stress_time_3600s_timeout_5400s-MAX_UTILIZATION-pytorch-stress-test-with-accuracy]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekR1LongBenchV2::test_fp8_8gpus
+accuracy/test_llm_api_pytorch.py::TestDeepSeekR1LongBenchV2::test_nvfp4_4gpus
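
Note: the two new test IDs are appended to what appears to be a long-running test list (the file path is not shown in this view). To run one directly, an invocation along these lines should work from tests/integration/defs, subject to the tests' own markers (8 MPI ranks, skip_less_device_memory(183000), and pre-Blackwell skip):

    import pytest

    # Illustrative direct run of the FP8 case; gated by the markers above.
    pytest.main([
        "accuracy/test_llm_api_pytorch.py::TestDeepSeekR1LongBenchV2::test_fp8_8gpus",
    ])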
