Skip to content

Commit 2442668

Browse files
committed
change graph to 4d, use scatterupdate
1 parent 75c720a commit 2442668

20 files changed

+288
-252
lines changed

ggml/src/ggml-openvino/ggml-decoder.cpp

Lines changed: 83 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,26 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_tensor * node,
4444
int num_heads_kv,
4545
int head_size,
4646
const std::vector<int> & swa_layers) :
47+
m_is_static(is_static),
4748
m_cgraph(cgraph),
4849
m_node(node),
4950
m_op_name(std::string(node->name)),
50-
m_context_size(context_size),
51-
m_context_size_swa(context_size_swa),
52-
m_swa_layers(swa_layers),
53-
m_num_heads(num_heads),
54-
m_num_heads_kv(num_heads_kv),
51+
m_ctx(context_size),
52+
m_ctx_swa(context_size_swa),
53+
m_n_heads(num_heads),
54+
m_n_heads_kv(num_heads_kv),
5555
m_head_size(head_size),
56-
m_is_static(is_static) {
56+
m_swa_layers(swa_layers) {
5757
set_input_output(node);
5858
}
5959

6060
GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
6161
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
6262
bool is_static) :
63+
m_is_static(is_static),
6364
m_cgraph(cgraph),
6465
m_op_name(m_node ? std::string(m_node->name) : ""),
65-
m_model_weights(model_weights),
66-
m_is_static(is_static) {
66+
m_model_weights(model_weights) {
6767
if (auto * env = getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); env && std::string(env) != "0") {
6868
unsetenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS");
6969
print_tensor_address_map(cgraph);
@@ -78,7 +78,7 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
7878
set_input_output(cur_node);
7979
}
8080

81-
// add_extra_inputs();
81+
add_extra_inputs();
8282
}
8383

8484
GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, std::map<std::string, std::shared_ptr<ov::Node>> & model_weights) {
@@ -125,7 +125,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
125125
// Add model inputs and weights constants, if called for the whole graph
126126
if (naive) {
127127
if (m_model_weights.find(src_name) == m_model_weights.end()) {
128-
auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(src));
128+
auto param_node =
129+
std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(node, src));
129130
param_node->set_friendly_name(src_name);
130131
param_node->output(0).get_tensor().set_names({src_name});
131132
m_model_inputs[src_name] = param_node;
@@ -142,7 +143,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
142143
if (m_model_inputs.find(src_name) != m_model_inputs.end()) {
143144
continue;
144145
}
145-
auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(src));
146+
auto param_node =
147+
std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(node, src));
146148
param_node->set_friendly_name(src_name);
147149
param_node->output(0).get_tensor().set_names({src_name});
148150
m_model_inputs[src_name] = param_node;
@@ -175,15 +177,20 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
175177
if (m_node) {
176178
switch (node->op) {
177179
case GGML_OP_RESHAPE: {
178-
if (node->src[0]->op == GGML_OP_RESHAPE && node->src[0]->src[0]->ne[0] == node->ne[0] &&
179-
node->src[0]->src[0]->ne[1] == node->ne[1]) {
180+
auto * src = node->src[0];
181+
if (src->op == GGML_OP_RESHAPE && src->src[0]->ne[0] == node->ne[0] && src->src[0]->ne[1] == node->ne[1]) {
180182
m_op_case = 4;
181-
} else if (node->ne[0] * node->ne[1] == node->src[0]->ne[0]) {
183+
} else if (node->ne[0] * node->ne[1] == src->ne[0]) {
182184
m_op_case = 1;
183-
} else if (node->src[0]->ne[0] * node->src[0]->ne[1] == node->ne[0]) {
185+
} else if (src->ne[0] * src->ne[1] == node->ne[0]) {
184186
m_op_case = 2;
185-
} else if (node->src[0]->ne[0] * node->src[0]->ne[1] == node->ne[1]) {
187+
if (src->ne[2] * src->ne[3] == node->ne[1]) {
188+
m_op_case = 5;
189+
}
190+
} else if (src->ne[0] * src->ne[1] == node->ne[1]) {
186191
m_op_case = 3;
192+
} else if (src->ne[1] * src->ne[2] == node->ne[1]) {
193+
m_op_case = 6;
187194
}
188195
break;
189196
}
@@ -204,7 +211,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
204211
} else if (ggml_is_contiguous(node->src[0])) {
205212
std::string src_name(node->view_src->name);
206213
if (src_name.find("cache") == std::string::npos) {
207-
m_op_case = 1;
214+
// permute Qcur
215+
m_op_case = 4;
208216
} else {
209217
// Permute kv cache (view)
210218
int layer = extract_layer_from_name(src_name);
@@ -272,64 +280,80 @@ void GgmlOvDecoder::set_llm_params() {
272280
auto * node = m_cgraph->nodes[i];
273281
std::string name = std::string(node->name);
274282
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
275-
auto * cache_k = node->src[1];
276-
cache_k = cache_k->view_src ? cache_k->view_src : cache_k;
283+
auto * cache_k_perm = node->src[1];
284+
assert(cache_k_perm->op == GGML_OP_PERMUTE);
285+
auto * cache_k_view = cache_k_perm->src[0];
286+
assert(cache_k_view->op == GGML_OP_VIEW);
287+
288+
auto * cache_k = cache_k_view->src[0];
277289
int layer = extract_layer_from_name(cache_k->name);
290+
auto * mask = node->src[3];
291+
std::string mask_name(mask->name);
292+
assert(mask_name.find("KQ_mask") == 0);
278293

279294
if (std::string(node->src[3]->name).find("swa") != std::string::npos) {
280295
m_swa_layers.push_back(layer);
281-
m_context_size_swa = cache_k->ne[1];
296+
m_ctx_per_seq_swa = cache_k->ne[1];
282297
} else {
283-
m_context_size = cache_k->ne[1];
298+
m_ctx_per_seq = cache_k->ne[1];
299+
m_n_seq = cache_k->ne[2];
284300
}
301+
302+
m_n_seq_active = mask->ne[3];
303+
auto seq_size = cache_k->ne[0] * cache_k->ne[1] * ggml_type_size(cache_k->type);
304+
m_seq_active_start = ((size_t *) cache_k_view->op_params)[0] / seq_size;
305+
m_token_len_per_seq = node->ne[2];
306+
307+
if (mask_name.find("swa") != std::string::npos) {
308+
m_attention_size_swa = mask->ne[0];
309+
} else {
310+
m_attention_size = mask->ne[0];
311+
}
312+
285313
} else if (node->op == GGML_OP_ROPE) {
286314
if (name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0) {
287315
m_head_size = node->ne[0];
288-
m_num_heads = node->ne[1];
316+
m_n_heads = node->ne[1];
289317
m_rope_params = node->op_params;
290318
auto * inp_pos = node->src[1];
291319
m_input_len = inp_pos->ne[0];
292-
m_past_kv_len = *(int32_t *) inp_pos->data;
293320
} else if (name.find("Kcur-0") == 0 || std::string(node->src[0]->name).find("Kcur-0") == 0) {
294-
m_num_heads_kv = node->ne[1];
321+
m_n_heads_kv = node->ne[1];
295322
}
296323
}
297324
}
325+
m_ctx = m_ctx_per_seq * m_n_seq;
326+
m_ctx_swa = m_ctx_per_seq_swa * m_n_seq;
298327
}
299328

300-
void GgmlOvDecoder::validate_cgraph() const {}
329+
void GgmlOvDecoder::validate_cgraph() const {
330+
if (m_n_seq > 1 && m_is_static == true) {
331+
throw std::runtime_error("n_seq > 1 is not supported on NPU");
332+
}
333+
}
301334

302-
ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * src) const {
303-
auto name = std::string(src->name);
335+
ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const {
336+
auto name = std::string(input->name);
304337
ov::PartialShape input_shape;
305338

306339
if (name == "inp_tokens" || name == "inp_pos" || name == "inp_out_ids") {
307-
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
340+
input_shape = ov::PartialShape{1, 1, 1, m_is_static ? 1 : -1};
308341

309342
} else if (name.find("KQ_mask") == 0) {
310343
if (m_is_static) {
311-
input_shape = ov::PartialShape{1, 1, m_context_size};
344+
input_shape = ov::PartialShape{1, 1, 1, m_ctx};
312345
} else {
313-
input_shape = ov::PartialShape{1, -1, -1};
314-
}
315-
316-
} else if (name.find("cache_") == 0) {
317-
auto past_token_len = -1;
318-
if (m_is_static) {
319-
int layer = extract_layer_from_name(name);
320-
bool is_swa = is_swa_layer(layer);
321-
past_token_len = is_swa ? m_context_size_swa : m_context_size;
346+
input_shape = ov::PartialShape{-1, 1, -1, -1};
322347
}
323-
input_shape = ov::PartialShape{past_token_len, m_num_heads_kv, m_head_size};
324348

325-
} else if (const auto * op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
326-
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
349+
} else if (op && op->op == GGML_OP_SET_ROWS && op->src[1] == input) {
350+
input_shape = ov::PartialShape{1, 1, 1, m_is_static ? 1 : -1};
327351

328-
} else if (src->op == GGML_OP_VIEW) {
352+
} else if (input->op == GGML_OP_VIEW) {
329353
// This case is added to make test-backend-ops work
330-
input_shape = ov::PartialShape{get_shape(src->view_src)};
354+
input_shape = ov::PartialShape{get_shape(input->view_src)};
331355
} else {
332-
input_shape = ov::PartialShape{get_shape(src)};
356+
input_shape = ov::PartialShape{get_shape(input)};
333357
}
334358
return input_shape;
335359
}
@@ -339,25 +363,9 @@ void GgmlOvDecoder::add_extra_inputs() {
339363
// 1. `attention_size`, used in FLASH_ATTN where the shapes of the matmuls are 256-aligned,
340364
// see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
341365
// Not used for NPU.
342-
// Update: not used anymore after the optimization of making kvcache dynamic (but breaks iSWA models)
343-
int64_t attention_size = -1;
344-
int64_t attention_size_swa = -1;
345-
for (const auto & node : m_nodes) {
346-
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
347-
auto * mask = node->src[3];
348-
std::string mask_name(mask->name);
349-
if (mask_name.find("KQ_mask") != 0) {
350-
throw std::runtime_error("Unexpected flash attention node: " + std::string(mask->name));
351-
}
352-
if (mask_name.find("swa") != std::string::npos) {
353-
attention_size_swa = mask->ne[0];
354-
} else {
355-
attention_size = mask->ne[0];
356-
}
357-
}
358-
}
366+
// 2. `n_seq_active` and `seq_active_start`, used in FLASH_ATTN_EXT to indicate the active sequences in the batch
359367

360-
auto create_attention_size_input = [this](const std::string & name, int64_t size) {
368+
auto create_1d_input = [this](const std::string & name, int64_t size) {
361369
auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
362370
param_node->set_friendly_name(name);
363371
param_node->output(0).get_tensor().set_names({name});
@@ -368,10 +376,14 @@ void GgmlOvDecoder::add_extra_inputs() {
368376
m_model_extra_input_values[name] = tensor;
369377
};
370378

371-
create_attention_size_input("attention_size", attention_size);
372-
if (attention_size_swa != -1) {
373-
create_attention_size_input("attention_size_swa", attention_size_swa);
379+
create_1d_input("attention_size", m_attention_size);
380+
if (m_attention_size_swa != -1) {
381+
create_1d_input("attention_size_swa", m_attention_size_swa);
374382
}
383+
create_1d_input("seq_active_start", m_seq_active_start);
384+
create_1d_input("n_seq_active", m_n_seq_active);
385+
create_1d_input("token_len_per_seq", m_token_len_per_seq);
386+
create_1d_input("seq_active_end", m_seq_active_start + m_n_seq_active);
375387
}
376388

377389
const ggml_tensor * GgmlOvDecoder::get_tensor_used_op(const ggml_tensor * tensor) const {
@@ -472,6 +484,8 @@ std::shared_ptr<ov::Node> GgmlOvDecoder::create_weight_node(ggml_tensor * tensor
472484
auto node_shape = get_shape(tensor);
473485
auto ne_total = ggml_nelements(tensor);
474486

487+
OPENVINO_ASSERT(node_shape[0] == 1, "Got 4D weights, expect all weights to be 2D: ", tensor->name);
488+
node_shape.erase(node_shape.begin());
475489
OPENVINO_ASSERT(node_shape[0] == 1, "Got 3D weights, expect all weights to be 2D: ", tensor->name);
476490
node_shape.erase(node_shape.begin());
477491

@@ -641,15 +655,15 @@ void print_tensor_address_map(const ggml_cgraph * cgraph) {
641655

642656
std::vector<size_t> GgmlOvDecoder::get_shape(const ggml_tensor * tensor) {
643657
std::vector<size_t> shape;
644-
for (int i = GGML_MAX_DIMS - 2; i >= 0; --i) {
658+
for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) {
645659
shape.push_back(static_cast<size_t>(tensor->ne[i]));
646660
}
647661
return shape;
648662
}
649663

650664
std::vector<size_t> GgmlOvDecoder::get_stride(const ggml_tensor * tensor) {
651665
std::vector<size_t> stride;
652-
for (int i = GGML_MAX_DIMS - 2; i >= 0; --i) {
666+
for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) {
653667
stride.push_back(static_cast<size_t>(tensor->nb[i]));
654668
}
655669
return stride;
@@ -738,8 +752,8 @@ int32_t * GgmlOvDecoder::get_output_op_params(const std::string & name) const {
738752

739753
void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const {
740754
for (const auto & node : m_nodes) {
741-
auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_context_size, m_context_size_swa,
742-
m_num_heads, m_num_heads_kv, m_head_size, m_swa_layers);
755+
auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_ctx, m_ctx_swa, m_n_heads,
756+
m_n_heads_kv, m_head_size, m_swa_layers);
743757
node_visitor(decoder);
744758
}
745759
}

ggml/src/ggml-openvino/ggml-decoder.h

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -103,19 +103,19 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
103103

104104
virtual const std::vector<std::string> & get_model_output_names() const override { return m_model_output_names; }
105105

106-
virtual int get_context_size() const override { return m_context_size; }
106+
virtual int get_ctx_size() const { return m_ctx; }
107107

108-
virtual int get_context_size_swa() const override { return m_context_size_swa; }
108+
virtual int get_ctx_swa_size() const { return m_ctx_swa; }
109109

110-
virtual int is_swa_layer(int layer) const override {
111-
return std::find(m_swa_layers.begin(), m_swa_layers.end(), layer) != m_swa_layers.end();
112-
}
110+
virtual int get_ctx_per_seq() const { return m_ctx_per_seq; }
113111

114-
virtual int get_num_heads() const override { return m_num_heads; }
112+
virtual int get_ctx_per_seq_swa() const { return m_ctx_per_seq_swa; }
115113

116-
virtual int get_num_heads_kv() const override { return m_num_heads_kv; }
114+
virtual int get_n_seq() const { return m_n_seq; }
117115

118-
virtual int get_head_size() const override { return m_head_size; }
116+
virtual int is_swa_layer(int layer) const override {
117+
return std::find(m_swa_layers.begin(), m_swa_layers.end(), layer) != m_swa_layers.end();
118+
}
119119

120120
int get_past_kv_len() const { return m_past_kv_len; }
121121

@@ -127,7 +127,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
127127

128128
virtual bool is_static() const override { return m_is_static; }
129129

130-
ov::PartialShape get_graph_input_shape(const ggml_tensor * src) const;
130+
ov::PartialShape get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const;
131131

132132
static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename);
133133

@@ -151,10 +151,11 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
151151
static std::vector<size_t> get_stride(const ggml_tensor * tensor);
152152
static ov::element::Type get_ov_type(const ggml_tensor * tensor);
153153

154-
// set context_size, num_heads, etc
155154
void set_llm_params();
156155
void validate_cgraph() const;
157156

157+
bool m_is_static = false;
158+
158159
ggml_cgraph * m_cgraph = nullptr;
159160
ggml_tensor * m_node = nullptr;
160161
std::vector<ggml_tensor *> m_nodes;
@@ -171,17 +172,28 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
171172
std::map<std::string, std::shared_ptr<ov::Tensor>> m_model_extra_input_values;
172173
std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights;
173174
std::vector<std::string> m_model_output_names;
174-
int m_context_size;
175-
int m_context_size_swa;
175+
176+
// Fixed for a model
177+
int m_ctx = -1;
178+
int m_ctx_swa = -1;
179+
int m_ctx_per_seq = -1;
180+
int m_ctx_per_seq_swa = -1;
181+
int m_n_seq = -1;
182+
int m_n_heads = -1;
183+
int m_n_heads_kv = -1;
184+
int m_head_size = -1;
176185
std::vector<int> m_swa_layers;
177-
int m_num_heads;
178-
int m_num_heads_kv;
179-
int m_head_size;
180-
int m_past_kv_len;
181-
int m_input_len;
182-
int32_t * m_rope_params;
183186
std::vector<std::string> m_kv_names;
184-
bool m_is_static = false;
187+
188+
// Changed per inference
189+
int m_n_seq_active = -1;
190+
int m_seq_active_start = -1;
191+
int m_attention_size = -1;
192+
int m_attention_size_swa = -1;
193+
int m_input_len = -1;
194+
int m_token_len_per_seq = -1;
195+
int m_past_kv_len = -1;
196+
int32_t * m_rope_params = nullptr;
185197
};
186198

187199
void print_tensor_address_map(const ggml_cgraph * cgraph);

0 commit comments

Comments
 (0)