Skip to content

Commit fdd8a38

Browse files
committed
Revert changes in fuse_to_sdpa
1 parent ff401b8 commit fdd8a38

File tree

3 files changed

+8
-15
lines changed

3 files changed

+8
-15
lines changed

ggml/src/ggml-openvino/openvino/op/soft_max.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,7 @@ OutputVector translate_soft_max(const NodeContext& context) {
5353

5454
auto mask_node = context.get_input(1);
5555

56-
std::shared_ptr<ov::Node> token_len = get_dimensions(input_node, {1});
57-
// Try using Q-cur to retrieve the token length, so that the translation of SOFT_MAX
58-
// does not depend on the result of the QK MatMul, so that QK matmul + softmax + qkv matmul
59-
// can be fused into SDPA.
60-
if (input_node->get_type_info() == ov::op::v0::MatMul::get_type_info_static()) {
61-
token_len = get_dimensions(input_node->get_input_node_shared_ptr(0), {1});
62-
}
56+
auto token_len = context.get_input("token_len");
6357
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
6458
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
6559
std::shared_ptr<ov::Node> mask_node_sliced =

ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <openvino/op/softmax.hpp>
1111
#include <openvino/op/transpose.hpp>
1212
#include <openvino/pass/pattern/op/label.hpp>
13-
#include <openvino/pass/pattern/op/optional.hpp>
1413
#include <openvino/pass/pattern/op/pattern.hpp>
1514
#include <openvino/pass/pattern/op/wrap_type.hpp>
1615

@@ -23,13 +22,15 @@ FuseToSDPA::FuseToSDPA() {
2322
const auto m_k = ov::pass::pattern::any_input();
2423
const auto m_q = ov::pass::pattern::any_input();
2524
const auto m_qk = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_q, m_k});
25+
const auto m_qk_f32 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_qk});
2626
const auto m_scale = ov::pass::pattern::any_input();
27-
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk, m_scale});
27+
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk_f32, m_scale});
2828
const auto m_mask = ov::pass::pattern::any_input();
2929
const auto m_masked_qk = ov::pass::pattern::wrap_type<ov::op::v1::Add>({m_scaled_qk, m_mask});
3030
const auto m_softmax_qk = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>({m_masked_qk});
31+
const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_softmax_qk});
3132
const auto m_v = ov::pass::pattern::any_input();
32-
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk, m_v});
33+
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk_f16, m_v});
3334

3435
const auto callback = [=](ov::pass::pattern::Matcher& m) {
3536
auto& pattern_to_output = m.get_pattern_value_map();
@@ -41,7 +42,9 @@ FuseToSDPA::FuseToSDPA() {
4142

4243
auto v_trans =
4344
register_new_node<ov::op::v1::Transpose>(v, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 2, 1}));
44-
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask, scale, false);
45+
auto mask_f16 = register_new_node<ov::op::v0::Convert>(mask, ov::element::f16);
46+
auto scale_f16 = register_new_node<ov::op::v0::Convert>(scale, ov::element::f16);
47+
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask_f16, scale_f16, false);
4548

4649
ov::replace_node(m.get_match_root(), sdpa);
4750
ov::copy_runtime_info(m.get_matched_nodes(), sdpa);

ggml/src/ggml-openvino/openvino/translate_session.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include <openvino/op/unsqueeze.hpp>
2323
#include <openvino/pass/constant_folding.hpp>
2424
#include <openvino/pass/make_stateful.hpp>
25-
#include <openvino/core/preprocess/pre_post_process.hpp>
2625

2726
#include "ggml-openvino/openvino/node_context.hpp"
2827
#include "ggml-openvino/openvino/utils.hpp"
@@ -269,12 +268,9 @@ std::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<M
269268
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
270269
}
271270

272-
// SDPA is even worse on performance
273271
manager.register_pass<pass::FuseToSDPA>();
274272
manager.run_passes(model);
275273
}
276-
auto preprocessor = ov::preprocess::PrePostProcessor(model);
277-
model = preprocessor.build();
278274
return model;
279275
}
280276

0 commit comments

Comments
 (0)