@@ -1197,9 +1197,9 @@ def forward_impl(self,
                 q_gen,
                 compressed_kv_gen,
                 k_pe_gen,
-                position_ids,
                 attn_metadata,
                 output[num_ctx_tokens:num_tokens, :],
+                position_ids=position_ids,
                 latent_cache=latent_cache_gen,
             )
 
@@ -1340,10 +1340,6 @@ def forward_context_default(
 
         helix_position_offsets = position_ids if self.mapping.cp_size > 1 else None
 
-        # helix_position_offsets = position_ids if self.mapping.has_cp_helix(
-        # ) else None
-        helix_position_offsets = position_ids if self.mapping.cp_size > 1 else None
-
         attn_output = self.mha.forward(
             q,
             k,
@@ -1389,7 +1385,6 @@ def forward_generation_dsa(
         q: torch.Tensor,
         compressed_kv: torch.Tensor,
         k_pe: torch.Tensor,
-        position_ids: Optional[torch.Tensor],
         attn_metadata: AttentionMetadata,
         output: torch.Tensor,
         latent_cache: Optional[torch.Tensor] = None,
@@ -1399,7 +1394,6 @@ def forward_generation_dsa(
         return self.forward_absorption_generation(q,
                                                    compressed_kv,
                                                    k_pe,
-                                                   position_ids,
                                                    attn_metadata,
                                                    output,
                                                    latent_cache=latent_cache,
@@ -1673,9 +1667,9 @@ def forward_absorption_generation(
         q: torch.Tensor,
         compressed_kv: torch.Tensor,
         k_pe: torch.Tensor,
-        position_ids: torch.Tensor,
         attn_metadata: AttentionMetadata,
         output: torch.Tensor,
+        position_ids: Optional[torch.Tensor] = None,
         latent_cache: Optional[torch.Tensor] = None,
         topk_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -1725,7 +1719,11 @@ def forward_absorption_generation(
         )
 
         # Compute helix_position_offsets for helix parallelism.
-        helix_position_offsets = position_ids if self.mapping.cp_size > 1 else None
+        if self.mapping.cp_size > 1:
+            assert position_ids is not None, "position_ids is required for helix parallelism."
+            helix_position_offsets = position_ids
+        else:
+            helix_position_offsets = None
 
         rope_stream = self.aux_stream if not has_fp8_kv_cache else None
         if self.k_b_proj_trans.dtype == torch.bfloat16:
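
The hunks above converge on one calling convention: `position_ids` is dropped as a required positional parameter of `forward_generation_dsa` and `forward_absorption_generation` and becomes an optional keyword argument that is validated only when helix context parallelism (`cp_size > 1`) is active. Below is a minimal, self-contained sketch of that convention, assuming a simplified `Mapping` stub and dummy tensors rather than the real TensorRT-LLM classes; it is not the actual implementation.

# Sketch only: illustrates optional keyword-only position_ids with a lazy
# helix-parallelism check. Mapping and the function name are hypothetical.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Mapping:
    cp_size: int = 1  # illustrative stand-in for the real parallel mapping object


def forward_absorption_generation_sketch(
    q: torch.Tensor,
    output: torch.Tensor,
    mapping: Mapping,
    position_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # position_ids is only consumed to build helix position offsets, so it is
    # optional and checked here instead of being a required positional argument.
    if mapping.cp_size > 1:
        assert position_ids is not None, "position_ids is required for helix parallelism."
        helix_position_offsets = position_ids
    else:
        helix_position_offsets = None
    # A real kernel would pass helix_position_offsets to the attention op;
    # this sketch only demonstrates that the non-helix path needs no position_ids.
    _ = helix_position_offsets
    return output


# Non-helix caller: position_ids can be omitted entirely.
out = forward_absorption_generation_sketch(
    q=torch.zeros(2, 8),
    output=torch.zeros(2, 8),
    mapping=Mapping(cp_size=1),
)

# Helix caller: position_ids is passed by keyword, mirroring the updated call site.
out = forward_absorption_generation_sketch(
    q=torch.zeros(2, 8),
    output=torch.zeros(2, 8),
    mapping=Mapping(cp_size=2),
    position_ids=torch.arange(2),
)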