Skip to content

Commit b91f9e0

Browse files
committed
SDPA in f32
1 parent ca5e725 commit b91f9e0

File tree

3 files changed

+9
-18
lines changed

3 files changed

+9
-18
lines changed

ggml/src/ggml-openvino/openvino/op/mulmat.cpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,8 @@ OutputVector translate_mulmat(const NodeContext& context) {
3131
ov::Output<ov::Node> B = context.get_input(0);
3232
ov::Output<ov::Node> A = context.get_input(1);
3333

34-
bool convert_out_type = false;
35-
if (ov::op::util::is_constant(B.get_node()) && context.get_input_type(0) != context.get_input_type(1)) {
34+
if (context.get_input_type(0) != context.get_input_type(1)) {
3635
B = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_input_type(1));
37-
} else if (context.get_input_type(0) != context.get_input_type(1)) {
38-
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
39-
convert_out_type = true;
4036
}
4137

4238
auto B_shape = context.get_input_shape(0).to_shape();
@@ -71,12 +67,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
7167
A = Z;
7268
}
7369

74-
if (convert_out_type) {
75-
auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
76-
res = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
77-
} else {
78-
res = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
79-
}
70+
res = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
8071

8172
return rename_outputs_with_suffix({res}, context.get_name());
8273
}

ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,13 @@ FuseToSDPA::FuseToSDPA() {
2222
const auto m_k = ov::pass::pattern::any_input();
2323
const auto m_q = ov::pass::pattern::any_input();
2424
const auto m_qk = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_q, m_k});
25-
const auto m_qk_f32 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_qk});
2625
const auto m_scale = ov::pass::pattern::any_input();
27-
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk_f32, m_scale});
26+
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk, m_scale});
2827
const auto m_mask = ov::pass::pattern::any_input();
2928
const auto m_masked_qk = ov::pass::pattern::wrap_type<ov::op::v1::Add>({m_scaled_qk, m_mask});
3029
const auto m_softmax_qk = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>({m_masked_qk});
31-
const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_softmax_qk});
3230
const auto m_v = ov::pass::pattern::any_input();
33-
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk_f16, m_v});
31+
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk, m_v});
3432

3533
const auto callback = [=](ov::pass::pattern::Matcher& m) {
3634
auto& pattern_to_output = m.get_pattern_value_map();
@@ -42,9 +40,7 @@ FuseToSDPA::FuseToSDPA() {
4240

4341
auto v_trans =
4442
register_new_node<ov::op::v1::Transpose>(v, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 2, 1}));
45-
auto mask_f16 = register_new_node<ov::op::v0::Convert>(mask, ov::element::f16);
46-
auto scale_f16 = register_new_node<ov::op::v0::Convert>(scale, ov::element::f16);
47-
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask_f16, scale_f16, false);
43+
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask, scale, false);
4844

4945
ov::replace_node(m.get_match_root(), sdpa);
5046
ov::copy_runtime_info(m.get_matched_nodes(), sdpa);

ggml/src/ggml-openvino/utils.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
7979
ov::AnyMap config;
8080
if (device == "NPU") {
8181
config = get_npu_config();
82+
} else if (device == "GPU") {
83+
config = {
84+
{"GPU_ENABLE_SDPA_OPTIMIZATION", "0"}
85+
};
8286
}
8387

8488
if (is_naive(cgraph)) {

0 commit comments

Comments
 (0)