Skip to content

Commit d5f7df3

Browse files
committed
Change due to ggml cgraph changes, not correct yet
1 parent 61ffb18 commit d5f7df3

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

ggml/src/ggml-openvino/ggml-decoder.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,16 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node) {
187187
case GGML_OP_MUL_MAT: {
188188
if (node->src[0]->view_src == nullptr) {
189189
m_op_case = 1;
190+
} else if (std::string(node->src[0]->name).find("cache_k") == 0) {
191+
m_op_case = 2;
192+
} else if (std::string(node->src[0]->name).find("cache_v") == 0) {
193+
m_op_case = 3;
194+
}
195+
break;
196+
}
197+
case GGML_OP_PERMUTE: {
198+
if (ggml_is_contiguous(node->src[0])) {
199+
m_op_case = 1;
190200
} else {
191201
m_op_case = 2;
192202
}

ggml/src/ggml-openvino/openvino/op/mulmat.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
2424
num_inputs_check(context, 2, 2);
2525

2626
int op_case = context.get_op_case();
27-
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2, "Unsupported MULMAT case");
27+
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported MULMAT case");
2828

2929
ov::Output<Node> res;
3030

@@ -59,8 +59,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
5959
auto src0 = context.get_input(0);
6060
auto src0_shape = context.get_input_shape(0).to_shape();
6161
auto src0_stride = context.get_input_stride(0);
62-
auto permuted = is_permuted(src0_stride);
63-
auto token_dim = permuted ? 0 : 2;
62+
auto token_dim = op_case == 2 ? 0 : 2;
6463

6564
auto attention_size = context.get_input("attention_size");
6665

@@ -81,7 +80,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
8180
auto src0_reshape = std::make_shared<ov::op::v1::Reshape>(src0, src0_reshape_shape, false);
8281

8382
std::shared_ptr<ov::Node> slice_end;
84-
if (permuted) {
83+
if (op_case == 2) {
8584
slice_end = std::make_shared<ov::op::v0::Concat>(
8685
ov::OutputVector{attention_size, ov::op::v0::Constant::create(ov::element::i64, {2}, src0_slice_shape)},
8786
0);
@@ -94,7 +93,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
9493
auto slice_step = ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>(3, 1));
9594
auto src0_slice = std::make_shared<ov::op::v8::Slice>(src0_reshape, slice_start, slice_end, slice_step);
9695

97-
if (permuted) {
96+
if (op_case == 2) {
9897
B = std::make_shared<ov::op::v1::Transpose>(
9998
src0_slice,
10099
ov::op::v0::Constant::create(ov::element::i64, {src0_perm.size()}, src0_perm));

ggml/src/ggml-openvino/openvino/op/permute.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,19 @@ namespace op {
1212
OutputVector translate_permute(const NodeContext& context) {
1313
num_inputs_check(context, 1, 1);
1414

15-
auto perm = argsort_descend(context.get_output_stride(0));
16-
auto res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
17-
ov::op::v0::Constant::create(ov::element::i64, {3}, perm));
18-
return rename_outputs_with_suffix({res}, context.get_name());
15+
int op_case = context.get_op_case();
16+
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2, "Unsupported CONT case");
17+
ov::Output<Node> res;
18+
19+
if (op_case == 1) {
20+
auto perm = argsort_descend(context.get_output_stride(0));
21+
auto res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
22+
ov::op::v0::Constant::create(ov::element::i64, {3}, perm));
23+
return rename_outputs_with_suffix({res}, context.get_name());
24+
} else {
25+
auto res = context.get_input(0);
26+
return {res};
27+
}
1928
}
2029

2130
} // namespace op

0 commit comments

Comments
 (0)