Commit 41720fb

save changes to get seqlen 64 working

committed · 1 parent 59ea383 · commit 41720fb

6 files changed (+54 -6 lines)

cpp/tensorrt_llm/kernels/mlaKernels.cu

Lines changed: 8 additions & 0 deletions
@@ -354,6 +354,14 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
     float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets,
     bool const* helix_is_inactive_rank)
 {
+    // if (helix_is_inactive_rank != nullptr)
+    // {
+    //     printf("[applyMLARopeAndAssignQKVKernelGeneration] helix_is_inactive_rank: %p\n", helix_is_inactive_rank);
+    // }
+    // else
+    // {
+    //     printf("[applyMLARopeAndAssignQKVKernelGeneration] helix_is_inactive_rank: nullptr\n");
+    // }

     // Constants.
     using VecT = typename VecType<T>::Type;

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 11 additions & 1 deletion
@@ -231,10 +231,20 @@ class Runner : public RunnerBase
         if (mla_helix_position_offsets.has_value())
         {
             mla_params.helix_position_offsets = mla_helix_position_offsets->data_ptr<int32_t>();
+            printf("[AttentionOp] helix_position_offsets: %p\n", mla_params.helix_position_offsets);
+        }
+        else
+        {
+            printf("[AttentionOp] helix_position_offsets: nullptr\n");
         }
         if (mla_helix_is_inactive_rank.has_value())
         {
-            mla_params.helix_is_inactive_rank = mla_helix_is_inactive_rank->const_data_ptr<bool>();
+            printf("[AttentionOp] helix_is_inactive_rank: %p\n", mla_helix_is_inactive_rank->data_ptr<bool>());
+            mla_params.helix_is_inactive_rank = mla_helix_is_inactive_rank->data_ptr<bool>();
+        }
+        else
+        {
+            printf("[AttentionOp] helix_is_inactive_rank: nullptr\n");
         }
     }
     else

cpp/tensorrt_llm/thop/dsv3RopeOp.cpp

Lines changed: 16 additions & 0 deletions
@@ -108,6 +108,22 @@ void invokeMLARopeGenerationHelper(T const* latent_cache_ptr, T* q_pe_ptr, T* fu
     mla_params.helix_position_offsets = args.helix_position_offsets_ptr;
     mla_params.helix_is_inactive_rank = args.helix_is_inactive_rank_ptr;

+    if (mla_params.helix_position_offsets != nullptr)
+    {
+        printf("[invokeMLARopeGenerationHelper] helix_position_offsets: %p\n", mla_params.helix_position_offsets);
+    }
+    else
+    {
+        printf("[invokeMLARopeGenerationHelper] helix_position_offsets: nullptr\n");
+    }
+    if (mla_params.helix_is_inactive_rank != nullptr)
+    {
+        printf("[invokeMLARopeGenerationHelper] helix_is_inactive_rank: %p\n", mla_params.helix_is_inactive_rank);
+    }
+    else
+    {
+        printf("[invokeMLARopeGenerationHelper] helix_is_inactive_rank: nullptr\n");
+    }
     tk::invokeMLARopeGeneration<T>(mla_params, kv_cache_buffer, stream);
 }

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 2 additions & 1 deletion
@@ -296,9 +296,10 @@ def plan(
         self.sparse_mla_topk = sparse_mla_topk
         self.helix_position_offsets = helix_position_offsets
         self.helix_is_inactive_rank = helix_is_inactive_rank
-        if self.helix_is_inactive_rank is not None:
+        if self.helix_is_inactive_rank is not None and not isinstance(self.helix_is_inactive_rank, torch.Tensor):
             self.helix_is_inactive_rank = torch.tensor(
                 self.helix_is_inactive_rank, dtype=torch.bool, pin_memory=True)
+        print(f"[TrtllmAttention] helix_is_inactive_rank: {self.helix_is_inactive_rank}")

         if max_sequence_length > self.rope_params.max_positions:
             self.rope_params.max_positions = max_sequence_length

tensorrt_llm/_torch/modules/attention.py

Lines changed: 14 additions & 2 deletions
@@ -1702,6 +1702,14 @@ def forward_absorption_generation(

         # Compute helix_position_offsets for helix parallelism.
         helix_position_offsets = position_ids if self.mapping.cp_size > 1 else None
+        # Get helix_is_inactive_rank from attn_metadata for helix parallelism.
+        helix_is_inactive_rank = getattr(attn_metadata, 'helix_is_inactive_rank', None)
+
+        if self.mapping.cp_size > 1:
+            assert helix_position_offsets is not None
+            assert helix_is_inactive_rank is not None
+            print(f"[Attention] helix_position_offsets: {helix_position_offsets}")
+            print(f"[Attention] helix_is_inactive_rank: {helix_is_inactive_rank}")

         rope_stream = self.aux_stream if not has_fp8_kv_cache else None
         if self.k_b_proj_trans.dtype == torch.bfloat16:
@@ -1727,7 +1735,9 @@
                     mla_bmm2_scale,
                     quant_q_buffer,
                     helix_position_offsets=
-                    helix_position_offsets),
+                    helix_position_offsets,
+                    helix_is_inactive_rank=
+                    helix_is_inactive_rank),
                 self.ln_events[0],
                 self.ln_events[1],
                 rope_stream,
@@ -1756,7 +1766,9 @@
                     mla_bmm2_scale,
                     quant_q_buffer,
                     helix_position_offsets=
-                    helix_position_offsets),
+                    helix_position_offsets,
+                    helix_is_inactive_rank=
+                    helix_is_inactive_rank),
                 self.ln_events[0],
                 self.ln_events[1],
                 rope_stream,
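
Note: helix_is_inactive_rank is fetched from attn_metadata with getattr and a None default because metadata objects created outside the helix path may never define the attribute; the non-CP code path keeps working unchanged, while cp_size > 1 now requires both helix inputs. A small self-contained sketch (the class and values are stand-ins, not the real metadata type):

    import torch

    class _FakeMetadata:
        # Stand-in for the attention metadata object; helix runs attach
        # helix_is_inactive_rank to it (see the test change below).
        pass

    attn_metadata = _FakeMetadata()
    cp_size = 2  # stand-in for self.mapping.cp_size

    # Attribute absent (non-helix metadata): getattr falls back to None.
    assert getattr(attn_metadata, 'helix_is_inactive_rank', None) is None

    # Helix run: the test attaches a bool tensor before prepare().
    attn_metadata.helix_is_inactive_rank = torch.tensor([False, True], dtype=torch.bool)
    helix_is_inactive_rank = getattr(attn_metadata, 'helix_is_inactive_rank', None)
    if cp_size > 1:
        assert helix_is_inactive_rank is not None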

tests/unittest/_torch/modules/test_mla_helix.py

Lines changed: 3 additions & 2 deletions
@@ -528,14 +528,15 @@ def _run_mla_distributed(
                 num_cached_tokens_per_seq=cached_tokens_per_seq,
             ),
             enable_context_mla_with_cached_kv=True,
-            helix_is_inactive_rank=helix_is_inactive_rank,
+            helix_is_inactive_rank=torch.tensor(helix_is_inactive_rank, dtype=torch.bool, device="cuda"),
         )
     else:
         attn_metadata.kv_cache_params = KVCacheParams(
             use_cache=True,
             num_cached_tokens_per_seq=cached_tokens_per_seq,
         )
-        attn_metadata.helix_is_inactive_rank = helix_is_inactive_rank
+        attn_metadata.helix_is_inactive_rank = torch.tensor(
+            helix_is_inactive_rank, dtype=torch.bool, device="cuda")
     attn_metadata.prepare()
     extra_attrs["attention_metadata"] = weakref.ref(attn_metadata)
     if not use_cuda_graph:
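
Note: the test-side conversion matches what the C++ op expects: attentionOp.cpp above reads the mask through data_ptr<bool>() and the CUDA kernel takes it as bool const*, so the per-rank Python list has to be materialized as a bool tensor in GPU memory. A minimal sketch of that conversion (the two-entry mask is illustrative):

    import torch

    # Assumed per-sequence flags for the local CP rank.
    helix_is_inactive_rank = [False, True]

    # The attention kernel dereferences this buffer on the device as
    # bool const*, so it must be a contiguous CUDA bool tensor rather
    # than a Python list.
    mask = torch.tensor(helix_is_inactive_rank, dtype=torch.bool, device="cuda")
    assert mask.dtype == torch.bool and mask.is_cuda and mask.is_contiguous()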
