
Commit fb50fd2

updates to test - seqlen 64 works
1 parent 2b9acff commit fb50fd2

6 files changed: +70 additions, −11 deletions

cpp/tensorrt_llm/kernels/mlaKernels.cu

Lines changed: 8 additions & 0 deletions
@@ -354,6 +354,14 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
     float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets,
     bool const* helix_is_inactive_rank)
 {
+    // if (helix_is_inactive_rank != nullptr)
+    // {
+    //     printf("[applyMLARopeAndAssignQKVKernelGeneration] helix_is_inactive_rank: %p\n", helix_is_inactive_rank);
+    // }
+    // else
+    // {
+    //     printf("[applyMLARopeAndAssignQKVKernelGeneration] helix_is_inactive_rank: nullptr\n");
+    // }
 
     // Constants.
     using VecT = typename VecType<T>::Type;

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 11 additions & 1 deletion
@@ -231,10 +231,20 @@ class Runner : public RunnerBase
         if (mla_helix_position_offsets.has_value())
         {
             mla_params.helix_position_offsets = mla_helix_position_offsets->data_ptr<int32_t>();
+            printf("[AttentionOp] helix_position_offsets: %p\n", mla_params.helix_position_offsets);
+        }
+        else
+        {
+            printf("[AttentionOp] helix_position_offsets: nullptr\n");
         }
         if (mla_helix_is_inactive_rank.has_value())
         {
-            mla_params.helix_is_inactive_rank = mla_helix_is_inactive_rank->const_data_ptr<bool>();
+            printf("[AttentionOp] helix_is_inactive_rank: %p\n", mla_helix_is_inactive_rank->data_ptr<bool>());
+            mla_params.helix_is_inactive_rank = mla_helix_is_inactive_rank->data_ptr<bool>();
+        }
+        else
+        {
+            printf("[AttentionOp] helix_is_inactive_rank: nullptr\n");
         }
     }
     else

cpp/tensorrt_llm/thop/dsv3RopeOp.cpp

Lines changed: 16 additions & 0 deletions
@@ -108,6 +108,22 @@ void invokeMLARopeGenerationHelper(T const* latent_cache_ptr, T* q_pe_ptr, T* fu
     mla_params.helix_position_offsets = args.helix_position_offsets_ptr;
     mla_params.helix_is_inactive_rank = args.helix_is_inactive_rank_ptr;
 
+    if (mla_params.helix_position_offsets != nullptr)
+    {
+        printf("[invokeMLARopeGenerationHelper] helix_position_offsets: %p\n", mla_params.helix_position_offsets);
+    }
+    else
+    {
+        printf("[invokeMLARopeGenerationHelper] helix_position_offsets: nullptr\n");
+    }
+    if (mla_params.helix_is_inactive_rank != nullptr)
+    {
+        printf("[invokeMLARopeGenerationHelper] helix_is_inactive_rank: %p\n", mla_params.helix_is_inactive_rank);
+    }
+    else
+    {
+        printf("[invokeMLARopeGenerationHelper] helix_is_inactive_rank: nullptr\n");
+    }
     tk::invokeMLARopeGeneration<T>(mla_params, kv_cache_buffer, stream);
 }

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 2 additions & 1 deletion
@@ -296,9 +296,10 @@ def plan(
         self.sparse_mla_topk = sparse_mla_topk
         self.helix_position_offsets = helix_position_offsets
         self.helix_is_inactive_rank = helix_is_inactive_rank
-        if self.helix_is_inactive_rank is not None:
+        if self.helix_is_inactive_rank is not None and not isinstance(self.helix_is_inactive_rank, torch.Tensor):
             self.helix_is_inactive_rank = torch.tensor(
                 self.helix_is_inactive_rank, dtype=torch.bool, pin_memory=True)
+        print(f"[TrtllmAttention] helix_is_inactive_rank: {self.helix_is_inactive_rank}")
 
         if max_sequence_length > self.rope_params.max_positions:
             self.rope_params.max_positions = max_sequence_length
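The plan() change above only wraps a plain Python list into a tensor once; a value that is already a torch.Tensor (as the updated test now passes) is left untouched. Below is a minimal standalone sketch of that guard, with a hypothetical helper name and pin_memory made conditional so the snippet also runs on CPU-only builds (the real code always pins):

import torch

def normalize_helix_is_inactive_rank(value):
    # None passes through; an existing tensor is kept as-is; anything else
    # (e.g. a list of bools) is wrapped into a bool tensor, pinned when CUDA
    # is available so async host-to-device copies stay cheap.
    if value is not None and not isinstance(value, torch.Tensor):
        value = torch.tensor(value, dtype=torch.bool,
                             pin_memory=torch.cuda.is_available())
    return value

print(normalize_helix_is_inactive_rank([True, False]))          # wrapped into tensor([ True, False])
print(normalize_helix_is_inactive_rank(torch.tensor([False])))  # returned unchanged
print(normalize_helix_is_inactive_rank(None))                   # None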

tensorrt_llm/_torch/modules/attention.py

Lines changed: 14 additions & 2 deletions
@@ -1702,6 +1702,14 @@ def forward_absorption_generation(
 
         # Compute helix_position_offsets for helix parallelism.
         helix_position_offsets = position_ids if self.mapping.cp_size > 1 else None
+        # Get helix_is_inactive_rank from attn_metadata for helix parallelism.
+        helix_is_inactive_rank = getattr(attn_metadata, 'helix_is_inactive_rank', None)
+
+        if self.mapping.cp_size > 1:
+            assert helix_position_offsets is not None
+            assert helix_is_inactive_rank is not None
+            print(f"[Attention] helix_position_offsets: {helix_position_offsets}")
+            print(f"[Attention] helix_is_inactive_rank: {helix_is_inactive_rank}")
 
         rope_stream = self.aux_stream if not has_fp8_kv_cache else None
         if self.k_b_proj_trans.dtype == torch.bfloat16:
@@ -1727,7 +1735,9 @@
                     mla_bmm2_scale,
                     quant_q_buffer,
                     helix_position_offsets=
-                    helix_position_offsets),
+                    helix_position_offsets,
+                    helix_is_inactive_rank=
+                    helix_is_inactive_rank),
                 self.ln_events[0],
                 self.ln_events[1],
                 rope_stream,
@@ -1756,7 +1766,9 @@
                     mla_bmm2_scale,
                     quant_q_buffer,
                     helix_position_offsets=
-                    helix_position_offsets),
+                    helix_position_offsets,
+                    helix_is_inactive_rank=
+                    helix_is_inactive_rank),
                 self.ln_events[0],
                 self.ln_events[1],
                 rope_stream,
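For reference, the retrieval pattern introduced in forward_absorption_generation can be exercised on its own. The snippet below is a minimal sketch, with a hypothetical resolve_helix_inputs helper and a SimpleNamespace standing in for the real attention metadata object:

from types import SimpleNamespace

def resolve_helix_inputs(attn_metadata, position_ids, cp_size):
    # position_ids double as helix_position_offsets only when context
    # parallelism (helix) is enabled; otherwise both inputs stay None.
    helix_position_offsets = position_ids if cp_size > 1 else None
    # getattr with a None default keeps metadata objects that predate the
    # helix_is_inactive_rank field working unchanged.
    helix_is_inactive_rank = getattr(attn_metadata, 'helix_is_inactive_rank', None)
    if cp_size > 1:
        assert helix_position_offsets is not None
        assert helix_is_inactive_rank is not None
    return helix_position_offsets, helix_is_inactive_rank

meta = SimpleNamespace(helix_is_inactive_rank=[True, False])
print(resolve_helix_inputs(meta, position_ids=[0, 1], cp_size=2))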

tests/unittest/_torch/modules/test_mla_helix.py

Lines changed: 19 additions & 7 deletions
@@ -125,18 +125,20 @@ def max_position_embeddings(self) -> int:
     Scenario(batch=16, ctx_len=16384),
     Scenario(batch=16, ctx_len=32768),
     Scenario(batch=16, ctx_len=65536),
+    Scenario(batch=1, ctx_len=64),
 ]
 
 # limit the number of test scenarios to avoid taking too long
 test_scenarios = [
     # note: tests with ctx_len=1024 (or less) are currently failing, most likely due to
     # bad numerics especially with bf16. We ignore those tests for now.
-    all_scenarios[2],
-    all_scenarios[5],
-    all_scenarios[12],
-    all_scenarios[15],
-    all_scenarios[21],
-    all_scenarios[22],
+    # all_scenarios[2],
+    # all_scenarios[5],
+    # all_scenarios[12],
+    # all_scenarios[15],
+    # all_scenarios[21],
+    # all_scenarios[22],
+    all_scenarios[-1],
 ]
 
 
@@ -501,9 +503,16 @@ def _run_mla_distributed(
     start = time.time()
 
     for step in range(gen_steps):
+        helix_is_inactive_rank = []
         for req_id in range(scenario.batch):
            kv_cache_manager.impl.add_token(req_id)
-            cache_add = step if rank == world_size - 1 else 0
+            # Assume last rank is active for all gen steps.
+            if rank == world_size - 1:
+                helix_is_inactive_rank.append(False)
+                cache_add = step
+            else:
+                helix_is_inactive_rank.append(True)
+                cache_add = 0
         cached_tokens_per_seq = [ctx_len_per_gpu + cache_add for _ in range(scenario.batch)]
         if step == 0:
             attn_metadata = get_attention_backend("TRTLLM").Metadata(
@@ -519,12 +528,15 @@ def _run_mla_distributed(
                     num_cached_tokens_per_seq=cached_tokens_per_seq,
                 ),
                 enable_context_mla_with_cached_kv=True,
+                helix_is_inactive_rank=torch.tensor(helix_is_inactive_rank, dtype=torch.bool, device="cuda"),
             )
         else:
             attn_metadata.kv_cache_params = KVCacheParams(
                 use_cache=True,
                 num_cached_tokens_per_seq=cached_tokens_per_seq,
             )
+            attn_metadata.helix_is_inactive_rank = torch.tensor(
+                helix_is_inactive_rank, dtype=torch.bool, device="cuda")
         attn_metadata.prepare()
         extra_attrs["attention_metadata"] = weakref.ref(attn_metadata)
         if not use_cuda_graph:
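The test change follows the convention that only the last CP rank is active during generation: it is the one whose KV cache grows by the generated tokens, while every other rank is flagged inactive and adds nothing. A small standalone sketch of that per-step bookkeeping (the function name is hypothetical, not part of the test):

import torch

def build_step_inputs(batch, rank, world_size, step):
    # Only the last rank is treated as active, so its cache grows by `step`
    # generated tokens; inactive ranks keep cache_add at 0.
    inactive = rank != world_size - 1
    helix_is_inactive_rank = [inactive] * batch
    cache_add = 0 if inactive else step
    return torch.tensor(helix_is_inactive_rank, dtype=torch.bool), cache_add

print(build_step_inputs(batch=1, rank=0, world_size=4, step=3))  # (tensor([True]), 0)
print(build_step_inputs(batch=1, rank=3, world_size=4, step=3))  # (tensor([False]), 3)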
