diff --git a/tensorrt_llm/_torch/attention_backend/sparse/kernel.py b/tensorrt_llm/_torch/attention_backend/sparse/kernel.py index 6ed4ffaa717..f9eed5e9563 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/kernel.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/kernel.py @@ -3,6 +3,97 @@ import torch import triton import triton.language as tl +import triton.language.core as core +from triton.language.standard import _log2, sum, zeros_like + +######################################################## +# Argsort utilities for topk operations +# Adapted from https://github.com/triton-lang/triton/issues/3698#issuecomment-2067681396 +######################################################## + + +@triton.jit +def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr): + n_outer: core.constexpr = x.numel >> n_dims + shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)] + y = core.reshape(x, shape) + # slice left/right with 'stride' 2**(n_dims - i - 1) + mask = core.arange(0, 2)[None, :, None] + left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape) + right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape) + left = core.reshape(left, x.shape) + right = core.reshape(right, x.shape) + + # idx + y_idx = core.reshape(ids, shape) + left_idx = core.broadcast_to(sum(y_idx * (1 - mask), 1)[:, None, :], shape) + right_idx = core.broadcast_to(sum(y_idx * mask, 1)[:, None, :], shape) + left_idx = core.reshape(left_idx, x.shape) + right_idx = core.reshape(right_idx, x.shape) + + # actual compare-and-swap + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, + signed=True) + ileft = left.to(idtype, bitcast=True) + iright = right.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + + cond = (left > right) ^ flip + + ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix)) + + new_ids = ids ^ core.where(cond, left_idx ^ right_idx, zeros_like(ids)) + + return ret.to(x.dtype, bitcast=True), new_ids + + +@triton.jit +def _bitonic_merge(x, ids, stage: core.constexpr, order: core.constexpr, + n_dims: core.constexpr): + ''' + order_type 0 == ascending + order_type 1 == descending + order_type 2 == alternating + ''' + n_outer: core.constexpr = x.numel >> n_dims + core.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... 
then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage] + # Create boolean flip pattern instead of integer + flip = core.reshape( + core.broadcast_to(core.arange(0, 2)[None, :, None], shape), + x.shape) != 0 + else: + # Ensure flip is boolean for XOR operations + flip = order != 0 + # perform `stage` rounds of `compare-and-swap` + for i in core.static_range(stage): + x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims) + return x, ids + + +@triton.jit +def argsort(x, + ids, + dim: core.constexpr = None, + descending: core.constexpr = core.CONSTEXPR_0): + # handle default dimension or check that it is the most minor dim + _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim + core.static_assert(_dim == len(x.shape) - 1, + "only minor dimension is currently supported") + # iteratively run bitonic merge-sort steps + n_dims: core.constexpr = _log2(x.shape[_dim]) + + for i in core.static_range(1, n_dims + 1): + x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, + n_dims) + return x, ids + ######################################################## # Index gather kernel @@ -270,16 +361,15 @@ def fused_qk_split(input_tensor: torch.Tensor, num_heads: int, ######################################################## -# BMM softmax kernel +# BMM kernel ######################################################## @triton.jit -def bmm_softmax_kernel( +def bmm_kernel( q_ptr, k_ptr, scores_ptr, - m_i_stored_ptr, q_cu_seqlens_ptr, k_cu_seqlens_ptr, total_q_tokens, @@ -296,15 +386,13 @@ def bmm_softmax_kernel( BLOCK_K: tl.constexpr, ): """ - Optimized bmm softmax kernel for computing softmax(QK^T) with online softmax.. + Optimized BMM kernel for computing QK^T with tiled matrix multiplication. + Inspired by mm_demo.py optimization techniques. 
Args: q_ptr: Query tensor [num_q_heads, total_q_tokens, head_dim] k_ptr: Key tensor [num_kv_heads, total_k_tokens, head_dim] scores_ptr: Output tensor [num_q_heads, q_len_per_seq, total_k_tokens] - where q_len_per_seq = total_q_tokens // batch_size (uniform seq assumption) - m_i_stored_ptr: Tensor to store m_i_new values [num_q_heads, q_len_per_seq, total_k_tokens] - for correct final normalization while maintaining numerical stability BLOCK_M: Query block size (compile-time constant) BLOCK_N: Key block size (compile-time constant) BLOCK_K: Head dimension block size for tiled matmul (compile-time constant) @@ -336,10 +424,6 @@ def bmm_softmax_kernel( q_mask = q_offsets < q_seqlen q_global_offsets = q_seq_start + q_offsets - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float( - "inf") # Running max - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # Running sum - for k_block_start in tl.range(0, k_seqlen, BLOCK_N): k_offsets = k_block_start + tl.arange(0, BLOCK_N) k_mask = k_offsets < k_seqlen @@ -348,6 +432,7 @@ def bmm_softmax_kernel( # Initialize QK^T accumulator for this (M, N) block qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # Tiled matrix multiplication following mm_demo.py pattern for k_dim_start in tl.range(0, head_dim, BLOCK_K): k_dim_offsets = k_dim_start + tl.arange(0, BLOCK_K) k_dim_mask = k_dim_offsets < head_dim @@ -366,11 +451,13 @@ def bmm_softmax_kernel( mask=k_mask[:, None] & k_dim_mask[None, :], other=0.0) + # Accumulate QK^T using tl.dot for better performance qk += tl.dot(q_chunk, tl.trans(k_chunk)) # Scale the accumulated QK^T qk = qk * sm_scale + # Apply masking valid_mask = q_mask[:, None] & k_mask[None, :] if causal: # Create causal mask based on positions within this batch's sequence @@ -381,82 +468,24 @@ def bmm_softmax_kernel( else: qk = tl.where(valid_mask, qk, float("-inf")) - # Online softmax update - m_ij = tl.max(qk, 1) # Max across keys [BLOCK_M] - m_i_new = tl.maximum(m_i, m_ij) - - # Rescale previous sum - alpha = tl.exp(m_i - m_i_new) - l_i = l_i * alpha - - # Add contribution from current block - p = tl.exp( - qk - - m_i_new[:, None]) # [BLOCK_M, BLOCK_N] - numerically stable - - l_ij = tl.sum(p, 1) # Sum across keys [BLOCK_M] - l_i = l_i + l_ij - - # Update running max - m_i = m_i_new - - # Vectorized output index calculation + # Store results - note that we store in the global k position space + # This matches the original bmm_softmax kernel behavior output_indices = (head_idx * q_len_per_seq * total_k_tokens + q_offsets[:, None] * total_k_tokens + k_global_offsets[None, :]) # [BLOCK_M, BLOCK_N] - # Store exp(qk - m_i_new) for numerical stability - tl.store(scores_ptr + output_indices, p, mask=valid_mask) - - # Store corresponding m_i_new values for each position, this is needed for correct final normalization - tl.store(m_i_stored_ptr + output_indices, - m_i_new[:, None], - mask=valid_mask) - - # Perform normalization for this q_block only after all k_blocks are processed - for k_block_start in tl.range(0, k_seqlen, BLOCK_N): - k_offsets = k_block_start + tl.arange(0, BLOCK_N) - k_mask = k_offsets < k_seqlen - k_global_offsets = k_seq_start + k_offsets + tl.store(scores_ptr + output_indices, qk, mask=valid_mask) - valid_mask = q_mask[:, None] & k_mask[None, :] - output_indices = (head_idx * q_len_per_seq * total_k_tokens + - q_offsets[:, None] * total_k_tokens + - k_global_offsets[None, :]) - - # Load current scores exp(qk - m_i_new_block) - stored_scores = tl.load(scores_ptr + output_indices, - mask=valid_mask, - other=0.0) - - # Load the 
stored m_i_new values for each position - stored_m_i_new = tl.load(m_i_stored_ptr + output_indices, - mask=valid_mask, - other=float("-inf")) - - # Apply correct normalization: - correction_factor = tl.exp(stored_m_i_new - m_i[:, None]) - - normalized_scores = tl.where( - valid_mask, stored_scores * correction_factor / l_i[:, None], - tl.zeros_like(stored_scores)) - - # Store normalized scores - tl.store(scores_ptr + output_indices, - normalized_scores, - mask=valid_mask) - - -def bmm_softmax(q: torch.Tensor, - k: torch.Tensor, - q_cu_seqlens: torch.Tensor, - k_cu_seqlens: torch.Tensor, - batch_size: int, - sm_scale: float = None, - causal: bool = False) -> torch.Tensor: +def bmm(q: torch.Tensor, + k: torch.Tensor, + q_cu_seqlens: torch.Tensor, + k_cu_seqlens: torch.Tensor, + batch_size: int, + sm_scale: float = None, + causal: bool = False) -> torch.Tensor: """ - Compute softmax(QK^T) using optimized bmm softmax algorithm with tiled matrix multiplication. + Compute softmax(QK^T) using separated BMM and Softmax kernels. Args: q: Query tensor [num_q_heads, total_q_tokens, head_dim] @@ -469,8 +498,6 @@ def bmm_softmax(q: torch.Tensor, Returns: scores: Attention scores [num_q_heads, q_len_per_seq, total_k_tokens] - where q_len_per_seq = total_q_tokens // batch_size - Each batch's results are concatenated along the last dimension """ num_q_heads, total_q_tokens, head_dim = q.shape num_k_heads, total_k_tokens, _ = k.shape @@ -486,27 +513,21 @@ def bmm_softmax(q: torch.Tensor, if sm_scale is None: sm_scale = 1.0 / math.sqrt(head_dim) - # Create output tensor with correct shape: [num_heads, q_len_per_seq, total_k_tokens] - scores = torch.empty((num_q_heads, q_len_per_seq, total_k_tokens), - dtype=torch.float32, - device=q.device) - - # Create tensor to store m_i_new values for each position - m_i_stored = torch.empty((num_q_heads, q_len_per_seq, total_k_tokens), - dtype=torch.float32, - device=q.device) + bmm_results = torch.empty((num_q_heads, q_len_per_seq, total_k_tokens), + dtype=torch.float32, + device=q.device) + # BMM kernel configuration BLOCK_M = 32 - BLOCK_N = 512 + BLOCK_N = 256 BLOCK_K = 64 - grid = lambda meta: (batch_size, num_q_heads) + grid_bmm = lambda meta: (batch_size, num_q_heads) - bmm_softmax_kernel[grid]( + bmm_kernel[grid_bmm]( q, k, - scores, - m_i_stored, + bmm_results, q_cu_seqlens, k_cu_seqlens, total_q_tokens, @@ -523,132 +544,19 @@ def bmm_softmax(q: torch.Tensor, BLOCK_K=BLOCK_K, ) - return scores + return bmm_results ######################################################## -# Separated BMM and Softmax kernels +# Softmax kernel ######################################################## -@triton.jit -def bmm_kernel( - q_ptr, - k_ptr, - scores_ptr, - q_cu_seqlens_ptr, - k_cu_seqlens_ptr, - total_q_tokens, - total_k_tokens, - head_dim, - batch_size, - num_q_heads, - num_k_heads, - q_len_per_seq, - sm_scale, - causal, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - """ - Optimized BMM kernel for computing QK^T with tiled matrix multiplication. - Inspired by mm_demo.py optimization techniques. 
- - Args: - q_ptr: Query tensor [num_q_heads, total_q_tokens, head_dim] - k_ptr: Key tensor [num_kv_heads, total_k_tokens, head_dim] - scores_ptr: Output tensor [num_q_heads, q_len_per_seq, total_k_tokens] - BLOCK_M: Query block size (compile-time constant) - BLOCK_N: Key block size (compile-time constant) - BLOCK_K: Head dimension block size for tiled matmul (compile-time constant) - """ - - batch_idx = tl.program_id(0) - head_idx = tl.program_id(1) - - if batch_idx >= batch_size or head_idx >= num_q_heads: - return - - # Continuous mapping of query heads to key heads - k_head_idx = head_idx // (num_q_heads // num_k_heads) - - q_seq_start = tl.load(q_cu_seqlens_ptr + batch_idx) - q_seq_end = tl.load(q_cu_seqlens_ptr + batch_idx + 1) - k_seq_start = tl.load(k_cu_seqlens_ptr + batch_idx) - k_seq_end = tl.load(k_cu_seqlens_ptr + batch_idx + 1) - - q_seqlen = q_seq_end - q_seq_start - k_seqlen = k_seq_end - k_seq_start - - if q_seqlen <= 0 or k_seqlen <= 0: - return - - # Process queries in this batch with BLOCK_M parallelization - for q_block_start in tl.range(0, q_seqlen, BLOCK_M): - q_offsets = q_block_start + tl.arange(0, BLOCK_M) - q_mask = q_offsets < q_seqlen - q_global_offsets = q_seq_start + q_offsets - - for k_block_start in tl.range(0, k_seqlen, BLOCK_N): - k_offsets = k_block_start + tl.arange(0, BLOCK_N) - k_mask = k_offsets < k_seqlen - k_global_offsets = k_seq_start + k_offsets - - # Initialize QK^T accumulator for this (M, N) block - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - - # Tiled matrix multiplication following mm_demo.py pattern - for k_dim_start in tl.range(0, head_dim, BLOCK_K): - k_dim_offsets = k_dim_start + tl.arange(0, BLOCK_K) - k_dim_mask = k_dim_offsets < head_dim - - # Load query chunk [BLOCK_M, BLOCK_K] - q_indices = head_idx * total_q_tokens * head_dim + q_global_offsets[:, None] * head_dim + k_dim_offsets[ - None, :] - q_chunk = tl.load(q_ptr + q_indices, - mask=q_mask[:, None] & k_dim_mask[None, :], - other=0.0) - - # Load key chunk [BLOCK_N, BLOCK_K] - k_indices = k_head_idx * total_k_tokens * head_dim + k_global_offsets[:, None] * head_dim + k_dim_offsets[ - None, :] - k_chunk = tl.load(k_ptr + k_indices, - mask=k_mask[:, None] & k_dim_mask[None, :], - other=0.0) - - # Accumulate QK^T using tl.dot for better performance - qk += tl.dot(q_chunk, tl.trans(k_chunk)) - - # Scale the accumulated QK^T - qk = qk * sm_scale - - # Apply masking - valid_mask = q_mask[:, None] & k_mask[None, :] - if causal: - # Create causal mask based on positions within this batch's sequence - q_pos_in_seq = q_offsets[:, None] # [BLOCK_M, 1] - k_pos_in_seq = k_offsets[None, :] # [1, BLOCK_N] - causal_mask = q_pos_in_seq >= k_pos_in_seq - qk = tl.where(causal_mask & valid_mask, qk, float("-inf")) - else: - qk = tl.where(valid_mask, qk, float("-inf")) - - # Store results - note that we store in the global k position space - # This matches the original bmm_softmax kernel behavior - output_indices = (head_idx * q_len_per_seq * total_k_tokens + - q_offsets[:, None] * total_k_tokens + - k_global_offsets[None, :]) # [BLOCK_M, BLOCK_N] - - tl.store(scores_ptr + output_indices, qk, mask=valid_mask) - - @triton.jit def softmax_kernel_batched( input_ptr, output_ptr, - q_cu_seqlens_ptr, - k_cu_seqlens_ptr, + cu_seq_lens_ptr, batch_size, num_heads, q_len_per_seq, @@ -673,8 +581,8 @@ def softmax_kernel_batched( return # Get k sequence boundaries for this batch - k_seq_start = tl.load(k_cu_seqlens_ptr + batch_idx) - k_seq_end = tl.load(k_cu_seqlens_ptr + batch_idx + 1) + k_seq_start = 
tl.load(cu_seq_lens_ptr + batch_idx) + k_seq_end = tl.load(cu_seq_lens_ptr + batch_idx + 1) k_seqlen = k_seq_end - k_seq_start if k_seqlen <= 0: @@ -730,96 +638,47 @@ def softmax_kernel_batched( tl.store(output_ptr + output_indices, softmax_values, mask=k_mask) -def separated_bmm_softmax(q: torch.Tensor, - k: torch.Tensor, - q_cu_seqlens: torch.Tensor, - k_cu_seqlens: torch.Tensor, - batch_size: int, - sm_scale: float = None, - causal: bool = False) -> torch.Tensor: +def triton_softmax( + input_tensor: torch.Tensor, + cum_lens: torch.Tensor, + batch_size: int, +) -> torch.Tensor: """ - Compute softmax(QK^T) using separated BMM and Softmax kernels. + Apply softmax to KT token scores with batch-aware sequence boundaries. Args: - q: Query tensor [num_q_heads, total_q_tokens, head_dim] - k: Key tensor [num_kv_heads, total_k_tokens, head_dim] - q_cu_seqlens: Query cumulative sequence lengths [batch_size + 1] - k_cu_seqlens: Key cumulative sequence lengths [batch_size + 1] - batch_size: Number of batches - sm_scale: Scaling factor (default: 1/sqrt(head_dim)) - causal: Whether to apply causal masking + input_tensor: Input tensor [total_num_heads, 1, total_kt_tokens] + cum_lens: Cumulative lengths [batch_size + 1] + batch_size: Number of generation batches Returns: - scores: Attention scores [num_q_heads, q_len_per_seq, total_k_tokens] + output: Softmax results [total_num_heads, 1, total_kt_tokens] """ - num_q_heads, total_q_tokens, head_dim = q.shape - num_k_heads, total_k_tokens, _ = k.shape - - assert total_q_tokens % batch_size == 0, "total_q_tokens must be divisible by batch_size" - q_len_per_seq = total_q_tokens // batch_size - - if total_k_tokens == 0: - return torch.zeros((num_q_heads, q_len_per_seq, total_k_tokens), - dtype=torch.float32, - device=q.device) - - if sm_scale is None: - sm_scale = 1.0 / math.sqrt(head_dim) - - # Step 1: BMM to compute QK^T - raw_scores = torch.empty((num_q_heads, q_len_per_seq, total_k_tokens), - dtype=torch.float32, - device=q.device) - - # BMM kernel configuration - BLOCK_M = 32 - BLOCK_N = 256 - BLOCK_K = 64 - - grid_bmm = lambda meta: (batch_size, num_q_heads) - - bmm_kernel[grid_bmm]( - q, - k, - raw_scores, - q_cu_seqlens, - k_cu_seqlens, - total_q_tokens, - total_k_tokens, - head_dim, - batch_size, - num_q_heads, - num_k_heads, - q_len_per_seq, - sm_scale, - causal, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_K=BLOCK_K, - ) + total_num_heads, q_len_per_seq, total_kt_tokens = input_tensor.shape - # Step 2: Apply softmax with sequence length awareness - final_scores = torch.empty_like(raw_scores) + # Create output tensor + output = torch.empty_like(input_tensor, + dtype=torch.float32, + device=input_tensor.device) # Softmax kernel configuration - SOFTMAX_BLOCK_SIZE = 1024 + BLOCK_SIZE = 1024 - # Grid: (num_heads, batch_size * q_len_per_seq) - grid_softmax = lambda meta: (num_q_heads, batch_size * q_len_per_seq) + # Grid: (total_num_heads, batch_size) + grid = (total_num_heads, batch_size * q_len_per_seq) - softmax_kernel_batched[grid_softmax]( - raw_scores, - final_scores, - q_cu_seqlens, - k_cu_seqlens, + softmax_kernel_batched[grid]( + input_tensor, + output, + cum_lens, batch_size, - num_q_heads, + total_num_heads, q_len_per_seq, - total_k_tokens, - BLOCK_SIZE=SOFTMAX_BLOCK_SIZE, + total_kt_tokens, + BLOCK_SIZE=BLOCK_SIZE, ) - return final_scores + return output ######################################################## @@ -1085,148 +944,10 @@ def flatten_sparse_indices( ######################################################## -# Sparse KT cache 
update kernel +# KT cache update kernel ######################################################## -@triton.jit -def sparse_update_kt_cache_kernel(qkv_ptr, kt_cache_tensor_ptr, - sparse_kv_indices_ptr, sparse_kv_offsets_ptr, - kt_cache_slots_ptr, num_heads, num_kv_heads, - head_dim, page_size, num_pages_per_block, - num_total_tokens, num_total_sparse_tokens, - batch_size, BLOCK_SIZE: tl.constexpr): - """ - Batched sparse KT cache update kernel for RocketKV algorithm. - - Args: - qkv_ptr: Input QKV tensor [num_total_tokens, num_heads*head_dim + num_kv_heads*head_dim + num_kv_heads*head_dim] - kt_cache_tensor_ptr: KT cache tensor [max_batch_size, num_kv_heads, 2*head_dim, num_pages_per_block] - sparse_kv_indices_ptr: Sparse indices [num_kv_heads, num_total_sparse_tokens] - sparse_kv_offsets_ptr: Sparse offsets [batch_size + 1] - kt_cache_slots_ptr: Cache slot indices for each batch [batch_size] - num_heads: Number of Q heads - num_kv_heads: Number of KV heads - head_dim: Head dimension - page_size: Page size for grouping tokens - num_pages_per_block: Maximum pages per block in cache - batch_size: Number of batches to process - """ - batch_idx = tl.program_id(0) - head_idx = tl.program_id(1) - page_idx = tl.program_id(2) - - if batch_idx >= batch_size or head_idx >= num_kv_heads: - return - - # Get cache slot for current batch - cache_slot = tl.load(kt_cache_slots_ptr + batch_idx) - - # Get sparse token range for current batch - sparse_start = tl.load(sparse_kv_offsets_ptr + batch_idx) - sparse_end = tl.load(sparse_kv_offsets_ptr + batch_idx + 1) - num_sparse_tokens = sparse_end - sparse_start - - if num_sparse_tokens <= 0: - return - - # Calculate page boundaries - page_token_start = page_idx * page_size - page_token_end = tl.minimum(page_token_start + page_size, num_sparse_tokens) - - if page_token_start >= num_sparse_tokens or page_idx >= num_pages_per_block: - return - - # Load existing cache values using cache_slot - cache_base = cache_slot * num_kv_heads * 2 * head_dim * num_pages_per_block + head_idx * 2 * head_dim * num_pages_per_block - - # Process head dimensions in blocks using BLOCK_SIZE - for dim_block_start in tl.range(0, head_dim, BLOCK_SIZE): - # Calculate dimension offsets for current block - dim_offsets = dim_block_start + tl.arange(0, BLOCK_SIZE) - dim_mask = dim_offsets < head_dim - - # Initialize min/max values for this dimension block - k_min = tl.full([BLOCK_SIZE], float('inf'), dtype=tl.float32) - k_max = tl.full([BLOCK_SIZE], float('-inf'), dtype=tl.float32) - - # Process tokens in current page for this dimension block - for token_offset in range(page_token_start, page_token_end): - # Get global sparse index for current head and token - sparse_global_idx = sparse_start + token_offset - sparse_indices_base = head_idx * num_total_sparse_tokens + sparse_global_idx - global_token_idx = tl.load(sparse_kv_indices_ptr + - sparse_indices_base) - - qkv_k_start_offset = num_heads * head_dim - qkv_stride = num_heads * head_dim + 2 * num_kv_heads * head_dim - - k_indices = (global_token_idx * qkv_stride + qkv_k_start_offset + - head_idx * head_dim + dim_offsets) - - # Load values for current dimension block - k_values = tl.load(qkv_ptr + k_indices, mask=dim_mask, other=0.0) - - # Update min/max for this page and dimension block - k_min = tl.where(dim_mask, tl.minimum(k_min, k_values), k_min) - k_max = tl.where(dim_mask, tl.maximum(k_max, k_values), k_max) - - # Calculate indices for min and max cache locations for current dimension block - cache_min_indices = cache_base + 
dim_offsets * num_pages_per_block + page_idx - cache_max_indices = cache_base + ( - head_dim + dim_offsets) * num_pages_per_block + page_idx - - # Store updated values back to cache for current dimension block - tl.store(kt_cache_tensor_ptr + cache_min_indices, k_min, mask=dim_mask) - tl.store(kt_cache_tensor_ptr + cache_max_indices, k_max, mask=dim_mask) - - -def batched_update_kt_cache(qkv: torch.Tensor, kt_cache_tensor: torch.Tensor, - kt_cache_slots: torch.Tensor, - sparse_kv_indices: torch.Tensor, - sparse_kv_offsets: torch.Tensor, batch_size: int, - num_heads: int, num_kv_heads: int, head_dim: int, - max_num_pages: int, page_size: int) -> None: - """ - Batched update KT cache with QKV tensor using Triton kernel. - Updated to work directly with QKV input tensor. - - Args: - qkv: Input QKV tensor [num_total_tokens, num_heads*head_dim + num_kv_heads*head_dim + num_kv_heads*head_dim] - kt_cache_tensor: KT cache tensor [max_batch_size, num_kv_heads, 2*head_dim, num_pages_per_block] - kt_cache_slots: Cache slot indices for each batch [batch_size] (tensor or list) - sparse_kv_indices: Sparse indices [num_kv_heads, num_total_sparse_tokens] - sparse_kv_offsets: Sparse offsets [batch_size + 1] - batch_size: Number of batches - num_heads: Number of Q heads - num_kv_heads: Number of KV heads - head_dim: Head dimension - max_num_pages: Maximum number of pages across all batches - page_size: Page size for grouping tokens - """ - num_total_tokens = qkv.shape[0] - max_batch_size, _, _, num_pages_per_block = kt_cache_tensor.shape - num_total_sparse_tokens = sparse_kv_indices.shape[1] - - grid = (batch_size, num_kv_heads, max_num_pages) - BLOCK_SIZE = 128 - - sparse_update_kt_cache_kernel[grid](qkv, - kt_cache_tensor, - sparse_kv_indices, - sparse_kv_offsets, - kt_cache_slots, - num_heads, - num_kv_heads, - head_dim, - page_size, - num_pages_per_block, - num_total_tokens, - num_total_sparse_tokens, - batch_size, - BLOCK_SIZE=BLOCK_SIZE) - - @triton.jit def _update_kt_cache_ctx_kernel(k_ptr, cache_ptr, block_offsets_ptr, cum_seq_lens_ptr, cum_kt_seq_lens_ptr, @@ -1427,6 +1148,93 @@ def _load_kt_cache_kernel(kt_states_ptr, cache_ptr, block_offsets_ptr, tl.store(kt_states_base + hidden_indices, kt, mask=mask) +@triton.jit +def kt_cache_update_kernel( + k_ptr, + kt_cache_tensor_ptr, + kt_cache_block_offsets_ptr, + kv_lens_ptr, + num_gen_tokens, + num_kv_heads, + head_dim, + kt_page_size, + tokens_per_block, + max_kt_blocks_per_seq, + DIM_BLOCK_SIZE: tl.constexpr, +): + """ + Specialized kernel for updating KT cache during generation phase. + + Grid: (num_gen_tokens, num_kv_heads, ceil(head_dim / DIM_BLOCK_SIZE)) + Each program handles one (batch, kv_head, dimension_block) combination. 
+ + Args: + k_ptr: Key tensor [num_gen_tokens, num_kv_heads * head_dim] + kt_cache_tensor_ptr: KT cache [num_blocks, tokens_per_block, num_kv_heads, 2*head_dim] + kt_cache_block_offsets_ptr: Block offsets [batch_size, max_kt_blocks_per_seq] + kv_lens_ptr: Sequence lengths [num_gen_tokens] + num_gen_tokens: Number of generation tokens (batch size) + num_kv_heads: Number of KV heads + head_dim: Head dimension + kt_page_size: Page size for KT tokens + tokens_per_block: Tokens per cache block + max_kt_blocks_per_seq: Maximum KT blocks per sequence + DIM_BLOCK_SIZE: Size of dimension blocks for processing + """ + batch_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + dim_block_idx = tl.program_id(2) + + if batch_idx >= num_gen_tokens or kv_head_idx >= num_kv_heads: + return + + dim_block_start = dim_block_idx * DIM_BLOCK_SIZE + dim_offsets = tl.arange(0, DIM_BLOCK_SIZE) + dim_indices = dim_block_start + dim_offsets + dim_mask = dim_indices < head_dim + + k_base = batch_idx * num_kv_heads * head_dim + kv_head_idx * head_dim + k_indices = k_base + dim_indices + k_values = tl.load(k_ptr + k_indices, mask=dim_mask, other=0.0) + + kv_len = tl.load(kv_lens_ptr + batch_idx) + if kv_len <= 0: + return + + # Determine which kt_token to update (the last one) + last_token_idx = kv_len - 1 + last_kt_token_idx = last_token_idx // kt_page_size + + block_offset_in_seq = last_kt_token_idx // tokens_per_block + if block_offset_in_seq >= max_kt_blocks_per_seq: + return + + block_idx = tl.load(kt_cache_block_offsets_ptr + + batch_idx * max_kt_blocks_per_seq + block_offset_in_seq) + token_idx_in_block = last_kt_token_idx % tokens_per_block + + cache_base = ((block_idx * tokens_per_block + token_idx_in_block) * + num_kv_heads * 2 * head_dim + kv_head_idx * 2 * head_dim) + + cache_min_indices = cache_base + dim_indices + cache_max_indices = cache_base + head_dim + dim_indices + + k_min_existing = tl.load(kt_cache_tensor_ptr + cache_min_indices, + mask=dim_mask, + other=float('inf')) + k_max_existing = tl.load(kt_cache_tensor_ptr + cache_max_indices, + mask=dim_mask, + other=float('-inf')) + + k_min_new = tl.minimum(k_min_existing, k_values) + k_max_new = tl.maximum(k_max_existing, k_values) + k_min_new = k_min_new.to(kt_cache_tensor_ptr.dtype.element_ty) + k_max_new = k_max_new.to(kt_cache_tensor_ptr.dtype.element_ty) + + tl.store(kt_cache_tensor_ptr + cache_min_indices, k_min_new, mask=dim_mask) + tl.store(kt_cache_tensor_ptr + cache_max_indices, k_max_new, mask=dim_mask) + + def triton_update_kt_cache(k, kt_cache_tensor, kt_cache_block_offsets, @@ -1564,3 +1372,596 @@ def triton_update_kt_cache(k, BLOCK_SIZE=1024) return kt_states + + +######################################################## +# Paged KT cache BMM kernel +######################################################## + + +@triton.jit +def paged_kt_cache_bmm_kernel( + q_ptr, + q_mask_ptr, + kt_cache_tensor_ptr, + kt_cache_block_offsets_ptr, + dim_pos_ptr, + kv_lens_ptr, + output_ptr, + output_offsets_ptr, + num_gen_tokens, + num_kv_heads, + num_heads_per_kv, + head_dim, + kt_page_size, + tokens_per_block, + max_kt_blocks_per_seq, + max_num_kt_tokens, + total_kt_tokens, + sm_scale, + KT_BLOCK_SIZE: tl.constexpr, + DIM_BLOCK_SIZE: tl.constexpr, +): + """ + Specialized kernel for paged KT cache matrix multiplication in generation phase. + + Grid: (num_gen_tokens, ceil(max_num_kt_tokens / KT_BLOCK_SIZE), total_num_heads) + Each program handles KT_BLOCK_SIZE kt_tokens for one (batch, global_head) combination. 
+ + Args: + q_ptr: Query tensor [num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim] + q_mask_ptr: Query mask [num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim] + kt_cache_tensor_ptr: KT cache [num_blocks, tokens_per_block, num_kv_heads, 2*head_dim] + kt_cache_block_offsets_ptr: Block offsets [batch_size, max_kt_blocks_per_seq] + dim_pos_ptr: Dimension positions [num_gen_tokens * num_kv_heads * head_dim] (0 or head_dim for each dim) + kv_lens_ptr: Sequence lengths [num_gen_tokens] + output_ptr: Output tensor [num_heads, 1, sum(ceil(kv_len / kt_page_size))] + output_offsets_ptr: Output offsets [num_gen_tokens + 1] + KT_BLOCK_SIZE: Number of kt_tokens to process per thread block + DIM_BLOCK_SIZE: Size of dimension blocks for processing + """ + batch_idx = tl.program_id(0) + kt_block_idx = tl.program_id(1) + global_head_idx = tl.program_id(2) + + total_num_heads = num_kv_heads * num_heads_per_kv + + if batch_idx >= num_gen_tokens or global_head_idx >= total_num_heads: + return + + kv_head_idx = global_head_idx // num_heads_per_kv + q_head_idx = global_head_idx % num_heads_per_kv + + kv_len = tl.load(kv_lens_ptr + batch_idx) + num_kt_tokens = (kv_len + kt_page_size - 1) // kt_page_size + + if num_kt_tokens <= 0: + return + + kt_token_start = kt_block_idx * KT_BLOCK_SIZE + + kt_token_offsets = tl.arange(0, KT_BLOCK_SIZE) + kt_token_indices = kt_token_start + kt_token_offsets + kt_token_mask = kt_token_indices < num_kt_tokens + + q_base = (batch_idx * num_kv_heads * num_heads_per_kv * head_dim + + kv_head_idx * num_heads_per_kv * head_dim + q_head_idx * head_dim) + dim_pos_base = (batch_idx * num_kv_heads * head_dim + + kv_head_idx * head_dim) + + block_offsets = batch_idx * max_kt_blocks_per_seq + kt_token_indices // tokens_per_block + block_indices = tl.load(kt_cache_block_offsets_ptr + block_offsets, + mask=kt_token_mask, + other=0) + token_indices_in_block = kt_token_indices % tokens_per_block + + cache_bases = ((block_indices * tokens_per_block + token_indices_in_block) * + num_kv_heads * 2 * head_dim + kv_head_idx * 2 * head_dim) + + results = tl.zeros([KT_BLOCK_SIZE], dtype=tl.float32) + + for dim_block_start in tl.range(0, head_dim, DIM_BLOCK_SIZE): + dim_offsets = tl.arange(0, DIM_BLOCK_SIZE) + dim_indices = dim_block_start + dim_offsets + dim_mask = dim_indices < head_dim + + q_indices = q_base + dim_indices + q_mask_values = tl.load(q_mask_ptr + q_indices, + mask=dim_mask, + other=0.0) + q_raw_values = tl.load(q_ptr + q_indices, mask=dim_mask, other=0.0) + q_values = tl.where(q_mask_values != 0.0, q_raw_values, 0.0) + + dim_pos_indices = dim_pos_base + dim_indices + kt_cache_offsets = tl.load(dim_pos_ptr + dim_pos_indices, + mask=dim_mask, + other=0) + + dim_indices_expanded = dim_indices[:, + None] # Shape: [DIM_BLOCK_SIZE, 1] + cache_bases_expanded = cache_bases[None, :] # Shape: [1, KT_BLOCK_SIZE] + dim_mask_expanded = dim_mask[:, None] # Shape: [DIM_BLOCK_SIZE, 1] + + kt_cache_offsets_expanded = kt_cache_offsets[:, + None] # Shape: [DIM_BLOCK_SIZE, 1] + kt_cache_indices = cache_bases_expanded + kt_cache_offsets_expanded + dim_indices_expanded # Shape: [DIM_BLOCK_SIZE, KT_BLOCK_SIZE] + + kt_token_mask_expanded = kt_token_mask[ + None, :] # Shape: [1, KT_BLOCK_SIZE] + combined_mask = dim_mask_expanded & kt_token_mask_expanded # Shape: [DIM_BLOCK_SIZE, KT_BLOCK_SIZE] + + kt_cache_flat = tl.reshape(kt_cache_indices, + [DIM_BLOCK_SIZE * KT_BLOCK_SIZE]) + mask_flat = tl.reshape(combined_mask, [DIM_BLOCK_SIZE * KT_BLOCK_SIZE]) + + kt_values_flat = 
tl.load(kt_cache_tensor_ptr + kt_cache_flat, + mask=mask_flat, + other=0.0) + kt_values = tl.reshape(kt_values_flat, [DIM_BLOCK_SIZE, KT_BLOCK_SIZE]) + + q_values_expanded = q_values[:, None] # Shape: [DIM_BLOCK_SIZE, 1] + products = q_values_expanded * kt_values # Shape: [DIM_BLOCK_SIZE, KT_BLOCK_SIZE] + masked_products = tl.where(combined_mask, products, 0.0) + + results += tl.sum(masked_products, axis=0) # Shape: [KT_BLOCK_SIZE] + + output_offset = tl.load(output_offsets_ptr + batch_idx) + output_indices = global_head_idx * total_kt_tokens + output_offset + kt_token_indices + tl.store(output_ptr + output_indices, + results * sm_scale, + mask=kt_token_mask) + + +def kt_cache_update_and_bmm( + q: torch.Tensor, + k: torch.Tensor, + q_mask: torch.Tensor, + dim_pos: torch.Tensor, + layer_idx: int, + metadata: "RocketTrtllmAttentionMetadata", + sm_scale: float = None, +) -> torch.Tensor: + """ + Separated KT cache update and BMM computation for generation phase. + + This function first updates the KT cache with new key values, then performs + the matrix multiplication with the cached values. + + Args: + q: Query tensor [num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim] + k: Key tensor [num_gen_tokens, num_kv_heads * head_dim] + q_mask: Query mask [num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim] (1 for top-r dims, 0 elsewhere) + dim_pos: Dimension offsets [num_gen_tokens, num_kv_heads, 1, head_dim] (0 or head_dim for each dim) + layer_idx: Layer index + metadata: Metadata + + Returns: + output: BMM results [num_heads, 1, sum(ceil(kv_lens/kt_page_size))] + """ + num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim = q.shape + total_num_heads = num_kv_heads * num_heads_per_kv + + kv_lens = metadata.kv_lens_cuda_runtime[metadata.num_contexts:] + kt_page_size = metadata.page_size + tokens_per_block = metadata.kt_tokens_per_block + max_kt_blocks_per_seq = metadata.kv_cache_manager.max_kt_blocks_per_seq + + # Calculate number of kt tokens for each batch and total + num_kt_tokens = metadata.num_kt_tokens_cuda[:metadata.num_generations] + total_kt_tokens = metadata.total_kt_tokens + max_num_kt_tokens = metadata.max_kt_tokens + + kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers(layer_idx) + kt_cache_block_offsets = metadata.kt_cache_block_offsets[metadata. 
+ num_contexts:] + + # Calculate output offsets for each batch + output_offsets = metadata.cum_kt_lens_cuda[:metadata.num_generations + 1] + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + # Step 1: Update KT cache with new key values + grid = (num_gen_tokens, num_kv_heads, 1) + + kt_cache_update_kernel[grid](k, + kt_cache_tensor, + kt_cache_block_offsets, + kv_lens, + num_gen_tokens, + num_kv_heads, + head_dim, + kt_page_size, + tokens_per_block, + max_kt_blocks_per_seq, + DIM_BLOCK_SIZE=128) + + # Step 2: Perform BMM with updated cache + # Create output tensor with shape [num_heads, 1, sum(ceil(kv_lens/kt_page_size))] + output = torch.empty((total_num_heads, 1, total_kt_tokens), + dtype=torch.float32, + device=q.device) + + KT_BLOCK_SIZE = 64 + DIM_BLOCK_SIZE = 32 + grid = (num_gen_tokens, + (max_num_kt_tokens + KT_BLOCK_SIZE - 1) // KT_BLOCK_SIZE, + total_num_heads) + + paged_kt_cache_bmm_kernel[grid]( + q, + q_mask, + kt_cache_tensor, + kt_cache_block_offsets, + dim_pos, + kv_lens, + output, + output_offsets, + num_gen_tokens, + num_kv_heads, + num_heads_per_kv, + head_dim, + kt_page_size, + tokens_per_block, + max_kt_blocks_per_seq, + max_num_kt_tokens, + total_kt_tokens, + sm_scale, + KT_BLOCK_SIZE=KT_BLOCK_SIZE, + DIM_BLOCK_SIZE=DIM_BLOCK_SIZE, + num_warps=2, + num_stages=2, + ) + + return output + + +######################################################## +# Triton TopK kernel with optional interleave +######################################################## + + +@triton.jit +def triton_interleave_kernel( + input_ptr, + output_ptr, + kt_offsets_ptr, + kv_lens_ptr, + kv_offsets_ptr, + batch_size, + num_kv_heads, + kt_page_size, + total_kt_tokens, + total_kv_tokens, + BLOCK_SIZE: tl.constexpr, +): + """ + Interleave kt tokens to kv tokens by repeating each kt token kt_page_size times. 
+ + Args: + input_ptr: Input tensor [num_kv_heads, total_kt_tokens] + output_ptr: Output tensor [num_kv_heads, total_kv_tokens] + kt_offsets_ptr: KT offsets [batch_size + 1] + kv_lens_ptr: KV lengths [batch_size] + kv_offsets_ptr: KV offsets [batch_size + 1] + batch_size: Number of batches + num_kv_heads: Number of KV heads + kt_page_size: Page size for interleaving + total_kt_tokens: Total KT tokens + total_kv_tokens: Total KV tokens + """ + batch_idx = tl.program_id(0) + head_idx = tl.program_id(1) + + if batch_idx >= batch_size or head_idx >= num_kv_heads: + return + + # Get batch kt and kv ranges + kt_start = tl.load(kt_offsets_ptr + batch_idx) + kt_end = tl.load(kt_offsets_ptr + batch_idx + 1) + kt_len = kt_end - kt_start + + kv_len = tl.load(kv_lens_ptr + batch_idx) + kv_start = tl.load(kv_offsets_ptr + batch_idx) + + if kt_len <= 0 or kv_len <= 0: + return + + # Process in blocks + for block_start in tl.range(0, kv_len, BLOCK_SIZE): + block_offsets = block_start + tl.arange(0, BLOCK_SIZE) + block_mask = block_offsets < kv_len + + # Calculate which kt_token each kv position corresponds to + kt_indices = block_offsets // kt_page_size + kt_valid_mask = kt_indices < kt_len + combined_mask = block_mask & kt_valid_mask + + # Load from input kt tokens + input_indices = head_idx * total_kt_tokens + kt_start + kt_indices + + values = tl.load(input_ptr + input_indices, + mask=combined_mask, + other=0.0) + + # Store to output kv positions + output_indices = head_idx * total_kv_tokens + kv_start + block_offsets + tl.store(output_ptr + output_indices, values, mask=block_mask) + + +@triton.jit +def triton_topk_kernel( + input_ptr, + output_indices_ptr, + temp_values_ptr, + temp_indices_ptr, + input_offsets_ptr, + sparse_offsets_ptr, + batch_size, + num_kv_heads, + topk, + total_input_tokens, + total_sparse_indices, + max_seq_len, + BLOCK_SIZE: tl.constexpr, +): + """ + Perform topk operation on each batch independently using efficient argsort implementation. 
+ + Args: + input_ptr: Input tensor [num_kv_heads, total_input_tokens] + output_indices_ptr: Output indices [num_kv_heads, total_sparse_indices] + temp_values_ptr: Temporary values storage [batch_size, num_kv_heads, max_seq_len] + temp_indices_ptr: Temporary indices storage [batch_size, num_kv_heads, max_seq_len] + input_offsets_ptr: Input offsets [batch_size + 1] + sparse_offsets_ptr: Sparse offsets [batch_size + 1] + batch_size: Number of batches + num_kv_heads: Number of KV heads + topk: TopK parameter + total_input_tokens: Total input tokens + total_sparse_indices: Total sparse indices + max_seq_len: Maximum sequence length + """ + batch_idx = tl.program_id(0) + head_idx = tl.program_id(1) + + if batch_idx >= batch_size or head_idx >= num_kv_heads: + return + + input_start = tl.load(input_offsets_ptr + batch_idx) + input_end = tl.load(input_offsets_ptr + batch_idx + 1) + input_len = input_end - input_start + + sparse_start = tl.load(sparse_offsets_ptr + batch_idx) + sparse_end = tl.load(sparse_offsets_ptr + batch_idx + 1) + sparse_len = sparse_end - sparse_start + + if input_len <= 0 or sparse_len <= 0: + return + + actual_topk = tl.minimum(topk, input_len) + actual_topk = tl.minimum(actual_topk, sparse_len) + + # Base addresses + input_base = head_idx * total_input_tokens + input_start + temp_base = batch_idx * num_kv_heads * max_seq_len * 2 + head_idx * max_seq_len * 2 + output_base = head_idx * total_sparse_indices + sparse_start + + # Process sequence in chunks to handle variable lengths efficiently + max_process_len = tl.cdiv(input_len, BLOCK_SIZE) * BLOCK_SIZE + + for block_start in tl.range(0, max_process_len, BLOCK_SIZE): + block_offsets = block_start + tl.arange(0, BLOCK_SIZE) + block_mask = block_offsets < input_len + + values = tl.load(input_ptr + input_base + block_offsets, + mask=block_mask, + other=0.0) + + # Store values to temporary storage + tl.store(temp_values_ptr + temp_base + block_offsets, + values, + mask=block_mask) + # Store original indices + tl.store(temp_indices_ptr + temp_base + block_offsets, + block_offsets, + mask=block_mask) + + # Multi-round iterative argsort approach + # This works for both short and long sequences uniformly + current_len = input_len + current_base = temp_base + round_num = 0 + + while current_len > BLOCK_SIZE: + round_num += 1 + + num_chunks = tl.cdiv(current_len, BLOCK_SIZE) + + # Alternate between two halves of temp storage to avoid conflicts + if round_num % 2 == 1: + next_base = temp_base + max_seq_len + else: + next_base = temp_base + + next_len = 0 + + # Process each chunk in this round + for chunk_id in tl.range(0, num_chunks): + chunk_start = chunk_id * BLOCK_SIZE + chunk_end = tl.minimum(chunk_start + BLOCK_SIZE, current_len) + chunk_len = chunk_end - chunk_start + + if chunk_len > 0: + # Load chunk data from current round's storage + chunk_offsets = tl.arange(0, BLOCK_SIZE) + chunk_mask = chunk_offsets < chunk_len + + chunk_values = tl.load(temp_values_ptr + current_base + + chunk_start + chunk_offsets, + mask=chunk_mask, + other=0.0) + chunk_indices = tl.load(temp_indices_ptr + current_base + + chunk_start + chunk_offsets, + mask=chunk_mask, + other=0.0).to(tl.int32) + + # Sort this chunk using argsort + chunk_sorted_values, chunk_sorted_indices = argsort( + chunk_values, chunk_indices, dim=0, descending=True) + + # Extract top-k candidates from this chunk + chunk_topk = tl.minimum(actual_topk, chunk_len) + chunk_topk_mask = chunk_offsets < chunk_topk + + # Store top-k candidates to next round's storage + next_offsets = 
next_len + chunk_offsets + next_store_mask = chunk_topk_mask & (next_offsets < max_seq_len) + + tl.store(temp_values_ptr + next_base + next_offsets, + chunk_sorted_values, + mask=next_store_mask) + tl.store(temp_indices_ptr + next_base + next_offsets, + chunk_sorted_indices, + mask=next_store_mask) + + next_len += chunk_topk + + # Update parameters for next round + current_len = next_len + current_base = next_base + + final_offsets = tl.arange(0, BLOCK_SIZE) + final_mask = final_offsets < current_len + + final_values = tl.load(temp_values_ptr + current_base + final_offsets, + mask=final_mask, + other=0.0) + final_indices = tl.load(temp_indices_ptr + current_base + final_offsets, + mask=final_mask, + other=0.0).to(tl.int32) + + final_sorted_values, final_sorted_indices = argsort(final_values, + final_indices, + dim=0, + descending=True) + + result_offsets = tl.arange(0, BLOCK_SIZE) + result_mask = result_offsets < actual_topk + + selected_indices = tl.where(result_mask, final_sorted_indices, + tl.zeros_like(final_sorted_indices)) + tl.store(output_indices_ptr + output_base + result_offsets, + selected_indices, + mask=result_mask) + + +def triton_topk( + input_tensor: torch.Tensor, + kt_offsets: torch.Tensor, + kv_lens: torch.Tensor, + kt_lens: torch.Tensor, + kv_cu_lens: torch.Tensor, + total_kv_tokens: int, + topk: int, + kt_page_size: int, + use_interleave: bool = False) -> tuple[torch.Tensor, torch.Tensor]: + """ + Perform topk operation with optional interleaving. + + Args: + input_tensor: Input scores [num_kv_heads, sum(kt_lens)] + kt_offsets: KT offsets [batch_size + 1] + kv_lens: KV lengths [batch_size] + kt_lens: KT lengths [batch_size] + topk: TopK parameter + kt_page_size: Page size for interleaving + use_interleave: Whether to perform interleaving + + Returns: + output_indices: Selected indices [num_kv_heads, num_total_sparse_indices] + sparse_offsets: Sparse offsets [batch_size + 1] + """ + + num_kv_heads = input_tensor.shape[0] + batch_size = len(kv_lens) + device = input_tensor.device + + if use_interleave: + total_kt_tokens = input_tensor.shape[1] + + # Create interleaved tensor + interleaved_tensor = torch.empty((num_kv_heads, total_kv_tokens), + dtype=input_tensor.dtype, + device=device) + + # Launch interleave kernel + grid = (batch_size, num_kv_heads) + triton_interleave_kernel[grid](input_tensor, + interleaved_tensor, + kt_offsets, + kv_lens, + kv_cu_lens, + batch_size, + num_kv_heads, + kt_page_size, + total_kt_tokens, + total_kv_tokens, + BLOCK_SIZE=1024) + + # Use interleaved tensor and kv_cu_lens for topk + working_tensor = interleaved_tensor + working_offsets = kv_cu_lens + working_lens = kv_lens + else: + # Use original tensor and kt_offsets for topk + working_tensor = input_tensor + working_offsets = kt_offsets + working_lens = kt_lens + topk = (topk + kt_page_size - 1) // kt_page_size + + # Calculate sparse counts and offsets + sparse_counts = torch.minimum(torch.tensor(topk, device=device), + working_lens) + sparse_offsets = torch.cumsum(torch.cat( + [torch.zeros(1, device=device), sparse_counts]), + dim=0).to(torch.int32) + + total_sparse_indices = sparse_offsets[-1].item() + total_working_tokens = working_tensor.shape[1] + max_seq_len = working_lens.max().item() + + # Create output tensor + output_indices = torch.empty((num_kv_heads, total_sparse_indices), + dtype=torch.int32, + device=device) + + # Create temporary storage for topk algorithm (double size for dual-buffer design) + temp_values = torch.empty((batch_size, num_kv_heads, max_seq_len * 2), + 
dtype=working_tensor.dtype, + device=device) + temp_indices = torch.empty((batch_size, num_kv_heads, max_seq_len * 2), + dtype=torch.int32, + device=device) + + grid = (batch_size, num_kv_heads) + + triton_topk_kernel[grid]( + working_tensor, + output_indices, + temp_values, + temp_indices, + working_offsets, + sparse_offsets, + batch_size, + num_kv_heads, + topk, + total_working_tokens, + total_sparse_indices, + max_seq_len, + BLOCK_SIZE=512, + num_warps=16, + num_stages=3, + ) + + return output_indices, sparse_offsets diff --git a/tensorrt_llm/_torch/attention_backend/sparse/rocket.py b/tensorrt_llm/_torch/attention_backend/sparse/rocket.py index 4d9715d6689..579af97af46 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/rocket.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/rocket.py @@ -21,9 +21,10 @@ from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig -from .kernel import (flatten_sparse_indices, fused_qk_split, - reshape_flatten_to_batched, separated_bmm_softmax, - triton_index_gather, triton_update_kt_cache) +from .kernel import (bmm, flatten_sparse_indices, fused_qk_split, + kt_cache_update_and_bmm, reshape_flatten_to_batched, + triton_index_gather, triton_softmax, triton_topk, + triton_update_kt_cache) ModelConfig = tensorrt_llm.bindings.ModelConfig @@ -36,6 +37,7 @@ def __post_init__(self): raise ValueError("Sparse attention config is not set") self.prompt_budget = self.sparse_attention_config.prompt_budget self.window_size = self.sparse_attention_config.window_size + self.page_size = self.sparse_attention_config.page_size self.context_lens_cuda = torch.empty( self.max_num_sequences, @@ -153,6 +155,33 @@ def __post_init__(self): device='cuda', ) + self.num_kt_tokens = torch.empty( + self.max_num_sequences, + device='cpu', + dtype=torch.int32, + ) + self.num_kt_tokens_cuda = torch.empty_like(self.num_kt_tokens, + device='cuda', + dtype=torch.int32) + + self.cum_kt_lens = torch.zeros( + self.max_num_sequences + 1, + device='cpu', + dtype=torch.int32, + ) + self.cum_kt_lens_cuda = torch.empty_like(self.cum_kt_lens, + device='cuda', + dtype=torch.int32) + + self.cum_kv_gen_lens = torch.zeros( + self.max_num_sequences + 1, + device='cpu', + dtype=torch.int32, + ) + self.cum_kv_gen_lens_cuda = torch.empty_like(self.cum_kv_gen_lens, + device='cuda', + dtype=torch.int32) + @property def kt_tokens_per_block(self) -> Optional[int]: """ @@ -275,6 +304,31 @@ def prepare(self): self.valid_batch_size = valid_batch_size self.total_sparse_tokens = self.sparse_offsets[self.num_contexts].item() + self.num_kt_tokens[:self.num_generations] = ( + self.kv_lens[self.num_contexts:self.num_seqs] + self.page_size - + 1) // self.page_size + self.num_kt_tokens_cuda[:self.num_generations].copy_( + self.num_kt_tokens[:self.num_generations], non_blocking=True) + + self.cum_kt_lens[1:self.num_generations + 1] = torch.cumsum( + self.num_kt_tokens[:self.num_generations], dim=0) + self.cum_kt_lens_cuda[:self.num_generations + 1].copy_( + self.cum_kt_lens[:self.num_generations + 1], non_blocking=True) + + if self.num_generations > 0: + self.max_kt_tokens = self.num_kt_tokens[:self.num_generations].max( + ).item() + + self.total_kt_tokens = self.cum_kt_lens[self.num_generations].item() + + self.cum_kv_gen_lens[1:self.num_generations + 1] = torch.cumsum( + self.kv_lens_cuda_runtime[self.num_contexts:self.num_seqs], dim=0) + self.cum_kv_gen_lens_cuda[:self.num_generations + 1].copy_( + self.cum_kv_gen_lens[:self.num_generations + 1], non_blocking=True) + + 
self.total_kv_gen_tokens = self.cum_kv_gen_lens[ + self.num_generations].item() + class RocketTrtllmAttention(TrtllmAttention): Metadata = RocketTrtllmAttentionMetadata @@ -312,7 +366,7 @@ def __init__( self.kernel_size = sparse_attention_config.kernel_size self.page_size = sparse_attention_config.page_size - def batched_ctx_sparse_predict( + def _batched_sparse_kv_predict( self, qkv: torch.Tensor, metadata: RocketTrtllmAttentionMetadata ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: """ @@ -332,12 +386,15 @@ def batched_ctx_sparse_predict( self.window_size, self.prompt_budget, metadata) - scores = separated_bmm_softmax(q_window, - k_context, - metadata.q_cu_seqlens_cuda, - metadata.k_cu_seqlens_cuda, - metadata.valid_batch_size, - causal=False) + scores = bmm(q_window, + k_context, + metadata.q_cu_seqlens_cuda, + metadata.k_cu_seqlens_cuda, + metadata.valid_batch_size, + causal=False) + + scores = triton_softmax(scores, metadata.k_cu_seqlens_cuda, + metadata.valid_batch_size) scores = scores.view(self.num_kv_heads, self.num_heads // self.num_kv_heads, @@ -385,262 +442,85 @@ def batched_ctx_sparse_predict( return sparse_kv_indices, metadata.sparse_offsets_cuda[:metadata. num_contexts + 1] - def batched_gen_sparse_predict( - self, q: torch.Tensor, k: torch.Tensor, - metadata: RocketTrtllmAttentionMetadata - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - pass + @torch.compile(dynamic=True) + def preprocess_for_gen( + self, qkv: torch.Tensor, metadata: RocketTrtllmAttentionMetadata + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qkv_input = qkv[metadata.num_ctx_tokens:] + q, k = qkv_input[:, :self.num_heads * self. + head_dim], qkv_input[:, self.num_heads * + self.head_dim:self.num_heads * + self.head_dim + + self.num_kv_heads * self.head_dim] + q = q.view(-1, self.num_kv_heads, self.num_heads // self.num_kv_heads, + self.head_dim) + + q_abs = torch.abs(q) + q_mask = torch.zeros_like(q) + + i1 = torch.topk(q_abs.mean(dim=2, keepdim=True), self.topr, + dim=-1).indices + q_mask.scatter_(-1, i1.expand_as(q[..., :self.topr]), 1) - def batched_sparse_attention_predict( - self, q: torch.Tensor, k: torch.Tensor, - metadata: RocketTrtllmAttentionMetadata - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - """ - Predict sparse KV indices and sparse attention indices for the input sequence. - """ - assert k is None, "RocketKV can only support fused qkv input." - sparse_kv_indices, sparse_kv_offsets = self.batched_ctx_sparse_predict( - q, metadata) + dim_pos = torch.where(q.sum(dim=2, keepdim=True) > 0, self.head_dim, + 0).to(torch.int32) - return sparse_kv_indices, sparse_kv_offsets + return q, k, q_mask, dim_pos - def sparse_attention_predict( - self, q: torch.Tensor, k: torch.Tensor, - metadata: RocketTrtllmAttentionMetadata + def _batched_sparse_attn_predict( + self, qkv: torch.Tensor, metadata: RocketTrtllmAttentionMetadata ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - """ - Predict sparse KV indices and sparse attention indices for the input sequence. 
- - For RocketKV: - - Context phase: predict SnapKV sparse kv indices - - Generation phase: predict RocketKV sparse attention indices - - Returns: - Tuple of (flattened_indices, batch_offsets) - - flattened_indices: [total_selected_indices, num_kv_heads] - - batch_offsets: [batch_size + 1] with cumulative indices count - """ - q, k, _ = q.split([ - self.num_heads * self.head_dim, self.num_kv_heads * self.head_dim, - self.num_kv_heads * self.head_dim - ], - dim=-1) - - if k is None or metadata is None: + metadata.num_ctx_tokens + if metadata.num_generations == 0: return None, None - num_contexts = metadata.num_contexts - num_generations = metadata.num_generations - seq_lens = metadata.seq_lens - seq_lens_kv = metadata.seq_lens_kv if metadata.seq_lens_kv is not None else seq_lens - past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq - - sparse_kv_indices = [] - sparse_attn_indices = [] - sparse_kv_offsets = [0] - sparse_attn_offsets = [0] - - q_offset = 0 - k_offset = 0 - - for i in range(num_contexts + num_generations): - seq_len = seq_lens[i].item() - seq_len_kv = seq_lens_kv[i].item() - - if seq_len <= 0 or seq_len_kv <= 0: - assert False, "Invalid sequence length" - - single_q = q[q_offset:q_offset + seq_len] - single_k = k[k_offset:k_offset + seq_len_kv] - - single_q = single_q.view(1, seq_len, self.num_heads, - self.head_dim).transpose(1, 2) - single_k = single_k.view(1, seq_len_kv, self.num_kv_heads, - self.head_dim) - - past_seen_token = past_seen_tokens[i] - if i < num_contexts: - pass - # # Context phase: SnapKV sparse kv indices - # _sparse_kv_indices = self._get_snapkv_indices( - # single_q, single_k, past_seen_token, kt_cache_slot, - # metadata) - # if _sparse_kv_indices is not None: - # sparse_kv_indices.append( - # _sparse_kv_indices.squeeze(0)) # [budget, num_kv_heads] - # sparse_kv_offsets.append(sparse_kv_offsets[-1] + - # _sparse_kv_indices.size(1)) - # else: - # sparse_kv_offsets.append(sparse_kv_offsets[-1]) - else: - # Generation phase: RocketKV sparse attention indices - _sparse_attn_indices = self._rocketkv_selection( - single_q, single_k, past_seen_token, metadata, i) - if _sparse_attn_indices is not None: - sparse_attn_indices.append( - _sparse_attn_indices.squeeze(0)) # [topk, num_kv_heads] - sparse_attn_offsets.append(sparse_attn_offsets[-1] + - _sparse_attn_indices.size(1)) - else: - sparse_attn_offsets.append(sparse_attn_offsets[-1]) - - q_offset += seq_len - k_offset += seq_len_kv - - if len(sparse_kv_indices) == 0: - sparse_kv_indices, sparse_kv_offsets = None, None - else: - sparse_kv_indices = torch.cat(sparse_kv_indices, - dim=0).to(torch.int32) - sparse_kv_offsets = torch.tensor(sparse_kv_offsets, - dtype=torch.int32).to(q.device) - if len(sparse_attn_indices) == 0: - sparse_attn_indices, sparse_attn_offsets = None, None - else: - sparse_attn_indices = torch.cat(sparse_attn_indices, - dim=0).to(torch.int32) - sparse_attn_offsets = torch.tensor(sparse_attn_offsets, - dtype=torch.int32).to(q.device) - - return sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets - - def _get_snapkv_indices(self, q: Tensor, k: Tensor, past_seen_token: int, - metadata: RocketTrtllmAttentionMetadata, - sample_idx: int) -> Optional[Tensor]: - """ - Get SnapKV sparse kv indices from the input sequence for context phase. 
- The shape of output is (1, prompt_budget, num_kv_heads) - """ - bsz = 1 - seq_len = k.size(1) # k shape: (1, seq_len, num_kv_heads, head_dim) - - # If the sequence length is less than the prompt budget, do not enable sparse kv cache - if seq_len <= self.prompt_budget: - return None - - # Use last window_size tokens as observation - # (1, num_heads, window_size, head_dim) - q_obs = q[:, :, -self.window_size:] - # (1, num_kv_heads, seq_len, head_dim) - k_pre = repeat_kv(k.transpose(1, 2), self.num_heads // - self.num_kv_heads)[:, :, :-self.window_size] - - score = torch.matmul(q_obs, k_pre.transpose(-1, -2)) / math.sqrt( - self.head_dim) - - score = torch.nn.functional.softmax(score, dim=-1) - - score = score.sum(dim=-2) - - score = score.view(bsz, self.num_kv_heads, - self.num_heads // self.num_kv_heads, -1).sum(dim=2) - score = torch.nn.functional.max_pool1d(score, - kernel_size=self.kernel_size, - padding=self.kernel_size // 2, - stride=1) + q, k, q_mask, dim_pos = self.preprocess_for_gen(qkv, metadata) - # Select top important tokens from prefix - prefix_len = seq_len - self.window_size - selected_prefix_indices = score.topk(self.prompt_budget - - self.window_size, - dim=-1).indices.sort().values + scores = kt_cache_update_and_bmm( + q, + k, + q_mask, + dim_pos, + self.layer_idx, + metadata, + ) - # Combine selected prefix indices with window indices - window_indices = torch.arange( - prefix_len, seq_len, - device=k.device).unsqueeze(0).unsqueeze(0).expand( - bsz, self.num_kv_heads, -1) - selected_indices = torch.cat([selected_prefix_indices, window_indices], - dim=-1).transpose(1, 2) + scores = triton_softmax(scores, metadata.cum_kt_lens_cuda, + metadata.num_generations) - k = k.reshape(1, -1, self.num_kv_heads, self.head_dim) - k_snap = triton_index_gather(k, selected_indices) - # Update KT cache - kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers( - self.layer_idx) - k_snap_len = torch.clamp( - metadata.kv_lens_cuda_runtime[sample_idx:sample_idx + 1], - max=self.prompt_budget).int() - triton_update_kt_cache( - k_snap.squeeze(0).contiguous(), - kt_cache_tensor, - metadata.kt_cache_block_offsets[sample_idx:sample_idx + 1], - k_snap_len, + scores = scores.view(self.num_kv_heads, + self.num_heads // self.num_kv_heads, + -1).mean(dim=-2) + + # Use triton_topk to select sparse attention indices + selected_indices, sparse_attn_offsets = triton_topk( + scores, + metadata.cum_kt_lens_cuda[:metadata.num_generations + 1], + metadata.kv_lens_cuda_runtime[metadata.num_contexts:metadata. + num_seqs], + metadata.num_kt_tokens_cuda[:metadata.num_generations], + metadata.cum_kv_gen_lens_cuda[:metadata.num_generations + 1], + metadata.total_kv_gen_tokens, + self.topk, self.page_size, - metadata.kt_tokens_per_block, - metadata.kv_cache_manager.max_kt_blocks_per_seq, - update=False) + use_interleave=False) - return selected_indices + return selected_indices, sparse_attn_offsets - def _rocketkv_selection(self, q: Tensor, k: Tensor, past_seen_token: int, - metadata: RocketTrtllmAttentionMetadata, - sample_idx: int) -> Tensor: + def batched_sparse_attention_predict( + self, q: torch.Tensor, k: torch.Tensor, + metadata: RocketTrtllmAttentionMetadata + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: """ - Implement RocketKV's two-stage selection process for generation phase. - The shape of output is (1, topk, num_kv_heads) + Predict sparse KV indices and sparse attention indices for the input sequence. 
""" - bsz = 1 - q_len = q.size(2) - - # Helper functions - def _gather(t: Tensor, dim: int, i: Tensor) -> Tensor: - dim += (dim < 0) * t.ndim - return t.gather( - dim, i.expand(*t.shape[:dim], i.shape[dim], *t.shape[dim + 1:])) - - @torch.compile(disable=not torch.cuda.is_available()) - def _scaled_softmax(x: Tensor, divscale: Tensor | float, - dim: int) -> Tensor: - return torch.softmax(x / divscale, dim=dim) - - # Get KT cache for key-token matching - kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers( - self.layer_idx) - target_seq_len = past_seen_token + 1 # +1 for current token - - # Update KT cache - kt_states = triton_update_kt_cache( - k.squeeze(0).contiguous(), kt_cache_tensor, - metadata.kt_cache_block_offsets[sample_idx:sample_idx + 1], - metadata.kv_lens_cuda_runtime[sample_idx:sample_idx + 1], - self.page_size, metadata.kt_tokens_per_block, - metadata.kv_cache_manager.max_kt_blocks_per_seq) - kt_states = kt_states.unsqueeze(0).permute(0, 2, 3, 1) - - # Reshape query for multi-head processing - qi = q.view(bsz, self.num_kv_heads, self.num_heads // self.num_kv_heads, - q_len, self.head_dim) - qi_abs = torch.abs(qi) - - # Top-r selection on query features - i1 = torch.topk(qi_abs.mean(dim=2, keepdim=True), self.topr, - dim=-1).indices - qi_hat = _gather(qi, -1, i1) - - # Generate signed indices for key-token matching - i1_sign = torch.where( - qi_hat.sum(dim=2, keepdim=True) > 0, i1 + self.head_dim, - i1).transpose(-1, -2) - - # Gather key tokens and compute attention scores - kt_hat = _gather(kt_states.unsqueeze(2), -2, i1_sign) - qk_hat = qi_hat @ kt_hat - qk_hat = qk_hat.repeat_interleave(self.page_size, - dim=-1)[:, :, :, :, :target_seq_len] - scale = torch.sqrt(self.head_dim * - torch.abs(qi_hat).sum(dim=-1, keepdim=True) / - qi_abs.sum(dim=-1, keepdim=True)) - - # (1, num_kv_heads, num_heads, target_seq_len) - s_hat = _scaled_softmax(qk_hat, scale, dim=-1) - - topk = min(self.topk, target_seq_len) - i2 = torch.topk(s_hat.mean(dim=2, keepdim=True), topk, dim=-1).indices - - iKV = i2[:, :, 0, 0, :].transpose(1, 2) # (1, topk, num_kv_heads) + assert k is None, "RocketKV can only support fused qkv input." + sparse_kv_indices, sparse_kv_offsets = self._batched_sparse_kv_predict( + q, metadata) + sparse_attn_indices, sparse_attn_offsets = self._batched_sparse_attn_predict( + q, metadata) - return iKV + return sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets class RocketVanillaAttentionMetadata(VanillaAttentionMetadata): diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index fe5e0efd0c9..ec0558ec76a 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -1340,11 +1340,7 @@ def forward( sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = None, None, None, None if self.sparse_attention_config is not None: - sparse_kv_indices, sparse_kv_offsets = self.batched_sparse_attention_predict( - q, k, metadata) - - sparse_attn_indices, sparse_attn_offsets = None, None - _, _, sparse_attn_indices, sparse_attn_offsets = self.sparse_attention_predict( + sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = self.batched_sparse_attention_predict( q, k, metadata) if sparse_attn_indices is not None: