40 changes: 18 additions & 22 deletions tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
@@ -154,11 +154,6 @@ def forward(
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
 
-        # warm up does not prepare resources, there are two warmup requests
-        is_warmup = attn_metadata.kv_cache_manager is None or attn_metadata.request_ids == [
-            0
-        ]
-
         # calculate split size
         num_contexts = attn_metadata.num_contexts
         num_generations = attn_metadata.seq_lens.shape[0] - num_contexts
@@ -167,11 +162,18 @@ def forward(
         split_gen = sum_seq[-1] - split_ctx
         split_size = [split_ctx, split_gen]
 
-        # handle warm up request
-        if not is_warmup:
-            state_indices = attn_metadata.kv_cache_manager.get_state_indices()
-            split_indices = torch.split(state_indices,
-                                        [num_contexts, num_generations])
+        state_indices = attn_metadata.kv_cache_manager.get_state_indices()
+
+        # warm up does not prepare resources, so no relevant state indices
+        is_warmup = state_indices.numel() == 0
+        if is_warmup:
+            # in this case, assume batch takes first indices in mamba cache
+            state_indices = torch.arange(num_contexts + num_generations,
+                                         device=state_indices.device,
+                                         dtype=state_indices.dtype)
+
+        split_indices = torch.split(state_indices,
+                                    [num_contexts, num_generations])
 
         split_seq_lens = torch.split(attn_metadata.seq_lens,
                                      [num_contexts, num_generations])
@@ -200,16 +202,11 @@ def forward(
         out = []
         for req_type in batch:
 
-            if not is_warmup:
-                indices = split_indices[req_type].to(torch.device("cuda"))
-                conv_states = attn_metadata.kv_cache_manager.get_conv_states(
-                    self.layer_idx)
-                ssm_states = attn_metadata.kv_cache_manager.get_ssm_states(
-                    self.layer_idx)
-            else:
-                indices = None
-                conv_states = None
-                ssm_states = None
+            indices = split_indices[req_type].to(torch.device("cuda"))
+            conv_states = attn_metadata.kv_cache_manager.get_conv_states(
+                self.layer_idx)
+            ssm_states = attn_metadata.kv_cache_manager.get_ssm_states(
+                self.layer_idx)
 
             z, xbc, dt = torch.split(
                 split_zxbcdt[req_type],
@@ -278,8 +275,7 @@ def forward(
                 y = rearrange(y, "b l h p -> (b l) (h p)")
 
                 # copy new ssm state
-                if not is_warmup:
-                    ssm_states[indices] = current_ssm_states
+                ssm_states[indices] = current_ssm_states
 
             # decode
             else:
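Note on the mixer change: instead of flagging warmup from `kv_cache_manager is None` / `request_ids == [0]`, the forward pass now infers it from the cache manager returning an empty `state_indices` tensor and falls back to the first mamba-cache slots, so the rest of the path no longer needs `None` guards. A minimal sketch of that selection logic, with plain tensors standing in for `attn_metadata` and the cache manager (the helper name `select_state_indices` is illustrative, not part of the PR):

```python
import torch


def select_state_indices(state_indices: torch.Tensor, num_contexts: int,
                         num_generations: int):
    """Illustrative stand-in for the warmup handling in Mamba2Mixer.forward().

    Warmup requests do not prepare cache resources, so the cache manager
    returns an empty index tensor; in that case fall back to the first slots
    of the mamba cache. Returns the (context, generation) index splits.
    """
    is_warmup = state_indices.numel() == 0
    if is_warmup:
        # assume the warmup batch occupies the first indices in the mamba cache
        state_indices = torch.arange(num_contexts + num_generations,
                                     device=state_indices.device,
                                     dtype=state_indices.dtype)
    return torch.split(state_indices, [num_contexts, num_generations])


# warmup: empty indices -> defaults to slots [0, 1]
ctx_idx, gen_idx = select_state_indices(torch.tensor([], dtype=torch.int32),
                                        num_contexts=1,
                                        num_generations=1)

# normal step: indices come straight from the cache manager
ctx_idx, gen_idx = select_state_indices(torch.tensor([4, 7, 2], dtype=torch.int32),
                                        num_contexts=2,
                                        num_generations=1)
```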
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -578,7 +578,9 @@ def __init__(
         self.mamba_cache_index: Dict[int, int] = {}
 
         # mamba cache state indices
-        self.state_indices: torch.Tensor = torch.Tensor()
+        self.state_indices: torch.Tensor = torch.tensor([],
+                                                         device=device,
+                                                         dtype=torch.int32)
 
     def prepare_mamba_cache_blocks(self, request_ids: List[int]):
         state_indices = []
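Note on the resource-manager change: `torch.Tensor()` produces an empty float32 CPU tensor, while the new initialization keeps `state_indices` on the manager's device with an integer dtype from the start, so the mixer's `numel() == 0` warmup check and later index writes see a consistent tensor. A small sketch of the difference (the `"cuda"` fallback below is an assumption; the real constructor passes its own `device`):

```python
import torch

# the old default: an empty float32 CPU tensor with no explicit dtype/device
legacy = torch.Tensor()
assert legacy.numel() == 0 and legacy.dtype == torch.float32

# the PR's initialization: same "empty means no requests prepared" semantics,
# but with the dtype and device that the index writes will use later
device = "cuda" if torch.cuda.is_available() else "cpu"  # assumption for this sketch
state_indices = torch.tensor([], device=device, dtype=torch.int32)
assert state_indices.numel() == 0 and state_indices.dtype == torch.int32
```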