
Commit ec8dadc

fix: fix cuda graph padding for spec decoding (#4853)

Signed-off-by: Fanrong Li <[email protected]>

1 parent 9f45e80 · commit ec8dadc

5 files changed: 81 additions & 41 deletions


tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 4 additions & 0 deletions
@@ -268,6 +268,10 @@ def create_response(
         return LlmResponse(response,
                            self.py_result) if response is not None else None

+    @property
+    def is_dummy(self):
+        return self.is_attention_dp_dummy or self.is_cuda_graph_dummy
+

 def convert_wordlist(word_list) -> List[List[int]]:
     """Converts a wordlist from format:

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 34 additions & 38 deletions
@@ -1114,13 +1114,18 @@ def _prepare_tp_inputs(
         new_tokens_lens_device = new_tensors_device.new_tokens_lens  # [batch]
         next_draft_tokens_device = new_tensors_device.next_draft_tokens  # [batch, draft_len]

-        # Requests with draft tokens are treated like extend requests.
+        # Requests with draft tokens are treated like extend requests. Dummy extend requests should be
+        # at the end of extend_requests.
         extend_requests = []
+        extend_dummy_requests = []
         generation_requests = []
         for request in scheduled_requests.generation_requests:
             if len(request.py_draft_tokens
                    ) > 0 or next_draft_tokens_device is not None:
-                extend_requests.append(request)
+                if request.is_dummy:
+                    extend_dummy_requests.append(request)
+                else:
+                    extend_requests.append(request)
             else:
                 generation_requests.append(request)

@@ -1130,6 +1135,7 @@ def _prepare_tp_inputs(
                     torch.tensor([mrope_position_deltas],
                                  dtype=torch.int32).to('cuda',
                                                        non_blocking=True))
+        extend_requests += extend_dummy_requests

         if not self._disable_overlap_scheduler and self.is_spec_decode:
             spec_dec_mode = self.spec_config.spec_dec_mode
@@ -1139,21 +1145,18 @@
         # will contain previous batch incices of generation requests
         previous_batch_indices = []
         previous_pos_indices = []
-        request_ids_with_previous_batch = []
-        num_extend_reqs_wo_previous_batch = 0
         for request in extend_requests:
-            if next_draft_tokens_device is None or request.py_batch_idx is None:
-                # the request has no previous device tensors:
-                # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
-                # (2) request.py_batch_idx is None, which means the request has no previous batch.
-                # the second condition includes dummy generation requests created for CUDA graph padding or
-                # attention DP. These dummy generation requests should be at the head of generation_requests.
-                # TODO: move the dummy generation requests to the end of generation_requests to align with
-                # the logic for those requests in generation_requests.
-                # get token ids, including input token ids and draft token ids
-                input_ids.append(request.get_last_tokens(0))
-                input_ids.extend(request.py_draft_tokens)
-                draft_tokens.extend(request.py_draft_tokens)
+            # the request has no previous tensor:
+            # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
+            # (2) a dummy request; or
+            # (3) the first step in the generation server of disaggregated serving
+            if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
+                # get token ids, including input token ids and draft token ids. For these dummy requests,
+                # no need to copy the token ids.
+                if not request.is_dummy:
+                    input_ids.append(request.get_last_tokens(0))
+                    input_ids.extend(request.py_draft_tokens)
+                    draft_tokens.extend(request.py_draft_tokens)
                 # get other ids and lengths
                 num_draft_tokens = len(request.py_draft_tokens)
                 past_seen_token_num = request.max_beam_num_tokens - 1
@@ -1173,7 +1176,6 @@
                 # update batch index
                 request.py_batch_idx = batch_idx
                 batch_idx += 1
-                num_extend_reqs_wo_previous_batch += 1
             else:
                 # update batch index
                 previous_batch_idx = request.py_batch_idx
@@ -1200,10 +1202,7 @@
                 num_cached_tokens_per_seq.append(past_seen_token_num +
                                                  self.max_draft_len + 1)
                 prompt_lengths.append(request.py_prompt_len)
-                request_ids_with_previous_batch.append(request.py_request_id)
-
-        # move requests with previous batch to the end of the list
-        request_ids.extend(request_ids_with_previous_batch)
+            request_ids.append(request.py_request_id)

         sequence_lengths.extend([1] * len(generation_requests))
         gather_ids.extend(
@@ -1238,6 +1237,7 @@
         num_tokens = len(input_ids)
         num_draft_tokens = len(draft_tokens)
         previous_batchs = len(previous_batch_indices)
+        num_requests = len(request_ids)
         # if exist requests that do not have previous batch, copy input_ids and draft_tokens
         if num_tokens > 0:
             input_ids = torch.tensor(input_ids,
@@ -1276,31 +1276,27 @@
                                                   non_blocking=True)
                 # prepare data for the preprocess inputs
                 kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
-                pre_tokens_start_idx = num_extend_reqs_wo_previous_batch * (
-                    1 + self.max_draft_len)
-                pre_tokens_end_idx = pre_tokens_start_idx + previous_batch_tokens
-                pre_batch_start_idx = num_extend_reqs_wo_previous_batch
-                pre_batch_end_idx = pre_batch_start_idx + previous_batchs
                 previous_pos_indices = torch.tensor(previous_pos_indices,
                                                     dtype=torch.int,
                                                     pin_memory=True)
-                self.previous_pos_indices_cuda[
-                    pre_tokens_start_idx:pre_tokens_end_idx].copy_(
-                        previous_pos_indices, non_blocking=True)
+                self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
+                    previous_pos_indices, non_blocking=True)
                 self.previous_pos_id_offsets_cuda[
-                    pre_tokens_start_idx:pre_tokens_end_idx].copy_(
+                    0:previous_batch_tokens].copy_(
                         new_tokens_lens_device[self.previous_pos_indices_cuda[
-                            pre_tokens_start_idx:pre_tokens_end_idx]],
-                        non_blocking=True)
-                self.previous_kv_lens_offsets_cuda[
-                    pre_batch_start_idx:pre_batch_end_idx].copy_(
-                        kv_len_offsets_device[
-                            self.previous_batch_indices_cuda[:previous_batchs]],
+                            0:previous_batch_tokens]],
                         non_blocking=True)
+                self.previous_kv_lens_offsets_cuda[0:previous_batchs].copy_(
+                    kv_len_offsets_device[
+                        self.previous_batch_indices_cuda[:previous_batchs]],
+                    non_blocking=True)
                 # for the requests that do not have previous batch, set the previous_pos_id_offsets and
                 # previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
-                self.previous_pos_id_offsets_cuda[:pre_tokens_start_idx] *= 0
-                self.previous_kv_lens_offsets_cuda[:pre_batch_start_idx] *= 0
+                self.previous_pos_id_offsets_cuda[
+                    previous_batch_tokens:num_requests *
+                    (1 + self.max_draft_len)] *= 0
+                self.previous_kv_lens_offsets_cuda[
+                    previous_batchs:num_requests] *= 0
             else:
                 # change the data to zeros to skip the value changes in _preprocess_inputs
                 self.previous_pos_id_offsets_cuda *= 0
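
Because dummy extend requests now sit at the end of extend_requests, the requests that do have a previous batch pack the front of the preallocated buffers, and the old pre_tokens_start_idx / pre_batch_start_idx bookkeeping collapses to "fill [0:previous_batch_tokens], then zero the tail up to num_requests * (1 + max_draft_len)". A simplified CPU-only sketch of that fill-then-zero layout (illustrative sizes and stand-in values; the buffer names are merely patterned after previous_pos_id_offsets_cuda and previous_kv_lens_offsets_cuda, this is not the engine code):

# Simplified CPU-only sketch of the new buffer layout; sizes and values are illustrative.
import torch

max_draft_len = 2                   # draft tokens per request
tokens_per_req = 1 + max_draft_len  # one accepted token + draft tokens
max_batch = 8                       # capacity of the preallocated buffers

# Preallocated buffers (device buffers in the real engine, CPU tensors here).
previous_pos_id_offsets = torch.full((max_batch * tokens_per_req,), -1, dtype=torch.int32)
previous_kv_lens_offsets = torch.full((max_batch,), -1, dtype=torch.int32)

# This step's batch: 3 requests with a previous batch, then 2 without (e.g. CUDA graph dummies).
previous_batchs = 3
num_requests = 5
previous_batch_tokens = previous_batchs * tokens_per_req

# 1) Requests with a previous batch occupy the front of the buffers.
previous_pos_id_offsets[0:previous_batch_tokens] = torch.arange(
    previous_batch_tokens, dtype=torch.int32)  # stand-in for the real copy_ from device data
previous_kv_lens_offsets[0:previous_batchs] = torch.tensor([1, 2, 0], dtype=torch.int32)

# 2) Requests without a previous batch (including dummies) get zeros, so
#    _preprocess_inputs leaves their values untouched.
previous_pos_id_offsets[previous_batch_tokens:num_requests * tokens_per_req] *= 0
previous_kv_lens_offsets[previous_batchs:num_requests] *= 0

print(previous_pos_id_offsets[:num_requests * tokens_per_req].tolist())
# -> [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0]
print(previous_kv_lens_offsets[:num_requests].tolist())
# -> [1, 2, 0, 0, 0]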

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 39 additions & 2 deletions
@@ -568,16 +568,53 @@ def test_fp8_block_scales(self, mtp_nextn, fp8kv, attention_dp, cuda_graph,
             task.evaluate(llm)

     @pytest.mark.skip_device_not_contain(["H100"])
-    def test_fp8_block_scales_cuda_graph_padding(self):
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
         # OOM on H100 with default free_gpu_memory_fraction=0.9
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
         pytorch_config = PyTorchConfig(disable_overlap_scheduler=False,
                                        use_cuda_graph=True,
                                        cuda_graph_max_batch_size=512,
                                        cuda_graph_padding_enabled=True)
         llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
                   kv_cache_config=kv_cache_config,
-                  pytorch_backend_config=pytorch_config)
+                  pytorch_backend_config=pytorch_config,
+                  speculative_config=mtp_config)
+        assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
+        with llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_device_not_contain(["H100", "H200"])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids("attention_dp", [False, True])
+    def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
+                                                       attention_dp):
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+        pytorch_config = PyTorchConfig(
+            disable_overlap_scheduler=False,
+            use_cuda_graph=True,
+            cuda_graph_padding_enabled=True,
+        )
+        quant_config = QuantConfig()
+        quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+
+        llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
+                  tensor_parallel_size=4,
+                  kv_cache_config=kv_cache_config,
+                  pytorch_backend_config=pytorch_config,
+                  quant_config=quant_config,
+                  enable_attention_dp=attention_dp,
+                  speculative_config=mtp_config)
         assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
         with llm:
             task = CnnDailymail(self.MODEL_NAME)

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 2 additions & 0 deletions
@@ -101,6 +101,8 @@ l0_dgx_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=0]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=2]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp[DeepSeek-V3-Lite-fp8]

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 2 additions & 1 deletion
@@ -51,7 +51,8 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency]
-  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=0]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=2]
   - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
   - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-instruct-hf-fp8-True-True]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu[DeepSeek-V3-Lite-fp8]
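
For context, the bracketed suffixes in the entries above (e.g. [mtp_nextn=2] or [attention_dp=True-mtp_nextn=0]) are pytest parametrization IDs generated from the new test parameters. A rough stand-alone approximation using plain pytest.mark.parametrize (the repo's parametrize_with_ids helper is assumed to produce IDs of this shape; the test body here is only a placeholder):

# Rough approximation of how the parametrized IDs in the test lists are formed.
import pytest


def _ids(name, values):
    # Render "name=value" IDs, e.g. mtp_nextn=0 or attention_dp=True.
    return [f"{name}={v}" for v in values]


class TestDeepSeekV3LiteSketch:

    @pytest.mark.parametrize("mtp_nextn", [0, 2], ids=_ids("mtp_nextn", [0, 2]))
    def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
        # Placeholder body; the real test builds an LLM with an optional
        # MTPDecodingConfig and runs accuracy tasks.
        assert mtp_nextn in (0, 2)


# Collected node IDs then look like:
#   test_fp8_block_scales_cuda_graph_padding[mtp_nextn=0]
#   test_fp8_block_scales_cuda_graph_padding[mtp_nextn=2]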
