Skip to content

Commit 448700e

Browse files
committed
Change the concat to copy_ to reduce memory usage
Signed-off-by: Chenghao Zhang <[email protected]>
1 parent ad94c87 commit 448700e

File tree

1 file changed

+12
-22
lines changed

1 file changed

+12
-22
lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,10 @@ def _triton_cached_ssm(
189189

190190
num_prefill, num_prefill_tokens, num_decode = batch_info_tensor.tolist()
191191

192-
# Prefill: concatenate tokens at the front and run combined scan
193192
y_prefill = None
193+
y_dec = None
194+
195+
# Prefill: concatenate tokens at the front and run combined scan
194196
if num_prefill > 0:
195197
hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H, D]
196198
B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
@@ -234,11 +236,7 @@ def _triton_cached_ssm(
234236
0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype)
235237
)
236238

237-
# y_prefill is [1, S_p, H, D] -> remove batch dim
238-
y_prefill = y_prefill[0]
239-
240239
# Decode: batch single-token updates via selective_state_update
241-
y_dec = None
242240
if num_decode > 0:
243241
slot_idx_decode = slot_idx[num_prefill:]
244242

@@ -266,27 +264,19 @@ def _triton_cached_ssm(
266264
state_batch_indices=slot_idx_decode,
267265
) # [nd, H, D]
268266

269-
# Combine results
267+
# Dispatch return logic
270268
if num_prefill > 0 and num_decode > 0:
271-
# Concatenate prefill and decode outputs to form the final flattened output
272-
# Both need to be the same dtype
273-
y_flat = torch.cat(
274-
[y_prefill.to(hidden_states.dtype), y_dec.to(hidden_states.dtype)], dim=0
275-
)
269+
y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format)
270+
y_flat = y.view(bs, *y.shape[2:])
271+
y_flat[:num_prefill_tokens].copy_(y_prefill[0])
272+
y_flat[num_prefill_tokens : num_prefill_tokens + num_decode].copy_(y_dec)
273+
return y
276274
elif num_prefill > 0:
277-
y_flat = y_prefill.to(hidden_states.dtype)
275+
return y_prefill[0].view(b, s, num_heads, head_dim).to(hidden_states.dtype)
278276
elif num_decode > 0:
279-
y_flat = y_dec.to(hidden_states.dtype)
277+
return y_dec.view(b, s, num_heads, head_dim).to(hidden_states.dtype)
280278
else:
281-
# Should not happen given input shapes, but handle empty case
282-
y_flat = torch.empty(
283-
0, num_heads, head_dim, device=hidden_states.device, dtype=hidden_states.dtype
284-
)
285-
286-
# Reshape back to [B, S, H, D] if needed, or return flat if layout allows
287-
# The original code reshaped y_flat into y [b, s, h, d] via view at the start.
288-
# We constructed y_flat directly, so we just view it back to original shape.
289-
return y_flat.view(b, s, num_heads, head_dim)
279+
return torch.empty_like(hidden_states)
290280

291281

292282
@_triton_cached_ssm.register_fake

0 commit comments

Comments (0)