diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp
index 751fa192a4261..7c6bfe7ee74d7 100644
--- a/ggml/src/ggml-openvino/ggml-decoder.cpp
+++ b/ggml/src/ggml-openvino/ggml-decoder.cpp
@@ -73,7 +73,7 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph,
         set_input_output(cur_node);
     }
 
-    add_extra_inputs();
+    // add_extra_inputs();
 }
 
 GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph,
@@ -316,9 +316,13 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
             input_shape = ov::PartialShape{1, -1, -1};
         }
     } else if (name.find("cache_") == 0) {
-        int layer = extract_layer_from_name(name);
-        bool is_swa = is_swa_layer(layer);
-        input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
+        if (m_is_static) {
+            int layer = extract_layer_from_name(name);
+            bool is_swa = is_swa_layer(layer);
+            input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
+        } else {
+            input_shape = ov::PartialShape{1, -1, m_num_heads_kv, m_head_size};
+        }
     } else if (const auto* op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
         input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
     } else if (src->op == GGML_OP_VIEW) {
@@ -332,9 +336,10 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
 
 void GgmlOvDecoder::add_extra_inputs() {
     // Extra inputs:
-    // 1. `attention_size`, used in matmul's in the attention block. The shape of those matmul's are 32 aligned,
+    // 1. `attention_size`, used in FLASH_ATTN, where the matmul shapes are 256-aligned;
     //    see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
-    //    Not used for NPU
+    //    Not used for NPU.
+    //    Update: no longer used after the optimization that makes the KV cache dynamic (which breaks iSWA models).
    int64_t attention_size = -1;
    int64_t attention_size_swa = -1;
    for (const auto& node : m_nodes) {
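
Review note: the decoder change above is the heart of this patch. For non-static (non-NPU) devices, each per-layer KV-cache input now has a dynamic token axis instead of being preallocated at the full context size, so a single compiled model serves any context length. A minimal sketch of what such an input looks like on the OpenVINO side (the helper name and f16 element type are illustrative assumptions, not part of the patch):

    #include <openvino/openvino.hpp>

    // One KV-cache parameter per layer: batch 1, dynamic KV length, fixed head count and
    // head size, mirroring ov::PartialShape{1, -1, m_num_heads_kv, m_head_size} above.
    std::shared_ptr<ov::op::v0::Parameter> make_cache_input(int64_t n_heads_kv, int64_t head_size,
                                                            const std::string& name) {
        auto param = std::make_shared<ov::op::v0::Parameter>(
            ov::element::f16, ov::PartialShape{1, -1, n_heads_kv, head_size});
        param->set_friendly_name(name);  // e.g. "cache_k_l0"
        return param;
    }
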
diff --git a/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp b/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp
index 8b67778fb9373..9845fe0a02aa5 100644
--- a/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp
+++ b/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp
@@ -32,7 +32,7 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
     auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16);
     auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});
 
-    ov::Output<ov::Node> mask_sliced;
+    ov::Output<ov::Node> mask_sliced, res;
    std::string mask_name = "KQ_mask_sliced";
    if (context.get_input_names()[3].find("swa") != std::string::npos) {
        mask_name = "KQ_mask_swa_sliced";
@@ -40,33 +40,55 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
     if (context.has_input(mask_name)) {
         mask_sliced = context.get_input(mask_name);
     } else {
-        auto token_len = get_dimensions(q, {1});
-        auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
-        auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
-        mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
+        auto token_len = get_dimensions(q, {2});
+        auto kv_len = get_dimensions(k.get_node_shared_ptr(), {2});
+
+        auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 0});
+        auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 1});
+        auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+        auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
+        auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 2});
+
+        auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, kv_len}, 0);
+        mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
+        mask_sliced = std::make_shared<ov::op::v0::Squeeze>(mask_sliced, zero_1d);
     }
 
     if (mask_sliced.get_element_type() != ov::element::f16) {
         mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
     }
 
-    auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<ov::Node> kv) {
+    auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<ov::Node> kv, bool is_static) {
         int64_t factor = q_batch / kv_batch;
         if (factor > 1) {
             auto q_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{q_batch});
             auto kv_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_batch});
             auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
 
-            auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
-            auto kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
+            ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape;
+            if (is_static) {
+                auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
+                kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
 
-            auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
-            auto kv_broadcast_shape =
-                std::make_shared<ov::op::v0::Concat>(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
-            kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape);
+                auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
+                kv_broadcast_shape = std::make_shared<ov::op::v0::Concat>(
+                    ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
+                new_kv_shape =
+                    std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
+            } else {
+                auto one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+                auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
+                kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
+
+                auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {2, 3});
+                kv_broadcast_shape = std::make_shared<ov::op::v0::Concat>(
+                    ov::OutputVector{one_1d, kv_batch_node, factor_node, kv_last_two_dims}, 0);
+                new_kv_shape =
+                    std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, q_batch_node, kv_last_two_dims}, 0);
+            }
 
-            auto new_kv_shape =
-                std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
+            kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape);
             kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false);
         }
         return kv;
@@ -74,13 +96,18 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
 
     auto q_shape = context.get_input_shape(0).to_shape();
     auto k_shape = context.get_input_shape(1).to_shape();
 
-    k = tile_kv(q_shape[0], k_shape[0], k);
-    v = tile_kv(q_shape[0], k_shape[0], v);
+    k = tile_kv(q_shape[0], k_shape[0], k, context.is_static());
+    v = tile_kv(q_shape[0], k_shape[0], v, context.is_static());
 
     auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false);
     auto sdpa_f32 = std::make_shared<ov::op::v0::Convert>(sdpa, ov::element::f32);
-    auto res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
-                                                       ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
+    if (context.is_static()) {
+        res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
+                                                      ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
+    } else {
+        res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
+                                                      ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
+    }
 
     return rename_outputs_with_suffix({res}, context.get_name());
 }
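
Review note: in the dynamic path, tile_kv now operates on 4D [1, n_kv_heads, kv_len, head_size] tensors. The technique is the usual grouped-query-attention head repeat: insert a repeat axis, broadcast it by factor = n_heads / n_kv_heads, then fold it back into the head axis. A self-contained sketch under those shape assumptions (the function name is hypothetical, and it uses a bidirectional broadcast where the patch builds explicit target shapes from get_dimensions):

    // [1, n_kv_heads, L, D] -> [1, n_kv_heads * factor, L, D], with L dynamic.
    ov::Output<ov::Node> repeat_kv_heads(ov::Output<ov::Node> kv, int64_t n_kv_heads,
                                         int64_t factor, int64_t head_size) {
        auto axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2});
        auto unsq = std::make_shared<ov::op::v0::Unsqueeze>(kv, axis);  // [1, H, 1, L, D]
        auto target = ov::op::v0::Constant::create(ov::element::i64, {5},
                                                   std::vector<int64_t>{1, n_kv_heads, factor, 1, 1});
        auto bcast = std::make_shared<ov::op::v3::Broadcast>(unsq, target,
                                                             ov::op::BroadcastType::BIDIRECTIONAL);
        auto new_shape = ov::op::v0::Constant::create(ov::element::i64, {4},
                                                      std::vector<int64_t>{1, n_kv_heads * factor, -1, head_size});
        return std::make_shared<ov::op::v1::Reshape>(bcast, new_shape, false);  // -1 absorbs dynamic L
    }
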
diff --git a/ggml/src/ggml-openvino/openvino/op/permute.cpp b/ggml/src/ggml-openvino/openvino/op/permute.cpp
index 086b1e4cdb172..5f86f47c1cca3 100644
--- a/ggml/src/ggml-openvino/openvino/op/permute.cpp
+++ b/ggml/src/ggml-openvino/openvino/op/permute.cpp
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include
 
 #include "../node_context.hpp"
 #include "../op_table.hpp"
@@ -23,35 +24,40 @@ OutputVector translate_permute(const NodeContext& context) {
     int op_case = context.get_op_case();
     FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported PERMUTE case");
     ov::Output<ov::Node> res;
+    auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
     if (op_case == 1) {
-        res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
-                                                      ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
-    } else {
-        auto src = context.get_input(0);
-        Output<Node> attention_size;
         if (context.is_static()) {
-            attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {INT_MAX});
-        } else if (op_case == 2) {
-            attention_size = context.get_input("attention_size");
+            res = std::make_shared<ov::op::v1::Transpose>(
+                context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
         } else {
-            attention_size = context.get_input("attention_size_swa");
+            auto src = context.get_input(0);
+            if (src.get_partial_shape().rank() == 3) {
+                src = std::make_shared<ov::op::v0::Unsqueeze>(src, zero);
+            }
+            res = std::make_shared<ov::op::v1::Transpose>(
+                src, ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
         }
-
-        auto src_shape_ = context.get_input_shape(0).to_shape();
-        std::vector<int64_t> src_shape(src_shape_.begin(), src_shape_.end());
-
-        auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
-            src,
-            ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
-            false);
-
-        auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+    } else {
+        auto src = context.get_input(0);
         auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
-        auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, zero);
-        res = std::make_shared<ov::op::v1::Transpose>(src_slice,
-                                                      ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
+        if (context.is_static()) {
+            auto src_shape_ = context.get_input_shape(0).to_shape();
+            std::vector<int64_t> src_shape(src_shape_.begin(), src_shape_.end());
+            auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
+                src,
+                ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
+                false);
+            res = std::make_shared<ov::op::v1::Transpose>(
+                src_reshaped, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
+        } else {
+            if (src.get_partial_shape().rank() == 3) {
+                src = std::make_shared<ov::op::v0::Unsqueeze>(src, zero);
+            }
+            res = std::make_shared<ov::op::v1::Transpose>(
+                src, ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
+        }
     }
     return rename_outputs_with_suffix({res}, context.get_name());
 }
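
Review note: both PERMUTE branches now normalize rank before transposing: a 3D [n_heads, n_tokens, head_size] tensor gets a leading batch axis so that every tensor on the dynamic path is 4D, after which a single permutation {0, 2, 1, 3} swaps the token and head axes. Condensed into a sketch (helper name assumed):

    ov::Output<ov::Node> to_4d_then_permute(ov::Output<ov::Node> src) {
        auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
        if (src.get_partial_shape().rank() == 3) {
            src = std::make_shared<ov::op::v0::Unsqueeze>(src, zero);  // [H, T, D] -> [1, H, T, D]
        }
        auto perm = ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3});
        return std::make_shared<ov::op::v1::Transpose>(src, perm);      // -> [1, T, H, D]
    }
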
diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp
index 4b1e3b500cf3e..484730d2897f1 100644
--- a/ggml/src/ggml-openvino/openvino/op/rope.cpp
+++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp
@@ -84,6 +84,9 @@ OutputVector translate_rope(const NodeContext& context) {
                                                          ov::op::v0::Constant::create(ov::element::i64, {1}, {3}));
         auto stack = std::make_shared<ov::op::v0::Concat>(OutputVector{first_half, second_half}, 3);
         res = std::make_shared<ov::op::v1::Reshape>(stack, std::make_shared<ov::op::v3::ShapeOf>(data_node), false);
+        if (!context.is_static()) {
+            res = std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
+        }
     } else if (mode == ROPE_TYPE_NEOX) {
         auto data_split = std::make_shared<ov::op::v1::Split>(
             data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2}), 2);
diff --git a/ggml/src/ggml-openvino/openvino/op/set_rows.cpp b/ggml/src/ggml-openvino/openvino/op/set_rows.cpp
index 50817c8323bef..0b2f29441aec7 100644
--- a/ggml/src/ggml-openvino/openvino/op/set_rows.cpp
+++ b/ggml/src/ggml-openvino/openvino/op/set_rows.cpp
@@ -1,8 +1,10 @@
+#include
 #include
 #include
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -39,17 +41,27 @@ OutputVector translate_set_rows(const NodeContext& context) {
 
     auto dst = context.get_input(context.get_output_name());
     auto zero = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
-    auto dst_reshaped = std::make_shared<ov::op::v1::Reshape>(
-        dst,
-        ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) dst_shape[1], (int64_t) dst_shape[2]}),
-        false);
-    auto indices_reshaped =
-        std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
-    auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(
-        data, ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) -1, (int64_t) dst_shape[2]}), false);
-
-    auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
-    auto res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v3::ShapeOf>(dst), false);
+    Output<Node> res;
+    if (context.is_static()) {
+        auto dst_reshaped = std::make_shared<ov::op::v1::Reshape>(
+            dst,
+            ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) dst_shape[1], (int64_t) dst_shape[2]}),
+            false);
+        auto indices_reshaped =
+            std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
+        auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(
+            data, ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) -1, (int64_t) dst_shape[2]}), false);
+
+        auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
+        res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v3::ShapeOf>(dst), false);
+    } else {
+        assert(dst.get_partial_shape().rank() == 4 && dst.get_partial_shape()[2].is_static() &&
+               dst.get_partial_shape()[3].is_static());
+        int64_t dim2 = dst.get_partial_shape()[2].get_length();
+        int64_t dim3 = dst.get_partial_shape()[3].get_length();
+        data = std::make_shared<ov::op::v1::Reshape>(
+            data, ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 1, (int64_t) -1, dim2, dim3}), false);
+        res = std::make_shared<ov::op::v0::Concat>(OutputVector{dst, data}, 1);
+    }
     return rename_outputs_with_suffix({res}, context.get_name());
 }
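
Review note: with a dynamic cache, SET_ROWS no longer scatters into a preallocated buffer; it concatenates the new K/V rows onto the cache along the token axis, which is also why the ROPE output above gains a leading batch axis. A minimal sketch of the idea, assuming rows arrive strictly in positional order (helper name hypothetical):

    // cache: [1, L_past, H, D]; new_rows holds the flattened K/V data for the current ubatch.
    ov::Output<ov::Node> append_rows(ov::Output<ov::Node> cache, ov::Output<ov::Node> new_rows,
                                     int64_t n_heads, int64_t head_size) {
        auto shape = ov::op::v0::Constant::create(ov::element::i64, {4},
                                                  std::vector<int64_t>{1, -1, n_heads, head_size});
        auto rows_4d = std::make_shared<ov::op::v1::Reshape>(new_rows, shape, false);  // [1, n_new, H, D]
        return std::make_shared<ov::op::v0::Concat>(ov::OutputVector{cache, rows_4d}, 1);
    }

The trade-off is that out-of-order or windowed row placement cannot be expressed as a concat, which matches the iSWA breakage called out in the decoder comment above.
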
diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp
index 944381968226d..e35599084e973 100644
--- a/ggml/src/ggml-openvino/openvino/translate_session.cpp
+++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp
@@ -17,7 +17,9 @@
 #include
 #include
 #include
+#include
 #include
+#include
 #include
 #include
 #include
@@ -87,9 +89,30 @@ void add_sliced_mask(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
         if (is_static) {
             mask_sliced = mask;
         } else {
-            auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
-            auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
-            mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
+            auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 0});
+            auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 1});
+            auto one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+            auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+            auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
+            auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 2});
+
+            std::shared_ptr<ov::Node> kv_len;
+            {
+                auto start = ov::op::v0::Constant::create(element::i64, Shape{3}, {0, 0, -1});
+                auto stride = ov::op::v0::Constant::create(element::i64, Shape{3}, {1, 1, 1});
+                auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
+                kv_len = std::make_shared<ov::op::v1::StridedSlice>(
+                    inp_pos, start, start, stride, std::vector<int64_t>{0, 0, 0}, std::vector<int64_t>{1, 1, 1});
+            }
+            kv_len = std::make_shared<ov::op::v1::ReduceMax>(
+                kv_len, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
+            kv_len = std::make_shared<ov::op::v0::Convert>(kv_len, ov::element::i64);
+            kv_len = std::make_shared<ov::op::v1::Add>(kv_len, one_1d);
+            auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, kv_len}, 0);
+
+            mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
+            mask_sliced = std::make_shared<ov::op::v0::Squeeze>(mask_sliced, zero_1d);
             mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
             mask_sliced->set_friendly_name(sliced_name);
         }
@@ -98,7 +121,8 @@ void add_sliced_mask(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
     };
 
     create_sliced_mask("KQ_mask", "KQ_mask_sliced", ggml_model_decoder.is_static());
-    create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced", ggml_model_decoder.is_static());
+    // SWA is disabled for now: `kv_len` as computed above is not correct for SWA layers
+    // create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced", ggml_model_decoder.is_static());
 }
 
 void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
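
Review note: the mask slicing above derives kv_len at runtime from the `inp_pos` tensor: the last element of inp_pos is the highest position processed so far, so kv_len = last_position + 1. This assumes contiguous positions, and it does not hold for SWA layers, which is presumably why the SWA mask stays disabled. The patch extracts the last element with StridedSlice + ReduceMax; a simpler equivalent sketch using Gather (names assumed):

    // inp_pos: [1, 1, n_tokens], positions of the current ubatch (assumed i64 here).
    ov::Output<ov::Node> kv_len_from_positions(ov::Output<ov::Node> inp_pos) {
        auto last = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
        auto axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
        auto last_pos = std::make_shared<ov::op::v8::Gather>(inp_pos, last, axis);  // [1, 1, 1]
        auto flat = std::make_shared<ov::op::v1::Reshape>(
            last_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {1}), false);  // [1]
        auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
        return std::make_shared<ov::op::v1::Add>(flat, one);  // kv_len = last position + 1
    }
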
diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp
index 0ec815f07f4f9..9b000f26d55ba 100644
--- a/ggml/src/ggml-openvino/utils.cpp
+++ b/ggml/src/ggml-openvino/utils.cpp
@@ -80,11 +80,6 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
     bool is_static = device == "NPU" ? true : false;
     ov::AnyMap config;
 
-    if (device == "GPU") {
-        config = {
-            {"GPU_ENABLE_SDPA_OPTIMIZATION", "0"}
-        };
-    }
 
     if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) {
         std::string filename = "cgraph.txt";
@@ -186,6 +181,13 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
         ov::serialize(model, timestamped_filename);
     }
 
+    auto* disable_sdpa_optimization = getenv("GGML_OPENVINO_DISABLE_SDPA_OPTIMIZATION");
+    if (disable_sdpa_optimization && std::string(disable_sdpa_optimization) != "0") {
+        config = {
+            {"GPU_ENABLE_SDPA_OPTIMIZATION", "0"}
+        };
+    }
+
     auto compiled_model = core.compile_model(model, device, config);
     compile_end_time = ggml_time_us();
     infer_request_cache[cgraph] = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
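
Review note: the GPU workaround that unconditionally disabled SDPA fusion is gone; it is now opt-in via an environment variable. A standalone sketch of the same pattern (function name assumed; the env var and config key are taken from the patch):

    #include <cstdlib>
    #include <string>
    #include <openvino/openvino.hpp>

    ov::AnyMap make_compile_config() {
        ov::AnyMap config;
        const char* flag = std::getenv("GGML_OPENVINO_DISABLE_SDPA_OPTIMIZATION");
        if (flag && std::string(flag) != "0") {
            config["GPU_ENABLE_SDPA_OPTIMIZATION"] = "0";  // any non-"0" value disables SDPA fusion
        }
        return config;
    }
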