@@ -254,9 +254,9 @@ def model_forward(
254254 ):
255255 # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
256256 use_cache = use_cache if use_cache is not None else self .config .use_cache
257- input = input_ids if input_ids is not None else inputs_embeds
258- use_quantize_kv = use_quantize_kv_cache (self .layers [0 ].mlp .down_proj , input )
259- use_compress_kv = should_use_compresskv (input , input .shape [1 ])
257+ inputs = input_ids if input_ids is not None else inputs_embeds
258+ use_quantize_kv = use_quantize_kv_cache (self .layers [0 ].mlp .down_proj , inputs )
259+ use_compress_kv = should_use_compresskv (inputs , inputs .shape [1 ])
260260 if use_cache :
261261 if use_compress_kv and not isinstance (past_key_values ,
262262 DynamicCompressCache ):
@@ -272,6 +272,14 @@ def model_forward(
272272 DynamicCompressCache
273273 )):
274274 past_key_values = DynamicNormalCache .from_legacy_cache (past_key_values )
275+ if past_key_values .get_seq_length () == 0 :
276+ n_layer = self .config .num_hidden_layers
277+ n_head = self .config .num_attention_heads
278+ head_dim = self .config .hidden_size // self .config .num_attention_heads
279+ past_key_values = DynamicNormalCache .from_reserved (
280+ n_layer , inputs .size (0 ), n_head , inputs .size (1 ), head_dim ,
281+ inputs .dtype , inputs .device
282+ )
275283 return origin_model_forward (
276284 self = self ,
277285 input_ids = input_ids ,
0 commit comments