Skip to content

Commit 8b9b8ef

Browse files
committed
Simplify broadcast op in attention
1 parent 77d2195 commit 8b9b8ef

File tree

1 file changed

+16
-19
lines changed

1 file changed

+16
-19
lines changed

ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
#include "../op_table.hpp"
33
#include "../utils.hpp"
44

5+
#include <cstdint>
56
#include <memory>
67
#include <openvino/op/broadcast.hpp>
78
#include <openvino/op/concat.hpp>
9+
#include <openvino/op/constant.hpp>
810
#include <openvino/op/convert.hpp>
911
#include <openvino/op/reshape.hpp>
1012
#include <openvino/op/scaled_dot_product_attention.hpp>
@@ -58,45 +60,40 @@ OutputVector translate_flash_attn_ext(const NodeContext & context) {
5860
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
5961
}
6062

61-
auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv, bool is_static) {
62-
int64_t factor = q_batch / kv_batch;
63+
auto tile_kv = [&](int64_t num_heads, int64_t num_heads_kv, int64_t head_size, ov::Output<Node> kv,
64+
bool is_static) {
65+
int64_t factor = num_heads / num_heads_kv;
6366
if (factor > 1) {
64-
auto q_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{q_batch});
65-
auto kv_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_batch});
66-
auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
67-
6867
ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape;
6968
if (is_static) {
7069
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
7170
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
7271

73-
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
74-
kv_broadcast_shape = std::make_shared<ov::op::v0::Concat>(
75-
ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
72+
kv_broadcast_shape =
73+
ov::op::v0::Constant::create(ov::element::i64, {4}, {num_heads_kv, factor, (int64_t) 1, head_size});
7674
new_kv_shape =
77-
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
75+
ov::op::v0::Constant::create(ov::element::i64, {3}, {num_heads, (int64_t) -1, head_size});
7876
} else {
79-
auto one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
8077
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
8178
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
8279

83-
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {2, 3});
84-
kv_broadcast_shape = std::make_shared<ov::op::v0::Concat>(
85-
ov::OutputVector{one_1d, kv_batch_node, factor_node, kv_last_two_dims}, 0);
86-
new_kv_shape =
87-
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, q_batch_node, kv_last_two_dims}, 0);
80+
kv_broadcast_shape = ov::op::v0::Constant::create(
81+
ov::element::i64, {5}, {(int64_t) 1, num_heads_kv, factor, (int64_t) 1, head_size});
82+
new_kv_shape = ov::op::v0::Constant::create(ov::element::i64, {4},
83+
{(int64_t) 1, num_heads, (int64_t) -1, head_size});
8884
}
8985

90-
kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape);
86+
kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape,
87+
ov::op::BroadcastType::BIDIRECTIONAL);
9188
kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false);
9289
}
9390
return kv;
9491
};
9592

9693
auto q_shape = context.get_input_shape(0).to_shape();
9794
auto k_shape = context.get_input_shape(1).to_shape();
98-
k = tile_kv(q_shape[0], k_shape[0], k, context.is_static());
99-
v = tile_kv(q_shape[0], k_shape[0], v, context.is_static());
95+
k = tile_kv(q_shape[0], k_shape[0], q_shape[2], k, context.is_static());
96+
v = tile_kv(q_shape[0], k_shape[0], q_shape[2], v, context.is_static());
10097

10198
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false);
10299
auto sdpa_f32 = std::make_shared<ov::op::v0::Convert>(sdpa, ov::element::f32);

0 commit comments

Comments (0)