Skip to content

Commit 2a51cde

Browse files
authored
NPU Unify PD (#14)
* Stateless: fix llama-cli and llama-server * Simplify the broadcast op in attention * Replace get_output_tensor+memcpy with set_output_tensor * NPU: unify prefill/decode (PD); unify dynamic and static dims
1 parent 9e02b3f commit 2a51cde

File tree

11 files changed

+228
-371
lines changed

11 files changed

+228
-371
lines changed

ggml/src/ggml-openvino/ggml-decoder.cpp

Lines changed: 39 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
#include <openvino/op/constant.hpp>
2828
#include <openvino/op/convert.hpp>
2929
#include <openvino/op/parameter.hpp>
30-
#include <openvino/op/unsqueeze.hpp>
3130
#include <openvino/runtime/tensor.hpp>
3231
#include <optional>
3332
#include <ostream>
@@ -39,7 +38,6 @@
3938
GgmlOvDecoder::GgmlOvDecoder(ggml_tensor * node,
4039
ggml_cgraph * cgraph,
4140
bool is_static,
42-
bool is_first_token,
4341
int context_size,
4442
int context_size_swa,
4543
int num_heads,
@@ -55,25 +53,24 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_tensor * node,
5553
m_num_heads(num_heads),
5654
m_num_heads_kv(num_heads_kv),
5755
m_head_size(head_size),
58-
m_is_static(is_static),
59-
m_is_first_token(is_first_token) {
56+
m_is_static(is_static) {
6057
set_input_output(node);
6158
}
6259

6360
GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
6461
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
65-
bool is_static,
66-
bool is_first_token) :
62+
bool is_static) :
6763
m_cgraph(cgraph),
6864
m_op_name(m_node ? std::string(m_node->name) : ""),
6965
m_model_weights(model_weights),
70-
m_is_static(is_static),
71-
m_is_first_token(is_first_token) {
72-
if (is_first_token && getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS")) {
66+
m_is_static(is_static) {
67+
if (auto * env = getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); env && std::string(env) != "0") {
68+
unsetenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS");
7369
print_tensor_address_map(cgraph);
7470
}
7571

7672
set_llm_params();
73+
validate_cgraph();
7774

7875
for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
7976
auto * cur_node = cgraph->nodes[node_n];
@@ -160,8 +157,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
160157
// Model outputs are tensors with GGML_TENSOR_FLAG_OUTPUT flag and kv_caches
161158
static std::set<std::string> debug_output_names = {};
162159
// Workaround: the final tensor "result_output" does not have GGML_TENSOR_FLAG_OUTPUT flag set in cgraph
163-
if (node->op == GGML_OP_SET_ROWS || node->flags & GGML_TENSOR_FLAG_OUTPUT || node_name.find("result") == 0 ||
164-
debug_output_names.count(node_name)) {
160+
if (node->op == GGML_OP_SET_ROWS || node->flags & GGML_TENSOR_FLAG_OUTPUT ||
161+
node_name.find("output") != std::string::npos || debug_output_names.count(node_name)) {
165162
if (node->op == GGML_OP_SET_ROWS) {
166163
assert(node_name.find("cache_k") == 0 || node_name.find("cache_v") == 0);
167164
if (auto it = std::find(m_kv_names.begin(), m_kv_names.end(), node_name); it == m_kv_names.end()) {
@@ -285,53 +282,54 @@ void GgmlOvDecoder::set_llm_params() {
285282
} else {
286283
m_context_size = cache_k->ne[1];
287284
}
288-
} else if (node->op == GGML_OP_ROPE &&
289-
(name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0)) {
290-
m_head_size = node->ne[0];
291-
m_num_heads = node->ne[1];
292-
m_rope_params = node->op_params;
293-
} else if (node->op == GGML_OP_ROPE &&
294-
(name.find("Kcur-0") == 0 || std::string(node->src[0]->name).find("Kcur-0") == 0)) {
295-
m_num_heads_kv = node->ne[1];
285+
} else if (node->op == GGML_OP_ROPE) {
286+
if (name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0) {
287+
m_head_size = node->ne[0];
288+
m_num_heads = node->ne[1];
289+
m_rope_params = node->op_params;
290+
auto * inp_pos = node->src[1];
291+
m_input_len = inp_pos->ne[0];
292+
m_past_kv_len = *(int32_t *) inp_pos->data;
293+
} else if (name.find("Kcur-0") == 0 || std::string(node->src[0]->name).find("Kcur-0") == 0) {
294+
m_num_heads_kv = node->ne[1];
295+
}
296296
}
297297
}
298298
}
299299

300+
void GgmlOvDecoder::validate_cgraph() const {
301+
if (m_is_static && m_input_len != 1) {
302+
throw std::runtime_error("Static graph (NPU) must have input_len == 1, but got " + std::to_string(m_input_len) +
303+
", try set -ub 1");
304+
}
305+
}
306+
300307
ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * src) const {
301308
auto name = std::string(src->name);
302309
ov::PartialShape input_shape;
303-
if (name == "inp_tokens" || name == "inp_pos") {
304-
if (m_is_static) {
305-
if (m_is_first_token) {
306-
input_shape = ov::PartialShape{1, 1, m_context_size};
307-
} else {
308-
input_shape = ov::PartialShape{1, 1, 1};
309-
}
310-
} else {
311-
input_shape = ov::PartialShape{1, 1, -1};
312-
}
313-
} else if (name == "inp_out_ids" && !m_is_static) {
314-
input_shape = ov::PartialShape{1, 1, -1};
310+
311+
if (name == "inp_tokens" || name == "inp_pos" || name == "inp_out_ids") {
312+
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
313+
315314
} else if (name.find("KQ_mask") == 0) {
316315
if (m_is_static) {
317-
if (m_is_first_token) {
318-
input_shape = ov::PartialShape{1, m_context_size, m_context_size};
319-
} else {
320-
input_shape = ov::PartialShape{1, 1, m_context_size};
321-
}
316+
input_shape = ov::PartialShape{1, 1, m_context_size};
322317
} else {
323318
input_shape = ov::PartialShape{1, -1, -1};
324319
}
320+
325321
} else if (name.find("cache_") == 0) {
322+
auto past_token_len = -1;
326323
if (m_is_static) {
327324
int layer = extract_layer_from_name(name);
328325
bool is_swa = is_swa_layer(layer);
329-
input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
330-
} else {
331-
input_shape = ov::PartialShape{1, -1, m_num_heads_kv, m_head_size};
326+
past_token_len = is_swa ? m_context_size_swa : m_context_size;
332327
}
328+
input_shape = ov::PartialShape{past_token_len, m_num_heads_kv, m_head_size};
329+
333330
} else if (const auto * op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
334331
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
332+
335333
} else if (src->op == GGML_OP_VIEW) {
336334
// This case is added to make test-backend-ops work
337335
input_shape = ov::PartialShape{get_shape(src->view_src)};
@@ -745,9 +743,8 @@ int32_t * GgmlOvDecoder::get_output_op_params(const std::string & name) const {
745743

746744
void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const {
747745
for (const auto & node : m_nodes) {
748-
auto decoder =
749-
std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_is_first_token, m_context_size,
750-
m_context_size_swa, m_num_heads, m_num_heads_kv, m_head_size, m_swa_layers);
746+
auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_context_size, m_context_size_swa,
747+
m_num_heads, m_num_heads_kv, m_head_size, m_swa_layers);
751748
node_visitor(decoder);
752749
}
753750
}

ggml/src/ggml-openvino/ggml-decoder.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,12 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
1616
// Graph decoder
1717
GgmlOvDecoder(ggml_cgraph * cgraph,
1818
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
19-
bool is_static,
20-
bool is_first_token);
19+
bool is_static);
2120

2221
// Node decoder, called in GgmlOvDecoder::visit_subgraph
2322
GgmlOvDecoder(ggml_tensor * node,
2423
ggml_cgraph * cgraph,
2524
bool is_static,
26-
bool is_first_token,
2725
int context_size,
2826
int context_size_swa,
2927
int num_heads,
@@ -81,9 +79,9 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
8179

8280
virtual void visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const override;
8381

84-
const ggml_tensor * get_input_ggml_tensor(const std::string & name) const { return m_inputs.at(name); }
82+
ggml_tensor * get_input_ggml_tensor(const std::string & name) const { return m_inputs.at(name); }
8583

86-
const ggml_tensor * get_output_ggml_tensor(const std::string & name) const { return m_outputs.at(name); }
84+
ggml_tensor * get_output_ggml_tensor(const std::string & name) const { return m_outputs.at(name); }
8785

8886
virtual int get_op_case() const override { return m_op_case; }
8987

@@ -119,14 +117,16 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
119117

120118
virtual int get_head_size() const override { return m_head_size; }
121119

120+
int get_past_kv_len() const { return m_past_kv_len; }
121+
122+
int get_input_len() const { return m_input_len; }
123+
122124
virtual int32_t * get_rope_params() const override { return m_rope_params; }
123125

124126
virtual std::map<std::string, std::string> get_kv_param_res_names() const override;
125127

126128
virtual bool is_static() const override { return m_is_static; }
127129

128-
virtual bool is_first_token() const override { return m_is_first_token; }
129-
130130
ov::PartialShape get_graph_input_shape(const ggml_tensor * src) const;
131131

132132
static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename);
@@ -153,6 +153,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
153153

154154
// set context_size, num_heads, etc
155155
void set_llm_params();
156+
void validate_cgraph() const;
156157

157158
ggml_cgraph * m_cgraph = nullptr;
158159
ggml_tensor * m_node = nullptr;
@@ -176,10 +177,11 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
176177
int m_num_heads;
177178
int m_num_heads_kv;
178179
int m_head_size;
180+
int m_past_kv_len;
181+
int m_input_len;
179182
int32_t * m_rope_params;
180183
std::vector<std::string> m_kv_names;
181184
bool m_is_static = false;
182-
bool m_is_first_token;
183185
};
184186

185187
void print_tensor_address_map(const ggml_cgraph * cgraph);

ggml/src/ggml-openvino/openvino/decoder.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ class GgmlDecoder : public DecoderBase {
6565
virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;
6666

6767
virtual bool is_static() const = 0;
68-
virtual bool is_first_token() const = 0;
6968
virtual int get_context_size() const = 0;
7069
virtual int get_context_size_swa() const = 0;
7170
virtual int is_swa_layer(int layer) const = 0;

ggml/src/ggml-openvino/openvino/node_context.hpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,7 @@ class NodeContext : public frontend::NodeContext {
9797
int get_op_case() const {
9898
return m_decoder->get_op_case();
9999
}
100-
bool is_static() const {
101-
return m_decoder->is_static();
102-
}
103-
bool is_first_token() const {
104-
return m_decoder->is_first_token();
105-
}
100+
bool is_static() const { return m_decoder->is_static(); }
106101

107102
private:
108103
std::shared_ptr<GgmlDecoder> m_decoder;

ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
#include "../op_table.hpp"
33
#include "../utils.hpp"
44

5+
#include <cstdint>
56
#include <memory>
67
#include <openvino/op/broadcast.hpp>
78
#include <openvino/op/concat.hpp>
9+
#include <openvino/op/constant.hpp>
810
#include <openvino/op/convert.hpp>
911
#include <openvino/op/reshape.hpp>
1012
#include <openvino/op/scaled_dot_product_attention.hpp>
@@ -51,62 +53,38 @@ OutputVector translate_flash_attn_ext(const NodeContext & context) {
5153

5254
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, kv_len}, 0);
5355
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
54-
mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d);
5556
}
5657

5758
if (mask_sliced.get_element_type() != ov::element::f16) {
5859
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
5960
}
6061

61-
auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv, bool is_static) {
62-
int64_t factor = q_batch / kv_batch;
62+
auto tile_kv = [&](int64_t num_heads, int64_t num_heads_kv, int64_t head_size, ov::Output<Node> kv) {
63+
int64_t factor = num_heads / num_heads_kv;
6364
if (factor > 1) {
64-
auto q_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{q_batch});
65-
auto kv_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_batch});
66-
auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
67-
6865
ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape;
69-
if (is_static) {
70-
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
71-
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
72-
73-
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
74-
kv_broadcast_shape = std::make_shared<ov::op::v0::Concat>(
75-
ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
76-
new_kv_shape =
77-
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
78-
} else {
79-
auto one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
80-
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
81-
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
66+
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
67+
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
8268

83-
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {2, 3});
84-
kv_broadcast_shape = std::make_shared<ov::op::v0::Concat>(
85-
ov::OutputVector{one_1d, kv_batch_node, factor_node, kv_last_two_dims}, 0);
86-
new_kv_shape =
87-
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, q_batch_node, kv_last_two_dims}, 0);
88-
}
69+
kv_broadcast_shape =
70+
ov::op::v0::Constant::create(ov::element::i64, {4}, {num_heads_kv, factor, (int64_t) 1, head_size});
71+
new_kv_shape = ov::op::v0::Constant::create(ov::element::i64, {3}, {num_heads, (int64_t) -1, head_size});
8972

90-
kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape);
73+
kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape,
74+
ov::op::BroadcastType::BIDIRECTIONAL);
9175
kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false);
9276
}
9377
return kv;
9478
};
9579

9680
auto q_shape = context.get_input_shape(0).to_shape();
9781
auto k_shape = context.get_input_shape(1).to_shape();
98-
k = tile_kv(q_shape[0], k_shape[0], k, context.is_static());
99-
v = tile_kv(q_shape[0], k_shape[0], v, context.is_static());
82+
k = tile_kv(q_shape[0], k_shape[0], q_shape[2], k);
83+
v = tile_kv(q_shape[0], k_shape[0], q_shape[2], v);
10084

10185
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false);
102-
auto sdpa_f32 = std::make_shared<ov::op::v0::Convert>(sdpa, ov::element::f32);
103-
if (context.is_static()) {
104-
res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
105-
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
106-
} else {
107-
res = std::make_shared<ov::op::v1::Transpose>(
108-
sdpa_f32, ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
109-
}
86+
res = std::make_shared<ov::op::v1::Transpose>(sdpa, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
87+
res = std::make_shared<ov::op::v0::Convert>(res, ov::element::f32);
11088
return rename_outputs_with_suffix({res}, context.get_name());
11189
}
11290

ggml/src/ggml-openvino/openvino/op/permute.cpp

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,40 +26,8 @@ OutputVector translate_permute(const NodeContext & context) {
2626
ov::Output<Node> res;
2727
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
2828

29-
if (op_case == 1) {
30-
if (context.is_static()) {
31-
res = std::make_shared<ov::op::v1::Transpose>(
32-
context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
33-
} else {
34-
auto src = context.get_input(0);
35-
if (src.get_partial_shape().rank() == 3) {
36-
src = std::make_shared<ov::op::v0::Unsqueeze>(src, zero);
37-
}
38-
res = std::make_shared<ov::op::v1::Transpose>(
39-
src, ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
40-
}
41-
} else {
42-
auto src = context.get_input(0);
43-
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
44-
45-
if (context.is_static()) {
46-
auto src_shape_ = context.get_input_shape(0).to_shape();
47-
std::vector<int64_t> src_shape(src_shape_.begin(), src_shape_.end());
48-
auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
49-
src,
50-
ov::op::v0::Constant::create(ov::element::i64, {3},
51-
std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
52-
false);
53-
res = std::make_shared<ov::op::v1::Transpose>(
54-
src_reshaped, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
55-
} else {
56-
if (src.get_partial_shape().rank() == 3) {
57-
src = std::make_shared<ov::op::v0::Unsqueeze>(src, zero);
58-
}
59-
res = std::make_shared<ov::op::v1::Transpose>(
60-
src, ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
61-
}
62-
}
29+
auto src = context.get_input(0);
30+
res = std::make_shared<ov::op::v1::Transpose>(src, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
6331
return rename_outputs_with_suffix({res}, context.get_name());
6432
}
6533

ggml/src/ggml-openvino/openvino/op/rope.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,6 @@ OutputVector translate_rope(const NodeContext & context) {
8484
ov::op::v0::Constant::create(ov::element::i64, {1}, {3}));
8585
auto stack = std::make_shared<ov::op::v0::Concat>(OutputVector{first_half, second_half}, 3);
8686
res = std::make_shared<ov::op::v1::Reshape>(stack, std::make_shared<ov::op::v0::ShapeOf>(data_node), false);
87-
if (!(context.is_static())) {
88-
res =
89-
std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
90-
}
9187
} else if (mode == ROPE_TYPE_NEOX) {
9288
auto data_split = std::make_shared<ov::op::v1::Split>(
9389
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2}), 2);

0 commit comments

Comments
 (0)