diff --git a/docs/source/features/sparse-attention.md b/docs/source/features/sparse-attention.md
index 6ecaf4e8af6..058ee1b924a 100644
--- a/docs/source/features/sparse-attention.md
+++ b/docs/source/features/sparse-attention.md
@@ -1,6 +1,6 @@
 # Sparse Attention
-- [Motivation](#motivation)
+- [Background and Motivation](#background-and-motivation)
 - [Quick Start](#quick-start)
 - [Python API](#python-api)
 - [Usage with trtllm-bench or trtllm-serve](#usage-with-trtllm-bench-or-trtllm-serve)
@@ -21,108 +21,20 @@
 As Large Language Models (LLMs) are applied to increasingly complex tasks such as long-document summarization, code generation, and autonomous agents, the demand for processing long contexts and extended generation has surged. In Transformer-based models, the attention mechanism's computational complexity and memory usage grow quadratically and linearly with sequence length, respectively. This creates significant bottlenecks in both the **Context (Prefill)** and **Generation (Decode)** phases:
 * **Context Phase**: Processing long prompts requires substantial memory bandwidth and computation, affecting time-to-first-token (TTFT). Since the context phase is typically compute-bound, reducing the computational load here is critical.
-* **Generation Phase**: The Key-Value (KV) cache grows with every generated token, consuming vast amounts of GPU memory and bandwidth. Since the generation phase is memory-bound, reducing the memory footprint directly alleviates memory pressure, improves token-to-token latency (TPOT), and allows for larger batch sizes.
+* **Generation Phase**: The Key-Value (KV) cache grows with every generated token, consuming vast amounts of GPU memory and bandwidth. Since the generation phase is usually memory-bound, reducing the memory footprint directly alleviates memory pressure, improves token-to-token latency (TPOT), and allows for larger batch sizes.
-Consequently, using sparse attention to reduce overhead in both context and generation phases has attracted significant research interest. Several state-of-the-art models and techniques are evolving to minimize these overheads. Based on our research, we categorize sparse attention methods as follows:
+Fortunately, attention scores naturally exhibit sparsity: not all K/V tokens are needed for the attention computation. To improve the efficiency of long-sequence LLMs, numerous methods have been proposed that exploit approximate sparse attention. Sparsity can be applied along different dimensions of the attention: the head dimension, the hidden dimension, and the sequence dimension. Methods that apply sparsity along the sequence dimension selectively compute only the most important query-key pairs; this approach is referred to as token sparsity. Token sparsity has been widely explored in recent academic work, and it is a structured form of sparsity that maps well to GPUs. TensorRT LLM focuses on sparse attention methods that leverage token sparsity.
-| Sparse Computation (Context) | KV Cache Compression (Context) | Sparse Computation (Generation) | KV Cache Compression (Generation) | Training-Free | Methods |
-| --- | --- | --- | --- | --- | --- |
-| No | Yes | Yes | Yes | Yes | StreamingLLM |
-| Yes | Yes | Yes | Yes | No | DuoAttention |
-| No | No | Yes | Yes | Yes | H2O |
-| No | Yes | No | No | Yes | Minference |
-| No | No | Yes | No | Yes | Quest |
-| Yes | No | No | No | Yes | XAttention |
-| Yes | No | Yes | No | No | NSA,DSA |
-| Yes | No | No | No | No | MoBA |
-| No | Yes | Yes | No | Yes | RocketKV |
+Token sparsity can be applied to two distinct aspects of LLM inference:
+* **Sparse Computation**: If a query token does not require the entire history, the computation for the irrelevant tokens can simply be skipped, thereby reducing the attention computation cost.
+* **Sparse KV cache**: KV tokens that are not required for future generation steps are evicted from the cache, which reduces GPU memory usage and lowers the computation overhead of subsequent steps.
+Both methods can be enabled simultaneously to achieve better performance.
-The table above summarizes several representative sparse attention algorithms. DuoAttention, NSA and MoBA perform sparse computation in the context phase, but they require structural changes to the model and are therefore architecture-specific methods. For the other methods, we observe that most follow a pattern of performing KV cache compression in the context phase and sparse computation in the generation phase. Approaches such as StreamingLLM and H2O also dynamically compress (or evict) the KV cache during generation in addition to sparse computation, typically following a fixed pattern. Based on these observations, TensorRT LLM first focuses on supporting KV cache compression in the context phase and sparse computation in the generation phase, with RocketKV as the primary reference implementation. With the release of the DeepSeek V3.2 model that adopts sparse attention, we have also added support for this model. In the future, we plan to further explore and support sparse computation in the context phase and KV cache compression in the generation phase.
+To support these emerging techniques, TensorRT LLM provides a general, extensible Sparse Attention framework (which is continuously being optimized) for integrating advanced sparse algorithms. Currently, [RocketKV](https://arxiv.org/pdf/2502.14051) and [DSA](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/DeepSeek_V3_2.pdf) are supported.
 ## Quick Start
-This section provides a brief guide on enabling sparse attention in TensorRT LLM. For a detailed walkthrough of a specific algorithm, please refer to [RocketKV sparse attention](../../examples/sparse_attention/RocketKV.md).
+This section provides a brief guide on enabling sparse attention in TensorRT LLM, using RocketKV as an example. For more details, please refer to [RocketKV sparse attention](../../examples/sparse_attention/RocketKV.md).
 ### Python API
@@ -186,7 +98,7 @@ trtllm-eval --model --extra_llm_api_options extra_config.yaml lo
 ## Developer Guide
-This section describes the sparse attention framework architecture and guides developers on how to implement new sparse attention algorithms in TensorRT LLM. Unless otherwise specified, this framework primarily targets **MHA/MQA/GQA-based** attention mechanisms.
+This section describes the sparse attention framework architecture and guides developers on how to implement new sparse attention algorithms in TensorRT LLM. Unless otherwise specified, this framework primarily targets **MQA/GQA/MLA-based** attention mechanisms.
 ### Architecture Overview
@@ -197,30 +109,30 @@ This section describes the sparse attention framework architecture and guides de

Figure 1: The sparse attention framework in TensorRT LLM.

-Our goal is to design a generic, extensible, and flexible sparse attention framework. Figure 1 illustrates the overall design. The architecture is built by inheriting from the existing `AttentionBackend` to define algorithm-specific sparse attention backends. Within these backends, a `prediction` method is implemented to generate the corresponding sparse indices. These indices are then passed as arguments to the `AttentionOp` to perform the sparse attention computation. This approach balances system flexibility with extensibility, allowing new algorithms to be integrated by simply defining their prediction logic **without** modifying the core attention kernels.
+Our goal is to design a general, extensible, and flexible sparse attention framework. In this framework, the attention operator provides unified APIs to support both **sparse computation** and **sparse KV cache** based on token sparsity, so that users and developers can focus solely on the sparse attention algorithm itself, i.e., how to accurately identify the important query-key pairs.
-TensorRT LLM abstracts sparse attention into a prediction-based workflow: *a prediction module first identifies the sparse indices (tokens/blocks to keep or attend to), which are then used by the subsequent attention operator*. Currently, for standard attention, TensorRT LLM supports **KV cache compression** in the context phase and **sparse computation** in the generation phase as mentioned above. Different KV heads are allowed to use different sparse indices, while Q heads that map to the same KV head share the same sparse pattern. It does **not** yet support sparse computation in the context phase or KV cache compression in the generation phase.
+For generality, TensorRT LLM abstracts sparse attention into a prediction-based workflow: *a prediction module first identifies the sparse indices (tokens/blocks to keep or attend to), which are then used by the subsequent attention operator*. Currently, for standard attention (MQA/GQA), TensorRT LLM supports **sparse KV cache** in the context phase and **sparse computation** in the generation phase as mentioned above. Different KV heads are allowed to use different sparse indices, while Q heads that map to the same KV head share the same sparse pattern. It does **not** yet support sparse computation in the context phase or sparse KV cache in the generation phase.
-TensorRT LLM currently supports the following operations for standard attention:
+For extensibility, Figure 1 illustrates the overall design. The architecture is built by inheriting from the existing `AttentionBackend` to define algorithm-specific sparse attention backends. Within these backends, `prediction` methods are implemented to generate the corresponding sparse indices. These indices are then passed as arguments to the `AttentionOp` to perform the sparse attention computation. This approach balances system flexibility with extensibility, allowing new algorithms to be integrated by simply defining their prediction logic **without** modifying the core attention kernels.
-1. **Context Phase (KV cache compression)**:
-    * **Goal**: Reduce the size of the KV cache populated during the context phase.
-    * **Mechanism**: Identify important tokens from the prompt and permanently evict non-essential tokens before entering the generation phase.
+TensorRT LLM currently supports the following features:
-2. **Generation Phase (sparse computation)**:
-    * **Goal**: Accelerate attention computation during token generation.
-    * **Mechanism**: For each new token, dynamically select a subset of relevant blocks/tokens from the kv cache to attend to.
+1. **Context Phase**:
+    * **Sparse computation**: MLA only
+    * **Sparse KV cache**: MQA/GQA
-However, Multi-head Latent Attention (MLA), used by algorithms like DSA, is a special case. It currently supports sparse computation in both context and generation phases, but does not support KV cache compression. Its sparse computation implementation is handled directly within the TRTLLM-GEN MLA kernel and does not use the general pass described below.
+2. **Generation Phase**:
+    * **Sparse computation**: MLA/MQA/GQA
+    * **Sparse KV cache**: not supported yet
 ### Framework Implementation
 To hide the complexity of sparse algorithms, the main prediction logic is encapsulated within the `tensorrt_llm._torch.attention_backend` module.
-We have extended the existing `AttentionBackend` to include a prediction step that retrieves sparse indices before the attention operation. The logic flow in `TrtllmAttention` is conceptually:
+We have extended the existing `AttentionBackend` to include a prediction step that retrieves sparse indices before the attention operation. These indices are generated using two prediction methods:
 ```python
-# Predict indices for KV Cache compression (context phase)
+# Predict indices for sparse KV Cache (context phase)
 sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict(
     q, k, metadata, **kwargs)
@@ -244,21 +156,23 @@ The key files located in `tensorrt_llm/_torch/attention_backend/sparse/` are:

Figure 2: Sparse attention operator workflow in TensorRT LLM.

-In `AttentionOp`, as illustrated in Figure 2, for the GQA/MHA path we have implemented two kernels, `updateSparseKvCacheAfterFmha` and `gatherKvPageOffsetsKernel`, applied in the context and generation phases respectively:
+In `AttentionOp`, MQA/GQA sparse attention currently supports sparse computation only at block granularity in the generation phase, where the block size equals the page size of the KV cache. This means that the attention computation for unimportant pages can be skipped entirely. In addition, we provide a sparse MLA kernel that supports token-level sparse computation in both the context and generation phases.
+
+To support these features, as illustrated in Figure 2, we have implemented two kernels for the MQA/GQA path, `updateSparseKvCacheAfterFmha` and `gatherKvPageOffsetsKernel`, applied in the context and generation phases respectively:
-* **`updateSparseKvCacheAfterFmha`**: Invoked in the post-processing stage after the context attention computation. It performs a rewrite of the KV cache based on the selected indices, effectively implementing KV cache compression.
+* **`updateSparseKvCacheAfterFmha`**: Invoked in the post-processing stage after the context attention computation. It selects the important KV tokens and writes only their K/V vectors back to the KV cache, reducing the KV cache size.
-* **`gatherKvPageOffsetsKernel`**: Executed before the attention computation in the generation phase. It converts the input sparse indices (which can be of arbitrary granularity) into page-aligned indices. It then gathers `kv_page_offsets` and updates `kv_len` to produce new metadata, which is fed into the subsequent attention kernel for computation.
+* **`gatherKvPageOffsetsKernel`**: Executed before the attention computation in the generation phase. It converts the input sparse indices (which can be of arbitrary granularity) into page-aligned indices: if a single token is selected, the entire page it belongs to is included in the attention computation. After this conversion, it produces a new `kv_page_offsets` and an updated `kv_len` equal to the number of selected KV tokens; this new metadata is then fed into the subsequent attention kernel for computation (see the sketch at the end of this section).
-Currently, for GQA/MHA, sparse attention only supports sparse computation at page-size granularity in the generation phase. In addition, we provide a sparse MLA kernel that supports token-level sparse computation in both the context and generation phases.
+For sparse MLA, the kernel supports token sparsity directly, eliminating the need for `gatherKvPageOffsetsKernel`. However, please note that sparse KV cache is not yet supported for MLA.
-Many sparse attention algorithms also require additional auxiliary memory. In the current system, there are two paths to fulfill this requirement:
+Many sparse attention algorithms also require additional auxiliary memory. In the current system, there are two paths to support this requirement:
 * Implement a simple, custom CacheManager at the Python level, inheriting from `KVCacheManager`.
 * Use `KVCacheManagerCpp` to simultaneously manage both the KV Cache and auxiliary memory.
-Each option has its own advantages and disadvantages, which we summarize below.
+Each option has its own advantages and disadvantages; please refer to [Manage Auxiliary Memory Pool](#3-manage-auxiliary-memory-pool) for more details.
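+
+The page-alignment step performed by `gatherKvPageOffsetsKernel` can be summarized with a small PyTorch sketch. This is illustrative only and covers a single sample and a single KV head; the function name, argument layout, and the partial-page handling below are simplified assumptions rather than the kernel's actual interface:
+
+```python
+import torch
+
+def gather_kv_page_offsets(sparse_attn_indices: torch.Tensor,  # (nSelectedBlocks,) selected block ids
+                           kv_page_offsets: torch.Tensor,      # (nPages,) page table of this sample
+                           block_size: int,                    # sparse_attn_indices_block_size
+                           page_size: int):                    # tokens per KV cache page
+    # Expand each selected block into the token positions it covers.
+    tokens = (sparse_attn_indices.unsqueeze(1) * block_size
+              + torch.arange(block_size)).flatten()
+    # A page is kept as soon as any of its tokens is selected (page alignment).
+    pages = torch.unique(tokens // page_size)
+    # Gather the page table entries and update the effective KV length.
+    new_kv_page_offsets = kv_page_offsets[pages]
+    new_kv_len = pages.numel() * page_size  # the partial last page is ignored for brevity
+    return new_kv_page_offsets, new_kv_len
+```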
 ### Implementing a New Algorithm
@@ -277,32 +191,32 @@ class MySparseAttentionConfig(BaseSparseAttentionConfig):
 Create a new class inheriting from `TrtllmAttention` (in `tensorrt_llm/_torch/attention_backend/trtllm.py`). You typically need to override two main prediction methods:
-**`sparse_kv_predict(self, q, k, metadata, ...)`**
-* **Purpose**: Predict indices for KV cache compression during the context phase.
+**`sparse_kv_predict(self, q, k, metadata, **kwargs)`**
 * **Behavior**: This function performs prediction to return the indices of tokens to be preserved in the KV cache.
-* **Output**: `sparse_kv_indices` (tokens to keep).
-* **KV Cache Update**: The system calls `updateSparseKvCacheAfterFmha` to gather the KV cache based on these indices. This effectively "compresses" the prompt's KV cache.
+* **Output**:
+  - `sparse_kv_indices`: The indices of the important tokens along the sequence dimension, with shape `(nHeads, nTokens)`, where `nHeads` is the number of KV heads and `nTokens` is the total number of selected tokens across all samples in the batch.
+  - `sparse_kv_offsets`: The per-sample offsets into `sparse_kv_indices`, with shape `(nBatch + 1)`, where `nBatch` is the batch size. The indices for head `h` and sample `n` start at `sparse_kv_indices[h, sparse_kv_offsets[n]]`.
 * **Constraint**: Returned indices must be **sorted** to ensure safe in-place gathering in memory. Note that this post-processing "gather" step introduces some overhead, but significantly improves flexibility, allowing compatibility with features in context like chunked prefill.
-**`sparse_attn_predict(self, q, k, metadata, ...)`**
-* **Purpose**: Predict indices for **sparse computation** during the generation phase.
-* **Behavior**: For the current query token, predict which pages/blocks in the KV cache are relevant.
-* **Output**: `sparse_attn_indices` (relevant blocks/tokens).
-* **KV Cache Selection**: These indices are passed to the underlying C++ attention operator. The `gatherKvPageOffsetsKernel` uses these indices to gather `kv_page_offsets` and update `kv_len`, enabling the attention kernel to perform "dense" attention on just the selected pages.
+**`sparse_attn_predict(self, q, k, metadata, **kwargs)`**
+* **Behavior**: For the current query tokens, predict and return the sparse indices for **sparse computation**.
+* **Output**:
+  - `sparse_attn_indices`: The indices of the selected blocks for block-sparse attention along the KV sequence dimension, with shape `(nHeads, nBlocks)`, where `nHeads` is the number of KV heads and `nBlocks` is the total number of selected blocks across all samples in the batch. The block size is defined by `sparse_attn_indices_block_size`, which supports arbitrary values.
+  - `sparse_attn_offsets`: The per-sample offsets into `sparse_attn_indices`, with shape `(nBatch + 1)`, where `nBatch` is the batch size. The indices for head `h` and sample `n` start at `sparse_attn_indices[h, sparse_attn_offsets[n]]`.
 * **Constraint**: The generation phase sparse computation is supported for NVIDIA Blackwell GPUs and newer (SM 100+) using TRTLLM-GEN kernels. However, it is flexible enough to extend to different architectures. Currently, only KV cache's **page-level** granularity is supported for sparse computation.
-**Note**: The prediction process can be time-consuming, especially in low-latency scenarios where it might account for a significant portion of the attention time. It is highly recommended to optimize this step using custom Triton or CUDA kernels.
+**Note**: The prediction process can be time-consuming, especially in low-latency scenarios where it might account for a significant portion of the attention time. It is highly recommended to optimize this step using custom kernels.
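+
+To make the contract concrete, below is a minimal, hypothetical backend skeleton. The importance scores, the `num_kept_tokens`/`num_kept_blocks` config fields, and the `score_blocks` helper are placeholders for the algorithm-specific prediction logic; only the method signatures and the shapes of the returned indices/offsets follow the framework contract described above. A single-sample batch is assumed to keep the offset handling short:
+
+```python
+import torch
+
+from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttention
+
+
+class MySparseAttention(TrtllmAttention):
+
+    def sparse_kv_predict(self, q, k, metadata, **kwargs):
+        # Context phase: choose which prompt tokens to keep in the KV cache.
+        # Placeholder importance metric; assumed to yield (nKvHeads, nCtxTokens).
+        scores = k.abs().mean(dim=-1)
+        keep = min(self.config.num_kept_tokens, scores.shape[-1])
+        topk = scores.topk(keep, dim=-1).indices
+        sparse_kv_indices = topk.sort(dim=-1).values.int()  # must be sorted
+        sparse_kv_offsets = torch.tensor([0, keep], dtype=torch.int32)  # (nBatch + 1,)
+        return sparse_kv_indices, sparse_kv_offsets
+
+    def sparse_attn_predict(self, q, k, metadata, **kwargs):
+        # Generation phase: choose which KV blocks the new query should attend to.
+        block_scores = self.score_blocks(q, metadata)  # (nKvHeads, nBlocks), hypothetical helper
+        keep = min(self.config.num_kept_blocks, block_scores.shape[-1])
+        sparse_attn_indices = block_scores.topk(keep, dim=-1).indices.int()
+        sparse_attn_offsets = torch.tensor([0, keep], dtype=torch.int32)  # (nBatch + 1,)
+        return sparse_attn_indices, sparse_attn_offsets
+```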
 #### 3. Manage Auxiliary Memory Pool
-Many sparse algorithms (like RocketKV or DSA) require auxiliary structures (e.g., a "KT cache" or "Kcache") to select relevant tokens. There are two primary ways to manage this memory in TensorRT LLM:
+Many sparse algorithms (like RocketKV or DSA) require auxiliary structures (e.g., a "KT cache" in RocketKV) to select relevant tokens. There are two primary ways to manage this memory in TensorRT LLM:
 **Option A: Python-level Custom Manager**
 You can implement a custom manager in Python.
 * **Use Case**: Algorithms like RocketKV use this approach to store the KT cache (e.g., `RocketKVCacheManager` in `rocket.py`).
 * **Implementation**: Create a Python level cache manager that handles the allocation and lifecycle of the auxiliary tensors.
-* **BlockManager Integration**: It is recommended to use the existing `BlockManager` to manage the auxiliary pools if possible. This allows the auxiliary pool to share block logic with the main KV cache, reducing implementation overhead.
+* **BlockManager Integration**: It is recommended to use the existing `BlockManager` to manage the auxiliary pools if possible. This allows the auxiliary pool to share the block manager logic, reducing implementation overhead.
 * **Key Methods to Override**:
   * `get_cache_size_per_token` / `get_cache_bytes_per_token`: Update `kv_factor` correctly to include the size of the auxiliary structures so TensorRT LLM allocates sufficient GPU memory.
   * `add_dummy_requests` / `prepare_resources`: Ensure the auxiliary pool allocates correct resources/tokens for new requests.
@@ -312,11 +226,11 @@ You can implement a custom manager in Python.
 **Option B: C++ Integrated Manager**
 For tighter integration, you can manage the auxiliary memory within the C++ `KVCacheManager`.
-* **Use Case**: Algorithms like DSA use this approach to store the Kcache.
-* **Pros**: Enables compatibility with advanced features such as KV cache reuse and disagg-serving. For example, DSA's low-rank Kcache can be reused or transmitted between context and generation engines.
+* **Use Case**: Algorithms like DSA use this approach to store the indexer Kcache.
+* **Pros**: Enables compatibility with advanced features such as KV cache reuse and disagg-serving. For example, DSA's low-rank indexer Kcache can be reused or transmitted between context and generation engines.
 * **Cons**: Higher implementation complexity. The current C++ `KVCacheManager` is optimized for the standard KV cache pool. Adding custom pools often requires significant modifications or manual implementation of the pool management logic within the C++ level.
-**Note**: If your algorithm involves KV cache compression, standard KV cache block reuse is generally incompatible because eviction modifies the block content uniquely for each request. However, algorithms like DSA that use low-rank approximation without eviction can support block reuse.
+**Note**: If your algorithm involves a sparse KV cache, standard KV cache block reuse is generally incompatible because eviction modifies the block content uniquely for each request. However, algorithms like DSA that use low-rank approximation without eviction can support block reuse.
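+
+For Option A, a heavily simplified sketch of such a Python-level manager is shown below. The import path, the `aux_bytes_per_token` attribute, and the `allocate_aux_blocks` helper are assumptions for illustration only; refer to `RocketKVCacheManager` in `rocket.py` for a real implementation:
+
+```python
+from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager  # path assumed
+
+
+class MyAuxCacheManager(KVCacheManager):
+    """Manages the regular KV cache plus an algorithm-specific auxiliary pool."""
+
+    def get_cache_bytes_per_token(self, *args, **kwargs):
+        # Include the auxiliary structure (e.g., a per-token KT entry) so that
+        # TensorRT LLM reserves enough GPU memory for both pools.
+        return super().get_cache_bytes_per_token(*args, **kwargs) + self.aux_bytes_per_token
+
+    def prepare_resources(self, scheduled_batch):
+        # Allocate auxiliary blocks alongside the regular KV cache blocks.
+        super().prepare_resources(scheduled_batch)
+        self.allocate_aux_blocks(scheduled_batch)  # hypothetical helper
+```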
 #### 4. Registration and Dispatch
@@ -329,14 +243,13 @@ For tighter integration, you can manage the auxiliary memory within the C++ `KVC
 Currently, the status of the Sparse Attention framework is as follows:
-1. **Supported Operations**: The `AttentionOp` currently supports **KV cache compression** in the context phase and **sparse computation** in the generation phase. Other combinations (for example, sparse computation in the context phase) are not yet supported for MHA/GQA. For MLA, sparse computation is supported in both the context and generation phases.
+1. **Supported Operations**: The `AttentionOp` currently supports **sparse KV cache** in the context phase and **sparse computation** in the generation phase. Other combinations (for example, sparse computation in the context phase) are not yet supported for MQA/GQA. For MLA, sparse computation is supported in both the context and generation phases.
 2. **Algorithm Support**: RocketKV is supported in both the vanilla (PyTorch) backend and the TRTLLM backend, while DSA is supported in the TRTLLM backend. These implementations validate the generality and flexibility of the framework.
-3. **Auxiliary Memory**: Both Python-level and C++-level implementations are algorithm-specific. There is no unified abstraction for auxiliary memory management yet.
 ### Future Work
-* **Sparse Computation in Context Phase**: We plan to introduce sparse computation support for the context phase for MHA/GQA, allowing the TensorRT LLM sparse attention framework to cover most scenarios.
+* **Sparse Computation in Context Phase**: We plan to introduce context-phase sparse computation support for MQA/GQA, allowing the TensorRT LLM sparse attention framework to cover more scenarios.
 * **Dynamic Eviction in Generation Phase**: Dynamically evicting KV cache blocks during the generation phase poses significant challenges to KV cache flexibility. While difficult to implement in the current framework, block-level eviction appears to be a promising compromise and is under further exploration.
 * **Unified Auxiliary Memory Management**: We are exploring a unified mechanism to manage auxiliary memory pools. This would allow users to define custom auxiliary spaces more flexibly while automatically inheriting advanced features from the KV cache, such as reuse and offloading.
 * **Code Refactoring**: As more sparse attention algorithms are integrated, the framework will undergo refactoring to unify code and improve maintainability.
-* **Optimization and Feature Integration**: We are discussing further optimizations, such as enabling fine-grained token-level sparse computation for MHA/GQA. Additionally, we are exploring integration with other advanced features like Disaggregated Serving.
+* **Optimization and Feature Integration**: We are investigating further optimizations, such as improving DSA performance.