Skip to content

Commit 8934f73

Browse files
committed
draft NPU support version 2: prefill + kvcache
1 parent 9f99529 commit 8934f73

File tree

7 files changed

+211
-113
lines changed

7 files changed

+211
-113
lines changed

ggml/src/ggml-openvino/ggml-decoder.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,22 +108,25 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node) {
108108
ov::PartialShape input_shape;
109109
if (std::string(src->name) == "inp_tokens" || std::string(src->name) == "inp_pos") {
110110
if (m_is_static) {
111-
input_shape = ov::PartialShape(get_shape(src));
112-
// if (m_is_first_token) {
113-
// input_shape = ov::PartialShape{1, 1, m_max_token_len};
114-
// } else {
115-
// input_shape = ov::PartialShape{1, 1, 1};
116-
// }
111+
if (m_is_first_token) {
112+
input_shape = ov::PartialShape{1, 1, m_max_token_len};
113+
} else {
114+
input_shape = ov::PartialShape{1, 1, 1};
115+
}
117116
} else {
118117
input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_max_token_len)};
119118
}
120-
} else if (std::string(src->name).find("KQ_mask") == 0) {
119+
} else if (std::string(src->name) == "KQ_mask") {
121120
if (m_is_static) {
122-
input_shape = ov::PartialShape(get_shape(src));
121+
if (m_is_first_token) {
122+
input_shape = ov::PartialShape{1, m_max_token_len, m_max_token_len};
123+
} else {
124+
input_shape = ov::PartialShape{1, 1, m_max_token_len};
125+
}
123126
} else {
124-
auto max_token_len = GGML_PAD(m_max_token_len, GGML_KQ_MASK_PAD);
127+
auto max_mask_size = GGML_PAD(m_max_token_len, GGML_KQ_MASK_PAD);
125128
input_shape =
126-
ov::PartialShape{1, ov::Dimension(1, max_token_len), ov::Dimension(1, max_token_len)};
129+
ov::PartialShape{1, ov::Dimension(1, max_mask_size), ov::Dimension(1, max_mask_size)};
127130
}
128131
} else {
129132
input_shape = ov::Shape{get_shape(src)};
@@ -208,6 +211,7 @@ void GgmlOvDecoder::set_max_token_len() {
208211

209212
void GgmlOvDecoder::add_extra_inputs() {
210213
int64_t past_token_len;
214+
// attention_size not used for NPU
211215
int64_t attention_size;
212216

213217
for (const auto& node : m_nodes) {
@@ -231,8 +235,7 @@ void GgmlOvDecoder::add_extra_inputs() {
231235
for (const auto& node : m_nodes) {
232236
if (node->src[1] && std::string(node->src[1]->name).find("inp_tokens") == 0) {
233237
int64_t total_token_len = node->src[1]->ne[0] + past_token_len;
234-
attention_size = (total_token_len + 31) / 32 * 32;
235-
238+
attention_size = GGML_PAD(total_token_len, 32);
236239
std::string name = "attention_size";
237240
auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
238241
param_node->set_friendly_name(name);

ggml/src/ggml-openvino/ggml-decoder.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,12 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
9292
virtual bool is_static() const override {
9393
return m_is_static;
9494
}
95-
virtual bool is_first_token() const {
95+
virtual bool is_first_token() const override {
9696
return m_is_first_token;
9797
}
98+
virtual int get_max_token_len() const override {
99+
return m_max_token_len;
100+
}
98101

99102
private:
100103
void set_input_output(ggml_tensor* node);
@@ -106,7 +109,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
106109
static std::shared_ptr<ov::Node> create_weight_node(ggml_tensor* tensor);
107110

108111
void set_max_token_len();
109-
int64_t m_max_token_len;
112+
int m_max_token_len;
110113

111114
void add_weight_const_parallel(std::map<std::string, std::shared_ptr<ov::Node>>& model_weights);
112115

ggml/src/ggml-openvino/openvino/decoder.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <cstdint>
34
#include <map>
45
#include <openvino/core/node.hpp>
56
#include <openvino/frontend/decoder.hpp>
@@ -57,6 +58,8 @@ class GgmlDecoder : public DecoderBase {
5758
virtual const std::vector<std::string>& get_model_output_names() const = 0;
5859

5960
virtual bool is_static() const = 0;
61+
virtual bool is_first_token() const = 0;
62+
virtual int get_max_token_len() const = 0;
6063
};
6164

6265
} // namespace ggml

ggml/src/ggml-openvino/openvino/node_context.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <cstdint>
34
#include <openvino/frontend/node_context.hpp>
45

56
#include "decoder.hpp"
@@ -87,6 +88,12 @@ class NodeContext : public frontend::NodeContext {
8788
bool is_static() const {
8889
return m_decoder->is_static();
8990
}
91+
bool is_first_token() const {
92+
return m_decoder->is_first_token();
93+
}
94+
int get_max_token_len() const {
95+
return m_decoder->get_max_token_len();
96+
}
9097

9198
private:
9299
std::shared_ptr<GgmlDecoder> m_decoder;

ggml/src/ggml-openvino/openvino/op/cpy.cpp

Lines changed: 38 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include <openvino/op/broadcast.hpp>
99
#include <openvino/op/concat.hpp>
1010
#include <openvino/op/constant.hpp>
11-
#include <openvino/op/convert_like.hpp>
11+
#include <openvino/op/convert.hpp>
1212
#include <openvino/op/range.hpp>
1313
#include <openvino/op/reshape.hpp>
1414
#include <openvino/op/scatter_nd_update.hpp>
@@ -34,18 +34,26 @@ OutputVector translate_cpy(const NodeContext& context) {
3434

3535
auto src0 = context.get_input(0);
3636
auto src1 = context.get_input(1);
37-
auto past_token_len = context.get_input("past_token_len");
37+
auto past_token_len_scalar = context.get_input("past_token_len");
38+
39+
src0 = std::make_shared<ov::op::v0::Convert>(src0, context.get_input_type(1));
3840
ov::Output<Node> res;
3941

42+
if (context.is_static() && context.is_first_token()) {
43+
res = src0;
44+
return rename_outputs_with_suffix({res}, context.get_name());
45+
}
46+
4047
auto src0_shape = context.get_input_shape(0).to_shape();
4148
auto output_shape = context.get_output_shape(0).to_shape();
4249

4350
std::vector<size_t> input0_strides = context.get_input_stride(0);
4451
std::vector<size_t> output_strides = context.get_output_stride(0);
4552

46-
auto one = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
53+
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
54+
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
55+
auto one_scalar = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
4756

48-
src0 = std::make_shared<ov::op::v1::ConvertLike>(src0, src1);
4957
if (op_case == 1) {
5058
// Write K to cache_k
5159
int64_t head_size = src0_shape[2];
@@ -56,69 +64,36 @@ OutputVector translate_cpy(const NodeContext& context) {
5664
auto reshaped_src1 = std::make_shared<ov::op::v1::Reshape>(src1, reshaped_src1_shape, false);
5765

5866
auto token_len = get_dimensions(src0.get_node_shared_ptr(), {0});
59-
token_len = std::make_shared<ov::op::v1::Reshape>(token_len,
60-
ov::op::v0::Constant::create(ov::element::i64, {0}, {}),
61-
false);
67+
auto token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(token_len, zero);
6268

69+
std::shared_ptr<ov::Node> indices;
6370
if (context.is_static()) {
64-
int32_t* op_params = context.get_input_op_params(1);
65-
int64_t past_token_len_val = op_params[0] / context.get_input_stride(1)[2] / num_heads / head_size;
66-
past_token_len = ov::op::v0::Constant::create(ov::element::i64, {}, {past_token_len_val});
71+
indices = past_token_len_scalar.get_node_shared_ptr();
72+
indices = std::make_shared<ov::op::v0::Unsqueeze>(
73+
indices,
74+
ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector<int64_t>{0, 1}));
75+
} else {
76+
auto total_token_len_scalar = std::make_shared<ov::op::v1::Add>(past_token_len_scalar, token_len_scalar);
77+
indices = std::make_shared<ov::op::v4::Range>(past_token_len_scalar,
78+
total_token_len_scalar,
79+
one_scalar,
80+
ov::element::i64);
81+
indices = std::make_shared<ov::op::v0::Unsqueeze>(indices, one);
6782
}
6883

69-
auto total_token_len = std::make_shared<ov::op::v1::Add>(past_token_len, token_len);
70-
std::shared_ptr<ov::Node> indices =
71-
std::make_shared<ov::op::v4::Range>(past_token_len, total_token_len, one, ov::element::i64);
72-
indices = std::make_shared<ov::op::v0::Unsqueeze>(
73-
indices,
74-
ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{1}));
75-
7684
res = std::make_shared<ov::op::v3::ScatterNDUpdate>(reshaped_src1, indices, src0);
7785
} else {
7886
// Write V to cache_v
79-
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
8087
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
8188
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
82-
8389
auto zero_scalar = ov::op::v0::Constant::create(ov::element::i64, {}, {0});
84-
auto one_scalar = ov::op::v0::Constant::create(ov::element::i64, {}, {1});
8590

8691
int64_t total_head_size = src0_shape[1];
8792
auto total_head_size_node = ov::op::v0::Constant::create(ov::element::i64, {1}, {total_head_size});
8893
auto total_head_size_scalar = std::make_shared<ov::op::v0::Squeeze>(total_head_size_node, zero);
8994

9095
auto token_len = get_dimensions(src0.get_node_shared_ptr(), {2});
9196
auto token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(token_len, zero);
92-
if (context.is_static()) {
93-
int32_t* op_params = context.get_input_op_params(1);
94-
int64_t past_token_len_val = op_params[0] / context.get_input_stride(1)[2];
95-
past_token_len = ov::op::v0::Constant::create(ov::element::i64, {}, {past_token_len_val});
96-
}
97-
auto total_token_len_scalar = std::make_shared<ov::op::v1::Add>(past_token_len, token_len_scalar);
98-
99-
// auto reshaped_src1 = std::make_shared<ov::op::v1::Reshape>(
100-
// src1,
101-
// ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, -1}),
102-
// false);
103-
104-
// auto src1_left = std::make_shared<ov::op::v8::Slice>(
105-
// reshaped_src1,
106-
// ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 0, 0}),
107-
// std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one, total_head_size_node, past_token_len}, 0),
108-
// ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 1, 1}));
109-
110-
// auto src1_right = std::make_shared<ov::op::v8::Slice>(
111-
// reshaped_src1,
112-
// std::make_shared<ov::op::v0::Concat>(ov::OutputVector{zero, zero, total_token_len}, 0),
113-
// ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, INT_MAX}),
114-
// ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 1, 1}));
115-
116-
// auto reshaped_src0 = std::make_shared<ov::op::v1::Reshape>(
117-
// src0,
118-
// ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, -1}),
119-
// false);
120-
121-
// auto res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{src1_left, reshaped_src0, src1_right}, 2);
12297

12398
// 1D tensor of shape [total_head_size], values starting from 0
12499
auto range_row =
@@ -131,8 +106,19 @@ OutputVector translate_cpy(const NodeContext& context) {
131106
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0));
132107

133108
// 1D tensor of shape [token_len], values starting from past_token_len
134-
auto range_col =
135-
std::make_shared<ov::op::v4::Range>(past_token_len, total_token_len_scalar, one_scalar, element::i64);
109+
std::shared_ptr<ov::Node> range_col;
110+
if (context.is_static()) {
111+
range_col = past_token_len_scalar.get_node_shared_ptr();
112+
range_col = std::make_shared<ov::op::v0::Unsqueeze>(
113+
range_col,
114+
ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{0}));
115+
} else {
116+
auto total_token_len_scalar = std::make_shared<ov::op::v1::Add>(past_token_len_scalar, token_len_scalar);
117+
range_col = std::make_shared<ov::op::v4::Range>(past_token_len_scalar,
118+
total_token_len_scalar,
119+
one_scalar,
120+
ov::element::i64);
121+
}
136122
auto range_col_reshaped =
137123
std::make_shared<ov::op::v0::Unsqueeze>(range_col,
138124
ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 2}));

0 commit comments

Comments
 (0)