
Commit f121f13

[nvbug 5325284][fix] Increase Nemotron-H warmup request robustness (#4954)
Signed-off-by: Tomer Asida <[email protected]>
1 parent fdfc711 commit f121f13

3 files changed: +286 -256 lines changed


tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py

Lines changed: 18 additions & 22 deletions
@@ -154,11 +154,6 @@ def forward(
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
 
-        # warm up does not prepare resources, there are two warmup requests
-        is_warmup = attn_metadata.kv_cache_manager is None or attn_metadata.request_ids == [
-            0
-        ]
-
         # calculate split size
         num_contexts = attn_metadata.num_contexts
         num_generations = attn_metadata.seq_lens.shape[0] - num_contexts
@@ -167,11 +162,18 @@ def forward(
         split_gen = sum_seq[-1] - split_ctx
         split_size = [split_ctx, split_gen]
 
-        # handle warm up request
-        if not is_warmup:
-            state_indices = attn_metadata.kv_cache_manager.get_state_indices()
-            split_indices = torch.split(state_indices,
-                                        [num_contexts, num_generations])
+        state_indices = attn_metadata.kv_cache_manager.get_state_indices()
+
+        # warm up does not prepare resources, so no relevant state indices
+        is_warmup = state_indices.numel() == 0
+        if is_warmup:
+            # in this case, assume batch takes first indices in mamba cache
+            state_indices = torch.arange(num_contexts + num_generations,
+                                         device=state_indices.device,
+                                         dtype=state_indices.dtype)
+
+        split_indices = torch.split(state_indices,
+                                    [num_contexts, num_generations])
 
         split_seq_lens = torch.split(attn_metadata.seq_lens,
                                      [num_contexts, num_generations])
@@ -200,16 +202,11 @@ def forward(
         out = []
         for req_type in batch:
 
-            if not is_warmup:
-                indices = split_indices[req_type].to(torch.device("cuda"))
-                conv_states = attn_metadata.kv_cache_manager.get_conv_states(
-                    self.layer_idx)
-                ssm_states = attn_metadata.kv_cache_manager.get_ssm_states(
-                    self.layer_idx)
-            else:
-                indices = None
-                conv_states = None
-                ssm_states = None
+            indices = split_indices[req_type].to(torch.device("cuda"))
+            conv_states = attn_metadata.kv_cache_manager.get_conv_states(
+                self.layer_idx)
+            ssm_states = attn_metadata.kv_cache_manager.get_ssm_states(
+                self.layer_idx)
 
             z, xbc, dt = torch.split(
                 split_zxbcdt[req_type],
@@ -278,8 +275,7 @@ def forward(
                 y = rearrange(y, "b l h p -> (b l) (h p)")
 
                 # copy new ssm state
-                if not is_warmup:
-                    ssm_states[indices] = current_ssm_states
+                ssm_states[indices] = current_ssm_states
 
             # decode
             else:
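
The change above makes warmup detection depend only on the state-index tensor itself (empty means no resources were prepared), instead of checking kv_cache_manager for None or request_ids == [0]. A minimal sketch of that logic follows, assuming a stand-in helper (resolve_state_indices is illustrative, not part of the diff) and CPU tensors in place of the real cache manager:

import torch


def resolve_state_indices(state_indices: torch.Tensor, num_contexts: int,
                          num_generations: int):
    # Warmup requests do not prepare cache resources, so the cache manager
    # returns an empty index tensor; in that case fall back to the first
    # num_contexts + num_generations slots of the mamba cache.
    if state_indices.numel() == 0:
        state_indices = torch.arange(num_contexts + num_generations,
                                     device=state_indices.device,
                                     dtype=state_indices.dtype)
    # Split into per-context and per-generation index groups.
    return torch.split(state_indices, [num_contexts, num_generations])


# Warmup: empty index tensor -> slots 0..3 are assumed.
ctx_idx, gen_idx = resolve_state_indices(
    torch.tensor([], dtype=torch.int32), num_contexts=3, num_generations=1)
print(ctx_idx.tolist(), gen_idx.tolist())  # [0, 1, 2] [3]

# Regular batch: indices from the cache manager are used as-is.
ctx_idx, gen_idx = resolve_state_indices(
    torch.tensor([7, 2, 5, 9], dtype=torch.int32), num_contexts=3, num_generations=1)
print(ctx_idx.tolist(), gen_idx.tolist())  # [7, 2, 5] [9]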

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 3 additions & 1 deletion
@@ -578,7 +578,9 @@ def __init__(
         self.mamba_cache_index: Dict[int, int] = {}
 
         # mamba cache state indices
-        self.state_indices: torch.Tensor = torch.Tensor()
+        self.state_indices: torch.Tensor = torch.tensor([],
+                                                        device=device,
+                                                        dtype=torch.int32)
 
     def prepare_mamba_cache_blocks(self, request_ids: List[int]):
         state_indices = []
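
A small sketch of why this initializer changed (CPU-only here; in the diff the tensor lives on device, assumed to be the model's device): torch.Tensor() yields an empty float32 tensor, whereas torch.tensor([], device=device, dtype=torch.int32) is also empty but already carries the index dtype and device that the warmup fallback in mamba2_mixer.py inherits through torch.arange(..., device=state_indices.device, dtype=state_indices.dtype):

import torch

old_style = torch.Tensor()                        # empty, but float32 on CPU
new_style = torch.tensor([], dtype=torch.int32)   # empty, integer index dtype

print(old_style.numel() == 0, old_style.dtype)    # True torch.float32
print(new_style.numel() == 0, new_style.dtype)    # True torch.int32

# The warmup fallback builds its indices from this tensor's device/dtype,
# so the empty placeholder must already look like a real index tensor.
fallback = torch.arange(4, device=new_style.device, dtype=new_style.dtype)
print(fallback)  # tensor([0, 1, 2, 3], dtype=torch.int32)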
