@@ -214,12 +214,35 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
214214 auto states = infer_request->query_state ();
215215 int32_t kv_len = *(int32_t *) inp_pos->data ;
216216 int32_t kv_len_in_state = states[0 ].get_state ().get_shape ()[1 ];
217- if (kv_len != kv_len_in_state) {
217+
218+ // outdated if:
219+ // 1. kv_len != kv_len_in_state
220+ // 2. last row has different values
221+ bool state_outdated = kv_len != kv_len_in_state;
222+ if (!state_outdated && kv_len > 0 ) {
223+ auto state_tensor = states[0 ].get_state ();
224+ auto state_name = states[0 ].get_name ();
225+ state_name = state_name.substr (0 , state_name.size () / 2 );
226+ auto state_shape = state_tensor.get_shape ();
227+ auto * ggml_tensor = kv_tensors[state_name];
228+ auto offset = (kv_len - 1 ) * state_shape[2 ] * state_shape[3 ] * ggml_type_size (ggml_tensor->type );
229+ auto size = state_shape[2 ] * state_shape[3 ] * ggml_type_size (ggml_tensor->type );
230+ state_outdated =
231+ std::memcmp ((char *) ggml_tensor->data + offset, (char *) state_tensor.data () + offset, size) != 0 ;
232+ }
233+
234+ if (state_outdated) {
235+ GGML_LOG_DEBUG (
236+ " GGML OpenVINO Backend: updating kv cache states from ggml tensors (kv_len: %d, kv_len_in_state: %d)\n " ,
237+ kv_len,
238+ kv_len_in_state);
218239 for (auto & state : states) {
219- ov::Tensor state_tensor = state.get_state ();
220- ov::Shape state_shape = state_tensor.get_shape ();
240+ auto state_name = state.get_name ();
241+ state_name = state_name.substr (0 , state_name.size () / 2 );
242+ auto * ggml_tensor = kv_tensors[state_name];
243+ auto state_shape = state.get_state ().get_shape ();
221244 state_shape[1 ] = kv_len;
222- state_tensor. set_shape ( state_shape);
245+ ov::Tensor state_tensor (state. get_state (). get_element_type (), state_shape, ggml_tensor-> data );
223246 state.set_state (state_tensor);
224247 }
225248 }
@@ -255,6 +278,18 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
255278 print_output_tensor_info (result_name, output_tensor, gguf_tensor_addrs);
256279 }
257280 }
281+
282+ for (auto & state : infer_request->query_state ()) {
283+ auto state_name = state.get_name ();
284+ state_name = state_name.substr (0 , state_name.size () / 2 );
285+ auto state_tensor = state.get_state ();
286+ auto state_shape = state_tensor.get_shape ();
287+ auto * ggml_tensor = kv_tensors[state_name];
288+ auto size = state_shape[2 ] * state_shape[3 ] * inp_pos->ne [0 ] * ggml_type_size (ggml_tensor->type );
289+ auto offset = state_shape[2 ] * state_shape[3 ] * (*(int32_t *) inp_pos->data ) * ggml_type_size (ggml_tensor->type );
290+ std::memcpy ((char *) ggml_tensor->data + offset, (char *) state_tensor.data () + offset, size);
291+ }
292+
258293 auto end_time = ggml_time_us ();
259294
260295 if (getenv (" GGML_OPENVINO_PROFILING" )) {
0 commit comments