|
2 | 2 | #include "../op_table.hpp" |
3 | 3 | #include "../utils.hpp" |
4 | 4 |
|
| 5 | +#include <cstdint> |
5 | 6 | #include <memory> |
6 | 7 | #include <openvino/op/broadcast.hpp> |
7 | 8 | #include <openvino/op/concat.hpp> |
| 9 | +#include <openvino/op/constant.hpp> |
8 | 10 | #include <openvino/op/convert.hpp> |
9 | 11 | #include <openvino/op/reshape.hpp> |
10 | 12 | #include <openvino/op/scaled_dot_product_attention.hpp> |
@@ -58,45 +60,40 @@ OutputVector translate_flash_attn_ext(const NodeContext & context) { |
58 | 60 | mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16); |
59 | 61 | } |
60 | 62 |
|
61 | | - auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv, bool is_static) { |
62 | | - int64_t factor = q_batch / kv_batch; |
| 63 | + auto tile_kv = [&](int64_t num_heads, int64_t num_heads_kv, int64_t head_size, ov::Output<Node> kv, |
| 64 | + bool is_static) { |
| 65 | + int64_t factor = num_heads / num_heads_kv; |
63 | 66 | if (factor > 1) { |
64 | | - auto q_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{q_batch}); |
65 | | - auto kv_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_batch}); |
66 | | - auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor}); |
67 | | - |
68 | 67 | ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape; |
69 | 68 | if (is_static) { |
70 | 69 | auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1}); |
71 | 70 | kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes); |
72 | 71 |
|
73 | | - auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2}); |
74 | | - kv_broadcast_shape = std::make_shared<ov::op::v0::Concat>( |
75 | | - ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0); |
| 72 | + kv_broadcast_shape = |
| 73 | + ov::op::v0::Constant::create(ov::element::i64, {4}, {num_heads_kv, factor, (int64_t) 1, head_size}); |
76 | 74 | new_kv_shape = |
77 | | - std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0); |
| 75 | + ov::op::v0::Constant::create(ov::element::i64, {3}, {num_heads, (int64_t) -1, head_size}); |
78 | 76 | } else { |
79 | | - auto one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); |
80 | 77 | auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2}); |
81 | 78 | kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes); |
82 | 79 |
|
83 | | - auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {2, 3}); |
84 | | - kv_broadcast_shape = std::make_shared<ov::op::v0::Concat>( |
85 | | - ov::OutputVector{one_1d, kv_batch_node, factor_node, kv_last_two_dims}, 0); |
86 | | - new_kv_shape = |
87 | | - std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, q_batch_node, kv_last_two_dims}, 0); |
| 80 | + kv_broadcast_shape = ov::op::v0::Constant::create( |
| 81 | + ov::element::i64, {5}, {(int64_t) 1, num_heads_kv, factor, (int64_t) 1, head_size}); |
| 82 | + new_kv_shape = ov::op::v0::Constant::create(ov::element::i64, {4}, |
| 83 | + {(int64_t) 1, num_heads, (int64_t) -1, head_size}); |
88 | 84 | } |
89 | 85 |
|
90 | | - kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape); |
| 86 | + kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape, |
| 87 | + ov::op::BroadcastType::BIDIRECTIONAL); |
91 | 88 | kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false); |
92 | 89 | } |
93 | 90 | return kv; |
94 | 91 | }; |
95 | 92 |
|
96 | 93 | auto q_shape = context.get_input_shape(0).to_shape(); |
97 | 94 | auto k_shape = context.get_input_shape(1).to_shape(); |
98 | | - k = tile_kv(q_shape[0], k_shape[0], k, context.is_static()); |
99 | | - v = tile_kv(q_shape[0], k_shape[0], v, context.is_static()); |
| 95 | + k = tile_kv(q_shape[0], k_shape[0], q_shape[2], k, context.is_static()); |
| 96 | + v = tile_kv(q_shape[0], k_shape[0], q_shape[2], v, context.is_static()); |
100 | 97 |
|
101 | 98 | auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false); |
102 | 99 | auto sdpa_f32 = std::make_shared<ov::op::v0::Convert>(sdpa, ov::element::f32); |
|
0 commit comments