Commit b6130a7

Change openvino device_type to GPU; Enable flash_attn
1 parent 80f0969 commit b6130a7

File tree: 6 files changed (+104, -30 lines)

ggml/src/ggml-openvino/ggml-decoder.cpp

Lines changed: 7 additions & 0 deletions
```diff
@@ -299,6 +299,13 @@ void GgmlOvDecoder::add_extra_inputs() {
             attention_size = mask->ne[0];
             break;
         }
+        if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+            auto* mask = node->src[3];
+            if (std::string(mask->name).find("KQ_mask") != 0) {
+                throw std::runtime_error("Unexpected flash attention node: " + std::string(mask->name));
+            }
+            attention_size = mask->ne[0];
+        }
     }
 
     {
```
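With flash attention enabled, `add_extra_inputs` can now derive `attention_size` from a `GGML_OP_FLASH_ATTN_EXT` node as well: the mask is taken from `src[3]` and its name must start with `KQ_mask`, otherwise the decoder throws. A minimal sketch of the prefix check used here (the sample names other than `KQ_mask` itself are hypothetical):

```cpp
#include <cassert>
#include <string>

// std::string::find returns the index of the first match, so
// find("KQ_mask") != 0 is true exactly when the name does NOT start
// with "KQ_mask" -- the case the decoder rejects with a runtime_error.
static bool is_kq_mask(const std::string & name) {
    return name.find("KQ_mask") == 0;
}

int main() {
    assert(is_kq_mask("KQ_mask"));
    assert(is_kq_mask("KQ_mask_sliced"));  // any "KQ_mask*" name passes
    assert(!is_kq_mask("not_a_KQ_mask"));  // prefix not at index 0: rejected
}
```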

ggml/src/ggml-openvino/ggml-openvino.cpp

Lines changed: 5 additions & 4 deletions
```diff
@@ -173,14 +173,15 @@ static void ggml_backend_openvino_device_get_memory(ggml_backend_dev_t dev, size
     GGML_ASSERT(free != nullptr);
     GGML_ASSERT(total != nullptr);
     ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *)dev->context;
-    // Placeholder
     GGML_ASSERT(ctx->device >= 0);
     // ggml_openvino_set_device(ctx->device);
+    *total = 1;
+    *free = 1;
 }
 
 static enum ggml_backend_dev_type ggml_backend_openvino_device_get_type(ggml_backend_dev_t dev) {
     GGML_UNUSED(dev);
-    return GGML_BACKEND_DEVICE_TYPE_ACCEL;
+    return GGML_BACKEND_DEVICE_TYPE_GPU;
 }
 
 static void ggml_backend_openvino_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
@@ -293,7 +294,7 @@ static bool is_op_unsupported_case(const ggml_tensor* op) {
         GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode);
         return true;
     }
-    if (n_dims != op->src[0]->ne[0]) {
+    if (n_dims != 0.0f && n_dims != op->src[0]->ne[0]) {
         GGML_LOG_WARN("OpenVINO backend does not support ROPE with n_dims %d != src[0]->ne[0] %ld\n",
                       n_dims,
                       op->src[0]->ne[0]);
@@ -305,7 +306,7 @@ static bool is_op_unsupported_case(const ggml_tensor* op) {
     }
     float freq_scale;
     memcpy(&freq_scale, op_params + 6, sizeof(float));
-    if (freq_scale != 1.0f) {
+    if (freq_scale != 0.0f && freq_scale != 1.0f) {
         GGML_LOG_WARN("OpenVINO backend does not support ROPE with freq_scale %f != 1.0f\n", freq_scale);
         return true;
     }
```
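Two behavioural changes here: the device now reports itself as `GGML_BACKEND_DEVICE_TYPE_GPU` instead of `ACCEL` (per the commit title, so llama.cpp treats it like the other GPU backends when selecting devices; the `*total`/`*free` memory figures are still placeholders), and the ROPE support checks additionally accept `0` as an unset value for `n_dims` and `freq_scale` (note `n_dims` is an integer, so the `0.0f` literal merely converts; `0` would be the cleaner spelling). A hedged sketch, assuming the standard `ggml-backend.h` device-registry API, of how the new type surfaces to callers:

```cpp
#include <cstdio>

#include "ggml-backend.h"

// List registered devices and flag the ones reported as GPUs. After this
// commit the OpenVINO device should show up with the [GPU] tag.
int main() {
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        bool is_gpu = ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU;
        printf("%s%s\n", ggml_backend_dev_name(dev), is_gpu ? " [GPU]" : "");
    }
    return 0;
}
```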

ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp

Lines changed: 53 additions & 3 deletions
```diff
@@ -1,6 +1,12 @@
 #include <memory>
+#include <openvino/op/broadcast.hpp>
+#include <openvino/op/concat.hpp>
 #include <openvino/op/convert.hpp>
+#include <openvino/op/reshape.hpp>
 #include <openvino/op/scaled_dot_product_attention.hpp>
+#include <openvino/op/transpose.hpp>
+#include <openvino/op/unsqueeze.hpp>
+
 #include "../node_context.hpp"
 #include "../op_table.hpp"
 #include "../utils.hpp"
@@ -24,9 +30,53 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
 
     auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16);
     auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});
-    auto res = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v , mask, scale_node, false);
-    auto res_f32 = std::make_shared<ov::op::v0::Convert>(res, ov::element::f32);
-    return rename_outputs_with_suffix({res_f32}, context.get_name());
+
+    ov::Output<ov::Node> mask_sliced;
+    if (context.has_input("KQ_mask_sliced")) {
+        mask_sliced = context.get_input("KQ_mask_sliced");
+    } else {
+        auto token_len = get_dimensions(q, {1});
+        auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+        auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+        mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
+    }
+
+    if (mask_sliced.get_element_type() != ov::element::f16) {
+        mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
+    }
+
+    auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv) {
+        int64_t factor = q_batch / kv_batch;
+        if (factor > 1) {
+            auto q_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{q_batch});
+            auto kv_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_batch});
+            auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
+
+            auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
+            auto kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
+
+            auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
+            auto kv_broadcast_shape =
+                std::make_shared<ov::op::v0::Concat>(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
+            kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape);
+
+            auto new_kv_shape =
+                std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
+            kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false);
+        }
+        return kv;
+    };
+
+    auto q_shape = context.get_input_shape(0).to_shape();
+    auto k_shape = context.get_input_shape(1).to_shape();
+    k = tile_kv(q_shape[0], k_shape[0], k);
+    v = tile_kv(q_shape[0], k_shape[0], v);
+
+    auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false);
+    auto sdpa_f32 = std::make_shared<ov::op::v0::Convert>(sdpa, ov::element::f32);
+    auto res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
+                                                       ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
+    return rename_outputs_with_suffix({res}, context.get_name());
 }
 
 } // namespace op
```
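The rewritten `translate_flash_attn_ext` does three things: it reuses the pre-sliced `KQ_mask_sliced` input when the translate session provides one (falling back to slicing the mask locally), it converts the mask to f16 so the whole SDPA runs in half precision, and it tiles the KV heads up to the query head count for grouped-query attention, since `ScaledDotProductAttention` needs matching batch dimensions on Q, K, and V. The final `Transpose` with order `{1, 0, 2}` swaps the first two output dimensions, presumably heads and tokens, into the layout the surrounding graph expects. A shape walkthrough of `tile_kv` with assumed (hypothetical) GQA sizes:

```cpp
#include <cstdint>
#include <cstdio>

// Assume 32 query heads sharing 8 KV heads, T tokens, head size D.
// tile_kv then repeats each KV head factor = 32 / 8 = 4 times:
//   [8, T, D] --Unsqueeze(axis 1)--> [8, 1, T, D]
//             --Broadcast---------> [8, 4, T, D]
//             --Reshape-----------> [32, T, D]   (matches the Q heads)
int main() {
    int64_t q_batch = 32, kv_batch = 8;
    int64_t factor = q_batch / kv_batch;
    printf("factor = %lld, tiled KV batch = %lld\n",
           (long long) factor, (long long) (kv_batch * factor));
    return 0;
}
```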

ggml/src/ggml-openvino/openvino/op/mulmat.cpp

Lines changed: 16 additions & 16 deletions
```diff
@@ -62,34 +62,34 @@ OutputVector translate_mulmat(const NodeContext& context) {
         auto B_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{B_batch});
         auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
 
-        auto Z_last_two_dim = get_dimensions(Z.get_node_shared_ptr(), {1, 2});
+        auto Z_last_two_dims = get_dimensions(Z.get_node_shared_ptr(), {1, 2});
 
         auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
         auto Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes);
 
         Output<Node> batch_small = A_batch_larger ? B_batch_node : A_batch_node;
         Output<Node> batch_large = A_batch_larger ? A_batch_node : B_batch_node;
         auto broadcast_shape =
-            std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_small, factor_node, Z_last_two_dim}, 0);
+            std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_small, factor_node, Z_last_two_dims}, 0);
         auto Z_broadcasted = std::make_shared<ov::op::v3::Broadcast>(Z_unsqueezed, broadcast_shape);
 
-        auto new_Z_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_large, Z_last_two_dim}, 0);
+        auto new_Z_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_large, Z_last_two_dims}, 0);
         Z = std::make_shared<ov::op::v1::Reshape>(Z_broadcasted, new_Z_shape, false);
-        }
-        if (A_batch_larger) {
-            B = Z;
-        } else {
-            A = Z;
-        }
+    }
+    if (A_batch_larger) {
+        B = Z;
+    } else {
+        A = Z;
+    }
 
-        if (convert_out_type) {
-            auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
-            res = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
-        } else {
-            res = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
-        }
+    if (convert_out_type) {
+        auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
+        res = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
+    } else {
+        res = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
+    }
 
-        return rename_outputs_with_suffix({res}, context.get_name());
+    return rename_outputs_with_suffix({res}, context.get_name());
 }
 
 } // namespace op
```
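The `mulmat` changes are cosmetic: `Z_last_two_dim` becomes `Z_last_two_dims` (matching the naming in the new flash-attention tiling code) and the tail of `translate_mulmat` is re-indented one level out, since it sits outside the broadcast block; the batch-broadcasting logic itself is unchanged.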

ggml/src/ggml-openvino/openvino/op/softmax.cpp

Lines changed: 11 additions & 7 deletions
```diff
@@ -51,14 +51,18 @@ OutputVector translate_soft_max(const NodeContext& context) {
         return rename_outputs_with_suffix({res}, context.get_name());
     }
 
-    auto mask_node = context.get_input(1);
+    ov::Output<ov::Node> mask_node_sliced;
+    if (context.has_input("KQ_mask_sliced")) {
+        mask_node_sliced = context.get_input("KQ_mask_sliced");
+    } else {
+        auto token_len = get_dimensions(input_node, {1});
+        auto mask_node = context.get_input(1);
+        auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+        auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+        mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
+    }
 
-    auto token_len = context.has_input("token_len") ? context.get_input("token_len") : get_dimensions(input_node, {1});
-    auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
-    auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
-    std::shared_ptr<ov::Node> mask_node_sliced =
-        std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
-    if (mask_node_sliced->get_element_type() != context.get_output_type(0)) {
+    if (mask_node_sliced.get_element_type() != context.get_output_type(0)) {
         mask_node_sliced = std::make_shared<ov::op::v0::Convert>(mask_node_sliced, context.get_output_type(0));
     }
 
```

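`translate_soft_max` now prefers the graph-wide `KQ_mask_sliced` input built during preprocessing and only slices the mask itself as a fallback; accordingly `mask_node_sliced` becomes an `ov::Output<ov::Node>` and the element-type check uses `.` instead of `->`. For reference, `v8::Slice` takes `(data, start, stop, step, axes)`, so the fallback keeps entries `[0, token_len)` along axis 1 of the mask. A minimal sketch of that fallback as a standalone (hypothetical) helper:

```cpp
#include <memory>

#include <openvino/op/constant.hpp>
#include <openvino/op/slice.hpp>

// Slice the padded attention mask down to the current token count:
// with axes = {1} this keeps rows [0, token_len) along axis 1.
static ov::Output<ov::Node> slice_mask(const ov::Output<ov::Node> & mask,
                                       const ov::Output<ov::Node> & token_len) {
    auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
    auto one  = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
    return std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
}
```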
ggml/src/ggml-openvino/openvino/translate_session.cpp

Lines changed: 12 additions & 0 deletions
```diff
@@ -36,6 +36,7 @@ namespace ggml {
 using namespace ov::op;
 
 namespace {
+
 ov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs(
     const std::shared_ptr<ov::Model>& model, const std::map<std::string, std::string>& kv_param_res_names) {
     ov::pass::MakeStateful::ParamResPairs pairs;
@@ -76,6 +77,16 @@ void add_token_len(TensorMap& tensor_map) {
     tensor_map.insert({"token_len", token_len->output(0)});
 }
 
+void add_sliced_mask(TensorMap& tensor_map) {
+    auto mask = tensor_map.at("KQ_mask").get_node_shared_ptr();
+    auto token_len = tensor_map.at("token_len").get_node_shared_ptr();
+    auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+    auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+    std::shared_ptr<ov::Node> mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
+    mask_sliced->set_friendly_name("KQ_mask_sliced");
+    tensor_map.insert({"KQ_mask_sliced", mask_sliced->output(0)});
+}
+
 void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
     int32_t* rope_params = ggml_model_decoder.get_rope_params();
     auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
@@ -97,6 +108,7 @@ void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
 // Create common patterns
 void preprocess(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
     add_token_len(tensor_map);
+    add_sliced_mask(tensor_map);
     add_rope_sin_cos(tensor_map, ggml_model_decoder);
 }
```

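`add_sliced_mask` is the other half of the mask-slicing rework: `preprocess` now builds the `token_len`-sliced `KQ_mask` once per model and publishes it in the tensor map as `KQ_mask_sliced`, which is what the softmax and flash-attention translators look up via `context.has_input("KQ_mask_sliced")` above. This avoids each attention op re-deriving the same `Slice` subgraph.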