@@ -211,12 +211,35 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
211211 auto states = infer_request->query_state ();
212212 int32_t kv_len = *(int32_t *) inp_pos->data ;
213213 int32_t kv_len_in_state = states[0 ].get_state ().get_shape ()[1 ];
214- if (kv_len != kv_len_in_state) {
214+
215+ // outdated if:
216+ // 1. kv_len != kv_len_in_state
217+ // 2. last row has different values
218+ bool state_outdated = kv_len != kv_len_in_state;
219+ if (!state_outdated && kv_len > 0 ) {
220+ auto state_tensor = states[0 ].get_state ();
221+ auto state_name = states[0 ].get_name ();
222+ state_name = state_name.substr (0 , state_name.size () / 2 );
223+ auto state_shape = state_tensor.get_shape ();
224+ auto * ggml_tensor = kv_tensors[state_name];
225+ auto offset = (kv_len - 1 ) * state_shape[2 ] * state_shape[3 ] * ggml_type_size (ggml_tensor->type );
226+ auto size = state_shape[2 ] * state_shape[3 ] * ggml_type_size (ggml_tensor->type );
227+ state_outdated =
228+ std::memcmp ((char *) ggml_tensor->data + offset, (char *) state_tensor.data () + offset, size) != 0 ;
229+ }
230+
231+ if (state_outdated) {
232+ GGML_LOG_DEBUG (
233+ " GGML OpenVINO Backend: updating kv cache states from ggml tensors (kv_len: %d, kv_len_in_state: %d)\n " ,
234+ kv_len,
235+ kv_len_in_state);
215236 for (auto & state : states) {
216- ov::Tensor state_tensor = state.get_state ();
217- ov::Shape state_shape = state_tensor.get_shape ();
237+ auto state_name = state.get_name ();
238+ state_name = state_name.substr (0 , state_name.size () / 2 );
239+ auto * ggml_tensor = kv_tensors[state_name];
240+ auto state_shape = state.get_state ().get_shape ();
218241 state_shape[1 ] = kv_len;
219- state_tensor. set_shape ( state_shape);
242+ ov::Tensor state_tensor (state. get_state (). get_element_type (), state_shape, ggml_tensor-> data );
220243 state.set_state (state_tensor);
221244 }
222245 }
@@ -252,6 +275,18 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
252275 print_output_tensor_info (result_name, output_tensor, gguf_tensor_addrs);
253276 }
254277 }
278+
279+ for (auto & state : infer_request->query_state ()) {
280+ auto state_name = state.get_name ();
281+ state_name = state_name.substr (0 , state_name.size () / 2 );
282+ auto state_tensor = state.get_state ();
283+ auto state_shape = state_tensor.get_shape ();
284+ auto * ggml_tensor = kv_tensors[state_name];
285+ auto size = state_shape[2 ] * state_shape[3 ] * inp_pos->ne [0 ] * ggml_type_size (ggml_tensor->type );
286+ auto offset = state_shape[2 ] * state_shape[3 ] * (*(int32_t *) inp_pos->data ) * ggml_type_size (ggml_tensor->type );
287+ std::memcpy ((char *) ggml_tensor->data + offset, (char *) state_tensor.data () + offset, size);
288+ }
289+
255290 auto end_time = ggml_time_us ();
256291
257292 if (getenv (" GGML_OPENVINO_PROFILING" )) {
0 commit comments