|
2 | 2 | #include <openvino/op/broadcast.hpp> |
3 | 3 | #include <openvino/op/concat.hpp> |
4 | 4 | #include <openvino/op/convert.hpp> |
| 5 | +#include <openvino/op/gather.hpp> |
5 | 6 | #include <openvino/op/reshape.hpp> |
6 | 7 | #include <openvino/op/scaled_dot_product_attention.hpp> |
7 | 8 | #include <openvino/op/transpose.hpp> |
@@ -32,55 +33,82 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) { |
32 | 33 | auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16); |
33 | 34 | auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale}); |
34 | 35 |
|
35 | | - ov::Output<ov::Node> mask_sliced; |
| 36 | + ov::Output<ov::Node> mask_sliced, res; |
36 | 37 | std::string mask_name = "KQ_mask_sliced"; |
37 | 38 | if (context.get_input_names()[3].find("swa") != std::string::npos) { |
38 | 39 | mask_name = "KQ_mask_swa_sliced"; |
39 | 40 | } |
40 | 41 | if (context.has_input(mask_name)) { |
41 | 42 | mask_sliced = context.get_input(mask_name); |
42 | 43 | } else { |
43 | | - auto token_len = get_dimensions(q, {1}); |
44 | | - auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); |
45 | | - auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); |
46 | | - mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one); |
| 44 | + auto token_len = get_dimensions(q, {2}); |
| 45 | + auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0}); |
| 46 | + auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1}); |
| 47 | + auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); |
| 48 | + auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); |
| 49 | + auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,2}); |
| 50 | + auto leaf_8 = context.get_input("leaf_8"); |
| 51 | + auto shape_of_leaf_8 = std::make_shared<ov::op::v3::ShapeOf>(leaf_8); |
| 52 | + auto gather_leaf_8 = std::make_shared<ov::op::v8::Gather>(shape_of_leaf_8, two_1d, zero_1d); |
| 53 | + auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_leaf_8}, 0); |
| 54 | + mask_sliced = |
| 55 | + std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes); |
| 56 | + mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d); |
47 | 57 | } |
48 | 58 |
|
49 | 59 | if (mask_sliced.get_element_type() != ov::element::f16) { |
50 | 60 | mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16); |
51 | 61 | } |
52 | 62 |
|
53 | | - auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv) { |
| 63 | + auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv, bool is_static) { |
54 | 64 | int64_t factor = q_batch / kv_batch; |
55 | 65 | if (factor > 1) { |
56 | 66 | auto q_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{q_batch}); |
57 | 67 | auto kv_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_batch}); |
58 | 68 | auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor}); |
59 | 69 |
|
60 | | - auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1}); |
61 | | - auto kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes); |
| 70 | + ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape; |
| 71 | + if (is_static) { |
| 72 | + auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1}); |
| 73 | + kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes); |
62 | 74 |
|
63 | | - auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2}); |
64 | | - auto kv_broadcast_shape = |
65 | | - std::make_shared<ov::op::v0::Concat>(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0); |
66 | | - kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape); |
| 75 | + auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2}); |
| 76 | + kv_broadcast_shape = |
| 77 | + std::make_shared<ov::op::v0::Concat>(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0); |
| 78 | + new_kv_shape = |
| 79 | + std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0); |
| 80 | + } else { |
| 81 | + auto one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); |
| 82 | + auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2}); |
| 83 | + kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes); |
| 84 | + |
| 85 | + auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {2, 3}); |
| 86 | + kv_broadcast_shape = |
| 87 | + std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, kv_batch_node, factor_node, kv_last_two_dims}, 0); |
| 88 | + new_kv_shape = |
| 89 | + std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, q_batch_node, kv_last_two_dims}, 0); |
| 90 | + } |
67 | 91 |
|
68 | | - auto new_kv_shape = |
69 | | - std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0); |
| 92 | + kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape); |
70 | 93 | kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false); |
71 | 94 | } |
72 | 95 | return kv; |
73 | 96 | }; |
74 | 97 |
|
75 | 98 | auto q_shape = context.get_input_shape(0).to_shape(); |
76 | 99 | auto k_shape = context.get_input_shape(1).to_shape(); |
77 | | - k = tile_kv(q_shape[0], k_shape[0], k); |
78 | | - v = tile_kv(q_shape[0], k_shape[0], v); |
| 100 | + k = tile_kv(q_shape[0], k_shape[0], k, context.is_static()); |
| 101 | + v = tile_kv(q_shape[0], k_shape[0], v, context.is_static()); |
79 | 102 |
|
80 | 103 | auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false); |
81 | 104 | auto sdpa_f32 = std::make_shared<ov::op::v0::Convert>(sdpa, ov::element::f32); |
82 | | - auto res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32, |
83 | | - ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2})); |
| 105 | + if (context.is_static()) { |
| 106 | + res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32, |
| 107 | + ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2})); |
| 108 | + } else { |
| 109 | + res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32, |
| 110 | + ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3})); |
| 111 | + } |
84 | 112 | return rename_outputs_with_suffix({res}, context.get_name()); |
85 | 113 | } |
86 | 114 |
|
|
0 commit comments