@@ -244,22 +244,36 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
 }
 
 void GgmlOvDecoder::add_extra_inputs() {
-    // attention_size not used for NPU
+    // Extra inputs:
+    // 1. `attention_size`, used in the matmuls in the attention block. The shapes of those matmuls are 32-aligned,
+    //    see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
+    //    Not used for NPU.
     int64_t attention_size = -1;
 
     int64_t past_token_len = -1;
+    int64_t past_token_len_from_inp_pos = -1;
     for (const auto & node : m_nodes) {
+        if (node->op == GGML_OP_ROPE && std::string(node->src[1]->name) == "inp_pos") {
+            if (node->src[1]->type != GGML_TYPE_I32) {
+                throw std::runtime_error("Expected cgraph input `inp_pos` to be of type GGML_TYPE_I32");
+            }
+            past_token_len_from_inp_pos = ((int32_t *) (node->src[1]->data))[0];
+        }
         if (node->op == GGML_OP_CPY && ggml_is_contiguous(node)) {
             assert(std::string(node->view_src->name).find("cache_k") == 0);
-            int64_t head_size = node->src[0]->ne[0];
-            int64_t num_heads = node->src[0]->ne[1];
-            past_token_len = (int64_t) (node->src[1]->op_params[0] / node->src[1]->nb[0] / head_size / num_heads);
+            past_token_len =
+                (int64_t) (node->src[1]->op_params[0] / node->src[1]->nb[0] / m_head_size / m_num_heads_kv);
             break;
         }
     }
     if (past_token_len == -1) {
         throw std::runtime_error("Failed to find input \"cache_k\" in the graph");
     }
+    if (past_token_len != past_token_len_from_inp_pos) {
+        throw std::runtime_error("Mismatch between past_token_len from cache_k and inp_pos: " +
+                                 std::to_string(past_token_len) + " vs " + std::to_string(past_token_len_from_inp_pos));
+    }
+
     for (const auto & node : m_nodes) {
         if (node->src[1] && std::string(node->src[1]->name).find("inp_tokens") == 0) {
             int64_t total_token_len = node->src[1]->ne[0] + past_token_len;
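
For reference, here is a minimal standalone sketch of the offset arithmetic behind the new `past_token_len` computation. The destination view into `cache_k` starts right after the tokens already stored, so dividing its byte offset (`op_params[0]` of the view node) by the element size (`nb[0]`), the head size, and the KV head count recovers the past token count. All concrete values below (F16 cache, `head_size = 128`, `num_heads_kv = 8`, 32 past tokens) are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

int main() {
    const int64_t head_size    = 128;  // elements per head (hypothetical)
    const int64_t num_heads_kv = 8;    // KV heads (hypothetical)
    const int64_t elem_size    = 2;    // nb[0] of the view dst: bytes per F16 element

    // The view into cache_k begins right after the cached tokens, so its byte
    // offset encodes the past length.
    const int64_t past_tokens  = 32;
    const int64_t offset_bytes = past_tokens * head_size * num_heads_kv * elem_size;

    // Inverse computation, as in add_extra_inputs():
    const int64_t past_token_len = offset_bytes / elem_size / head_size / num_heads_kv;
    assert(past_token_len == past_tokens);
    return 0;
}
```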
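The `inp_pos` cross-check relies on position numbering starting at 0 for a single contiguous sequence, so the first entry of `inp_pos` equals the number of cached tokens. A minimal sketch of that validation, using a stand-in `std::vector` in place of the decoder's ggml tensor:

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// inp_pos holds the int32 positions of the tokens in the current batch;
// for a single contiguous sequence, inp_pos[0] is the past token count.
int64_t past_len_from_inp_pos(const std::vector<int32_t>& inp_pos) {
    if (inp_pos.empty()) {
        throw std::runtime_error("inp_pos is empty");
    }
    return inp_pos[0];
}

int main() {
    // Decoding one new token after 32 cached tokens: positions {32}.
    const std::vector<int32_t> inp_pos = {32};
    const int64_t past_token_len = 32;  // as derived from the cache_k view offset
    if (past_token_len != past_len_from_inp_pos(inp_pos)) {
        throw std::runtime_error("Mismatch between past_token_len from cache_k and inp_pos: " +
                                 std::to_string(past_token_len) + " vs " +
                                 std::to_string(past_len_from_inp_pos(inp_pos)));
    }
    return 0;
}
```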