17 changes: 11 additions & 6 deletions ggml/src/ggml-openvino/ggml-decoder.cpp
@@ -73,7 +73,7 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph,
set_input_output(cur_node);
}

add_extra_inputs();
// add_extra_inputs();
}

GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph,
@@ -316,9 +316,13 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
input_shape = ov::PartialShape{1, -1, -1};
}
} else if (name.find("cache_") == 0) {
int layer = extract_layer_from_name(name);
bool is_swa = is_swa_layer(layer);
input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
if (m_is_static) {
int layer = extract_layer_from_name(name);
bool is_swa = is_swa_layer(layer);
input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
} else {
input_shape = ov::PartialShape{1, -1, m_num_heads_kv, m_head_size};
}
} else if (const auto* op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
} else if (src->op == GGML_OP_VIEW) {
@@ -332,9 +336,10 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co

void GgmlOvDecoder::add_extra_inputs() {
// Extra inputs:
// 1. `attention_size`, used in matmul's in the attention block. The shape of those matmul's are 32 aligned,
// 1. `attention_size`, used in FLASH_ATTN, where the matmul shapes are 256-aligned,
//    see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
// Not used for NPU
// Not used for NPU.
// Update: no longer used after the optimization that made the kvcache dynamic (though that optimization breaks iSWA models)
int64_t attention_size = -1;
int64_t attention_size_swa = -1;
for (const auto& node : m_nodes) {
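
The hunks above leave the NPU path untouched (a fully static cache shape) and give every other device a 4-D cache input with a dynamic token dimension. A minimal stand-alone sketch of that selection, with the member fields of GgmlOvDecoder passed in as plain parameters (the free function itself is invented for illustration):

    #include <cstdint>
    #include <openvino/core/partial_shape.hpp>

    // Sketch of the choice made in GgmlOvDecoder::get_graph_input_shape for
    // "cache_*" inputs: static (NPU) graphs size the cache to the full (SWA or
    // regular) context window; dynamic graphs leave the token dimension open (-1).
    ov::PartialShape cache_input_shape(bool is_static, bool is_swa,
                                       int64_t context_size, int64_t context_size_swa,
                                       int64_t num_heads_kv, int64_t head_size) {
        if (is_static) {
            return ov::PartialShape{is_swa ? context_size_swa : context_size,
                                    num_heads_kv, head_size};
        }
        return ov::PartialShape{1, -1, num_heads_kv, head_size};
    }
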
63 changes: 45 additions & 18 deletions ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp
@@ -32,55 +32,82 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16);
auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});

ov::Output<ov::Node> mask_sliced;
ov::Output<ov::Node> mask_sliced, res;
std::string mask_name = "KQ_mask_sliced";
if (context.get_input_names()[3].find("swa") != std::string::npos) {
mask_name = "KQ_mask_swa_sliced";
}
if (context.has_input(mask_name)) {
mask_sliced = context.get_input(mask_name);
} else {
auto token_len = get_dimensions(q, {1});
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
auto token_len = get_dimensions(q, {2});
auto kv_len = get_dimensions(k.get_node_shared_ptr(), {2});

auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0});
auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1});
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,2});

auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, kv_len}, 0);
mask_sliced =
std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d);
}

if (mask_sliced.get_element_type() != ov::element::f16) {
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
}

auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv) {
auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv, bool is_static) {
int64_t factor = q_batch / kv_batch;
if (factor > 1) {
auto q_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{q_batch});
auto kv_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_batch});
auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});

auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
auto kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape;
if (is_static) {
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);

auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
auto kv_broadcast_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape);
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
kv_broadcast_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
new_kv_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
} else {
auto one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);

auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {2, 3});
kv_broadcast_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, kv_batch_node, factor_node, kv_last_two_dims}, 0);
new_kv_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, q_batch_node, kv_last_two_dims}, 0);
}

auto new_kv_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape);
kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false);
}
return kv;
};

auto q_shape = context.get_input_shape(0).to_shape();
auto k_shape = context.get_input_shape(1).to_shape();
k = tile_kv(q_shape[0], k_shape[0], k);
v = tile_kv(q_shape[0], k_shape[0], v);
k = tile_kv(q_shape[0], k_shape[0], k, context.is_static());
v = tile_kv(q_shape[0], k_shape[0], v, context.is_static());

auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false);
auto sdpa_f32 = std::make_shared<ov::op::v0::Convert>(sdpa, ov::element::f32);
auto res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
if (context.is_static()) {
res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
} else {
res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
}
return rename_outputs_with_suffix({res}, context.get_name());
}

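
The reworked tile_kv handles grouped-query attention: with q_batch query heads and kv_batch KV heads, each KV head must be repeated factor = q_batch / kv_batch times, which the graph expresses as Unsqueeze → Broadcast → Reshape — on 3-D tensors (axis 1) in the static path, on 4-D tensors with an extra leading 1 (axis 2) in the dynamic path. A data-level sketch of why that chain equals a repeat_interleave over the head axis (toy sizes, plain C++):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        const int64_t kv_heads = 2, factor = 3, q_heads = kv_heads * factor;
        const int64_t kv_len = 4;

        // kv[h][t]: value of KV head h at position t.
        std::vector<std::vector<int>> kv(kv_heads, std::vector<int>(kv_len));
        for (int64_t h = 0; h < kv_heads; ++h)
            for (int64_t t = 0; t < kv_len; ++t)
                kv[h][t] = static_cast<int>(h * 100 + t);

        // Unsqueeze inserts an axis of size 1 next to the head axis, Broadcast
        // copies each head `factor` times along it, Reshape merges it back.
        std::vector<std::vector<int>> tiled;
        for (int64_t h = 0; h < kv_heads; ++h)
            for (int64_t f = 0; f < factor; ++f)
                tiled.push_back(kv[h]);

        // repeat_interleave semantics: output head i comes from KV head i / factor.
        for (int64_t h = 0; h < q_heads; ++h)
            assert(tiled[h] == kv[h / factor]);
        return 0;
    }
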
50 changes: 28 additions & 22 deletions ggml/src/ggml-openvino/openvino/op/permute.cpp
@@ -7,6 +7,7 @@
#include <openvino/op/reshape.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/op/unsqueeze.hpp>

#include "../node_context.hpp"
#include "../op_table.hpp"
@@ -23,35 +24,40 @@ OutputVector translate_permute(const NodeContext& context) {
int op_case = context.get_op_case();
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported PERMUTE case");
ov::Output<Node> res;
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});

if (op_case == 1) {
res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
} else {
auto src = context.get_input(0);
Output<Node> attention_size;
if (context.is_static()) {
attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {INT_MAX});
} else if (op_case == 2) {
attention_size = context.get_input("attention_size");
res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
} else {
attention_size = context.get_input("attention_size_swa");
auto src = context.get_input(0);
if (src.get_partial_shape().rank() == 3) {
src = std::make_shared<ov::op::v0::Unsqueeze>(src, zero);
}
res = std::make_shared<ov::op::v1::Transpose>(src,
ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
}

auto src_shape_ = context.get_input_shape(0).to_shape();
std::vector<int64_t> src_shape(src_shape_.begin(), src_shape_.end());

auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
src,
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
false);

auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
} else {
auto src = context.get_input(0);
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, zero);

res = std::make_shared<ov::op::v1::Transpose>(src_slice,
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
if (context.is_static()) {
auto src_shape_ = context.get_input_shape(0).to_shape();
std::vector<int64_t> src_shape(src_shape_.begin(), src_shape_.end());
auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
src,
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
false);
res = std::make_shared<ov::op::v1::Transpose>(
src_reshaped, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
} else {
if (src.get_partial_shape().rank() == 3) {
src = std::make_shared<ov::op::v0::Unsqueeze>(src, zero);
}
res = std::make_shared<ov::op::v1::Transpose>(src,
ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
}
}
return rename_outputs_with_suffix({res}, context.get_name());
}
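
In the rewritten non-static paths the tensors stay 4-D end to end: a rank-3 source — laid out here as (tokens, heads, head_size), an assumption inferred from the surrounding ops rather than stated in the diff — gains a leading batch axis and is then permuted with {0, 2, 1, 3} into the (1, heads, tokens, head_size) layout that ScaledDotProductAttention expects. A shape-only sketch:

    #include <array>
    #include <cassert>
    #include <cstdint>

    int main() {
        std::array<int64_t, 3> src = {7, 32, 64};   // (tokens, heads, head_size), toy sizes
        std::array<int64_t, 4> unsq = {1, src[0], src[1], src[2]};  // Unsqueeze axis 0
        const std::array<int64_t, 4> perm = {0, 2, 1, 3};           // Transpose order
        std::array<int64_t, 4> out{};
        for (size_t i = 0; i < 4; ++i) out[i] = unsq[perm[i]];
        assert((out == std::array<int64_t, 4>{1, 32, 7, 64}));  // (1, heads, tokens, head_size)
        return 0;
    }
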
3 changes: 3 additions & 0 deletions ggml/src/ggml-openvino/openvino/op/rope.cpp
@@ -84,6 +84,9 @@ OutputVector translate_rope(const NodeContext& context) {
ov::op::v0::Constant::create(ov::element::i64, {1}, {3}));
auto stack = std::make_shared<ov::op::v0::Concat>(OutputVector{first_half, second_half}, 3);
res = std::make_shared<ov::op::v1::Reshape>(stack, std::make_shared<ov::op::v0::ShapeOf>(data_node), false);
if (!(context.is_static())) {
res = std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
}
} else if (mode == ROPE_TYPE_NEOX) {
auto data_split = std::make_shared<ov::op::v1::Split>(
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2}), 2);
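
The only change here is for the dynamic path: after the NORM-mode rotation is reassembled, an Unsqueeze on axis 0 adds the leading batch dimension that the 4-D permute and flash-attention paths above now assume. For reference, the rotation the surrounding Split/Multiply/Concat graph computes is ggml's NORM-mode RoPE, which rotates adjacent element pairs; a plain-C++ reference (theta values are hypothetical, normally pos * base^(-2i/dim)):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Reference NORM-mode RoPE: pair (x[2i], x[2i+1]) is rotated by theta[i].
    std::vector<float> rope_norm(const std::vector<float>& x,
                                 const std::vector<float>& theta) {
        std::vector<float> out(x.size());
        for (size_t i = 0; i + 1 < x.size(); i += 2) {
            const float c = std::cos(theta[i / 2]);
            const float s = std::sin(theta[i / 2]);
            out[i]     = x[i] * c - x[i + 1] * s;
            out[i + 1] = x[i] * s + x[i + 1] * c;
        }
        return out;
    }
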
34 changes: 23 additions & 11 deletions ggml/src/ggml-openvino/openvino/op/set_rows.cpp
@@ -1,8 +1,10 @@
#include <cassert>
#include <cstdint>
#include <memory>
#include <openvino/core/node.hpp>
#include <openvino/core/node_output.hpp>
#include <openvino/frontend/exception.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/gather.hpp>
@@ -39,17 +41,27 @@ OutputVector translate_set_rows(const NodeContext& context) {
auto dst = context.get_input(context.get_output_name());

auto zero = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
auto dst_reshaped = std::make_shared<ov::op::v1::Reshape>(
dst,
ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) dst_shape[1], (int64_t) dst_shape[2]}),
false);
auto indices_reshaped =
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(
data, ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) -1, (int64_t) dst_shape[2]}), false);

auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
auto res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false);
Output<Node> res;
if (context.is_static()) {
auto dst_reshaped = std::make_shared<ov::op::v1::Reshape>(
dst,
ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) dst_shape[1], (int64_t) dst_shape[2]}),
false);
auto indices_reshaped =
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(
data, ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) -1, (int64_t) dst_shape[2]}), false);

auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false);
} else {
assert(dst.get_partial_shape().rank() == 4 && dst.get_partial_shape()[2].is_static() && dst.get_partial_shape()[3].is_static());
int64_t dim2 = dst.get_partial_shape()[2].get_length();
int64_t dim3 = dst.get_partial_shape()[3].get_length();
data = std::make_shared<ov::op::v1::Reshape>(
data, ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 1, (int64_t) -1, dim2, dim3}), false);
res = std::make_shared<ov::op::v0::Concat>(OutputVector{dst, data}, 1);
}
return rename_outputs_with_suffix({res}, context.get_name());
}

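
This is the core of the dynamic-KV-cache change: static (NPU) graphs keep writing new rows into a preallocated cache with ScatterUpdate, while dynamic graphs reshape the incoming rows to (1, -1, dim2, dim3) and Concat them onto the cache along the token axis, so the cache tensor grows each step. A toy contrast of the two lowerings (1-D "rows", hypothetical sizes):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        std::vector<int> new_rows = {7, 8};

        // Static path: ScatterUpdate into a fixed-size cache at explicit indices.
        std::vector<int> cache_static(8, 0);
        const std::vector<int64_t> indices = {3, 4};
        for (size_t i = 0; i < indices.size(); ++i)
            cache_static[static_cast<size_t>(indices[i])] = new_rows[i];

        // Dynamic path: Concat on the token axis; the write position is implicit
        // because the cache only ever grows at the end.
        std::vector<int> cache_dynamic = {1, 2, 3};  // rows 0..2 already cached
        cache_dynamic.insert(cache_dynamic.end(), new_rows.begin(), new_rows.end());

        assert(cache_static[3] == 7 && cache_static[4] == 8);
        assert(cache_dynamic.size() == 5 && cache_dynamic[3] == 7);
        return 0;
    }
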
32 changes: 28 additions & 4 deletions ggml/src/ggml-openvino/openvino/translate_session.cpp
@@ -17,7 +17,9 @@
#include <openvino/op/reshape.hpp>
#include <openvino/op/result.hpp>
#include <openvino/op/sin.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/squeeze.hpp>
#include <openvino/op/strided_slice.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/op/unsqueeze.hpp>
#include <openvino/pass/constant_folding.hpp>
@@ -87,9 +89,30 @@ void add_sliced_mask(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
if (is_static) {
mask_sliced = mask;
} else {
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 0});
auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 1});
auto one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 2});

std::shared_ptr<ov::Node> kv_len;
{
auto start = ov::op::v0::Constant::create(element::i64, Shape{3}, {0, 0, -1});
auto stride = ov::op::v0::Constant::create(element::i64, Shape{3}, {1, 1, 1});
auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
kv_len = std::make_shared<ov::op::v1::StridedSlice>(
inp_pos, start, start, stride, std::vector<int64_t>{0, 0, 0}, std::vector<int64_t>{1, 1, 1});
}
kv_len = std::make_shared<ov::op::v0::Squeeze>(
kv_len, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
kv_len = std::make_shared<ov::op::v0::Convert>(kv_len, ov::element::i64);
kv_len = std::make_shared<ov::op::v1::Add>(kv_len, one_1d);
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, kv_len}, 0);

mask_sliced =
std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d);
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
mask_sliced->set_friendly_name(sliced_name);
}
@@ -98,7 +121,8 @@ void add_sliced_mask(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
};

create_sliced_mask("KQ_mask", "KQ_mask_sliced", ggml_model_decoder.is_static());
create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced", ggml_model_decoder.is_static());
// SWA is not working because the `kv_len` computed here is not correct for it
// create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced", ggml_model_decoder.is_static());
}

void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
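
Because the KQ mask arrives padded, the dynamic graph has to recover the real KV length at runtime. The new code derives it from the last element of inp_pos — the position of the newest token — so kv_len = inp_pos[-1] + 1, then slices a (token_len, kv_len) window out of the mask and unsqueezes it to add the batch axis. A toy version of that computation (sizes are hypothetical):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        // inp_pos holds the absolute position of each token in this batch.
        const std::vector<int64_t> inp_pos = {5, 6, 7};  // decoding 3 tokens
        const int64_t token_len = static_cast<int64_t>(inp_pos.size());
        const int64_t kv_len = inp_pos.back() + 1;       // 8 valid cache entries

        // The mask input is padded; keep only its top-left (token_len, kv_len)
        // corner, as the StridedSlice/Concat/Slice chain above does.
        std::vector<std::vector<float>> mask(32, std::vector<float>(256, -1e9f));
        std::vector<std::vector<float>> sliced;
        for (int64_t r = 0; r < token_len; ++r)
            sliced.emplace_back(mask[r].begin(), mask[r].begin() + kv_len);

        assert(sliced.size() == 3 && sliced[0].size() == 8);
        return 0;
    }

This derivation is also why the SWA mask is disabled above: for a sliding-window layer the number of valid cache entries is not simply the last position plus one, so the same kv_len would slice the wrong window.
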
12 changes: 7 additions & 5 deletions ggml/src/ggml-openvino/utils.cpp
@@ -80,11 +80,6 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c

bool is_static = device == "NPU" ? true : false;
ov::AnyMap config;
if (device == "GPU") {
config = {
{"GPU_ENABLE_SDPA_OPTIMIZATION", "0"}
};
}

if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) {
std::string filename = "cgraph.txt";
@@ -186,6 +181,13 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
ov::serialize(model, timestamped_filename);
}

auto* disable_sdpa_optimization = getenv("GGML_OPENVINO_DISABLE_SDPA_OPTIMIZATION");
if (disable_sdpa_optimization && std::string(disable_sdpa_optimization) != "0") {
config = {
{"GPU_ENABLE_SDPA_OPTIMIZATION", "0"}
};
}

auto compiled_model = core.compile_model(model, device, config);
compile_end_time = ggml_time_us();
infer_request_cache[cgraph] = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
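
With the workaround moved behind an environment variable, the GPU SDPA optimization is now left on by default. Setting GGML_OPENVINO_DISABLE_SDPA_OPTIMIZATION to any value other than "0" before launching — for example GGML_OPENVINO_DISABLE_SDPA_OPTIMIZATION=1 ./llama-cli ..., where the binary name is only illustrative — restores the old behaviour by passing GPU_ENABLE_SDPA_OPTIMIZATION=0 to compile_model.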