Skip to content

Commit c2c8c06

Browse files
committed
Fix for Phi3
1 parent 97d1ddd commit c2c8c06

File tree

5 files changed

+29
-23
lines changed

5 files changed

+29
-23
lines changed

ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
4747
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
4848
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
4949
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,2});
50-
auto leaf_8 = context.get_input("leaf_8");
51-
auto shape_of_leaf_8 = std::make_shared<ov::op::v3::ShapeOf>(leaf_8);
52-
auto gather_leaf_8 = std::make_shared<ov::op::v8::Gather>(shape_of_leaf_8, two_1d, zero_1d);
53-
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_leaf_8}, 0);
50+
auto inp_pos = context.get_input("inp_pos");
51+
auto shape_of_inp_pos = std::make_shared<ov::op::v3::ShapeOf>(inp_pos);
52+
auto gather_inp_pos = std::make_shared<ov::op::v8::Gather>(shape_of_inp_pos, two_1d, zero_1d);
53+
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_inp_pos}, 0);
5454
mask_sliced =
5555
std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
5656
mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d);

ggml/src/ggml-openvino/openvino/op/permute.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <openvino/op/reshape.hpp>
88
#include <openvino/op/slice.hpp>
99
#include <openvino/op/transpose.hpp>
10+
#include <openvino/op/unsqueeze.hpp>
1011

1112
#include "../node_context.hpp"
1213
#include "../op_table.hpp"
@@ -23,13 +24,18 @@ OutputVector translate_permute(const NodeContext& context) {
2324
int op_case = context.get_op_case();
2425
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported PERMUTE case");
2526
ov::Output<Node> res;
27+
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
2628

2729
if (op_case == 1) {
2830
if (context.is_static()) {
2931
res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
3032
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
3133
} else {
32-
res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
34+
auto src = context.get_input(0);
35+
if (src.get_partial_shape().rank() == 3) {
36+
src = std::make_shared<ov::op::v0::Unsqueeze>(src, zero);
37+
}
38+
res = std::make_shared<ov::op::v1::Transpose>(src,
3339
ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
3440
}
3541
} else {
@@ -43,7 +49,6 @@ OutputVector translate_permute(const NodeContext& context) {
4349
attention_size = context.get_input("attention_size_swa");
4450
}
4551

46-
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
4752
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
4853

4954
if (context.is_static()) {
@@ -57,6 +62,9 @@ OutputVector translate_permute(const NodeContext& context) {
5762
res = std::make_shared<ov::op::v1::Transpose>(src_slice,
5863
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
5964
} else {
65+
if (src.get_partial_shape().rank() == 3) {
66+
src = std::make_shared<ov::op::v0::Unsqueeze>(src, zero);
67+
}
6068
res = std::make_shared<ov::op::v1::Transpose>(src,
6169
ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
6270
}

ggml/src/ggml-openvino/openvino/op/set_rows.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <cassert>
12
#include <cstdint>
23
#include <memory>
34
#include <openvino/core/node.hpp>
@@ -8,7 +9,6 @@
89
#include <openvino/op/convert.hpp>
910
#include <openvino/op/gather.hpp>
1011
#include <openvino/op/reshape.hpp>
11-
#include <openvino/op/unsqueeze.hpp>
1212
#include <openvino/op/scatter_update.hpp>
1313
#include <openvino/op/shape_of.hpp>
1414
#include <openvino/op/slice.hpp>
@@ -55,14 +55,12 @@ OutputVector translate_set_rows(const NodeContext& context) {
5555
auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
5656
res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false);
5757
} else {
58-
// TODO: Better solution would be to reshape the data into 4D at first place (for stateful model)
59-
if (data.get_partial_shape().rank() + 1 == dst.get_partial_shape().rank()) {
60-
data = std::make_shared<ov::op::v0::Unsqueeze>(data, zero);
61-
}
62-
int concat_axis = 1;
63-
if (context.is_static())
64-
concat_axis = 0;
65-
res = std::make_shared<ov::op::v0::Concat>(OutputVector{dst, data}, concat_axis);
58+
assert(dst.get_partial_shape().rank() == 4 && dst.get_partial_shape()[2].is_static() && dst.get_partial_shape()[3].is_static());
59+
int64_t dim2 = dst.get_partial_shape()[2].get_length();
60+
int64_t dim3 = dst.get_partial_shape()[3].get_length();
61+
data = std::make_shared<ov::op::v1::Reshape>(
62+
data, ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 1, (int64_t) -1, dim2, dim3}), false);
63+
res = std::make_shared<ov::op::v0::Concat>(OutputVector{dst, data}, 1);
6664
}
6765
return rename_outputs_with_suffix({res}, context.get_name());
6866
}

ggml/src/ggml-openvino/openvino/op/softmax.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ OutputVector translate_soft_max(const NodeContext& context) {
6464
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
6565
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
6666
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,2});
67-
auto leaf_8 = context.get_input("leaf_8");
68-
auto shape_of_leaf_8 = std::make_shared<ov::op::v3::ShapeOf>(leaf_8);
69-
auto gather_leaf_8 = std::make_shared<ov::op::v8::Gather>(shape_of_leaf_8, two_1d, zero_1d);
70-
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_leaf_8}, 0);
67+
auto inp_pos = context.get_input("inp_pos");
68+
auto shape_of_inp_pos = std::make_shared<ov::op::v3::ShapeOf>(inp_pos);
69+
auto gather_inp_pos = std::make_shared<ov::op::v8::Gather>(shape_of_inp_pos, two_1d, zero_1d);
70+
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_inp_pos}, 0);
7171
mask_node_sliced =
7272
std::make_shared<ov::op::v8::Slice>(mask_node, zero_2d, stop, one_2d, axes);
7373
if (!(context.is_static())) {

ggml/src/ggml-openvino/openvino/translate_session.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ void add_sliced_mask(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
9393
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
9494
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
9595
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,2});
96-
auto leaf_8 = tensor_map.at("leaf_8").get_node_shared_ptr();
97-
auto shape_of_leaf_8 = std::make_shared<ov::op::v3::ShapeOf>(leaf_8);
98-
auto gather_leaf_8 = std::make_shared<ov::op::v8::Gather>(shape_of_leaf_8, two_1d, zero_1d);
99-
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_leaf_8}, 0);
96+
auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
97+
auto shape_of_inp_pos = std::make_shared<ov::op::v3::ShapeOf>(inp_pos);
98+
auto gather_inp_pos = std::make_shared<ov::op::v8::Gather>(shape_of_inp_pos, two_1d, zero_1d);
99+
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_inp_pos}, 0);
100100
mask_sliced =
101101
std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
102102
mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d);

0 commit comments

Comments
 (0)