@@ -73,6 +73,11 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph,
7373}
7474
7575GgmlOvDecoder::GgmlOvDecoder (struct ggml_cgraph * cgraph) {
76+    if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) {
77+        std::string filename = "cgraph.txt";
78+        dump_cgraph(cgraph, filename);
79+    }
80+
7681 m_cgraph = cgraph;
7782 for (int node_n = 0 ; node_n < cgraph->n_nodes ; node_n++) {
7883 auto * cur_node = cgraph->nodes [node_n];
@@ -173,49 +178,46 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
173178 break ;
174179 }
175180 case GGML_OP_CONT: {
176- if (ggml_nelements (node->src [0 ]) == ggml_nelements (node->src [0 ]->view_src )) {
177- // The input comes from a PERMUTE
178- m_op_case = 1 ;
179- } else {
180- // The input comes from a VIEW which is subtensor
181- m_op_case = 2 ;
182- }
183- break ;
184- }
185- case GGML_OP_SET_ROWS: {
186             if (std::string(node->name).find("cache_k") == 0) {
181+ if (node->src [0 ]->op == GGML_OP_PERMUTE) {
187182 m_op_case = 1 ;
188- } else {
183+ } else if (node-> src [ 0 ]-> op == GGML_OP_TRANSPOSE) {
189184 m_op_case = 2 ;
185+ } else if (node->src [0 ]->op == GGML_OP_VIEW) {
186+ // The input comes from a VIEW which is subtensor
187+ m_op_case = 3 ;
190188 }
191189 break ;
192190 }
193191 case GGML_OP_PERMUTE: {
194- if (node->src [0 ]->view_src == nullptr ) {
195- // Permute Qcur
192+ if (node->src [0 ]->op != GGML_OP_VIEW) {
196193 m_op_case = 1 ;
197194 } else if (ggml_is_contiguous (node->src [0 ])) {
198195 // Permute cache_k (view)
199196 m_op_case = 2 ;
200197 } else {
201- // Permute cache_v (view)
198+ // Permute cache_v (view), deprecated, cache_v will also fall to case 2
199+ m_op_case = 3 ;
200+ }
201+ break ;
202+ }
203+ case GGML_OP_MUL_MAT: {
204+ if (node->src [0 ]->op == GGML_OP_CONT && node->src [0 ]->src [0 ]->op == GGML_OP_TRANSPOSE) {
205+ m_op_case = 2 ;
206+ } else if (node->src [0 ]->op == GGML_OP_VIEW && node->src [1 ]->op == GGML_OP_VIEW) {
207+ // test-backend-ops case
202208 m_op_case = 3 ;
203209 }
204210 break ;
205211 }
206212 case GGML_OP_GET_ROWS: {
207213 if (node->src [1 ]->op == GGML_OP_VIEW) {
208214 m_op_case = 2 ;
209- } else {
210- m_op_case = 1 ;
211215 }
212216 break ;
213217 }
214218 case GGML_OP_ROPE: {
215219 if (node->src [0 ]->op == GGML_OP_VIEW) {
216220 m_op_case = 2 ;
217- } else {
218- m_op_case = 1 ;
219221 }
220222 break ;
221223 }
@@ -270,19 +272,9 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
270272 } else if (name.find (" cache_k" ) == 0 ) {
271273 input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
272274 } else if (name.find (" cache_v" ) == 0 ) {
273- input_shape = ov::PartialShape{m_num_heads_kv, m_head_size, m_context_size };
275+ input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size };
274276 } else if (const auto * op = get_tensor_used_op (src); op && op->op == GGML_OP_SET_ROWS) {
275- input_shape = ov::PartialShape{1 , 1 , -1 };
276- if (m_is_static) {
277- if (m_is_first_token) {
278- // Dummy static shape, since the indices are not used in this case
279- input_shape = ov::PartialShape{1 };
280-            } else if (std::string(op->name).find("cache_k") == 0) {
281- input_shape = ov::PartialShape{1 , 1 , 1 };
282- } else {
283- input_shape = ov::PartialShape{1 , 1 , m_num_heads_kv * m_head_size};
284- }
285- }
277+ input_shape = ov::PartialShape{1 , 1 , m_is_static ? 1 : -1 };
286278 } else if (src->op == GGML_OP_VIEW) {
287279 // This case is added to make test-backend-ops work
288280 input_shape = ov::PartialShape{get_shape (src->view_src )};
0 commit comments