Skip to content

Commit 1b8323f

Browse files
committed
Simpilfy translation of get_rows
1 parent ef4de4d commit 1b8323f

File tree

1 file changed

+8
-18
lines changed

1 file changed

+8
-18
lines changed

ggml/src/ggml-openvino/openvino/op/get_rows.cpp

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@
33
#include <openvino/op/constant.hpp>
44
#include <openvino/op/convert.hpp>
55
#include <openvino/op/gather.hpp>
6-
#include <openvino/op/reshape.hpp>
7-
#include <openvino/op/slice.hpp>
86
#include <openvino/op/squeeze.hpp>
9-
#include <openvino/op/unsqueeze.hpp>
107

118
#include "../node_context.hpp"
129
#include "../op_table.hpp"
@@ -31,22 +28,15 @@ OutputVector translate_get_rows(const NodeContext& context) {
3128
indices = process_view_input(context, 1);
3229
}
3330

34-
Output<Node> axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
35-
if (indices.get_partial_shape()[1].get_length() == 1) {
36-
indices =
37-
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
38-
if (data.get_partial_shape().rank() == 2) {
39-
axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
40-
}
41-
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis);
42-
if (data.get_partial_shape().rank() == 2) {
43-
res =
44-
std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
45-
}
46-
} else {
47-
indices =
48-
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
31+
// data[b,x,y] ind[1,b,x'] test-backend-ops case
32+
// data[x,y] ind[1,1,x'] normal case
33+
indices = std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
34+
if (data.get_partial_shape().rank() == 3) {
35+
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
4936
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1);
37+
} else {
38+
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
39+
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis);
5040
}
5141

5242
if (res.get_element_type() != context.get_output_type(0)) {

0 commit comments

Comments
 (0)