
Commit 7c4344b

[https://nvbugs/5590408][fix] Exclude num of draft tokens from mMaxSeqLenKv (NVIDIA#9210)
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 3ac11a6 commit 7c4344b

2 files changed: +73 -8 lines changed


tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 4 additions & 8 deletions
@@ -881,7 +881,10 @@ def prepare(self) -> None:
             ) <= self.kv_cache_manager.max_seq_len, error_message

         self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
-        self.kv_lens_runtime = self.kv_lens[:self.num_seqs]
+        # Don't use self.kv_lens here because it includes extra tokens.
+        # Use actual KV length (without extra tokens) for kv_lens_runtime,
+        # which becomes host_past_key_value_lengths and eventually mMaxSeqLenKv.
+        self.kv_lens_runtime = kv_lens[:self.num_seqs]
         self.prompt_lens_cuda_runtime = self.prompt_lens_cuda[:self.num_seqs]
         self.prompt_lens_cpu_runtime = self.prompt_lens_cpu[:self.num_seqs]
         self.host_request_types_runtime = self.host_request_types[:self.
@@ -898,13 +901,6 @@ def prepare_flash_mla(self) -> None:
         self.block_ids_per_seq[:self.num_generations, :num_blocks].copy_(
             block_ids_per_seq[self.num_contexts:], non_blocking=True)

-        self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
-        self.kv_lens_runtime = self.kv_lens[:self.num_seqs]
-        self.prompt_lens_cuda_runtime = self.prompt_lens_cuda[:self.num_seqs]
-        self.prompt_lens_cpu_runtime = self.prompt_lens_cpu[:self.num_seqs]
-        self.host_request_types_runtime = self.host_request_types[:self.
-                                                                   num_seqs]
-
     def pre_process_for_chunked_prefill(
         self,
         chunked_seq_len: torch.Tensor,
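The gist of the fix, as a minimal self-contained sketch: kv_lens is padded with num_extra_kv_tokens (reserved by EAGLE3 one-model for draft tokens), while kv_lens_runtime must carry the actual KV occupancy, since it flows into host_past_key_value_lengths and then mMaxSeqLenKv. The tensor names cached_len and new_tokens below are illustrative, not from the source; the values mirror the test added in this commit.

import torch

num_extra_kv_tokens = 7  # EAGLE3 one-model: max_draft_len - 1
cached_len = torch.tensor([49, 99, 74], dtype=torch.int32)  # past KV tokens per sequence
new_tokens = torch.tensor([1, 1, 1], dtype=torch.int32)     # tokens appended this step

# Internal bookkeeping includes the reserved draft slots.
kv_lens = cached_len + num_extra_kv_tokens + new_tokens
# Actual KV occupancy excludes them; this is what mMaxSeqLenKv must see.
actual_kv_lens = cached_len + new_tokens

# Before the fix, kv_lens_runtime took the padded values ([57, 107, 82] here),
# inflating mMaxSeqLenKv by num_extra_kv_tokens; after the fix it takes the
# actual lengths.
assert torch.equal(actual_kv_lens, torch.tensor([50, 100, 75], dtype=torch.int32))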

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 69 additions & 0 deletions
@@ -4,12 +4,15 @@
 import tempfile
 import unittest
 from pathlib import Path
+from unittest.mock import MagicMock

 import pytest
 import torch
 from utils.llm_data import llm_models_root

 from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata
+from tensorrt_llm._torch.metadata import KVCacheParams
 from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
                                  KvCacheConfig)

@@ -22,6 +25,72 @@ def enforce_single_worker(monkeypatch):
     yield


+def test_kv_lens_runtime_with_eagle3_one_model():
+    """
+    Validates that kv_lens_runtime correctly excludes num_extra_kv_tokens when
+    preparing attention metadata during EAGLE3 one-model speculative decoding.
+
+    Background:
+    - EAGLE3 reserves num_extra_kv_tokens = max_draft_len - 1 in the KV cache
+      for draft token management
+    - kv_lens_runtime becomes host_past_key_value_lengths, which eventually
+      becomes mMaxSeqLenKv in the FMHA kernel
+    - Bug: mMaxSeqLenKv was incorrectly set to actual_kv_length + num_extra_kv_tokens
+    - Fix: mMaxSeqLenKv should be set to actual_kv_length only (without extra tokens)
+
+    This test validates the fix by directly testing the prepare() logic.
+    """
+
+    # Test parameters
+    num_seqs = 3
+    num_extra_kv_tokens = 7  # e.g., max_draft_len = 8, so extra = 7
+    prompt_lens = [50, 100, 75]  # These represent actual KV lengths
+    seq_lens_q = [1, 1, 1]  # 1 token each in generation
+    num_cached_tokens_per_seq = [
+        prompt_lens[i] - seq_lens_q[i] for i in range(num_seqs)
+    ]
+
+    # Create a mock KV cache manager
+    mock_kv_cache_manager = MagicMock()
+    mock_kv_cache_manager.tokens_per_block = 32
+    mock_kv_cache_manager.num_pools = 1
+    mock_kv_cache_manager.max_blocks_per_seq = 16
+    mock_kv_cache_manager.max_batch_size = num_seqs
+    mock_kv_cache_manager.max_seq_len = 512  # Large enough for the test sequences
+    mock_kv_cache_manager.impl.copy_batch_block_offsets = MagicMock()
+
+    attn_metadata = TrtllmAttentionMetadata(
+        max_num_requests=num_seqs,
+        max_num_tokens=sum(seq_lens_q),
+        kv_cache_manager=mock_kv_cache_manager,
+    )
+
+    # Set required attributes
+    attn_metadata.request_ids = list(range(1, num_seqs + 1))
+    attn_metadata.prompt_lens = prompt_lens
+    attn_metadata._seq_lens = torch.tensor(seq_lens_q, dtype=torch.int32)
+    # seq_lens_kv is the number of new KV tokens added in this step
+    # (for generation, same as seq_lens_q)
+    attn_metadata._seq_lens_kv = torch.tensor(seq_lens_q, dtype=torch.int32)
+
+    # Set KV cache params with num_extra_kv_tokens (EAGLE3 one-model case)
+    attn_metadata.kv_cache_params = KVCacheParams(
+        use_cache=True,
+        num_cached_tokens_per_seq=num_cached_tokens_per_seq,
+        num_extra_kv_tokens=num_extra_kv_tokens)
+
+    attn_metadata.prepare()
+    actual_kv_lengths = torch.tensor(prompt_lens, dtype=torch.int32)
+
+    # kv_lens_runtime should equal actual KV lengths (without extra tokens)
+    kv_lens_runtime = attn_metadata.kv_lens_runtime[:num_seqs]
+    assert torch.equal(kv_lens_runtime, actual_kv_lengths), \
+        f"kv_lens_runtime should be {actual_kv_lengths.tolist()}, but got {kv_lens_runtime.tolist()}"
+
+    # Internal kv_lens should include extra tokens
+    kv_lens_internal = attn_metadata.kv_lens[:num_seqs]
+    expected_kv_lens_with_extra = actual_kv_lengths + num_extra_kv_tokens
+    assert torch.equal(kv_lens_internal, expected_kv_lens_with_extra), \
+        f"kv_lens should be {expected_kv_lens_with_extra.tolist()}, but got {kv_lens_internal.tolist()}"
+
+
 @pytest.mark.parametrize(
     "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp",
     [
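The regression test can be run in isolation (assuming a standard pytest setup from the repo root) with something like:

pytest tests/unittest/_torch/speculative/test_eagle3.py -k test_kv_lens_runtime_with_eagle3_one_model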
