11#include < memory>
2+ #include < openvino/op/broadcast.hpp>
3+ #include < openvino/op/concat.hpp>
24#include < openvino/op/convert.hpp>
5+ #include < openvino/op/reshape.hpp>
36#include < openvino/op/scaled_dot_product_attention.hpp>
7+ #include < openvino/op/transpose.hpp>
8+ #include < openvino/op/unsqueeze.hpp>
9+
410#include " ../node_context.hpp"
511#include " ../op_table.hpp"
612#include " ../utils.hpp"
@@ -24,9 +30,53 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
2430
2531 auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16 );
2632 auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16 , ov::Shape{}, std::vector<float >{scale});
27- auto res = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v , mask, scale_node, false );
28- auto res_f32 = std::make_shared<ov::op::v0::Convert>(res, ov::element::f32 );
29- return rename_outputs_with_suffix ({res_f32}, context.get_name ());
33+
34+ ov::Output<ov::Node> mask_sliced;
35+ if (context.has_input (" KQ_mask_sliced" )) {
36+ mask_sliced = context.get_input (" KQ_mask_sliced" );
37+ } else {
38+ auto token_len = get_dimensions (q, {1 });
39+ auto zero = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {0 });
40+ auto one = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {1 });
41+ mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
42+ }
43+
44+ if (mask_sliced.get_element_type () != ov::element::f16 ) {
45+ mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16 );
46+ }
47+
48+ auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv) {
49+ int64_t factor = q_batch / kv_batch;
50+ if (factor > 1 ) {
51+ auto q_batch_node = ov::op::v0::Constant::create (ov::element::i64 , {1 }, std::vector<int64_t >{q_batch});
52+ auto kv_batch_node = ov::op::v0::Constant::create (ov::element::i64 , {1 }, std::vector<int64_t >{kv_batch});
53+ auto factor_node = ov::op::v0::Constant::create (ov::element::i64 , {1 }, std::vector<int64_t >{factor});
54+
55+ auto unsqueeze_axes = ov::op::v0::Constant::create (ov::element::i64 , Shape{}, {1 });
56+ auto kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
57+
58+ auto kv_last_two_dims = get_dimensions (kv.get_node_shared_ptr (), {1 , 2 });
59+ auto kv_broadcast_shape =
60+ std::make_shared<ov::op::v0::Concat>(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0 );
61+ kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape);
62+
63+ auto new_kv_shape =
64+ std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0 );
65+ kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false );
66+ }
67+ return kv;
68+ };
69+
70+ auto q_shape = context.get_input_shape (0 ).to_shape ();
71+ auto k_shape = context.get_input_shape (1 ).to_shape ();
72+ k = tile_kv (q_shape[0 ], k_shape[0 ], k);
73+ v = tile_kv (q_shape[0 ], k_shape[0 ], v);
74+
75+ auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false );
76+ auto sdpa_f32 = std::make_shared<ov::op::v0::Convert>(sdpa, ov::element::f32 );
77+ auto res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
78+ ov::op::v0::Constant::create (ov::element::i64 , {3 }, {1 , 0 , 2 }));
79+ return rename_outputs_with_suffix ({res}, context.get_name ());
3080}
3181
3282} // namespace op
0 commit comments