@@ -1702,6 +1702,14 @@ def forward_absorption_generation(
 
         # Compute helix_position_offsets for helix parallelism.
         helix_position_offsets = position_ids if self.mapping.cp_size > 1 else None
+        # Get helix_is_inactive_rank from attn_metadata for helix parallelism.
+        helix_is_inactive_rank = getattr(attn_metadata, 'helix_is_inactive_rank', None)
+
+        if self.mapping.cp_size > 1:
+            assert helix_position_offsets is not None
+            assert helix_is_inactive_rank is not None
+            print(f"[Attention] helix_position_offsets: {helix_position_offsets}")
+            print(f"[Attention] helix_is_inactive_rank: {helix_is_inactive_rank}")
 
         rope_stream = self.aux_stream if not has_fp8_kv_cache else None
         if self.k_b_proj_trans.dtype == torch.bfloat16:
@@ -1727,7 +1735,9 @@ def forward_absorption_generation(
                     mla_bmm2_scale,
                     quant_q_buffer,
                     helix_position_offsets=
-                    helix_position_offsets),
+                    helix_position_offsets,
+                    helix_is_inactive_rank=
+                    helix_is_inactive_rank),
                 self.ln_events[0],
                 self.ln_events[1],
                 rope_stream,
@@ -1756,7 +1766,9 @@ def forward_absorption_generation(
                     mla_bmm2_scale,
                     quant_q_buffer,
                     helix_position_offsets=
-                    helix_position_offsets),
+                    helix_position_offsets,
+                    helix_is_inactive_rank=
+                    helix_is_inactive_rank),
                 self.ln_events[0],
                 self.ln_events[1],
                 rope_stream,
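
For reference, below is a minimal, self-contained sketch of the calling convention this diff relies on. `AttnMetadataSketch` and `forward_generation_sketch` are hypothetical placeholders, not the real TensorRT-LLM classes or methods; they only illustrate how `helix_is_inactive_rank` can be read off the attention metadata with `getattr` and validated alongside `helix_position_offsets` before being forwarded when `cp_size > 1`.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class AttnMetadataSketch:
    """Hypothetical stand-in for the attention metadata object.

    `helix_is_inactive_rank` marks whether this context-parallel (helix)
    rank holds no KV-cache tokens for a given request in the current
    decode step.
    """
    helix_is_inactive_rank: Optional[torch.Tensor] = None


def forward_generation_sketch(attn_metadata: AttnMetadataSketch,
                              position_ids: torch.Tensor,
                              cp_size: int) -> None:
    # Same pattern as the diff: derive per-rank position offsets and read
    # the inactive-rank flag defensively with getattr, so metadata objects
    # that predate the field still work (they simply yield None).
    helix_position_offsets = position_ids if cp_size > 1 else None
    helix_is_inactive_rank = getattr(attn_metadata, 'helix_is_inactive_rank',
                                     None)

    if cp_size > 1:
        # With helix parallelism enabled, both inputs must be present
        # before they are forwarded to the attention kernel.
        assert helix_position_offsets is not None
        assert helix_is_inactive_rank is not None

    # The real change forwards both values as keyword arguments into the
    # MLA generation kernel call; that call is omitted here.


# Usage: two in-flight requests, with this rank inactive for the second.
meta = AttnMetadataSketch(helix_is_inactive_rank=torch.tensor([False, True]))
forward_generation_sketch(meta,
                          position_ids=torch.tensor([[128, 64]]),
                          cp_size=2)
```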