Commit 6f7ffc7

minor updates to attention.py
1 parent 20edf93 commit 6f7ffc7

File tree: 1 file changed (+7, -9 lines)

tensorrt_llm/_torch/modules/attention.py

Lines changed: 7 additions & 9 deletions
@@ -1197,9 +1197,9 @@ def forward_impl(self,
                 q_gen,
                 compressed_kv_gen,
                 k_pe_gen,
-                position_ids,
                 attn_metadata,
                 output[num_ctx_tokens:num_tokens, :],
+                position_ids=position_ids,
                 latent_cache=latent_cache_gen,
             )

@@ -1340,10 +1340,6 @@ def forward_context_default(

         helix_position_offsets = position_ids if self.mapping.cp_size > 1 else None

-        # helix_position_offsets = position_ids if self.mapping.has_cp_helix(
-        # ) else None
-        helix_position_offsets = position_ids if self.mapping.cp_size > 1 else None
-
         attn_output = self.mha.forward(
             q,
             k,
@@ -1389,7 +1385,6 @@ def forward_generation_dsa(
         q: torch.Tensor,
         compressed_kv: torch.Tensor,
         k_pe: torch.Tensor,
-        position_ids: Optional[torch.Tensor],
         attn_metadata: AttentionMetadata,
         output: torch.Tensor,
         latent_cache: Optional[torch.Tensor] = None,
@@ -1399,7 +1394,6 @@ def forward_generation_dsa(
         return self.forward_absorption_generation(q,
                                                    compressed_kv,
                                                    k_pe,
-                                                   position_ids,
                                                    attn_metadata,
                                                    output,
                                                    latent_cache=latent_cache,
@@ -1673,9 +1667,9 @@ def forward_absorption_generation(
         q: torch.Tensor,
         compressed_kv: torch.Tensor,
         k_pe: torch.Tensor,
-        position_ids: torch.Tensor,
         attn_metadata: AttentionMetadata,
         output: torch.Tensor,
+        position_ids: Optional[torch.Tensor] = None,
         latent_cache: Optional[torch.Tensor] = None,
         topk_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -1725,7 +1719,11 @@ def forward_absorption_generation(
         )

         # Compute helix_position_offsets for helix parallelism.
-        helix_position_offsets = position_ids if self.mapping.cp_size > 1 else None
+        if self.mapping.cp_size > 1:
+            assert position_ids is not None, "position_ids is required for helix parallelism."
+            helix_position_offsets = position_ids
+        else:
+            helix_position_offsets = None

         rope_stream = self.aux_stream if not has_fp8_kv_cache else None
         if self.k_b_proj_trans.dtype == torch.bfloat16:
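The net effect of the change: forward_generation_dsa no longer accepts position_ids, and forward_absorption_generation takes it as an optional keyword argument that is only mandatory when helix context parallelism is active (cp_size > 1). The sketch below is a minimal, standalone illustration of that calling pattern, not TensorRT-LLM code; the function name, the dict stand-in for AttentionMetadata, and the cp_size parameter are hypothetical simplifications.

# Minimal sketch of the pattern this commit moves to: position_ids becomes an
# optional keyword argument, asserted non-None only under helix parallelism.
from typing import Optional

import torch


def forward_absorption_generation_sketch(
        q: torch.Tensor,
        attn_metadata: dict,  # stand-in for AttentionMetadata
        output: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        cp_size: int = 1,
) -> torch.Tensor:
    # Mirror of the new guard: position_ids is required only for helix parallelism.
    if cp_size > 1:
        assert position_ids is not None, "position_ids is required for helix parallelism."
        helix_position_offsets = position_ids
    else:
        helix_position_offsets = None
    # ... the attention kernel would consume helix_position_offsets here ...
    return output


# Callers pass position_ids by keyword after the positional arguments,
# matching the updated call site in forward_impl.
out = forward_absorption_generation_sketch(
    torch.zeros(2, 8),
    attn_metadata={},
    output=torch.zeros(2, 8),
    position_ids=torch.arange(2),
    cp_size=2,
)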
