9 changes: 5 additions & 4 deletions cpp/tensorrt_llm/kernels/mlaKernels.cu
@@ -351,7 +351,8 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
int* seqQOffset, uint32_t* fmha_tile_counter, int32_t const* kv_cache_lengths, int* seqKVOffsets, int q_pe_ld,
int q_pe_stride, KvCacheDataType cache_type, float* bmm1_scale, float* bmm2_scale, float const* quant_scale_o,
float const* quant_scale_q, float const* quant_scale_kv, float const* dequant_scale_q,
float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets)
float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets,
bool const* helix_is_inactive_rank)
{

// Constants.
@@ -460,7 +461,7 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,

if (valid_token)
{
if (head_idx == head_num)
if (head_idx == head_num && (helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx]))
{
auto const token_kv_idx = kv_cache_lengths[batch_idx] - seq_len + local_token_idx;

@@ -514,7 +515,7 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
auto local_token_idx = global_token_idx % seq_len;
bool valid_token = global_token_idx < total_s_len;

if (valid_token)
if (valid_token && (helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx]))
{
if (head_dim_vec_idx == 0)
{
@@ -1047,7 +1048,7 @@ void invokeMLARopeGeneration(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer
params.seqQOffset, params.fmha_tile_counter, params.cache_seq_lens, params.cu_kv_seqlens, params.q_pe_ld,
params.q_pe_stride, params.cache_type, params.bmm1_scale, params.bmm2_scale, params.quant_scale_o,
params.quant_scale_q, params.quant_scale_kv, params.dequant_scale_q, params.dequant_scale_kv,
params.host_bmm1_scale, params.helix_position_offsets);
params.host_bmm1_scale, params.helix_position_offsets, params.helix_is_inactive_rank);
}

template <typename T, typename TCache>
4 changes: 4 additions & 0 deletions cpp/tensorrt_llm/kernels/mlaKernels.h
@@ -107,6 +107,10 @@ struct MlaParams

// for Helix parallelism: the rotary position offsets [b]
int32_t const* helix_position_offsets{nullptr};

// for Helix parallelism: whether the current rank is inactive, shape [b]
// (the current query tokens are not appended to this rank's KV cache)
bool const* helix_is_inactive_rank{nullptr};
};

template <typename T, typename KVCacheBuffer>
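The new MlaParams field above is a per-batch flag: when a rank is inactive for Helix, the generation kernel skips appending the current query tokens' compressed KV and k_pe to that rank's KV cache, and attention for those sequences covers only previously cached tokens. A minimal caller-side sketch (illustrative only; the helper name and its inputs are assumptions, not part of this PR) of how the flag might be built and packed together with helix_position_offsets as the two optional MLA tensor params:

    import torch
    from typing import List, Optional

    def build_mla_tensor_params(
            helix_position_offsets: Optional[torch.Tensor],
            is_inactive_per_seq: Optional[List[bool]],
    ) -> List[Optional[torch.Tensor]]:
        # helix_is_inactive_rank: bool tensor of shape [batch_size]; True means this
        # rank must not append the current query tokens to its KV cache.
        helix_is_inactive_rank = None
        if is_inactive_per_seq is not None:
            helix_is_inactive_rank = torch.tensor(
                is_inactive_per_seq, dtype=torch.bool, device="cuda")
        # Order matters: index 0 is helix_position_offsets, index 1 is
        # helix_is_inactive_rank; a None entry maps to nullptr on the C++ side.
        return [helix_position_offsets, helix_is_inactive_rank]

The ordering and the None-to-nullptr behavior mirror how attentionOp.cpp and dsv3RopeOp.cpp unpack mla_tensor_params below.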
9 changes: 7 additions & 2 deletions cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -181,8 +181,8 @@ class Runner : public RunnerBase
[[maybe_unused]] MlaParams<T> mla_params;
if (op.isMLAEnabled())
{
TORCH_CHECK(mla_tensor_params.size() == 1,
"Expecting 1 tensor for custom MLA tensor params: helix_position_offsets.");
TORCH_CHECK(mla_tensor_params.size() == 2,
"Expecting 2 tensors for custom MLA tensor params: helix_position_offsets and helix_is_inactive_rank.");
if (is_context && op.mUseSparseAttention)
{
if (latent_cache.has_value())
@@ -227,10 +227,15 @@

// For generation, helix position is in ropeOp
auto& mla_helix_position_offsets = mla_tensor_params[0];
auto& mla_helix_is_inactive_rank = mla_tensor_params[1];
if (mla_helix_position_offsets.has_value())
{
mla_params.helix_position_offsets = mla_helix_position_offsets->data_ptr<int32_t>();
}
if (mla_helix_is_inactive_rank.has_value())
{
mla_params.helix_is_inactive_rank = mla_helix_is_inactive_rank->data_ptr<bool>();
}
}
else
{
11 changes: 8 additions & 3 deletions cpp/tensorrt_llm/thop/dsv3RopeOp.cpp
@@ -66,6 +66,7 @@ struct MlaRopeGenArgs
float const* kv_scale_quant_orig_ptr;
float host_bmm1_scale;
int32_t const* helix_position_offsets_ptr;
bool const* helix_is_inactive_rank_ptr;
};

template <typename T, typename KVCacheBuffer>
@@ -105,6 +106,7 @@ void invokeMLARopeGenerationHelper(T const* latent_cache_ptr, T* q_pe_ptr, T* fu
mla_params.dequant_scale_kv = args.kv_scale_quant_orig_ptr;
mla_params.host_bmm1_scale = args.host_bmm1_scale;
mla_params.helix_position_offsets = args.helix_position_offsets_ptr;
mla_params.helix_is_inactive_rank = args.helix_is_inactive_rank_ptr;

tk::invokeMLARopeGeneration<T>(mla_params, kv_cache_buffer, stream);
}
@@ -133,8 +135,8 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
TLLM_CHECK_WITH_INFO(
head_size == kv_lora_rank + qk_rope_head_dim, "head_size must = kv_lora_rank + qk_rope_head_dim");
TLLM_CHECK_WITH_INFO(num_kv_heads == 1, "num_kv_heads must = 1");
TORCH_CHECK(
mla_tensor_params.size() == 1, "Expecting 1 tensor for custom MLA tensor params: helix_position_offsets.");
TORCH_CHECK(mla_tensor_params.size() == 2,
"Expecting 2 tensors for custom MLA tensor params: helix_position_offsets and helix_is_inactive_rank.");

auto stream = at::cuda::getCurrentCUDAStream(fused_q.get_device());
auto const kv_cache_quant_mode = tc::QuantMode(uint32_t(quant_mode));
@@ -153,6 +155,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
int32_t const num_gen_tokens = num_tokens;
int32_t const seq_offset = num_contexts;
auto& mla_helix_position_offsets = mla_tensor_params[0];
auto& mla_helix_is_inactive_rank = mla_tensor_params[1];
int32_t const layer_num = host_kv_cache_pool_mapping.value().size(0);

tk::MlaMetaParams mla_meta_params = {static_cast<int>(q_lora_rank), static_cast<int>(kv_lora_rank),
@@ -161,6 +164,8 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +

int32_t const* helix_position_offsets_ptr
= mla_helix_position_offsets.has_value() ? mla_helix_position_offsets->data_ptr<int32_t>() : nullptr;
bool const* helix_is_inactive_rank_ptr
= mla_helix_is_inactive_rank.has_value() ? mla_helix_is_inactive_rank->data_ptr<bool>() : nullptr;

int* cu_q_seqlens_ptr = reinterpret_cast<int*>(cu_q_seqlens.data_ptr());
int* cu_kv_seqlens_ptr = reinterpret_cast<int*>(cu_kv_seqlens.data_ptr());
@@ -274,7 +279,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
static_cast<int32_t>(num_heads), mla_meta_params, sequence_lengths_ptr, max_context_q_len,
block_ids_per_seq_ptr, cache_type, cu_q_seqlens_ptr, cu_kv_seqlens_ptr, fmha_tile_counter_ptr,
mla_bmm1_scale_ptr, mla_bmm2_scale_ptr, quant_q_buffer_ptr, quant_scale_o_ptr, kv_scale_orig_quant_ptr,
kv_scale_quant_orig_ptr, host_bmm1_scale, helix_position_offsets_ptr};
kv_scale_quant_orig_ptr, host_bmm1_scale, helix_position_offsets_ptr, helix_is_inactive_rank_ptr};

auto const input_dtype = fused_q.scalar_type();
if (input_dtype == torch::kFloat16)
59 changes: 52 additions & 7 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -189,7 +189,6 @@ def plan(
q_pe: Optional[torch.Tensor] = None,
mrope_config: Optional[dict] = None,
softmax_stats_tensor: Optional[torch.Tensor] = None,
helix_position_offsets: Optional[torch.Tensor] = None,
is_spec_decoding_enabled: bool = False,
use_spec_decoding: bool = False,
is_spec_dec_tree: bool = False,
@@ -207,6 +206,8 @@
sparse_attn_offsets: Optional[torch.Tensor] = None,
sparse_attn_indices_block_size: int = 1,
sparse_mla_topk: int = 0,
helix_position_offsets: Optional[torch.Tensor] = None,
helix_is_inactive_rank: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -243,7 +244,6 @@ def plan(
use_paged_context_fmha (bool): Sets the mPagedContextFMHA attribute in the op runner.
mrope_config (dict): The dictionary containing the mRope configuration.
softmax_stats_tensor (torch.Tensor): The tensor to store the softmax statistics (max/sum)
helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU.
attention_sinks (torch.Tensor): The attention sinks (additional value in the denominator of the softmax) with shape of (num_heads_q) on GPU.
chunked_prefill_buffer_batch_size (int): used for malloc buffer for k and v in fp8 context mla. the max input kv length is not max_num_tokens in this case. It is chunked_prefill_buffer_batch_size * max_num_tokens.
sparse_kv_indices (torch.Tensor): The sparse indices for the KV cache, with shape of (num_heads_kv, num_sparse_tokens) on GPU.
@@ -252,6 +252,8 @@
sparse_attn_offsets (torch.Tensor): The batch offsets for the sparse attention indices, with shape of (num_generations + 1) on GPU.
sparse_attn_indices_block_size (int): The granularity of the sparse attention indices, used by block sparse attention.
sparse_mla_topk (int): The topk for the sparse MLA, used by DSA attention.
helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU.
helix_is_inactive_rank (torch.Tensor): For Helix: whether the current rank is inactive, with shape (batch_size) on GPU.
"""
self.layer_idx = layer_idx
self.tokens_per_block = tokens_per_block
@@ -287,14 +289,20 @@
'mrope_position_deltas') if mrope_config is not None else None
self.block_ids_per_seq = block_ids_per_seq
self.softmax_stats_tensor = softmax_stats_tensor
self.helix_position_offsets = helix_position_offsets
self.attention_sinks = attention_sinks
self.sparse_kv_indices = sparse_kv_indices
self.sparse_kv_offsets = sparse_kv_offsets
self.sparse_attn_indices = sparse_attn_indices
self.sparse_attn_offsets = sparse_attn_offsets
self.sparse_attn_indices_block_size = sparse_attn_indices_block_size
self.sparse_mla_topk = sparse_mla_topk
self.helix_position_offsets = helix_position_offsets
self.helix_is_inactive_rank = helix_is_inactive_rank
if self.helix_is_inactive_rank is not None and not isinstance(
self.helix_is_inactive_rank, torch.Tensor):
self.helix_is_inactive_rank = torch.tensor(
self.helix_is_inactive_rank, dtype=torch.bool, pin_memory=True)

if max_sequence_length > self.rope_params.max_positions:
self.rope_params.max_positions = max_sequence_length
self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
@@ -473,7 +481,9 @@ def run(
spec_decoding_tensor_params.append(self.spec_decoding_bl_tree_mask)
spec_decoding_tensor_params.append(
self.spec_bl_tree_first_sparse_mask_offset_kv)
mla_tensor_params = [self.helix_position_offsets]
mla_tensor_params = [
self.helix_position_offsets, self.helix_is_inactive_rank
]

thop.attention(
q,
@@ -632,6 +642,13 @@ class TrtllmAttentionMetadata(AttentionMetadata):
spec_decoding_bl_tree_mask: Optional[torch.Tensor] = None
spec_bl_tree_first_sparse_mask_offset_kv: Optional[torch.Tensor] = None

# Whether the current rank is inactive for helix parallelism.
# In helix parallelism, only the active rank appends KV cache for the query token
# and attends to the previously cached tokens as well as the query token. Inactive
# ranks do not append KV cache for the query token and attend to the previously
# cached tokens only.
helix_is_inactive_rank: Optional[torch.Tensor] = None

@property
def max_seq_len(self) -> int:
"""
@@ -840,7 +857,21 @@ def prepare(self) -> None:
if self.enable_flash_mla:
self.prepare_flash_mla()
# number of tokens needed in the kv cache for each sequence after the next pass
kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv
if self.helix_is_inactive_rank is not None and len(
self.helix_is_inactive_rank):
# If helix is inactive, attend to the previously cached tokens only.
assert cached_token_lens is not None, "cached_token_lens should be set for helix"
kv_lens = cached_token_lens.clone()
helix_is_inactive_rank_cpu = torch.tensor(
self.helix_is_inactive_rank,
dtype=torch.bool,
device='cpu',
)
active_rank = ~helix_is_inactive_rank_cpu
kv_lens[active_rank] += self.seq_lens_kv[active_rank]
else:
kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv

# self.kv_lens is the valid kv cache length, while the self.kv_lens_cuda is
# the sequence length including the cached tokens and the input tokens.
self.kv_lens[:self.num_seqs].copy_(
@@ -1527,7 +1558,6 @@ def forward(
q_pe=q_pe,
mrope_config=mrope_config,
softmax_stats_tensor=softmax_stats_tensor,
helix_position_offsets=helix_position_offsets,
is_spec_decoding_enabled=metadata.is_spec_decoding_enabled,
use_spec_decoding=metadata.use_spec_decoding,
is_spec_dec_tree=metadata.is_spec_dec_tree,
@@ -1550,6 +1580,8 @@ def forward(
sparse_attn_indices_block_size=sparse_attn_indices_block_size,
sparse_mla_topk=metadata.sparse_mla_topk if hasattr(
metadata, 'sparse_mla_topk') else 0,
helix_position_offsets=helix_position_offsets,
helix_is_inactive_rank=metadata.helix_is_inactive_rank,
)
out_dtype = None
if out_scale is not None:
@@ -1809,6 +1841,7 @@ def mla_rope_generation(
mla_bmm2_scale: torch.Tensor,
quant_q_buffer: torch.Tensor,
helix_position_offsets: Optional[torch.Tensor] = None,
helix_is_inactive_rank: Optional[torch.Tensor] = None,
out_scale: Optional[torch.Tensor] = None,
) -> None:
"""
@@ -1828,7 +1861,19 @@ def mla_rope_generation(
assert self.is_mla_enable and self.mla_params is not None
assert metadata.kv_cache_manager is not None
sink_token_length = 0
mla_tensor_params = [helix_position_offsets]

# Ensure helix_is_inactive_rank is on the same device as other tensors
if helix_is_inactive_rank is not None:
if isinstance(helix_is_inactive_rank, list):
helix_is_inactive_rank = torch.tensor(
helix_is_inactive_rank,
dtype=torch.bool,
device=helix_position_offsets.device)
elif helix_is_inactive_rank.device.type != 'cuda':
helix_is_inactive_rank = helix_is_inactive_rank.to(
helix_position_offsets.device)

mla_tensor_params = [helix_position_offsets, helix_is_inactive_rank]

torch.ops.trtllm.mla_rope_generation(
fused_q,
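The prepare() change above is the host-side counterpart of the kernel gating: for a sequence whose rank is inactive, the next pass needs only the previously cached tokens, so its kv_lens entry stays at cached_token_lens; active sequences also count the incoming query tokens. A small standalone sketch with made-up numbers (not code from this PR) showing the arithmetic:

    import torch

    # Hypothetical batch of three generation requests on one Helix rank.
    cached_token_lens = torch.tensor([10, 10, 7])  # tokens already in this rank's KV cache
    seq_lens_kv = torch.tensor([1, 1, 1])          # query tokens arriving this step
    helix_is_inactive_rank = [False, True, False]  # per-sequence activity of this rank

    # Mirrors the branch added in prepare(): inactive sequences keep only the
    # cached length; active sequences also account for the incoming tokens.
    inactive = torch.tensor(helix_is_inactive_rank, dtype=torch.bool)
    kv_lens = cached_token_lens.clone()
    kv_lens[~inactive] += seq_lens_kv[~inactive]

    print(kv_lens.tolist())  # [11, 10, 8]

The same flag then flows into mla_rope_generation, which accepts a Python list or a CPU tensor and moves it onto the device of helix_position_offsets before packing mla_tensor_params.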
18 changes: 17 additions & 1 deletion tensorrt_llm/_torch/distributed/communicator.py
@@ -1,3 +1,4 @@
import copy
import math
import pickle # nosec B403
from abc import ABC, abstractmethod
@@ -341,9 +342,24 @@ class MPIDist(Distributed):

def __init__(self, mapping: Mapping):
super().__init__(mapping)
self.create_cp_comm()
# Repurpose CP ranks to TP for Helix so that the right comms are created.
mapping_with_cp = None
if self.mapping.cp_size > 1:
logger.info(
f"[MPIDist::__init__] Repurposing CP ranks to TP for Helix.")
mapping_with_cp = copy.deepcopy(self.mapping)
self.mapping = self.mapping.repurpose_helix_cp_to_tp()

self.create_tp_comm()
self.create_pp_comm()
self.create_cp_comm()

# Restore the original mapping.
if mapping_with_cp is not None:
logger.info(
f"[MPIDist::__init__] Restoring original mapping undoing Helix manipulation."
)
self.mapping = mapping_with_cp

def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
comm = mpi_comm()
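The MPIDist change temporarily swaps the mapping for a Helix-repurposed copy (CP ranks treated as TP) while the TP, PP, and CP communicators are created, then restores the original mapping. The same save/repurpose/restore pattern could also be expressed as a context manager so the restore happens even if communicator creation raises; this is only a sketch, assuming repurpose_helix_cp_to_tp may modify the mapping in place (hence the deep copy, mirroring the diff), and the helper name is hypothetical:

    import copy
    from contextlib import contextmanager

    @contextmanager
    def helix_repurposed_mapping(dist):
        """Temporarily repurpose CP ranks to TP on dist.mapping, then restore."""
        if dist.mapping.cp_size <= 1:
            # Nothing to repurpose; use the mapping as-is.
            yield dist.mapping
            return
        saved = copy.deepcopy(dist.mapping)  # survives any in-place changes
        dist.mapping = dist.mapping.repurpose_helix_cp_to_tp()
        try:
            yield dist.mapping
        finally:
            dist.mapping = saved

    # Hypothetical usage inside MPIDist.__init__:
    #   with helix_repurposed_mapping(self):
    #       self.create_tp_comm()
    #       self.create_pp_comm()
    #       self.create_cp_comm()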