44import tempfile
55import unittest
66from pathlib import Path
7+ from unittest .mock import MagicMock
78
89import pytest
910import torch
1011from utils .llm_data import llm_models_root
1112
1213from tensorrt_llm import LLM , SamplingParams
14+ from tensorrt_llm ._torch .attention_backend .trtllm import TrtllmAttentionMetadata
15+ from tensorrt_llm ._torch .metadata import KVCacheParams
1316from tensorrt_llm .llmapi import (CudaGraphConfig , EagleDecodingConfig ,
1417 KvCacheConfig )
1518
@@ -22,6 +25,72 @@ def enforce_single_worker(monkeypatch):
2225 yield
2326
2427
def test_kv_lens_runtime_with_eagle3_one_model():
    """
    Validates that kv_lens_runtime correctly excludes num_extra_kv_tokens when
    preparing attention metadata during EAGLE3 one-model speculative decoding.

    Background:
    - EAGLE3 reserves num_extra_kv_tokens = max_draft_len - 1 in the KV cache
      for draft token management
    - kv_lens_runtime becomes host_past_key_value_lengths, which eventually
      becomes mMaxSeqLenKv in the FMHA kernel
    - Bug: mMaxSeqLenKv was incorrectly set to actual_kv_length + num_extra_kv_tokens
    - Fix: mMaxSeqLenKv should be set to actual_kv_length only (without extra tokens)

    This test validates the fix by directly testing the prepare() logic.
    """

    # Test parameters
    batch_size = 3
    extra_tokens = 7  # e.g., max_draft_len = 8, so extra = 7
    kv_lengths = [50, 100, 75]  # actual KV length of each sequence
    new_tokens_per_seq = [1, 1, 1]  # one query token per sequence in generation
    cached_tokens_per_seq = [
        kv - q for kv, q in zip(kv_lengths, new_tokens_per_seq)
    ]

    # Stub KV cache manager exposing only the attributes prepare() reads.
    kv_cache_manager = MagicMock()
    kv_cache_manager.tokens_per_block = 32
    kv_cache_manager.num_pools = 1
    kv_cache_manager.max_blocks_per_seq = 16
    kv_cache_manager.max_batch_size = batch_size
    kv_cache_manager.max_seq_len = 512  # comfortably above every test sequence
    kv_cache_manager.impl.copy_batch_block_offsets = MagicMock()

    metadata = TrtllmAttentionMetadata(
        max_num_requests=batch_size,
        max_num_tokens=sum(new_tokens_per_seq),
        kv_cache_manager=kv_cache_manager,
    )

    # Populate the per-request state prepare() expects.
    metadata.request_ids = list(range(1, batch_size + 1))
    metadata.prompt_lens = kv_lengths
    metadata._seq_lens = torch.tensor(new_tokens_per_seq, dtype=torch.int32)
    # seq_lens_kv is the number of new KV tokens added in this step
    # (identical to seq_lens_q during generation).
    metadata._seq_lens_kv = torch.tensor(new_tokens_per_seq, dtype=torch.int32)

    # KV cache params carrying num_extra_kv_tokens (EAGLE3 one-model case).
    metadata.kv_cache_params = KVCacheParams(
        use_cache=True,
        num_cached_tokens_per_seq=cached_tokens_per_seq,
        num_extra_kv_tokens=extra_tokens)

    metadata.prepare()
    actual_kv_lengths = torch.tensor(kv_lengths, dtype=torch.int32)

    # kv_lens_runtime must equal the actual KV lengths (extra tokens excluded).
    kv_lens_runtime = metadata.kv_lens_runtime[:batch_size]
    assert torch.equal(kv_lens_runtime, actual_kv_lengths), \
        f"kv_lens_runtime should be {actual_kv_lengths.tolist()}, but got {kv_lens_runtime.tolist()}"

    # The internal kv_lens, by contrast, must still include the extra tokens.
    kv_lens_internal = metadata.kv_lens[:batch_size]
    expected_kv_lens_with_extra = actual_kv_lengths + extra_tokens
    assert torch.equal(kv_lens_internal, expected_kv_lens_with_extra), \
        f"kv_lens should be {expected_kv_lens_with_extra.tolist()}, but got {kv_lens_internal.tolist()}"
93+
2594@pytest .mark .parametrize (
2695 "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp" ,
2796 [
0 commit comments