
Commit 59ea383

updates to test - seqlen 64 fails
1 parent 2b9acff commit 59ea383

File tree: 1 file changed (+18, -7 lines)

1 file changed

+18
-7
lines changed

tests/unittest/_torch/modules/test_mla_helix.py

Lines changed: 18 additions & 7 deletions
@@ -125,18 +125,20 @@ def max_position_embeddings(self) -> int:
     Scenario(batch=16, ctx_len=16384),
     Scenario(batch=16, ctx_len=32768),
     Scenario(batch=16, ctx_len=65536),
+    Scenario(batch=1, ctx_len=64),
 ]
 
 # limit the number of test scenarios to avoid taking too long
 test_scenarios = [
     # note: tests with ctx_len=1024 (or less) are currently failing, most likely due to
     # bad numerics especially with bf16. We ignore those tests for now.
-    all_scenarios[2],
-    all_scenarios[5],
-    all_scenarios[12],
-    all_scenarios[15],
-    all_scenarios[21],
-    all_scenarios[22],
+    # all_scenarios[2],
+    # all_scenarios[5],
+    # all_scenarios[12],
+    # all_scenarios[15],
+    # all_scenarios[21],
+    # all_scenarios[22],
+    all_scenarios[-1],
 ]
 
@@ -501,9 +503,16 @@ def _run_mla_distributed(
     start = time.time()
 
     for step in range(gen_steps):
+        helix_is_inactive_rank = []
         for req_id in range(scenario.batch):
             kv_cache_manager.impl.add_token(req_id)
-            cache_add = step if rank == world_size - 1 else 0
+            # Assume last rank is active for all gen steps.
+            if rank == world_size - 1:
+                helix_is_inactive_rank.append(False)
+                cache_add = step
+            else:
+                helix_is_inactive_rank.append(True)
+                cache_add = 0
         cached_tokens_per_seq = [ctx_len_per_gpu + cache_add for _ in range(scenario.batch)]
         if step == 0:
             attn_metadata = get_attention_backend("TRTLLM").Metadata(
@@ -519,12 +528,14 @@ def _run_mla_distributed(
                     num_cached_tokens_per_seq=cached_tokens_per_seq,
                 ),
                 enable_context_mla_with_cached_kv=True,
+                helix_is_inactive_rank=helix_is_inactive_rank,
             )
         else:
             attn_metadata.kv_cache_params = KVCacheParams(
                 use_cache=True,
                 num_cached_tokens_per_seq=cached_tokens_per_seq,
             )
+            attn_metadata.helix_is_inactive_rank = helix_is_inactive_rank
         attn_metadata.prepare()
         extra_attrs["attention_metadata"] = weakref.ref(attn_metadata)
         if not use_cuda_graph:
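
For readers skimming the diff: the test drives a helix-parallel generation loop in which only the last rank appends new tokens to its KV cache each step, while every other rank is flagged inactive. The sketch below mirrors that flag-building logic in a self-contained form; build_helix_flags and its plain-int parameters are hypothetical names for illustration, not the repository's code.

from typing import List, Tuple


def build_helix_flags(rank: int, world_size: int, batch: int,
                      step: int) -> Tuple[List[bool], int]:
    """Hypothetical helper mirroring the diff: mark this rank as active or
    inactive for every request in the batch, and compute how many extra
    tokens its KV cache holds at this generation step."""
    if rank == world_size - 1:
        # Assumption carried over from the diff: the last rank stays active
        # for all generation steps, so its cache grows by one token per step.
        helix_is_inactive_rank = [False] * batch
        cache_add = step
    else:
        # Every other rank is inactive; its cached length stays fixed.
        helix_is_inactive_rank = [True] * batch
        cache_add = 0
    return helix_is_inactive_rank, cache_add


# Example: rank 0 of 4 at generation step 3 with batch 2 reports itself inactive.
flags, cache_add = build_helix_flags(rank=0, world_size=4, batch=2, step=3)
assert flags == [True, True] and cache_add == 0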
