@@ -212,12 +212,35 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
212212 auto states = infer_request->query_state ();
213213 int32_t kv_len = *(int32_t *) inp_pos->data ;
214214 int32_t kv_len_in_state = states[0 ].get_state ().get_shape ()[1 ];
215- if (kv_len != kv_len_in_state) {
215+
216+ // outdated if:
217+ // 1. kv_len != kv_len_in_state
218+ // 2. last row has different values
219+ bool state_outdated = kv_len != kv_len_in_state;
220+ if (!state_outdated && kv_len > 0 ) {
221+ auto state_tensor = states[0 ].get_state ();
222+ auto state_name = states[0 ].get_name ();
223+ state_name = state_name.substr (0 , state_name.size () / 2 );
224+ auto state_shape = state_tensor.get_shape ();
225+ auto * ggml_tensor = kv_tensors[state_name];
226+ auto offset = (kv_len - 1 ) * state_shape[2 ] * state_shape[3 ] * ggml_type_size (ggml_tensor->type );
227+ auto size = state_shape[2 ] * state_shape[3 ] * ggml_type_size (ggml_tensor->type );
228+ state_outdated =
229+ std::memcmp ((char *) ggml_tensor->data + offset, (char *) state_tensor.data () + offset, size) != 0 ;
230+ }
231+
232+ if (state_outdated) {
233+ GGML_LOG_DEBUG (
234+ " GGML OpenVINO Backend: updating kv cache states from ggml tensors (kv_len: %d, kv_len_in_state: %d)\n " ,
235+ kv_len,
236+ kv_len_in_state);
216237 for (auto & state : states) {
217- ov::Tensor state_tensor = state.get_state ();
218- ov::Shape state_shape = state_tensor.get_shape ();
238+ auto state_name = state.get_name ();
239+ state_name = state_name.substr (0 , state_name.size () / 2 );
240+ auto * ggml_tensor = kv_tensors[state_name];
241+ auto state_shape = state.get_state ().get_shape ();
219242 state_shape[1 ] = kv_len;
220- state_tensor. set_shape ( state_shape);
243+ ov::Tensor state_tensor (state. get_state (). get_element_type (), state_shape, ggml_tensor-> data );
221244 state.set_state (state_tensor);
222245 }
223246 }
@@ -253,6 +276,18 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
253276 print_output_tensor_info (result_name, output_tensor, gguf_tensor_addrs);
254277 }
255278 }
279+
280+ for (auto & state : infer_request->query_state ()) {
281+ auto state_name = state.get_name ();
282+ state_name = state_name.substr (0 , state_name.size () / 2 );
283+ auto state_tensor = state.get_state ();
284+ auto state_shape = state_tensor.get_shape ();
285+ auto * ggml_tensor = kv_tensors[state_name];
286+ auto size = state_shape[2 ] * state_shape[3 ] * inp_pos->ne [0 ] * ggml_type_size (ggml_tensor->type );
287+ auto offset = state_shape[2 ] * state_shape[3 ] * (*(int32_t *) inp_pos->data ) * ggml_type_size (ggml_tensor->type );
288+ std::memcpy ((char *) ggml_tensor->data + offset, (char *) state_tensor.data () + offset, size);
289+ }
290+
256291 auto end_time = ggml_time_us ();
257292
258293 if (getenv (" GGML_OPENVINO_PROFILING" )) {
0 commit comments