Skip to content

Commit 8083bc7

Browse files
committed
Fix NPU
1 parent 2f77e57 commit 8083bc7

File tree

4 files changed

+65
-6
lines changed

4 files changed

+65
-6
lines changed

ggml/src/ggml-openvino/ggml-decoder.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,14 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
193193
}
194194
break;
195195
}
196+
case GGML_OP_SET_ROWS: {
197+
if (std::string(node->name).find("cache_k") == 0) {
198+
m_op_case = 1;
199+
} else {
200+
m_op_case = 2;
201+
}
202+
break;
203+
}
196204
case GGML_OP_PERMUTE: {
197205
if (node->src[0]->view_src == nullptr) {
198206
// Permute Qcur
@@ -274,8 +282,18 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
274282
input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
275283
} else if (name.find("cache_v") == 0) {
276284
input_shape = ov::PartialShape{m_num_heads_kv, m_head_size, m_context_size};
277-
} else if (get_tensor_used_op(src)->op == GGML_OP_SET_ROWS) {
285+
} else if (const auto* op = get_tensor_used_op(src); op->op == GGML_OP_SET_ROWS) {
278286
input_shape = ov::PartialShape{1, 1, -1};
287+
if (m_is_static) {
288+
if (m_is_first_token) {
289+
// Dummy static shape, since the indices are not used in this case
290+
input_shape = ov::PartialShape{1};
291+
} else if (std::string(op->name).find("cache_k") == 0) {
292+
input_shape = ov::PartialShape{1, 1, 1};
293+
} else {
294+
input_shape = ov::PartialShape{1, 1, m_num_heads_kv * m_head_size};
295+
}
296+
}
279297
} else if (src->op == GGML_OP_VIEW) {
280298
// This case is added to make test-backend-ops work
281299
input_shape = ov::PartialShape{get_shape(src->view_src)};
@@ -316,6 +334,7 @@ void GgmlOvDecoder::add_extra_inputs() {
316334
if (node->op == GGML_OP_SET_ROWS && std::string(node->name).find("cache_k") == 0) {
317335
assert(node->src[1]->type == GGML_TYPE_I64);
318336
past_token_len = *(int64_t*) (node->src[1]->data);
337+
break;
319338
}
320339
}
321340

@@ -366,6 +385,22 @@ const ggml_tensor* GgmlOvDecoder::get_tensor_used_op(const ggml_tensor* tensor)
366385
throw std::runtime_error("Tensor not found in cgraph");
367386
}
368387

388+
const ggml_tensor* GgmlOvDecoder::get_tensor_from_name(const std::string& name) const {
389+
for (int i = 0; i < m_cgraph->n_nodes; i++) {
390+
const auto* node = m_cgraph->nodes[i];
391+
for (int j = 0; j < GGML_MAX_SRC; j++) {
392+
const auto* src = node->src[j];
393+
if (src == nullptr) {
394+
break;
395+
}
396+
if (std::string(src->name) == name) {
397+
return src;
398+
}
399+
}
400+
}
401+
return nullptr;
402+
}
403+
369404
std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const {
370405
std::map<std::string, std::string> kv_param_res_names;
371406
for (const auto& name : m_kv_names) {

ggml/src/ggml-openvino/ggml-decoder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
119119
static std::map<std::string, std::shared_ptr<ov::Node>> create_weight_nodes(struct ggml_cgraph* cgraph);
120120

121121
const ggml_tensor* get_tensor_used_op(const ggml_tensor* tensor) const;
122+
const ggml_tensor* get_tensor_from_name(const std::string& name) const;
122123

123124
void clear_model_weights() { m_model_weights.clear(); }
124125

ggml/src/ggml-openvino/openvino/op/set_rows.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <openvino/op/shape_of.hpp>
1212
#include <openvino/op/slice.hpp>
1313
#include <openvino/op/squeeze.hpp>
14+
#include <openvino/op/transpose.hpp>
1415

1516
#include "../node_context.hpp"
1617
#include "../op_table.hpp"
@@ -25,21 +26,40 @@ OutputVector translate_set_rows(const NodeContext& context) {
2526
num_inputs_check(context, 2, 2);
2627

2728
auto data = context.get_input(0);
28-
auto indices = context.get_input(1);
29-
auto dst = context.get_input(context.get_output_name());
29+
data = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type(0));
30+
3031
auto dst_shape = context.get_output_shape(0).to_shape();
3132
FRONT_END_OP_CONVERSION_CHECK(dst_shape[0] == 1, "Unsupported shape in SET_ROWS");
3233

33-
auto zero = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
34+
if (context.is_static() && context.is_first_token()) {
35+
Output<Node> res;
36+
if (context.get_op_case() == 2) {
37+
res = std::make_shared<ov::op::v1::Reshape>(
38+
data,
39+
ov::op::v0::Constant::create(
40+
ov::element::i64,
41+
{3},
42+
{context.get_context_size(), context.get_num_heads_kv(), context.get_head_size()}),
43+
false);
44+
res = std::make_shared<ov::op::v1::Transpose>(
45+
res, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 2, 0}));
46+
} else {
47+
res = data;
48+
}
49+
return rename_outputs_with_suffix({res}, context.get_name());
50+
}
3451

52+
auto indices = context.get_input(1);
53+
auto dst = context.get_input(context.get_output_name());
54+
55+
auto zero = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
3556
auto dst_reshaped = std::make_shared<ov::op::v1::Reshape>(
3657
dst,
3758
ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) dst_shape[1], (int64_t) dst_shape[2]}),
3859
false);
3960
auto indices_reshaped =
4061
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
41-
auto data_converted = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type(0));
42-
auto data_reshaped = std::make_shared<ov::op::v0::Squeeze>(data_converted, zero);
62+
auto data_reshaped = std::make_shared<ov::op::v0::Squeeze>(data, zero);
4363
auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
4464
auto res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false);
4565
return rename_outputs_with_suffix({res}, context.get_name());

ggml/src/ggml-openvino/utils.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,9 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons
328328
std::copy(padded_data.begin(), padded_data.end(), data_ptr);
329329
}
330330

331+
} else if (const auto* op = ggml_decoder->get_tensor_used_op(ggml_decoder->get_tensor_from_name(param_name));
332+
op->op == GGML_OP_SET_ROWS && is_static && is_first_token) {
333+
input_tensor = ov::Tensor(ov::element::i64, ov::Shape{1});
331334
} else {
332335
input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name);
333336
}

0 commit comments

Comments
 (0)