Commit f55c03f

updates to test - seqlen 64 works
1 parent 2b9acff commit f55c03f

4 files changed: 35 additions and 13 deletions

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 1 addition & 1 deletion
@@ -234,7 +234,7 @@ class Runner : public RunnerBase
         }
         if (mla_helix_is_inactive_rank.has_value())
         {
-            mla_params.helix_is_inactive_rank = mla_helix_is_inactive_rank->const_data_ptr<bool>();
+            mla_params.helix_is_inactive_rank = mla_helix_is_inactive_rank->data_ptr<bool>();
         }
     }
     else

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 1 addition & 1 deletion
@@ -296,7 +296,7 @@ def plan(
         self.sparse_mla_topk = sparse_mla_topk
         self.helix_position_offsets = helix_position_offsets
         self.helix_is_inactive_rank = helix_is_inactive_rank
-        if self.helix_is_inactive_rank is not None:
+        if self.helix_is_inactive_rank is not None and not isinstance(self.helix_is_inactive_rank, torch.Tensor):
             self.helix_is_inactive_rank = torch.tensor(
                 self.helix_is_inactive_rank, dtype=torch.bool, pin_memory=True)
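
With the added isinstance guard, plan() accepts helix_is_inactive_rank either as a plain Python sequence of bools or as a tensor the caller already built (as the updated test now does). A minimal sketch of that normalization pattern, assuming a hypothetical normalize_inactive_rank helper that stands in for the plan() logic:

import torch

def normalize_inactive_rank(helix_is_inactive_rank):
    # Hypothetical helper mirroring the guard above: wrap a plain Python
    # sequence into a bool tensor, but leave an already-built tensor
    # (e.g. a CUDA bool tensor from the caller) untouched.
    if helix_is_inactive_rank is not None and not isinstance(
            helix_is_inactive_rank, torch.Tensor):
        helix_is_inactive_rank = torch.tensor(
            helix_is_inactive_rank, dtype=torch.bool,
            # plan() pins unconditionally (pin_memory=True); guarded here
            # only so the sketch also runs on a CPU-only build.
            pin_memory=torch.cuda.is_available())
    return helix_is_inactive_rank

wrapped = normalize_inactive_rank([True, False])          # list -> bool tensor
t = torch.tensor([True, False], dtype=torch.bool)
assert normalize_inactive_rank(t) is t                     # tensor passes through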

tensorrt_llm/_torch/modules/attention.py

Lines changed: 14 additions & 4 deletions
@@ -1700,8 +1700,14 @@ def forward_absorption_generation(
             device=q.device,
         )

-        # Compute helix_position_offsets for helix parallelism.
-        helix_position_offsets = position_ids if self.mapping.cp_size > 1 else None
+        if self.mapping.cp_size > 1:
+            helix_position_offsets = position_ids
+            helix_is_inactive_rank = attn_metadata.helix_is_inactive_rank
+            assert helix_position_offsets is not None, "helix_position_offsets must be provided for helix parallelism."
+            assert helix_is_inactive_rank is not None, "helix_is_inactive_rank must be provided for helix parallelism."
+        else:
+            helix_position_offsets = None
+            helix_is_inactive_rank = None

         rope_stream = self.aux_stream if not has_fp8_kv_cache else None
         if self.k_b_proj_trans.dtype == torch.bfloat16:
@@ -1727,7 +1733,9 @@ def forward_absorption_generation(
                     mla_bmm2_scale,
                     quant_q_buffer,
                     helix_position_offsets=
-                    helix_position_offsets),
+                    helix_position_offsets,
+                    helix_is_inactive_rank=
+                    helix_is_inactive_rank),
                 self.ln_events[0],
                 self.ln_events[1],
                 rope_stream,
@@ -1756,7 +1764,9 @@ def forward_absorption_generation(
                     mla_bmm2_scale,
                     quant_q_buffer,
                     helix_position_offsets=
-                    helix_position_offsets),
+                    helix_position_offsets,
+                    helix_is_inactive_rank=
+                    helix_is_inactive_rank),
                 self.ln_events[0],
                 self.ln_events[1],
                 rope_stream,
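
The new branch ties the two helix inputs together: they are populated only when context parallelism is enabled, and then both must be present. A standalone sketch of that selection logic, assuming a hypothetical select_helix_inputs helper (the real code reads helix_is_inactive_rank from attn_metadata inside forward_absorption_generation):

from typing import Optional, Tuple
import torch

def select_helix_inputs(
        cp_size: int,
        position_ids: Optional[torch.Tensor],
        metadata_inactive_rank: Optional[torch.Tensor],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Hypothetical helper mirroring the branch above: with helix (context)
    # parallelism, position_ids serve as the per-token position offsets and
    # the inactive-rank mask must come along; otherwise both are dropped.
    if cp_size > 1:
        assert position_ids is not None, \
            "helix_position_offsets must be provided for helix parallelism."
        assert metadata_inactive_rank is not None, \
            "helix_is_inactive_rank must be provided for helix parallelism."
        return position_ids, metadata_inactive_rank
    return None, None

# cp_size == 1: both helix inputs are None before calling the attention op.
offsets, inactive = select_helix_inputs(1, torch.arange(4), None)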

tests/unittest/_torch/modules/test_mla_helix.py

Lines changed: 19 additions & 7 deletions
@@ -125,18 +125,20 @@ def max_position_embeddings(self) -> int:
     Scenario(batch=16, ctx_len=16384),
     Scenario(batch=16, ctx_len=32768),
     Scenario(batch=16, ctx_len=65536),
+    Scenario(batch=1, ctx_len=64),
 ]

 # limit the number of test scenarios to avoid taking too long
 test_scenarios = [
     # note: tests with ctx_len=1024 (or less) are currently failing, most likely due to
     # bad numerics especially with bf16. We ignore those tests for now.
-    all_scenarios[2],
-    all_scenarios[5],
-    all_scenarios[12],
-    all_scenarios[15],
-    all_scenarios[21],
-    all_scenarios[22],
+    # all_scenarios[2],
+    # all_scenarios[5],
+    # all_scenarios[12],
+    # all_scenarios[15],
+    # all_scenarios[21],
+    # all_scenarios[22],
+    all_scenarios[-1],
 ]

@@ -501,9 +503,16 @@ def _run_mla_distributed(
     start = time.time()

     for step in range(gen_steps):
+        helix_is_inactive_rank = []
         for req_id in range(scenario.batch):
             kv_cache_manager.impl.add_token(req_id)
-        cache_add = step if rank == world_size - 1 else 0
+            # Assume last rank is active for all gen steps.
+            if rank == world_size - 1:
+                helix_is_inactive_rank.append(False)
+                cache_add = step
+            else:
+                helix_is_inactive_rank.append(True)
+                cache_add = 0
         cached_tokens_per_seq = [ctx_len_per_gpu + cache_add for _ in range(scenario.batch)]
         if step == 0:
             attn_metadata = get_attention_backend("TRTLLM").Metadata(
@@ -519,12 +528,15 @@ def _run_mla_distributed(
                     num_cached_tokens_per_seq=cached_tokens_per_seq,
                 ),
                 enable_context_mla_with_cached_kv=True,
+                helix_is_inactive_rank=torch.tensor(helix_is_inactive_rank, dtype=torch.bool, device="cuda"),
             )
         else:
             attn_metadata.kv_cache_params = KVCacheParams(
                 use_cache=True,
                 num_cached_tokens_per_seq=cached_tokens_per_seq,
             )
+            attn_metadata.helix_is_inactive_rank = torch.tensor(
+                helix_is_inactive_rank, dtype=torch.bool, device="cuda")
         attn_metadata.prepare()
         extra_attrs["attention_metadata"] = weakref.ref(attn_metadata)
         if not use_cuda_graph:
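
For reference, the per-step mask the test now builds has one entry per request, with only the last CP rank marked active, so only that rank's cached-token count grows with the step index. A rough standalone sketch of that construction under those assumptions (build_step_mask is a hypothetical name; the test inlines this in its step loop and moves the tensor to the GPU before attaching it to the metadata):

import torch

def build_step_mask(rank: int, world_size: int, batch: int, step: int):
    # Mirrors the test logic above: assume the last rank is active for all
    # generation steps; every other rank is inactive and adds no new tokens.
    helix_is_inactive_rank = []
    cache_add = 0
    for _ in range(batch):
        if rank == world_size - 1:
            helix_is_inactive_rank.append(False)
            cache_add = step
        else:
            helix_is_inactive_rank.append(True)
            cache_add = 0
    mask = torch.tensor(helix_is_inactive_rank, dtype=torch.bool)
    return mask, cache_add

mask, cache_add = build_step_mask(rank=0, world_size=2, batch=1, step=3)
# rank 0 of 2 is inactive: mask == tensor([True]), cache_add == 0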
