diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 895792ff05..89ae8c997f 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -486,7 +486,8 @@ def dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInferen zxBCdt_chunked_prefill = zxBCdt[ active_token_count - chunked_prefill_request_token_count : active_token_count ] - batch_index_chunked_prefill = batch_indices[context.chunked_prefill_request_id] + pos = torch.where(context.request_ids == context.chunked_prefill_request_id)[0][0] + batch_index_chunked_prefill = batch_indices[pos] y_prefill_chunked = self.ssm_prefill( zxBCdt_chunked_prefill,