Skip to content

Commit 5f25e52

Browse files
committed
Fix llama-server
1 parent 886d418 commit 5f25e52

File tree

1 file changed

+39
-4
lines changed

1 file changed

+39
-4
lines changed

ggml/src/ggml-openvino/utils.cpp

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,35 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
212212
auto states = infer_request->query_state();
213213
int32_t kv_len = *(int32_t*) inp_pos->data;
214214
int32_t kv_len_in_state = states[0].get_state().get_shape()[1];
215-
if (kv_len != kv_len_in_state) {
215+
216+
// outdated if:
217+
// 1. kv_len != kv_len_in_state
218+
// 2. last row has different values
219+
bool state_outdated = kv_len != kv_len_in_state;
220+
if (!state_outdated && kv_len > 0) {
221+
auto state_tensor = states[0].get_state();
222+
auto state_name = states[0].get_name();
223+
state_name = state_name.substr(0, state_name.size() / 2);
224+
auto state_shape = state_tensor.get_shape();
225+
auto* ggml_tensor = kv_tensors[state_name];
226+
auto offset = (kv_len - 1) * state_shape[2] * state_shape[3] * ggml_type_size(ggml_tensor->type);
227+
auto size = state_shape[2] * state_shape[3] * ggml_type_size(ggml_tensor->type);
228+
state_outdated =
229+
std::memcmp((char*) ggml_tensor->data + offset, (char*) state_tensor.data() + offset, size) != 0;
230+
}
231+
232+
if (state_outdated) {
233+
GGML_LOG_DEBUG(
234+
"GGML OpenVINO Backend: updating kv cache states from ggml tensors (kv_len: %d, kv_len_in_state: %d)\n",
235+
kv_len,
236+
kv_len_in_state);
216237
for (auto& state : states) {
217-
ov::Tensor state_tensor = state.get_state();
218-
ov::Shape state_shape = state_tensor.get_shape();
238+
auto state_name = state.get_name();
239+
state_name = state_name.substr(0, state_name.size() / 2);
240+
auto* ggml_tensor = kv_tensors[state_name];
241+
auto state_shape = state.get_state().get_shape();
219242
state_shape[1] = kv_len;
220-
state_tensor.set_shape(state_shape);
243+
ov::Tensor state_tensor(state.get_state().get_element_type(), state_shape, ggml_tensor->data);
221244
state.set_state(state_tensor);
222245
}
223246
}
@@ -253,6 +276,18 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
253276
print_output_tensor_info(result_name, output_tensor, gguf_tensor_addrs);
254277
}
255278
}
279+
280+
for (auto& state : infer_request->query_state()) {
281+
auto state_name = state.get_name();
282+
state_name = state_name.substr(0, state_name.size() / 2);
283+
auto state_tensor = state.get_state();
284+
auto state_shape = state_tensor.get_shape();
285+
auto* ggml_tensor = kv_tensors[state_name];
286+
auto size = state_shape[2] * state_shape[3] * inp_pos->ne[0] * ggml_type_size(ggml_tensor->type);
287+
auto offset = state_shape[2] * state_shape[3] * (*(int32_t*) inp_pos->data) * ggml_type_size(ggml_tensor->type);
288+
std::memcpy((char*) ggml_tensor->data + offset, (char*) state_tensor.data() + offset, size);
289+
}
290+
256291
auto end_time = ggml_time_us();
257292

258293
if (getenv("GGML_OPENVINO_PROFILING")) {

0 commit comments

Comments
 (0)