Skip to content

Commit 9a85e53

Browse files
committed
Add GeGLU
1 parent fa237a1 commit 9a85e53

File tree

5 files changed

+85
-11
lines changed

5 files changed

+85
-11
lines changed

ggml/src/ggml-openvino/ggml-openvino.cpp

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -249,17 +249,30 @@ static bool is_op_unsupported_case(const ggml_tensor* op) {
249249
const auto* op_params = op->op_params;
250250
memcpy(&scale, (const float*) op_params + 0, sizeof(float));
251251
memcpy(&max_bias, (const float*) op_params + 1, sizeof(float));
252-
const uint32_t h = op->src[0]->ne[2];
253-
const uint32_t n_head = op->src[0]->ne[0];
254-
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
255-
256-
const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
257-
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
258-
const float slope =
259-
(max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f;
252+
if (max_bias > 0) {
253+
GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with max_bias > 0\n");
254+
return true;
255+
}
256+
}
260257

261-
if (slope != 1.0f) {
262-
GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with slope != 1.0f\n");
258+
if (op->op == GGML_OP_FLASH_ATTN_EXT) {
259+
if (op->src[4] != nullptr) {
260+
GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with sinks\n");
261+
return true;
262+
}
263+
float scale = 1.0f;
264+
float max_bias = 0.0f;
265+
float logit_softcap = 0.0f;
266+
const auto* op_params = op->op_params;
267+
memcpy(&scale, (const float*) op_params + 0, sizeof(float));
268+
memcpy(&max_bias, (const float*) op_params + 1, sizeof(float));
269+
memcpy(&logit_softcap, (const float*) op_params + 2, sizeof(float));
270+
if (max_bias > 0) {
271+
GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with max_bias > 0\n");
272+
return true;
273+
}
274+
if (logit_softcap != 0) {
275+
GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with logit_softcap != 0\n");
263276
return true;
264277
}
265278
}
@@ -357,7 +370,8 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
357370
GGML_OP_ROPE,
358371
GGML_OP_RMS_NORM,
359372
GGML_OP_SCALE,
360-
GGML_OP_SOFT_MAX,
373+
// softmax is not updated due to replaced by flash_attn_ext
374+
// GGML_OP_SOFT_MAX,
361375
GGML_OP_SET_ROWS,
362376
GGML_OP_FLASH_ATTN_EXT,
363377
GGML_OP_CPY};
@@ -366,6 +380,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
366380
};
367381
static const std::set<ggml_glu_op> supported_glu_ops{
368382
GGML_GLU_OP_SWIGLU,
383+
GGML_GLU_OP_GEGLU,
369384
};
370385

371386
switch (op->op) {
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#include <memory>
2+
#include <openvino/core/node_output.hpp>
3+
#include <openvino/op/constant.hpp>
4+
#include <openvino/op/gelu.hpp>
5+
#include <openvino/op/multiply.hpp>
6+
#include <openvino/op/sigmoid.hpp>
7+
#include <openvino/op/slice.hpp>
8+
#include <openvino/op/split.hpp>
9+
10+
#include "../node_context.hpp"
11+
#include "../op_table.hpp"
12+
#include "../utils.hpp"
13+
14+
namespace ov {
15+
namespace frontend {
16+
namespace ggml {
17+
namespace op {
18+
19+
OutputVector translate_glu_geglu(const NodeContext& context) {
20+
num_inputs_check(context, 1, 2);
21+
22+
ov::Output<ov::Node> src0;
23+
ov::Output<ov::Node> src1;
24+
if (context.get_input_size() == 2) {
25+
src0 = context.get_input(0);
26+
src1 = context.get_input(1);
27+
} else {
28+
auto combined = context.get_input(0);
29+
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {2});
30+
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
31+
src0 = split->output(0);
32+
src1 = split->output(1);
33+
}
34+
35+
int32_t* params = context.get_output_op_params(0);
36+
const int32_t swapped = params[1];
37+
if (swapped) {
38+
std::swap(src0, src1);
39+
}
40+
41+
auto gelu = std::make_shared<ov::op::v7::Gelu>(src0);
42+
auto res = std::make_shared<ov::op::v1::Multiply>(gelu, src1);
43+
44+
return rename_outputs_with_suffix({res}, context.get_name());
45+
}
46+
47+
} // namespace op
48+
} // namespace ggml
49+
} // namespace frontend
50+
} // namespace ov

ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ OutputVector translate_glu_swiglu(const NodeContext& context) {
3131
src0 = split->output(0);
3232
src1 = split->output(1);
3333
}
34+
35+
int32_t* params = context.get_output_op_params(0);
36+
const int32_t swapped = params[1];
37+
if (swapped) {
38+
std::swap(src0, src1);
39+
}
40+
3441
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(src0);
3542
auto silu = std::make_shared<ov::op::v1::Multiply>(src0, sigmoid);
3643
auto res = std::make_shared<ov::op::v1::Multiply>(silu, src1);

ggml/src/ggml-openvino/openvino/op_table.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ std::unordered_map<std::string, CreatorFunction> get_supported_ops() {
3434
{"GGML_UNARY_OP_SILU", op::translate_unary_silu },
3535
{"GGML_OP_VIEW", op::translate_view },
3636
{"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu },
37+
{"GGML_GLU_OP_GEGLU", op::translate_glu_geglu },
3738
{"GGML_OP_SET_ROWS", op::translate_set_rows },
3839
{"GGML_OP_CPY", op::translate_cpy },
3940
{"GGML_OP_FLASH_ATTN_EXT", op::translate_flash_attn_ext },

ggml/src/ggml-openvino/openvino/op_table.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ GGML_OP_CONVERTER(translate_soft_max);
2525
GGML_OP_CONVERTER(translate_transpose);
2626
GGML_OP_CONVERTER(translate_view);
2727
GGML_OP_CONVERTER(translate_glu_swiglu);
28+
GGML_OP_CONVERTER(translate_glu_geglu);
2829
GGML_OP_CONVERTER(translate_set_rows);
2930
GGML_OP_CONVERTER(translate_cpy);
3031
GGML_OP_CONVERTER(translate_flash_attn_ext);

0 commit comments

Comments
 (0)