@@ -189,8 +189,10 @@ def _triton_cached_ssm(
189189
190190 num_prefill , num_prefill_tokens , num_decode = batch_info_tensor .tolist ()
191191
192- # Prefill: concatenate tokens at the front and run combined scan
193192 y_prefill = None
193+ y_dec = None
194+
195+ # Prefill: concatenate tokens at the front and run combined scan
194196 if num_prefill > 0 :
195197 hs_prefill = hs_flat [:num_prefill_tokens ].unsqueeze (0 ) # [1, S_p, H, D]
196198 B_prefill = B_flat [:num_prefill_tokens ].unsqueeze (0 ) # [1, S_p, G, N]
@@ -234,11 +236,7 @@ def _triton_cached_ssm(
234236 0 , slot_idx [:num_prefill ], varlen_states .to (ssm_state_cache .dtype )
235237 )
236238
237- # y_prefill is [1, S_p, H, D] -> remove batch dim
238- y_prefill = y_prefill [0 ]
239-
240239 # Decode: batch single-token updates via selective_state_update
241- y_dec = None
242240 if num_decode > 0 :
243241 slot_idx_decode = slot_idx [num_prefill :]
244242
@@ -266,27 +264,19 @@ def _triton_cached_ssm(
266264 state_batch_indices = slot_idx_decode ,
267265 ) # [nd, H, D]
268266
269- # Combine results
267+ # Dispatch return logic
270268 if num_prefill > 0 and num_decode > 0 :
271- # Concatenate prefill and decode outputs to form the final flattened output
272- # Both need to be the same dtype
273- y_flat = torch . cat (
274- [ y_prefill . to ( hidden_states . dtype ), y_dec . to ( hidden_states . dtype )], dim = 0
275- )
269+ y = torch . empty_like ( hidden_states , memory_format = torch . contiguous_format )
270+ y_flat = y . view ( bs , * y . shape [ 2 :])
271+ y_flat [: num_prefill_tokens ]. copy_ ( y_prefill [ 0 ])
272+ y_flat [ num_prefill_tokens : num_prefill_tokens + num_decode ]. copy_ ( y_dec )
273+ return y
276274 elif num_prefill > 0 :
277- y_flat = y_prefill .to (hidden_states .dtype )
275+ return y_prefill [ 0 ]. view ( b , s , num_heads , head_dim ) .to (hidden_states .dtype )
278276 elif num_decode > 0 :
279- y_flat = y_dec .to (hidden_states .dtype )
277+ return y_dec . view ( b , s , num_heads , head_dim ) .to (hidden_states .dtype )
280278 else :
281- # Should not happen given input shapes, but handle empty case
282- y_flat = torch .empty (
283- 0 , num_heads , head_dim , device = hidden_states .device , dtype = hidden_states .dtype
284- )
285-
286- # Reshape back to [B, S, H, D] if needed, or return flat if layout allows
287- # The original code reshaped y_flat into y [b, s, h, d] via view at the start.
288- # We constructed y_flat directly, so we just view it back to original shape.
289- return y_flat .view (b , s , num_heads , head_dim )
279+ return torch .empty_like (hidden_states )
290280
291281
292282@_triton_cached_ssm .register_fake
0 commit comments