Skip to content

Commit 2221995

Browse files
committed
WIP: NPU ok, need to fix CPU GPU
1 parent e9abf1c commit 2221995

File tree

7 files changed

+129
-210
lines changed

7 files changed

+129
-210
lines changed

ggml/src/ggml-openvino/ggml-decoder.cpp

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
GgmlOvDecoder::GgmlOvDecoder(ggml_tensor * node,
4040
ggml_cgraph * cgraph,
4141
bool is_static,
42-
bool is_first_token,
4342
int context_size,
4443
int context_size_swa,
4544
int num_heads,
@@ -55,25 +54,24 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_tensor * node,
5554
m_num_heads(num_heads),
5655
m_num_heads_kv(num_heads_kv),
5756
m_head_size(head_size),
58-
m_is_static(is_static),
59-
m_is_first_token(is_first_token) {
57+
m_is_static(is_static) {
6058
set_input_output(node);
6159
}
6260

6361
GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
6462
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
65-
bool is_static,
66-
bool is_first_token) :
63+
bool is_static) :
6764
m_cgraph(cgraph),
6865
m_op_name(m_node ? std::string(m_node->name) : ""),
6966
m_model_weights(model_weights),
70-
m_is_static(is_static),
71-
m_is_first_token(is_first_token) {
72-
if (is_first_token && getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS")) {
67+
m_is_static(is_static) {
68+
if (auto * env = getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); env && std::string(env) != "0") {
69+
unsetenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS");
7370
print_tensor_address_map(cgraph);
7471
}
7572

7673
set_llm_params();
74+
validate_cgraph();
7775

7876
for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
7977
auto * cur_node = cgraph->nodes[node_n];
@@ -300,41 +298,39 @@ void GgmlOvDecoder::set_llm_params() {
300298
}
301299
}
302300

301+
// Sanity-checks the graph parameters before the nodes are processed.
// When m_is_static is set, m_input_len must be exactly 1; any other value
// raises std::runtime_error with a message that reports the actual length
// and suggests running with `-ub 1`.
// NOTE(review): m_input_len is presumably filled in by set_llm_params(), which
// the graph constructor calls immediately before this check -- confirm.
void GgmlOvDecoder::validate_cgraph() const {
302+
if (m_is_static && m_input_len != 1) {
303+
throw std::runtime_error("Static graph (NPU) must have input_len == 1, but got " + std::to_string(m_input_len) +
304+
", try set -ub 1");
305+
}
306+
}
307+
303308
ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * src) const {
304309
auto name = std::string(src->name);
305310
ov::PartialShape input_shape;
306-
if (name == "inp_tokens" || name == "inp_pos") {
307-
if (m_is_static) {
308-
if (m_is_first_token) {
309-
input_shape = ov::PartialShape{1, 1, m_context_size};
310-
} else {
311-
input_shape = ov::PartialShape{1, 1, 1};
312-
}
313-
} else {
314-
input_shape = ov::PartialShape{1, 1, -1};
315-
}
316-
} else if (name == "inp_out_ids" && !m_is_static) {
317-
input_shape = ov::PartialShape{1, 1, -1};
311+
312+
if (name == "inp_tokens" || name == "inp_pos" || name == "inp_out_ids") {
313+
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
314+
318315
} else if (name.find("KQ_mask") == 0) {
319316
if (m_is_static) {
320-
if (m_is_first_token) {
321-
input_shape = ov::PartialShape{1, m_context_size, m_context_size};
322-
} else {
323-
input_shape = ov::PartialShape{1, 1, m_context_size};
324-
}
317+
input_shape = ov::PartialShape{1, 1, m_context_size};
325318
} else {
326319
input_shape = ov::PartialShape{1, -1, -1};
327320
}
321+
328322
} else if (name.find("cache_") == 0) {
323+
auto past_token_len = -1;
329324
if (m_is_static) {
330325
int layer = extract_layer_from_name(name);
331326
bool is_swa = is_swa_layer(layer);
332-
input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
333-
} else {
334-
input_shape = ov::PartialShape{1, -1, m_num_heads_kv, m_head_size};
327+
past_token_len = is_swa ? m_context_size_swa : m_context_size;
335328
}
329+
input_shape = ov::PartialShape{past_token_len, m_num_heads_kv, m_head_size};
330+
336331
} else if (const auto * op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
337332
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
333+
338334
} else if (src->op == GGML_OP_VIEW) {
339335
// This case is added to make test-backend-ops work
340336
input_shape = ov::PartialShape{get_shape(src->view_src)};
@@ -748,9 +744,8 @@ int32_t * GgmlOvDecoder::get_output_op_params(const std::string & name) const {
748744

749745
// Walks m_nodes and, for each entry, builds a node-level GgmlOvDecoder that
// carries the graph-wide settings (static flag, context sizes, head counts,
// head size, SWA layers), then passes that decoder to node_visitor.
// After this commit the node decoder no longer receives m_is_first_token.
void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const {
750746
for (const auto & node : m_nodes) {
751-
auto decoder =
752-
std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_is_first_token, m_context_size,
753-
m_context_size_swa, m_num_heads, m_num_heads_kv, m_head_size, m_swa_layers);
747+
auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_context_size, m_context_size_swa,
748+
m_num_heads, m_num_heads_kv, m_head_size, m_swa_layers);
754749
node_visitor(decoder);
755750
}
756751
}

ggml/src/ggml-openvino/ggml-decoder.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,12 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
1616
// Graph decoder: constructed from a whole ggml_cgraph and the map of model
// weight nodes. `is_static` is stored in m_is_static and reported by is_static().
// The parameter is spelled `is_static` to match the definition in ggml-decoder.cpp.
1717
GgmlOvDecoder(ggml_cgraph * cgraph,
1818
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
19-
bool is_static,
20-
bool is_first_token);
19+
bool is_static);
2120

2221
// Node decoder, called in GgmlOvDecoder::visit_subgraph
2322
GgmlOvDecoder(ggml_tensor * node,
2423
ggml_cgraph * cgraph,
2524
bool is_static,
26-
bool is_first_token,
2725
int context_size,
2826
int context_size_swa,
2927
int num_heads,
@@ -129,8 +127,6 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
129127

130128
virtual bool is_static() const override { return m_is_static; }
131129

132-
virtual bool is_first_token() const override { return m_is_first_token; }
133-
134130
ov::PartialShape get_graph_input_shape(const ggml_tensor * src) const;
135131

136132
static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename);
@@ -157,6 +153,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
157153

158154
// set context_size, num_heads, etc
159155
void set_llm_params();
156+
void validate_cgraph() const;
160157

161158
ggml_cgraph * m_cgraph = nullptr;
162159
ggml_tensor * m_node = nullptr;
@@ -185,7 +182,6 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
185182
int32_t * m_rope_params;
186183
std::vector<std::string> m_kv_names;
187184
bool m_is_static = false;
188-
bool m_is_first_token;
189185
};
190186

191187
void print_tensor_address_map(const ggml_cgraph * cgraph);

ggml/src/ggml-openvino/openvino/decoder.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ class GgmlDecoder : public DecoderBase {
6565
virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;
6666

6767
virtual bool is_static() const = 0;
68-
virtual bool is_first_token() const = 0;
6968
virtual int get_context_size() const = 0;
7069
virtual int get_context_size_swa() const = 0;
7170
virtual int is_swa_layer(int layer) const = 0;

ggml/src/ggml-openvino/openvino/node_context.hpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,7 @@ class NodeContext : public frontend::NodeContext {
9797
// Returns the op case reported by the wrapped decoder (m_decoder).
int get_op_case() const {
9898
return m_decoder->get_op_case();
9999
}
100-
// Returns the wrapped decoder's is_static() flag. This hunk collapses the
// accessor to a single line and removes the is_first_token() accessor.
bool is_static() const {
101-
return m_decoder->is_static();
102-
}
103-
bool is_first_token() const {
104-
return m_decoder->is_first_token();
105-
}
100+
bool is_static() const { return m_decoder->is_static(); }
106101

107102
private:
108103
std::shared_ptr<GgmlDecoder> m_decoder;

ggml/src/ggml-openvino/openvino/op/set_rows.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@ OutputVector translate_set_rows(const NodeContext & context) {
3333
auto dst_shape = context.get_output_shape(0).to_shape();
3434
FRONT_END_OP_CONVERSION_CHECK(dst_shape[0] == 1, "Unsupported shape in SET_ROWS");
3535

36-
if (context.is_static() && context.is_first_token()) {
37-
return rename_outputs_with_suffix({data}, context.get_name());
38-
}
39-
4036
auto indices = context.get_input(1);
4137
auto dst = context.get_input(context.get_output_name());
4238

0 commit comments

Comments
 (0)