Skip to content

Commit bc66c18

Browse files
bugfix
1 parent e8b9df1 commit bc66c18

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

megatron/core/ssm/mamba_mixer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,8 @@ def dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInferen
486486
zxBCdt_chunked_prefill = zxBCdt[
487487
active_token_count - chunked_prefill_request_token_count : active_token_count
488488
]
489-
batch_index_chunked_prefill = batch_indices[context.chunked_prefill_request_id]
489+
pos = torch.where(context.request_ids == context.chunked_prefill_request_id)[0][0]
490+
batch_index_chunked_prefill = batch_indices[pos]
490491

491492
y_prefill_chunked = self.ssm_prefill(
492493
zxBCdt_chunked_prefill,

0 commit comments

Comments
 (0)