Skip to content

Commit 94934fa

Browse files
committed
Fix llama-server
1 parent 5ea2158 commit 94934fa

File tree

1 file changed

+39
-4
lines changed

1 file changed

+39
-4
lines changed

ggml/src/ggml-openvino/utils.cpp

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,35 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
211211
auto states = infer_request->query_state();
212212
int32_t kv_len = *(int32_t*) inp_pos->data;
213213
int32_t kv_len_in_state = states[0].get_state().get_shape()[1];
214-
if (kv_len != kv_len_in_state) {
214+
215+
// outdated if:
216+
// 1. kv_len != kv_len_in_state
217+
// 2. last row has different values
218+
bool state_outdated = kv_len != kv_len_in_state;
219+
if (!state_outdated && kv_len > 0) {
220+
auto state_tensor = states[0].get_state();
221+
auto state_name = states[0].get_name();
222+
state_name = state_name.substr(0, state_name.size() / 2);
223+
auto state_shape = state_tensor.get_shape();
224+
auto* ggml_tensor = kv_tensors[state_name];
225+
auto offset = (kv_len - 1) * state_shape[2] * state_shape[3] * ggml_type_size(ggml_tensor->type);
226+
auto size = state_shape[2] * state_shape[3] * ggml_type_size(ggml_tensor->type);
227+
state_outdated =
228+
std::memcmp((char*) ggml_tensor->data + offset, (char*) state_tensor.data() + offset, size) != 0;
229+
}
230+
231+
if (state_outdated) {
232+
GGML_LOG_DEBUG(
233+
"GGML OpenVINO Backend: updating kv cache states from ggml tensors (kv_len: %d, kv_len_in_state: %d)\n",
234+
kv_len,
235+
kv_len_in_state);
215236
for (auto& state : states) {
216-
ov::Tensor state_tensor = state.get_state();
217-
ov::Shape state_shape = state_tensor.get_shape();
237+
auto state_name = state.get_name();
238+
state_name = state_name.substr(0, state_name.size() / 2);
239+
auto* ggml_tensor = kv_tensors[state_name];
240+
auto state_shape = state.get_state().get_shape();
218241
state_shape[1] = kv_len;
219-
state_tensor.set_shape(state_shape);
242+
ov::Tensor state_tensor(state.get_state().get_element_type(), state_shape, ggml_tensor->data);
220243
state.set_state(state_tensor);
221244
}
222245
}
@@ -252,6 +275,18 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
252275
print_output_tensor_info(result_name, output_tensor, gguf_tensor_addrs);
253276
}
254277
}
278+
279+
for (auto& state : infer_request->query_state()) {
280+
auto state_name = state.get_name();
281+
state_name = state_name.substr(0, state_name.size() / 2);
282+
auto state_tensor = state.get_state();
283+
auto state_shape = state_tensor.get_shape();
284+
auto* ggml_tensor = kv_tensors[state_name];
285+
auto size = state_shape[2] * state_shape[3] * inp_pos->ne[0] * ggml_type_size(ggml_tensor->type);
286+
auto offset = state_shape[2] * state_shape[3] * (*(int32_t*) inp_pos->data) * ggml_type_size(ggml_tensor->type);
287+
std::memcpy((char*) ggml_tensor->data + offset, (char*) state_tensor.data() + offset, size);
288+
}
289+
255290
auto end_time = ggml_time_us();
256291

257292
if (getenv("GGML_OPENVINO_PROFILING")) {

0 commit comments

Comments
 (0)