1010#include < openvino/op/softmax.hpp>
1111#include < openvino/op/transpose.hpp>
1212#include < openvino/pass/pattern/op/label.hpp>
13- #include < openvino/pass/pattern/op/optional.hpp>
1413#include < openvino/pass/pattern/op/pattern.hpp>
1514#include < openvino/pass/pattern/op/wrap_type.hpp>
1615
@@ -23,13 +22,15 @@ FuseToSDPA::FuseToSDPA() {
2322 const auto m_k = ov::pass::pattern::any_input ();
2423 const auto m_q = ov::pass::pattern::any_input ();
2524 const auto m_qk = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_q, m_k});
25+ const auto m_qk_f32 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_qk});
2626 const auto m_scale = ov::pass::pattern::any_input ();
27- const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk , m_scale});
27+ const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk_f32 , m_scale});
2828 const auto m_mask = ov::pass::pattern::any_input ();
2929 const auto m_masked_qk = ov::pass::pattern::wrap_type<ov::op::v1::Add>({m_scaled_qk, m_mask});
3030 const auto m_softmax_qk = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>({m_masked_qk});
31+ const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_softmax_qk});
3132 const auto m_v = ov::pass::pattern::any_input ();
32- const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk , m_v});
33+ const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk_f16 , m_v});
3334
3435 const auto callback = [=](ov::pass::pattern::Matcher& m) {
3536 auto & pattern_to_output = m.get_pattern_value_map ();
@@ -41,7 +42,9 @@ FuseToSDPA::FuseToSDPA() {
4142
4243 auto v_trans =
4344 register_new_node<ov::op::v1::Transpose>(v, ov::op::v0::Constant::create (ov::element::i64 , {3 }, {0 , 2 , 1 }));
44- auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask, scale, false );
45+ auto mask_f16 = register_new_node<ov::op::v0::Convert>(mask, ov::element::f16 );
46+ auto scale_f16 = register_new_node<ov::op::v0::Convert>(scale, ov::element::f16 );
47+ auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask_f16, scale_f16, false );
4548
4649 ov::replace_node (m.get_match_root (), sdpa);
4750 ov::copy_runtime_info (m.get_matched_nodes (), sdpa);
0 commit comments