@@ -125,18 +125,20 @@ def max_position_embeddings(self) -> int:
125125 Scenario (batch = 16 , ctx_len = 16384 ),
126126 Scenario (batch = 16 , ctx_len = 32768 ),
127127 Scenario (batch = 16 , ctx_len = 65536 ),
128+ Scenario (batch = 1 , ctx_len = 64 ),
128129]
129130
130131# limit the number of test scenarios to avoid taking too long
131132test_scenarios = [
132133 # note: tests with ctx_len=1024 (or less) are currently failing, most likely due to
133134 # bad numerics especially with bf16. We ignore those tests for now.
134- all_scenarios [2 ],
135- all_scenarios [5 ],
136- all_scenarios [12 ],
137- all_scenarios [15 ],
138- all_scenarios [21 ],
139- all_scenarios [22 ],
135+ # all_scenarios[2],
136+ # all_scenarios[5],
137+ # all_scenarios[12],
138+ # all_scenarios[15],
139+ # all_scenarios[21],
140+ # all_scenarios[22],
141+ all_scenarios [- 1 ],
140142]
141143
142144
@@ -501,9 +503,16 @@ def _run_mla_distributed(
501503 start = time .time ()
502504
503505 for step in range (gen_steps ):
506+ helix_is_inactive_rank = []
504507 for req_id in range (scenario .batch ):
505508 kv_cache_manager .impl .add_token (req_id )
506- cache_add = step if rank == world_size - 1 else 0
509+ # Assume last rank is active for all gen steps.
510+ if rank == world_size - 1 :
511+ helix_is_inactive_rank .append (False )
512+ cache_add = step
513+ else :
514+ helix_is_inactive_rank .append (True )
515+ cache_add = 0
507516 cached_tokens_per_seq = [ctx_len_per_gpu + cache_add for _ in range (scenario .batch )]
508517 if step == 0 :
509518 attn_metadata = get_attention_backend ("TRTLLM" ).Metadata (
@@ -519,12 +528,14 @@ def _run_mla_distributed(
519528 num_cached_tokens_per_seq = cached_tokens_per_seq ,
520529 ),
521530 enable_context_mla_with_cached_kv = True ,
531+ helix_is_inactive_rank = helix_is_inactive_rank ,
522532 )
523533 else :
524534 attn_metadata .kv_cache_params = KVCacheParams (
525535 use_cache = True ,
526536 num_cached_tokens_per_seq = cached_tokens_per_seq ,
527537 )
538+ attn_metadata .helix_is_inactive_rank = helix_is_inactive_rank
528539 attn_metadata .prepare ()
529540 extra_attrs ["attention_metadata" ] = weakref .ref (attn_metadata )
530541 if not use_cuda_graph :
0 commit comments