@@ -73,6 +73,11 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph,
 }
 
 GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph * cgraph) {
+    if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) {
+        std::string filename = "cgraph.txt";
+        dump_cgraph(cgraph, filename);
+    }
+
     m_cgraph = cgraph;
     for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
         auto * cur_node = cgraph->nodes[node_n];
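Note: the dump is gated on an environment variable, so it can be toggled without a rebuild, e.g. GGML_OPENVINO_DUMP_CGRAPH=1 before running; the graph is written to cgraph.txt in the working directory. Below is a hypothetical sketch of the kind of per-node listing dump_cgraph could produce; the real implementation lives elsewhere in this PR, and the formatting here is illustrative only.

// Hypothetical sketch, not the PR's dump_cgraph: walk the graph and print one
// line per node using only public ggml accessors.
#include <cstdio>
#include <string>
#include "ggml.h"

static void dump_cgraph_sketch(const struct ggml_cgraph * cgraph, const std::string & filename) {
    FILE * f = std::fopen(filename.c_str(), "w");
    if (!f) {
        return;
    }
    for (int i = 0; i < cgraph->n_nodes; i++) {
        const struct ggml_tensor * t = cgraph->nodes[i];
        std::fprintf(f, "%4d %-20s %-32s [%lld, %lld, %lld, %lld]\n",
                     i, ggml_op_name(t->op), t->name,
                     (long long) t->ne[0], (long long) t->ne[1],
                     (long long) t->ne[2], (long long) t->ne[3]);
    }
    std::fclose(f);
}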
@@ -173,49 +178,46 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
         break;
     }
     case GGML_OP_CONT: {
-        if (ggml_nelements(node->src[0]) == ggml_nelements(node->src[0]->view_src)) {
-            // The input comes from a PERMUTE
-            m_op_case = 1;
-        } else {
-            // The input comes from a VIEW which is subtensor
-            m_op_case = 2;
-        }
-        break;
-    }
-    case GGML_OP_SET_ROWS: {
-        if (std::string(node->name).find("cache_k") == 0) {
+        if (node->src[0]->op == GGML_OP_PERMUTE) {
             m_op_case = 1;
-        } else {
+        } else if (node->src[0]->op == GGML_OP_TRANSPOSE) {
             m_op_case = 2;
+        } else if (node->src[0]->op == GGML_OP_VIEW) {
+            // The input comes from a VIEW which is a subtensor
+            m_op_case = 3;
         }
         break;
     }
     case GGML_OP_PERMUTE: {
-        if (node->src[0]->view_src == nullptr) {
-            // Permute Qcur
+        if (node->src[0]->op != GGML_OP_VIEW) {
             m_op_case = 1;
         } else if (ggml_is_contiguous(node->src[0])) {
             // Permute cache_k (view)
             m_op_case = 2;
         } else {
-            // Permute cache_v (view)
+            // Permute cache_v (view); deprecated, cache_v now also falls into case 2
+            m_op_case = 3;
+        }
+        break;
+    }
+    case GGML_OP_MUL_MAT: {
+        if (node->src[0]->op == GGML_OP_CONT && node->src[0]->src[0]->op == GGML_OP_TRANSPOSE) {
+            m_op_case = 2;
+        } else if (node->src[0]->op == GGML_OP_VIEW && node->src[1]->op == GGML_OP_VIEW) {
+            // test-backend-ops case
             m_op_case = 3;
         }
         break;
     }
     case GGML_OP_GET_ROWS: {
         if (node->src[1]->op == GGML_OP_VIEW) {
             m_op_case = 2;
-        } else {
-            m_op_case = 1;
         }
         break;
     }
     case GGML_OP_ROPE: {
         if (node->src[0]->op == GGML_OP_VIEW) {
             m_op_case = 2;
-        } else {
-            m_op_case = 1;
         }
         break;
     }
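The reworked classification keys on the producer node's op rather than on element-count or view_src heuristics. A minimal, self-contained check of that pattern, assuming only ggml's public graph-building API (the tensor sizes are arbitrary):

#include <cassert>
#include "ggml.h"

int main() {
    // no_alloc: we only need graph metadata, not tensor data
    struct ggml_init_params params = {16 * 1024 * 1024, nullptr, true};
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 8, 4);
    struct ggml_tensor * c = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));

    // The CONT node's input op is what set_input_output now dispatches on:
    assert(c->op == GGML_OP_CONT);
    assert(c->src[0]->op == GGML_OP_PERMUTE);  // -> m_op_case = 1

    ggml_free(ctx);
    return 0;
}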
@@ -270,19 +272,9 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
     } else if (name.find("cache_k") == 0) {
         input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
     } else if (name.find("cache_v") == 0) {
-        input_shape = ov::PartialShape{m_num_heads_kv, m_head_size, m_context_size};
+        input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
     } else if (const auto * op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
-        input_shape = ov::PartialShape{1, 1, -1};
-        if (m_is_static) {
-            if (m_is_first_token) {
-                // Dummy static shape, since the indices are not used in this case
-                input_shape = ov::PartialShape{1};
-            } else if (std::string(op->name).find("cache_k") == 0) {
-                input_shape = ov::PartialShape{1, 1, 1};
-            } else {
-                input_shape = ov::PartialShape{1, 1, m_num_heads_kv * m_head_size};
-            }
-        }
+        input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
     } else if (src->op == GGML_OP_VIEW) {
         // This case is added to make test-backend-ops work
         input_shape = ov::PartialShape{get_shape(src->view_src)};
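The collapsed SET_ROWS branch relies on ov::PartialShape treating a -1 extent as a fully dynamic dimension: the indices input becomes {1, 1, ?} on the dynamic path and {1, 1, 1} on the static one. A small standalone illustration of that semantics; is_static here is a stand-in for the decoder's m_is_static:

#include <iostream>
#include <openvino/core/partial_shape.hpp>

int main() {
    bool is_static = false;  // stand-in for the decoder's m_is_static
    ov::PartialShape indices_shape{1, 1, is_static ? 1 : -1};

    std::cout << indices_shape << '\n';               // [1,1,?]  (-1 => dynamic dim)
    std::cout << indices_shape.is_dynamic() << '\n';  // 1
    return 0;
}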
@@ -610,26 +602,28 @@ void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecode
 
 const std::string& GgmlOvDecoder::get_op_type() const {
     static const std::map<ggml_op, std::string> ops = {
-        {GGML_OP_NONE, "GGML_OP_NONE"},
-        {GGML_OP_ACC, "GGML_OP_ACC"},
-        {GGML_OP_ADD, "GGML_OP_ADD"},
-        {GGML_OP_ADD1, "GGML_OP_ADD1"},
-        {GGML_OP_CONT, "GGML_OP_CONT"},
-        {GGML_OP_DIV, "GGML_OP_DIV"},
-        {GGML_OP_DUP, "GGML_OP_DUP"},
-        {GGML_OP_GET_ROWS, "GGML_OP_GET_ROWS"},
-        {GGML_OP_MUL, "GGML_OP_MUL"},
-        {GGML_OP_MUL_MAT, "GGML_OP_MUL_MAT"},
-        {GGML_OP_PERMUTE, "GGML_OP_PERMUTE"},
-        {GGML_OP_RESHAPE, "GGML_OP_RESHAPE"},
-        {GGML_OP_RMS_NORM, "GGML_OP_RMS_NORM"},
-        {GGML_OP_ROPE, "GGML_OP_ROPE"},
-        {GGML_OP_SCALE, "GGML_OP_SCALE"},
-        {GGML_OP_SOFT_MAX, "GGML_OP_SOFT_MAX"},
-        {GGML_OP_SUB, "GGML_OP_SUB"},
-        {GGML_OP_TRANSPOSE, "GGML_OP_TRANSPOSE"},
-        {GGML_OP_VIEW, "GGML_OP_VIEW"},
-        {GGML_OP_SET_ROWS, "GGML_OP_SET_ROWS"},
+        {GGML_OP_NONE,           "GGML_OP_NONE"          },
+        {GGML_OP_ACC,            "GGML_OP_ACC"           },
+        {GGML_OP_ADD,            "GGML_OP_ADD"           },
+        {GGML_OP_ADD1,           "GGML_OP_ADD1"          },
+        {GGML_OP_CONT,           "GGML_OP_CONT"          },
+        {GGML_OP_DIV,            "GGML_OP_DIV"           },
+        {GGML_OP_DUP,            "GGML_OP_DUP"           },
+        {GGML_OP_GET_ROWS,       "GGML_OP_GET_ROWS"      },
+        {GGML_OP_MUL,            "GGML_OP_MUL"           },
+        {GGML_OP_MUL_MAT,        "GGML_OP_MUL_MAT"       },
+        {GGML_OP_PERMUTE,        "GGML_OP_PERMUTE"       },
+        {GGML_OP_RESHAPE,        "GGML_OP_RESHAPE"       },
+        {GGML_OP_RMS_NORM,       "GGML_OP_RMS_NORM"      },
+        {GGML_OP_ROPE,           "GGML_OP_ROPE"          },
+        {GGML_OP_SCALE,          "GGML_OP_SCALE"         },
+        {GGML_OP_SOFT_MAX,       "GGML_OP_SOFT_MAX"      },
+        {GGML_OP_SUB,            "GGML_OP_SUB"           },
+        {GGML_OP_TRANSPOSE,      "GGML_OP_TRANSPOSE"     },
+        {GGML_OP_VIEW,           "GGML_OP_VIEW"          },
+        {GGML_OP_SET_ROWS,       "GGML_OP_SET_ROWS"      },
+        {GGML_OP_CPY,            "GGML_OP_CPY"           },
+        {GGML_OP_FLASH_ATTN_EXT, "GGML_OP_FLASH_ATTN_EXT"},
     };
     static const std::map<ggml_unary_op, std::string> unary_ops = {
         {GGML_UNARY_OP_ABS, "GGML_UNARY_OP_ABS"},
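The two new entries, GGML_OP_CPY and GGML_OP_FLASH_ATTN_EXT, register string identifiers for ops this decoder now reports; the rest of the block is a whitespace-only realignment. For reference, a sketch of the lookup idiom such a name table supports (illustrative helper, not the PR's code; the fallback string is an assumption):

#include <map>
#include <string>
#include "ggml.h"

// Illustrative lookup helper: resolve a ggml op to the string identifier used
// by the frontend, with an explicit miss case instead of map::at throwing.
static const std::string & op_type_name(enum ggml_op op) {
    static const std::map<ggml_op, std::string> ops = {
        {GGML_OP_CPY,            "GGML_OP_CPY"           },
        {GGML_OP_FLASH_ATTN_EXT, "GGML_OP_FLASH_ATTN_EXT"},
    };
    static const std::string unknown = "UNKNOWN_GGML_OP";
    auto it = ops.find(op);
    return it == ops.end() ? unknown : it->second;
}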