Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion megatron/core/ssm/mamba_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,8 @@ def dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInferen
zxBCdt_chunked_prefill = zxBCdt[
active_token_count - chunked_prefill_request_token_count : active_token_count
]
batch_index_chunked_prefill = batch_indices[context.chunked_prefill_request_id]
pos = torch.where(context.request_ids == context.chunked_prefill_request_id)[0][0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we abstract away this extraction in context?

batch_index_chunked_prefill = batch_indices[pos]

y_prefill_chunked = self.ssm_prefill(
zxBCdt_chunked_prefill,
Expand Down
Loading