1010#include < openvino/op/softmax.hpp>
1111#include < openvino/op/transpose.hpp>
1212#include < openvino/pass/pattern/op/label.hpp>
13+ #include < openvino/pass/pattern/op/optional.hpp>
1314#include < openvino/pass/pattern/op/pattern.hpp>
1415#include < openvino/pass/pattern/op/wrap_type.hpp>
1516
@@ -22,15 +23,13 @@ FuseToSDPA::FuseToSDPA() {
2223 const auto m_k = ov::pass::pattern::any_input ();
2324 const auto m_q = ov::pass::pattern::any_input ();
2425 const auto m_qk = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_q, m_k});
25- const auto m_qk_f32 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_qk});
2626 const auto m_scale = ov::pass::pattern::any_input ();
27- const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk_f32 , m_scale});
27+ const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk , m_scale});
2828 const auto m_mask = ov::pass::pattern::any_input ();
2929 const auto m_masked_qk = ov::pass::pattern::wrap_type<ov::op::v1::Add>({m_scaled_qk, m_mask});
3030 const auto m_softmax_qk = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>({m_masked_qk});
31- const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_softmax_qk});
3231 const auto m_v = ov::pass::pattern::any_input ();
33- const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk_f16 , m_v});
32+ const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk , m_v});
3433
3534 const auto callback = [=](ov::pass::pattern::Matcher& m) {
3635 auto & pattern_to_output = m.get_pattern_value_map ();
@@ -42,9 +41,7 @@ FuseToSDPA::FuseToSDPA() {
4241
4342 auto v_trans =
4443 register_new_node<ov::op::v1::Transpose>(v, ov::op::v0::Constant::create (ov::element::i64 , {3 }, {0 , 2 , 1 }));
45- auto mask_f16 = register_new_node<ov::op::v0::Convert>(mask, ov::element::f16 );
46- auto scale_f16 = register_new_node<ov::op::v0::Convert>(scale, ov::element::f16 );
47- auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask_f16, scale_f16, false );
44+ auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask, scale, false );
4845
4946 ov::replace_node (m.get_match_root (), sdpa);
5047 ov::copy_runtime_info (m.get_matched_nodes (), sdpa);
0 commit comments