@@ -154,11 +154,6 @@ def forward(
154154 attn_metadata : AttentionMetadata ,
155155 ) -> torch .Tensor :
156156
157- # warm up does not prepare resources, there are two warmup requests
158- is_warmup = attn_metadata .kv_cache_manager is None or attn_metadata .request_ids == [
159- 0
160- ]
161-
162157 # calculate split size
163158 num_contexts = attn_metadata .num_contexts
164159 num_generations = attn_metadata .seq_lens .shape [0 ] - num_contexts
@@ -167,11 +162,18 @@ def forward(
167162 split_gen = sum_seq [- 1 ] - split_ctx
168163 split_size = [split_ctx , split_gen ]
169164
170- # handle warm up request
171- if not is_warmup :
172- state_indices = attn_metadata .kv_cache_manager .get_state_indices ()
173- split_indices = torch .split (state_indices ,
174- [num_contexts , num_generations ])
165+ state_indices = attn_metadata .kv_cache_manager .get_state_indices ()
166+
167+ # warm up does not prepare resources, so no relevant state indices
168+ is_warmup = state_indices .numel () == 0
169+ if is_warmup :
170+ # in this case, assume batch takes first indices in mamba cache
171+ state_indices = torch .arange (num_contexts + num_generations ,
172+ device = state_indices .device ,
173+ dtype = state_indices .dtype )
174+
175+ split_indices = torch .split (state_indices ,
176+ [num_contexts , num_generations ])
175177
176178 split_seq_lens = torch .split (attn_metadata .seq_lens ,
177179 [num_contexts , num_generations ])
@@ -200,16 +202,11 @@ def forward(
200202 out = []
201203 for req_type in batch :
202204
203- if not is_warmup :
204- indices = split_indices [req_type ].to (torch .device ("cuda" ))
205- conv_states = attn_metadata .kv_cache_manager .get_conv_states (
206- self .layer_idx )
207- ssm_states = attn_metadata .kv_cache_manager .get_ssm_states (
208- self .layer_idx )
209- else :
210- indices = None
211- conv_states = None
212- ssm_states = None
205+ indices = split_indices [req_type ].to (torch .device ("cuda" ))
206+ conv_states = attn_metadata .kv_cache_manager .get_conv_states (
207+ self .layer_idx )
208+ ssm_states = attn_metadata .kv_cache_manager .get_ssm_states (
209+ self .layer_idx )
213210
214211 z , xbc , dt = torch .split (
215212 split_zxbcdt [req_type ],
@@ -278,8 +275,7 @@ def forward(
278275 y = rearrange (y , "b l h p -> (b l) (h p)" )
279276
280277 # copy new ssm state
281- if not is_warmup :
282- ssm_states [indices ] = current_ssm_states
278+ ssm_states [indices ] = current_ssm_states
283279
284280 # decode
285281 else :
0 commit comments