
Commit b10137f

[None][feat] Support MLA chunked prefill for DeepSeek V3.2 model (NVIDIA#9376)
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent 1bf2d75 commit b10137f

File tree: 8 files changed, +751 -127 lines changed

cpp/tensorrt_llm/kernels/mlaKernels.cu

Lines changed: 23 additions & 9 deletions
@@ -230,15 +230,22 @@ __global__ void applyMLARopeAndAssignQKVKernelOptContext(T* q_ptr, T* q_pe, T* k

     int const global_token_offset = cu_q_seqlens[batch_idx];
     int const cache_seq_len = kv_cache_lengths[batch_idx];
-    int token_idx_in_kv_cache = local_token_idx;
-    bool const valid_token = token_idx_in_kv_cache < cache_seq_len;
+
+    // Derive cached offset and current input length
+    int const current_seq_len = cu_q_seqlens[batch_idx + 1] - global_token_offset;
+    int const cached_offset = cache_seq_len - current_seq_len;
+
+    int token_idx_in_kv_cache = local_token_idx + cached_offset;
+    // Check against BOTH total cache length (valid slot) AND input length (valid read)
+    bool const valid_token = (token_idx_in_kv_cache < cache_seq_len) && (local_token_idx < current_seq_len);
+
     // Limit the token_idx to cache seq length (we need all threads in this block to be involved).
     token_idx_in_kv_cache = std::min(token_idx_in_kv_cache, cache_seq_len - 1);
-    local_token_idx = std::min(local_token_idx, cache_seq_len - 1);
-    int const global_token_idx = local_token_idx + global_token_offset;
+    int const safe_local_token_idx = std::min(local_token_idx, current_seq_len - 1);
+    int const global_token_idx = safe_local_token_idx + global_token_offset;

     auto const position_id
-        = helix_position_offsets ? helix_position_offsets[global_token_idx] : local_token_idx;
+        = helix_position_offsets ? helix_position_offsets[global_token_idx] : token_idx_in_kv_cache;
     float2 const* rotary_coef_cache_buffer
         = cos_sin_cache + static_cast<size_t>(ROPE_DIM) * position_id + (head_dim_idx / 2);

@@ -317,12 +324,19 @@ __global__ void applyMLARopeAndAssignQKVKernelOptContext(T* q_ptr, T* q_pe, T* k

     int const global_token_offset = cu_q_seqlens[batch_idx];
     int const cache_seq_len = kv_cache_lengths[batch_idx];
-    int token_idx_in_kv_cache = local_token_idx;
-    bool const valid_token = token_idx_in_kv_cache < cache_seq_len;
+
+    // Derive cached offset and current input length (same as first loop)
+    int const current_seq_len = cu_q_seqlens[batch_idx + 1] - global_token_offset;
+    int const cached_offset = cache_seq_len - current_seq_len;
+
+    int token_idx_in_kv_cache = local_token_idx + cached_offset;
+    // Check against BOTH total cache length (valid slot) AND input length (valid read)
+    bool const valid_token = (token_idx_in_kv_cache < cache_seq_len) && (local_token_idx < current_seq_len);
+
     // Limit the token_idx to cache seq length (we need all threads in this block to be involved).
     token_idx_in_kv_cache = std::min(token_idx_in_kv_cache, cache_seq_len - 1);
-    local_token_idx = std::min(local_token_idx, cache_seq_len - 1);
-    int const global_token_idx = local_token_idx + global_token_offset;
+    int const safe_local_token_idx = std::min(local_token_idx, current_seq_len - 1);
+    int const global_token_idx = safe_local_token_idx + global_token_offset;

     if (valid_token)
     {
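
The heart of the change: with chunked prefill, a chunk's tokens no longer start at KV-cache position 0, so the kernel shifts each token by the number of tokens cached from earlier chunks and takes the RoPE position from that absolute cache position. Below is a minimal host-side sketch (Python, with hypothetical sequence lengths) of the same index arithmetic; the real logic runs per thread inside the CUDA kernel.

# Hypothetical single-request batch: 96 tokens were prefilled by earlier chunks
# and the current chunk adds 32 more, so the KV cache now spans 128 tokens.
cu_q_seqlens = [0, 32]      # cumulative query lengths for the current chunk
kv_cache_lengths = [128]    # cached tokens (96) + current chunk (32)

batch_idx = 0
global_token_offset = cu_q_seqlens[batch_idx]
cache_seq_len = kv_cache_lengths[batch_idx]

# Derive the chunk length and how many tokens are already cached.
current_seq_len = cu_q_seqlens[batch_idx + 1] - global_token_offset  # 32
cached_offset = cache_seq_len - current_seq_len                      # 96

for local_token_idx in range(current_seq_len):
    # Each new token lands after the cached prefix (slots 96..127 here).
    token_idx_in_kv_cache = local_token_idx + cached_offset
    # Valid only if it fits in the cache AND reads a real input token.
    valid_token = (token_idx_in_kv_cache < cache_seq_len
                   and local_token_idx < current_seq_len)
    # Without helix offsets, RoPE now uses the absolute cache position, so the
    # second chunk continues at position 96 instead of restarting at 0.
    position_id = token_idx_in_kv_cache
    assert valid_token and 96 <= position_id < 128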

examples/llm-api/llm_sparse_attention.py

Lines changed: 6 additions & 0 deletions
@@ -121,6 +121,10 @@ def parse_arguments():
                         nargs='+',
                         type=int,
                         default=None)
+    parser.add_argument('--enable_chunked_prefill',
+                        default=False,
+                        action='store_true',
+                        help='Enable chunked prefill')
     args = parser.parse_args()
     return args

@@ -136,6 +140,7 @@ def run_llm(args, sparse_attention_config):
         False,  # sparse attention does not support kv cache reuse now
         free_gpu_memory_fraction=args.kv_cache_fraction,
         dtype=args.kv_cache_dtype,
+        tokens_per_block=64,
     )

     cuda_graph_config = CudaGraphConfig(
@@ -159,6 +164,7 @@
         print_iter_log=args.print_iter_log,
         enable_iter_perf_stats=args.print_iter_log,
         moe_config=MoeConfig(backend=args.moe_backend),
+        enable_chunked_prefill=args.enable_chunked_prefill,
     )

     prompts = []
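
Taken together, the example needs only two LLM API settings for chunked prefill: a 64-token KV-cache block size and the enable_chunked_prefill switch. A minimal sketch of that configuration, assuming the tensorrt_llm LLM API used by this example (the checkpoint path and memory fraction below are placeholders):

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Mirror the example's settings relevant to chunked prefill.
kv_cache_config = KvCacheConfig(
    enable_block_reuse=False,      # sparse attention does not support KV cache reuse
    free_gpu_memory_fraction=0.7,  # placeholder fraction
    tokens_per_block=64,           # block size the example now pins to 64
)

llm = LLM(
    "DeepSeek-V3.2-Exp",           # hypothetical local checkpoint path
    kv_cache_config=kv_cache_config,
    enable_chunked_prefill=True,   # what --enable_chunked_prefill toggles
)

From the command line, the same effect comes from passing --enable_chunked_prefill to llm_sparse_attention.py alongside its existing arguments.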

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 232 additions & 114 deletions
Large diffs are not rendered by default.

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 62 additions & 0 deletions
@@ -2530,6 +2530,68 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)

+    @pytest.mark.skip_less_mpi_world_size(8)
+    @pytest.mark.skip_less_device(8)
+    @skip_pre_blackwell
+    @pytest.mark.parametrize(
+        "tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend",
+        [
+            (8, 1, 8, 0, True, True, True, True, 32, "CUTLASS"),
+            (8, 1, 8, 3, False, False, True, True, 1, "TRTLLM"),
+        ],
+        ids=["baseline_fp8kv", "latency"])
+    def test_nvfp4_multi_gpus_chunked_prefill(self, tp_size, pp_size, ep_size,
+                                              mtp_nextn, fp8kv, attention_dp,
+                                              cuda_graph, overlap_scheduler,
+                                              max_batch_size, moe_backend):
+        if moe_backend == "TRTLLM" and (get_sm_version() == 120
+                                        or get_sm_version() == 121):
+            pytest.skip(
+                "MOE TRTLLM backend does not support SM version 120 or 121")
+
+        moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                        free_gpu_memory_fraction=0.7,
+                                        tokens_per_block=64)
+        cuda_graph_config = CudaGraphConfig(
+            enable_padding=True,
+            max_batch_size=max_batch_size) if cuda_graph else None
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=cuda_graph_config,
+            moe_config=moe_config,
+        )
+
+        if fp8kv:
+            kv_cache_config.dtype = "fp8"
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+        with LLM(f"{llm_models_root()}/DeepSeek-V3.2-Exp-FP4-v2",
+                 max_batch_size=max_batch_size,
+                 tensor_parallel_size=tp_size,
+                 pipeline_parallel_size=pp_size,
+                 moe_expert_parallel_size=ep_size,
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 enable_attention_dp=attention_dp,
+                 speculative_config=mtp_config,
+                 enable_chunked_prefill=True,
+                 max_num_tokens=512) as llm:
+
+            # GPQA Diamond takes too long to run, we enable it only for fp8kv.
+            if fp8kv:
+                task = GPQADiamond(self.MODEL_NAME)
+                task.evaluate(llm,
+                              extra_evaluator_kwargs=dict(
+                                  apply_chat_template=True,
+                                  chat_template_kwargs=dict(thinking=True)))
+            else:
+                task = MMLU(self.MODEL_NAME)
+                task.evaluate(llm)
+                task = GSM8K(self.MODEL_NAME)
+                task.evaluate(llm)
+

 @skip_pre_blackwell
 class TestGLM4_6(LlmapiAccuracyTestHarness):
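
In short, the new test reuses the NVFP4 multi-GPU DeepSeek V3.2 setup but prefills in 512-token chunks (enable_chunked_prefill=True with max_num_tokens=512) on 64-token KV-cache blocks, evaluating GPQA Diamond for the baseline_fp8kv variant and MMLU plus GSM8K for the latency variant.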

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 2 additions & 0 deletions
@@ -501,6 +501,8 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baselin
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[latency]
 accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False]

tests/integration/test_lists/qa/llm_function_core_sanity.txt

Lines changed: 2 additions & 0 deletions
@@ -56,6 +56,8 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baselin
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[latency]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0-moe_backend=WIDEEP]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=WIDEEP]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False]

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 0 deletions
@@ -123,6 +123,7 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency] TIMEOUT (180)
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv] TIMEOUT (180)
 - condition:
     ranges:
       system_gpu_count:
