
Commit 233e38a

Add support for KVCache reuse for DSAv32
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent e77a939 commit 233e38a

3 files changed, +4 -15 lines changed


cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 1 addition & 8 deletions
@@ -876,14 +876,7 @@ void WindowBlockManager::allocatePools(bool useUvm)
         }

         nvinfer1::Dims cacheShape;
-        if (pool.containsIndexerKCache)
-        {
-            cacheShape = ITensor::makeShape({mNumPrimaryBlocks, pool.numLayers, blockSize});
-        }
-        else
-        {
-            cacheShape = ITensor::makeShape({mNumPrimaryBlocks, pool.numLayers, mKVFactor, blockSize});
-        }
+        cacheShape = ITensor::makeShape({mNumPrimaryBlocks, pool.numLayers, mKVFactor, blockSize});

         TLLM_LOG_DEBUG("[%s] Allocating primary pool with %d blocks for %d layers with %d kv heads", mLogPrefix.c_str(),
             mNumPrimaryBlocks, pool.numLayers, pool.numKvHeads);

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 2 additions & 6 deletions
@@ -972,7 +972,7 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
             and host_cached_tokens.sum().item() > 0
             and metadata.runtime_features.chunked_prefill)

-        if has_mla_chunked_prefill:
+        if has_mla_chunked_prefill or metadata.kv_cache_manager.enable_block_reuse:
             # MLA chunked prefill mode: prepare single indexer chunk for current MLA chunk
             # The MLA has already split the sequence, we just process what's given
             chunk_specs = [(i, 0, host_seq_lens[i].item(),
@@ -1009,7 +1009,7 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):

         # Compute causal attention bounds accounting for cached KV tokens
         # For chunked prefill: Q has new tokens, K has cached + new tokens
-        if has_mla_chunked_prefill:
+        if has_mla_chunked_prefill or metadata.kv_cache_manager.enable_block_reuse:
             # Chunked prefill mode: adjust bounds for cached KV
             host_cu_seqlen_ks, host_cu_seqlen_ke = compute_cu_seqlen_kv_bounds_with_cache(
                 host_seq_lens, host_cached_tokens, num_contexts,
@@ -1639,10 +1639,6 @@ def __init__(
         sparse_attn_config: "SparseAttentionConfig",
         **kwargs,
     ) -> None:
-
-        if kv_cache_config.enable_block_reuse:
-            raise NotImplementedError(
-                "DSA indexer K-cache manager does not support block reuse yet")
         self.quant_block_size = 128
         self.index_head_dim = sparse_attn_config.index_head_dim
         # Use a fixed tokens_per_block for indexer k cache due to DG kernel constraints
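For context on the bounds adjustment above: with block reuse enabled, a context request may begin with KV tokens that are already cached, so the K range for that request must span cached plus new tokens while Q covers only the new tokens. The sketch below illustrates that bookkeeping with a hypothetical helper and simplified per-request outputs; the actual compute_cu_seqlen_kv_bounds_with_cache may take additional arguments and produce per-token bounds.

import torch

def cu_seqlen_kv_bounds_with_cache(seq_lens: torch.Tensor,
                                    cached_tokens: torch.Tensor,
                                    num_contexts: int):
    """Per-request K start/end offsets when cached KV precedes the new tokens (illustrative)."""
    # Total K length per context request = reused (cached) tokens + newly appended tokens.
    total_kv = cached_tokens[:num_contexts] + seq_lens[:num_contexts]
    # Exclusive prefix sum gives each request's starting offset in the packed K layout.
    ks = torch.zeros(num_contexts, dtype=seq_lens.dtype)
    ks[1:] = torch.cumsum(total_kv, dim=0)[:-1]
    # The causal end bound covers all cached tokens plus the request's new tokens.
    ke = ks + total_kv
    return ks, ke

# Example: request 0 reuses 64 cached tokens and appends 16 new ones;
# request 1 has no cache hit and appends 32 new tokens.
print(cu_seqlen_kv_bounds_with_cache(torch.tensor([16, 32]), torch.tensor([64, 0]), 2))
# (tensor([ 0, 80]), tensor([ 80, 112]))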

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 1 addition & 1 deletion
@@ -2490,7 +2490,7 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
                 "MOE TRTLLM backend does not support SM version 120 or 121")

         moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
-        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+        kv_cache_config = KvCacheConfig(enable_block_reuse=True,
                                         free_gpu_memory_fraction=0.7,
                                         tokens_per_block=64)
         cuda_graph_config = CudaGraphConfig(
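The test now exercises block reuse for the DSA path by flipping enable_block_reuse to True. For reference, a minimal sketch of passing the same configuration through the LLM API; the import paths and the placeholder model path are assumptions and may differ by TensorRT-LLM version.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Enable KV-cache block reuse, mirroring the updated test configuration.
kv_cache_config = KvCacheConfig(enable_block_reuse=True,
                                free_gpu_memory_fraction=0.7,
                                tokens_per_block=64)

# "<model-checkpoint-path>" is a placeholder; substitute an actual checkpoint directory.
llm = LLM(model="<model-checkpoint-path>", kv_cache_config=kv_cache_config)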
