Commit 500aead

Support iSWA

1 parent 3d31fa6 commit 500aead

File tree

ggml/src/ggml-openvino/ggml-decoder.cpp
ggml/src/ggml-openvino/ggml-decoder.h
ggml/src/ggml-openvino/openvino/decoder.hpp
ggml/src/ggml-openvino/openvino/node_context.hpp
ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp
ggml/src/ggml-openvino/openvino/op/permute.cpp
ggml/src/ggml-openvino/openvino/translate_session.cpp
ggml/src/ggml-openvino/utils.cpp
src/llama-graph.cpp

9 files changed: +124 −79 lines changed

ggml/src/ggml-openvino/ggml-decoder.cpp

Lines changed: 69 additions & 34 deletions
@@ -30,17 +30,21 @@
 #include <set>
 #include <stdexcept>
 #include <string>
+#include <vector>
 
 #include "ggml-backend-impl.h"
 #include "ggml-backend.h"
 #include "ggml-quants.hpp"
 
 GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token,
-                             int context_size, int num_heads, int num_heads_kv, int head_size) :
+                             int context_size, int context_size_swa, int num_heads, int num_heads_kv, int head_size,
+                             const std::vector<int>& swa_layers) :
     m_cgraph(cgraph),
     m_node(node),
     m_op_name(std::string(node->name)),
     m_context_size(context_size),
+    m_context_size_swa(context_size_swa),
+    m_swa_layers(swa_layers),
     m_num_heads(num_heads),
     m_num_heads_kv(num_heads_kv),
     m_head_size(head_size),
@@ -204,11 +208,14 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
         if (node->src[0]->op != GGML_OP_VIEW) {
             m_op_case = 1;
         } else if (ggml_is_contiguous(node->src[0])) {
-            // Permute cache_k (view)
-            m_op_case = 2;
-        } else {
-            // Permute cache_v (view), deprecated, cache_v will also fall to case 2
-            m_op_case = 3;
+            // Permute kv cache (view)
+            std::string src_name(node->view_src->name);
+            int layer = extract_layer_from_name(src_name);
+            if (!is_swa_layer(layer)) {
+                m_op_case = 2;
+            } else {
+                m_op_case = 3;
+            }
         }
         break;
     }
@@ -239,13 +246,34 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
     }
 }
 
+int extract_layer_from_name(const std::string& name) {
+    size_t pos1 = name.find("_l");
+    assert(pos1 != std::string::npos);
+    pos1 += 2;
+    size_t pos2 = name.find(' ', pos1);
+    if (pos2 == std::string::npos) {
+        pos2 = name.length();
+    }
+    std::string layer_str = name.substr(pos1, pos2 - pos1);
+    int layer = std::stoi(layer_str);
+    return layer;
+}
+
 void GgmlOvDecoder::set_llm_params() {
     for (int i = 0; i < m_cgraph->n_nodes; i++) {
         auto* node = m_cgraph->nodes[i];
         std::string name = std::string(node->name);
-        if (node->op == GGML_OP_VIEW && std::string(node->name) == "cache_k_l0 (view)") {
-            auto* cache_k = node->src[0];
-            m_context_size = cache_k->ne[1];
+        if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+            auto* cache_k = node->src[1];
+            cache_k = cache_k->view_src ? cache_k->view_src : cache_k;
+            int layer = extract_layer_from_name(cache_k->name);
+
+            if (std::string(node->src[3]->name).find("swa") != std::string::npos) {
+                m_swa_layers.push_back(layer);
+                m_context_size_swa = cache_k->ne[1];
+            } else {
+                m_context_size = cache_k->ne[1];
+            }
         } else if (node->op == GGML_OP_ROPE &&
                    (name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0)) {
             m_head_size = node->ne[0];
@@ -269,25 +297,24 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
                 input_shape = ov::PartialShape{1, 1, 1};
             }
         } else {
-            input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_context_size)};
+            input_shape = ov::PartialShape{1, 1, -1};
         }
     } else if (name == "inp_out_ids" && !m_is_static) {
-        input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_context_size)};
-    } else if (name == "KQ_mask") {
+        input_shape = ov::PartialShape{1, 1, -1};
+    } else if (name.find("KQ_mask") == 0) {
         if (m_is_static) {
             if (m_is_first_token) {
                 input_shape = ov::PartialShape{1, m_context_size, m_context_size};
             } else {
                 input_shape = ov::PartialShape{1, 1, m_context_size};
             }
         } else {
-            auto max_mask_size = GGML_PAD(m_context_size, GGML_KQ_MASK_PAD);
-            input_shape = ov::PartialShape{1, ov::Dimension(1, max_mask_size), ov::Dimension(1, max_mask_size)};
+            input_shape = ov::PartialShape{1, -1, -1};
         }
-    } else if (name.find("cache_k") == 0) {
-        input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
-    } else if (name.find("cache_v") == 0) {
-        input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
+    } else if (name.find("cache_") == 0) {
+        int layer = extract_layer_from_name(name);
+        bool is_swa = is_swa_layer(layer);
+        input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
     } else if (const auto* op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
         input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
     } else if (src->op == GGML_OP_VIEW) {
@@ -305,35 +332,35 @@ void GgmlOvDecoder::add_extra_inputs() {
     // see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
     // Not used for NPU
     int64_t attention_size = -1;
+    int64_t attention_size_swa = -1;
     for (const auto& node : m_nodes) {
-        if (node->op == GGML_OP_SOFT_MAX) {
-            auto* mask = node->src[1];
-            if (std::string(mask->name).find("KQ_mask") != 0) {
-                throw std::runtime_error("Unexpected softmax node: " + std::string(mask->name));
-            }
-            attention_size = mask->ne[0];
-            break;
-        }
         if (node->op == GGML_OP_FLASH_ATTN_EXT) {
             auto* mask = node->src[3];
-            if (std::string(mask->name).find("KQ_mask") != 0) {
+            std::string mask_name(mask->name);
+            if (mask_name.find("KQ_mask") != 0) {
                 throw std::runtime_error("Unexpected flash attention node: " + std::string(mask->name));
            }
-            attention_size = mask->ne[0];
+            if (mask_name.find("swa") != std::string::npos) {
+                attention_size_swa = mask->ne[0];
+            } else {
+                attention_size = mask->ne[0];
+            }
         }
     }
 
-    {
-        std::string name = "attention_size";
+    auto create_attention_size_input = [this](const std::string& name, int64_t size) {
         auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
         param_node->set_friendly_name(name);
         param_node->output(0).get_tensor().set_names({name});
         m_model_extra_inputs[name] = param_node;
 
         auto tensor = std::make_shared<ov::Tensor>(ov::element::i64, ov::Shape{1});
-        *tensor->data<int64_t>() = attention_size;
+        *tensor->data<int64_t>() = size;
         m_model_extra_input_values[name] = tensor;
-    }
+    };
+
+    create_attention_size_input("attention_size", attention_size);
+    create_attention_size_input("attention_size_swa", attention_size_swa);
 }
 
 const ggml_tensor* GgmlOvDecoder::get_tensor_used_op(const ggml_tensor* tensor) const {
@@ -706,8 +733,16 @@ int32_t* GgmlOvDecoder::get_output_op_params(const std::string& name) const {
 
 void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const {
     for (const auto& node : m_nodes) {
-        auto decoder = std::make_shared<GgmlOvDecoder>(
-            node, m_cgraph, m_is_static, m_is_first_token, m_context_size, m_num_heads, m_num_heads_kv, m_head_size);
+        auto decoder = std::make_shared<GgmlOvDecoder>(node,
+                                                       m_cgraph,
+                                                       m_is_static,
+                                                       m_is_first_token,
+                                                       m_context_size,
+                                                       m_context_size_swa,
+                                                       m_num_heads,
+                                                       m_num_heads_kv,
+                                                       m_head_size,
+                                                       m_swa_layers);
         node_visitor(decoder);
     }
 }
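For reference, here is a standalone sketch of the layer-index parsing this file introduces: the helper body and the "cache_k_l0 (view)" naming convention come directly from the diff above, while the sample inputs and the main() driver are illustrative only.

```cpp
#include <cassert>
#include <iostream>
#include <string>

// Copy of the extract_layer_from_name helper added in this commit: it reads the
// number that follows "_l" in a KV-cache tensor name, stopping at the first space.
static int extract_layer_from_name(const std::string& name) {
    size_t pos1 = name.find("_l");
    assert(pos1 != std::string::npos);
    pos1 += 2;
    size_t pos2 = name.find(' ', pos1);
    if (pos2 == std::string::npos) {
        pos2 = name.length();
    }
    return std::stoi(name.substr(pos1, pos2 - pos1));
}

int main() {
    std::cout << extract_layer_from_name("cache_k_l12 (view)") << '\n';  // prints 12
    std::cout << extract_layer_from_name("cache_v_l3") << '\n';          // prints 3
}
```

set_llm_params records every layer whose flash-attention mask name contains "swa" in m_swa_layers, so is_swa_layer() can later pick between m_context_size and m_context_size_swa when shaping that layer's cache inputs.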

ggml/src/ggml-openvino/ggml-decoder.h

Lines changed: 12 additions & 1 deletion
@@ -19,7 +19,8 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
 
     // Node decoder, called in GgmlOvDecoder::visit_subgraph
     GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token,
-                  int context_size, int num_heads, int num_heads_kv, int head_size);
+                  int context_size, int context_size_swa, int num_heads, int num_heads_kv, int head_size,
+                  const std::vector<int>& swa_layers);
 
     // Naive graph decoder
     GgmlOvDecoder(struct ggml_cgraph* cgraph, std::map<std::string, std::shared_ptr<ov::Node>>& model_weights);
@@ -101,6 +102,12 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
 
     virtual int get_context_size() const override { return m_context_size; }
 
+    virtual int get_context_size_swa() const override { return m_context_size_swa; }
+
+    virtual int is_swa_layer(int layer) const override {
+        return std::find(m_swa_layers.begin(), m_swa_layers.end(), layer) != m_swa_layers.end();
+    }
+
     virtual int get_num_heads() const override { return m_num_heads; }
 
     virtual int get_num_heads_kv() const override { return m_num_heads_kv; }
@@ -156,6 +163,8 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
     std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights;
     std::vector<std::string> m_model_output_names;
     int m_context_size;
+    int m_context_size_swa;
+    std::vector<int> m_swa_layers;
     int m_num_heads;
     int m_num_heads_kv;
     int m_head_size;
@@ -166,3 +175,5 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
 };
 
 void print_tensor_address_map(const struct ggml_cgraph* cgraph);
+
+int extract_layer_from_name(const std::string& name);

ggml/src/ggml-openvino/openvino/decoder.hpp

Lines changed: 2 additions & 0 deletions
@@ -67,6 +67,8 @@ class GgmlDecoder : public DecoderBase {
     virtual bool is_static() const = 0;
     virtual bool is_first_token() const = 0;
     virtual int get_context_size() const = 0;
+    virtual int get_context_size_swa() const = 0;
+    virtual int is_swa_layer(int layer) const = 0;
 };
 
 } // namespace ggml

ggml/src/ggml-openvino/openvino/node_context.hpp

Lines changed: 4 additions & 9 deletions
@@ -2,6 +2,7 @@
 
 #include <cstdint>
 #include <openvino/frontend/node_context.hpp>
+#include <string>
 
 #include "decoder.hpp"
 
@@ -30,6 +31,8 @@ class NodeContext : public frontend::NodeContext {
         return m_translate_session;
     }
 
+    const std::vector<std::string>& get_input_names() const { return m_input_names; }
+
     size_t get_input_size() const override {
         return m_decoder->get_input_size();
     }
@@ -101,15 +104,7 @@ class NodeContext : public frontend::NodeContext {
         return m_decoder->is_first_token();
     }
 
-    int get_num_heads() const { return m_decoder->get_num_heads(); }
-
-    int get_num_heads_kv() const { return m_decoder->get_num_heads_kv(); }
-
-    int get_head_size() const { return m_decoder->get_head_size(); }
-
-    int get_context_size() const { return m_decoder->get_context_size(); }
-
-private:
+private:
     std::shared_ptr<GgmlDecoder> m_decoder;
     std::shared_ptr<TensorMap>& m_tensor_map;
     TranslateSession* m_translate_session;

ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp

Lines changed: 7 additions & 2 deletions
@@ -6,6 +6,7 @@
 #include <openvino/op/scaled_dot_product_attention.hpp>
 #include <openvino/op/transpose.hpp>
 #include <openvino/op/unsqueeze.hpp>
+#include <string>
 
 #include "../node_context.hpp"
 #include "../op_table.hpp"
@@ -32,8 +33,12 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
     auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});
 
     ov::Output<ov::Node> mask_sliced;
-    if (context.has_input("KQ_mask_sliced")) {
-        mask_sliced = context.get_input("KQ_mask_sliced");
+    std::string mask_name = "KQ_mask_sliced";
+    if (context.get_input_names()[3].find("swa") != std::string::npos) {
+        mask_name = "KQ_mask_swa_sliced";
+    }
+    if (context.has_input(mask_name)) {
+        mask_sliced = context.get_input(mask_name);
     } else {
         auto token_len = get_dimensions(q, {1});
         auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
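A minimal sketch of the mask selection added above, assuming, as in the diff, that the FLASH_ATTN_EXT node's fourth input (index 3) is the attention mask; pick_sliced_mask and the plain string vector are illustrative stand-ins for NodeContext::get_input_names(), not part of the commit.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical helper mirroring the logic above: SWA layers read the
// "KQ_mask_swa_sliced" tensor, every other layer reads "KQ_mask_sliced".
static std::string pick_sliced_mask(const std::vector<std::string>& input_names) {
    std::string mask_name = "KQ_mask_sliced";
    if (input_names.size() > 3 && input_names[3].find("swa") != std::string::npos) {
        mask_name = "KQ_mask_swa_sliced";
    }
    return mask_name;
}

int main() {
    std::cout << pick_sliced_mask({"q", "k", "v", "KQ_mask"}) << '\n';      // KQ_mask_sliced
    std::cout << pick_sliced_mask({"q", "k", "v", "KQ_mask_swa"}) << '\n';  // KQ_mask_swa_sliced
}
```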

ggml/src/ggml-openvino/openvino/op/permute.cpp

Lines changed: 12 additions & 26 deletions
@@ -29,43 +29,29 @@ OutputVector translate_permute(const NodeContext& context) {
                                                       ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
     } else {
         auto src = context.get_input(0);
-        auto attention_size = context.get_input("attention_size");
+        Output<Node> attention_size;
         if (context.is_static()) {
             attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {INT_MAX});
+        } else if (op_case == 2) {
+            attention_size = context.get_input("attention_size");
+        } else {
+            attention_size = context.get_input("attention_size_swa");
         }
 
         auto src_shape_ = context.get_input_shape(0).to_shape();
         std::vector<int64_t> src_shape(src_shape_.begin(), src_shape_.end());
 
-        std::shared_ptr<ov::Node> src_reshaped;
-        if (op_case == 2) {
-            src_reshaped = std::make_shared<ov::op::v1::Reshape>(
-                src,
-                ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
-                false);
-        } else {
-            src_reshaped = std::make_shared<ov::op::v1::Reshape>(
-                src,
-                ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{src_shape[1], src_shape[0], -1}),
-                false);
-        }
+        auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
+            src,
+            ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
+            false);
 
         auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
         auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
-        auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
-        std::shared_ptr<ov::Node> slice_axis;
-        if (op_case == 2) {
-            slice_axis = zero;
-        } else {
-            slice_axis = two;
-        }
-        auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, slice_axis);
+        auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, zero);
 
-        if (op_case == 2) {
-            res = std::make_shared<ov::op::v1::Transpose>(src_slice, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
-        } else {
-            res = src_slice;
-        }
+        res = std::make_shared<ov::op::v1::Transpose>(src_slice,
+                                                      ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
     }
     return rename_outputs_with_suffix({res}, context.get_name());
 }
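To make the unified reshape, slice, and transpose path concrete, here is a small shape walk-through. The dimension meanings follow the cache input shape set in ggml-decoder.cpp ({context_size, num_heads_kv, head_size}); the numeric values and the attention_size are assumed for illustration and are not taken from the commit.

```cpp
#include <array>
#include <cstdio>

int main() {
    // Assumed example values: a KV-cache view of shape {context_size, num_heads_kv, head_size}
    // and a runtime attention_size of 512 tokens currently present in the cache.
    std::array<long, 3> cache = {4096, 8, 64};
    long attention_size = 512;

    // Reshape to {-1, src_shape[1], src_shape[2]}: the shape is unchanged here, the -1
    // only keeps the leading (token) dimension dynamic.
    std::array<long, 3> reshaped = {cache[0], cache[1], cache[2]};

    // Slice dimension 0 down to attention_size (or attention_size_swa for op_case 3).
    std::array<long, 3> sliced = {attention_size, reshaped[1], reshaped[2]};

    // Transpose {1, 0, 2}: heads-first layout for the attention that consumes it.
    std::array<long, 3> permuted = {sliced[1], sliced[0], sliced[2]};

    std::printf("permuted KV view: {%ld, %ld, %ld}\n", permuted[0], permuted[1], permuted[2]);
}
```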

ggml/src/ggml-openvino/openvino/translate_session.cpp

Lines changed: 15 additions & 6 deletions
@@ -78,13 +78,22 @@ void add_token_len(TensorMap& tensor_map) {
 }
 
 void add_sliced_mask(TensorMap& tensor_map) {
-    auto mask = tensor_map.at("KQ_mask").get_node_shared_ptr();
     auto token_len = tensor_map.at("token_len").get_node_shared_ptr();
-    auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
-    auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
-    std::shared_ptr<ov::Node> mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
-    mask_sliced->set_friendly_name("KQ_mask_sliced");
-    tensor_map.insert({"KQ_mask_sliced", mask_sliced->output(0)});
+
+    auto create_sliced_mask = [&](const std::string& mask_name, const std::string& sliced_name) {
+        if (tensor_map.find(mask_name) != tensor_map.end()) {
+            auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+            auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+            auto mask = tensor_map.at(mask_name).get_node_shared_ptr();
+            std::shared_ptr<ov::Node> mask_sliced =
+                std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
+            mask_sliced->set_friendly_name(sliced_name);
+            tensor_map.insert({sliced_name, mask_sliced->output(0)});
+        }
+    };
+
+    create_sliced_mask("KQ_mask", "KQ_mask_sliced");
+    create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced");
 }
 
 void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {

ggml/src/ggml-openvino/utils.cpp

Lines changed: 1 addition & 1 deletion
@@ -362,7 +362,7 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons
             input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name);
         }
 
-    } else if (param_name == "KQ_mask") {
+    } else if (param_name.find("KQ_mask") == 0) {
         size_t context_size = ggml_decoder->get_context_size();
         const auto* input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name);
         if (is_first_token) {

src/llama-graph.cpp

Lines changed: 2 additions & 0 deletions
@@ -1642,6 +1642,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
+        ggml_set_name(inp->self_kq_mask, "KQ_mask");
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1656,6 +1657,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
+        ggml_set_name(inp->self_kq_mask_swa, "KQ_mask_swa");
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
