Skip to content

Commit 0ab204b

Browse files
committed
Fix llama-cli
1 parent 8a9a4bc commit 0ab204b

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

ggml/src/ggml-openvino/ggml-decoder.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,22 +244,36 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
244244
}
245245

246246
void GgmlOvDecoder::add_extra_inputs() {
247-
// attention_size not used for NPU
247+
// Extra inputs:
248+
// 1. `attention_size`, used in matmul's in the attention block. The shape of those matmul's are 32 aligned,
249+
// see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
250+
// Not used for NPU
248251
int64_t attention_size = -1;
249252

250253
int64_t past_token_len = -1;
254+
int64_t past_token_len_from_inp_pos = -1;
251255
for (const auto& node : m_nodes) {
256+
if (node->op == GGML_OP_ROPE && std::string(node->src[1]->name) == "inp_pos") {
257+
if (node->src[1]->type != GGML_TYPE_I32) {
258+
throw std::runtime_error("Expected cgraph input `inp_pos` to be of type GGML_TYPE_I32");
259+
}
260+
past_token_len_from_inp_pos = ((int32_t*) (node->src[1]->data))[0];
261+
}
252262
if (node->op == GGML_OP_CPY && ggml_is_contiguous(node)) {
253263
assert(std::string(node->view_src->name).find("cache_k") == 0);
254-
int64_t head_size = node->src[0]->ne[0];
255-
int64_t num_heads = node->src[0]->ne[1];
256-
past_token_len = (int64_t) (node->src[1]->op_params[0] / node->src[1]->nb[0] / head_size / num_heads);
264+
past_token_len =
265+
(int64_t) (node->src[1]->op_params[0] / node->src[1]->nb[0] / m_head_size / m_num_heads_kv);
257266
break;
258267
}
259268
}
260269
if (past_token_len == -1) {
261270
throw std::runtime_error("Failed to find input \"cache_k\" in the graph");
262271
}
272+
if (past_token_len != past_token_len_from_inp_pos) {
273+
throw std::runtime_error("Mismatch between past_token_len from cache_k and inp_pos: " +
274+
std::to_string(past_token_len) + " vs " + std::to_string(past_token_len_from_inp_pos));
275+
}
276+
263277
for (const auto& node : m_nodes) {
264278
if (node->src[1] && std::string(node->src[1]->name).find("inp_tokens") == 0) {
265279
int64_t total_token_len = node->src[1]->ne[0] + past_token_len;

0 commit comments

Comments
 (0)