@@ -191,6 +191,135 @@ def cohere_model_forward(
     )


+def cohere_model_forward_4_41(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+):
+    use_cache = use_cache if use_cache is not None \
+        else self.config.use_cache
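+    # ipex-llm change: if the quantized KV cache path applies to this model and input,
+    # wrap/convert the cache into ipex-llm's FP8 KV cache implementation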
+    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
+        if not isinstance(past_key_values, DynamicFp8Cache):
+            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if input_ids is not None and inputs_embeds is not None:
+        invalidInputError(False,
+                          "You cannot specify both input_ids and inputs_embeds at the same time")
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        invalidInputError(False,
+                          "`use_cache=True` is incompatible "
+                          "with gradient checkpointing. Setting `use_cache=False`.")
+        use_cache = False
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    past_seen_tokens = 0
+    return_legacy_cache = False
+    # kept for BC (non `Cache` `past_key_values` inputs)
+    if use_cache and not isinstance(past_key_values, Cache):
+        return_legacy_cache = True
+        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+    if cache_position is None:
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    causal_mask = self._update_causal_mask(
+        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+    )
+
+    # embed positions
+    hidden_states = inputs_embeds
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = None
+
+    for decoder_layer in self.layers:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if self.gradient_checkpointing and self.training:
+            layer_outputs = self._gradient_checkpointing_func(
+                decoder_layer.__call__,
+                hidden_states,
+                causal_mask,
+                position_ids,
+                past_key_values,
+                output_attentions,
+                use_cache,
+                cache_position,
+            )
+        else:
+            # ipex-llm changes
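+            # move the attention mask and position ids onto the device that holds this
+            # decoder layer's weights (they can differ when layers sit on multiple devices)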
+            curr_device = decoder_layer.input_layernorm.weight.device
+            if causal_mask is not None:
+                causal_mask = causal_mask.to(curr_device)
+            if position_ids is not None:
+                position_ids = position_ids.to(curr_device)
+            # ipex-llm changes end
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
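+    # if the caller passed a legacy tuple cache, convert the Cache object back to that format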
+    if return_legacy_cache:
+        next_cache = next_cache.to_legacy_cache()
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache,
+                                 all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+
+
 def cohere_attention_forward(
     self,
     hidden_states: torch.Tensor,