3939GgmlOvDecoder::GgmlOvDecoder (ggml_tensor * node,
4040 ggml_cgraph * cgraph,
4141 bool is_static,
42- bool is_first_token,
4342 int context_size,
4443 int context_size_swa,
4544 int num_heads,
@@ -55,25 +54,24 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_tensor * node,
5554 m_num_heads(num_heads),
5655 m_num_heads_kv(num_heads_kv),
5756 m_head_size(head_size),
58- m_is_static(is_static),
59- m_is_first_token(is_first_token) {
57+ m_is_static(is_static) {
6058 set_input_output (node);
6159}
6260
6361GgmlOvDecoder::GgmlOvDecoder (ggml_cgraph * cgraph,
6462 std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
65- bool is_static,
66- bool is_first_token) :
63+ bool is_static) :
6764 m_cgraph(cgraph),
6865 m_op_name(m_node ? std::string(m_node->name) : ""),
6966 m_model_weights(model_weights),
70- m_is_static(is_static),
71- m_is_first_token(is_first_token ) {
72- if (is_first_token && getenv (" GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS" )) {
67+ m_is_static(is_static) {
68+ if ( auto * env = getenv ( " GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS " ); env && std::string (env) != " 0 " ) {
69+ unsetenv (" GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS" );
7370 print_tensor_address_map (cgraph);
7471 }
7572
7673 set_llm_params ();
74+ validate_cgraph ();
7775
7876 for (int node_n = 0 ; node_n < cgraph->n_nodes ; node_n++) {
7977 auto * cur_node = cgraph->nodes [node_n];
@@ -300,41 +298,39 @@ void GgmlOvDecoder::set_llm_params() {
300298 }
301299}
302300
// Checks the decoder's state against the constraint imposed on static graphs.
//
// Throws std::runtime_error when m_is_static is set and m_input_len is not 1.
// The message reports the offending m_input_len and suggests running with
// `-ub 1`. Returns normally, with no other effect, in every other case.
//
// NOTE(review): m_input_len is presumably populated by set_llm_params(), which
// the cgraph constructor calls immediately before this check — confirm.
301+ void GgmlOvDecoder::validate_cgraph () const {
302+ if (m_is_static && m_input_len != 1 ) {
// The error text itself labels the static-graph case as "NPU".
303+ throw std::runtime_error (" Static graph (NPU) must have input_len == 1, but got " + std::to_string (m_input_len) +
304+ " , try set -ub 1" );
305+ }
306+ }
307+
303308ov::PartialShape GgmlOvDecoder::get_graph_input_shape (const ggml_tensor * src) const {
304309 auto name = std::string (src->name );
305310 ov::PartialShape input_shape;
306- if (name == " inp_tokens" || name == " inp_pos" ) {
307- if (m_is_static) {
308- if (m_is_first_token) {
309- input_shape = ov::PartialShape{1 , 1 , m_context_size};
310- } else {
311- input_shape = ov::PartialShape{1 , 1 , 1 };
312- }
313- } else {
314- input_shape = ov::PartialShape{1 , 1 , -1 };
315- }
316- } else if (name == " inp_out_ids" && !m_is_static) {
317- input_shape = ov::PartialShape{1 , 1 , -1 };
311+
312+ if (name == " inp_tokens" || name == " inp_pos" || name == " inp_out_ids" ) {
313+ input_shape = ov::PartialShape{1 , 1 , m_is_static ? 1 : -1 };
314+
318315 } else if (name.find (" KQ_mask" ) == 0 ) {
319316 if (m_is_static) {
320- if (m_is_first_token) {
321- input_shape = ov::PartialShape{1 , m_context_size, m_context_size};
322- } else {
323- input_shape = ov::PartialShape{1 , 1 , m_context_size};
324- }
317+ input_shape = ov::PartialShape{1 , 1 , m_context_size};
325318 } else {
326319 input_shape = ov::PartialShape{1 , -1 , -1 };
327320 }
321+
328322 } else if (name.find (" cache_" ) == 0 ) {
323+ auto past_token_len = -1 ;
329324 if (m_is_static) {
330325 int layer = extract_layer_from_name (name);
331326 bool is_swa = is_swa_layer (layer);
332- input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
333- } else {
334- input_shape = ov::PartialShape{1 , -1 , m_num_heads_kv, m_head_size};
327+ past_token_len = is_swa ? m_context_size_swa : m_context_size;
335328 }
329+ input_shape = ov::PartialShape{past_token_len, m_num_heads_kv, m_head_size};
330+
336331 } else if (const auto * op = get_tensor_used_op (src); op && op->op == GGML_OP_SET_ROWS) {
337332 input_shape = ov::PartialShape{1 , 1 , m_is_static ? 1 : -1 };
333+
338334 } else if (src->op == GGML_OP_VIEW) {
339335 // This case is added to make test-backend-ops work
340336 input_shape = ov::PartialShape{get_shape (src->view_src )};
@@ -748,9 +744,8 @@ int32_t * GgmlOvDecoder::get_output_op_params(const std::string & name) const {
748744
// Invokes node_visitor once for each entry of m_nodes, in iteration order.
//
// For every node a fresh GgmlOvDecoder is created via std::make_shared from
// that node plus this decoder's m_cgraph, m_is_static, m_context_size,
// m_context_size_swa, m_num_heads, m_num_heads_kv, m_head_size and
// m_swa_layers; the resulting shared_ptr is handed to node_visitor.
//
// @param node_visitor callback receiving a std::shared_ptr<GgmlDecoder> for
//                     each node; taken by value.
749745void GgmlOvDecoder::visit_subgraph (std::function<void (std::shared_ptr<GgmlDecoder>)> node_visitor) const {
750746 for (const auto & node : m_nodes) {
// Lines marked '-' are the removed form of the construction, which also passed
// m_is_first_token; the '+' lines below are its replacement without that flag.
751- auto decoder =
752- std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_is_first_token, m_context_size,
753- m_context_size_swa, m_num_heads, m_num_heads_kv, m_head_size, m_swa_layers);
747+ auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_context_size, m_context_size_swa,
748+ m_num_heads, m_num_heads_kv, m_head_size, m_swa_layers);
754749 node_visitor (decoder);
755750 }
756751}
0 commit comments