From 08de982b9915301c4e8ccfa9b39bf22ad31123b5 Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Tue, 5 Aug 2025 19:51:01 +0800
Subject: [PATCH 01/30] Update supports_buft and supports_op for quantized
 models

---
 ggml/src/ggml-openvino/ggml-openvino.cpp | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp
index ed612a24660c4..f81b1ee4834de 100644
--- a/ggml/src/ggml-openvino/ggml-openvino.cpp
+++ b/ggml/src/ggml-openvino/ggml-openvino.cpp
@@ -8,6 +8,7 @@
 #include
 #include "ggml-backend-impl.h"
+#include "ggml-backend.h"
 #include "ggml-impl.h"
 #include "ggml-openvino/utils.h"
 #include "ggml.h"
@@ -332,8 +333,16 @@ static bool is_op_unsupported_case(const ggml_tensor* op) {
 static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor* op) {
     GGML_ASSERT(dev->reg != nullptr);
-    static const std::set<ggml_type> supported_types{
-        GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_I64, GGML_TYPE_I32};
+    static const std::set<ggml_type> supported_types{GGML_TYPE_F32,
+                                                     GGML_TYPE_F16,
+                                                     GGML_TYPE_BF16,
+                                                     GGML_TYPE_I64,
+                                                     GGML_TYPE_I32,
+                                                     GGML_TYPE_Q4_0,
+                                                     GGML_TYPE_Q4_1,
+                                                     GGML_TYPE_Q4_K,
+                                                     GGML_TYPE_Q8_0,
+                                                     GGML_TYPE_Q6_K};

     static const std::set<ggml_op> supported_ops{GGML_OP_NONE,
                                                  GGML_OP_ADD,
@@ -411,7 +420,8 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
 }

 static bool ggml_backend_openvino_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    return ggml_backend_buft_is_host(buft);
+    // TODO quantized weights are cpu_repack_buffer_type which does not implement ggml_backend_buft_is_host
+    return ggml_backend_buft_is_host(buft) || strcmp(buft->device->iface.get_name(buft->device), "CPU") == 0;
     GGML_UNUSED(dev);
 }

From 7cabe37536f4493824abb69cfd110e736afd3b82 Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Tue, 5 Aug 2025 20:56:50 +0800
Subject: [PATCH 02/30] Add quant weight conversion functions from genai gguf
 reader

---
 ggml/src/ggml-openvino/ggml-decoder.cpp |  76 +++++-
 ggml/src/ggml-openvino/ggml-quant.cpp   | 313 ++++++++++++++++++++++++
 ggml/src/ggml-openvino/ggml-quant.hpp   |  44 ++++
 3 files changed, 429 insertions(+), 4 deletions(-)
 create mode 100644 ggml/src/ggml-openvino/ggml-quant.cpp
 create mode 100644 ggml/src/ggml-openvino/ggml-quant.hpp

diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp
index 0fd64c685f71c..c2e164b808baa 100644
--- a/ggml/src/ggml-openvino/ggml-decoder.cpp
+++ b/ggml/src/ggml-openvino/ggml-decoder.cpp
@@ -20,6 +20,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -29,6 +30,7 @@
 #include "ggml-backend-impl.h"
 #include "ggml-backend.h"
+#include "ggml-quant.hpp"

 GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token,
                              int context_size, int num_heads, int num_heads_kv, int head_size) :
@@ -402,12 +404,78 @@ std::map<std::string, std::shared_ptr<ov::Node>> GgmlOvDecoder::create_weight_no
 }

 std::shared_ptr<ov::Node> GgmlOvDecoder::create_weight_node(ggml_tensor* tensor) {
+    std::set<ggml_type> weight_types = {
+        GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K};
+    if (weight_types.find(tensor->type) == weight_types.end()) {
+        throw std::runtime_error("Unexpected weight tensor type: " + std::string(tensor->name) + " with type " +
+                                 ggml_type_name(tensor->type));
+    }
+
     auto node_type = get_ov_type(tensor);
     auto node_shape = get_shape(tensor);
    auto ne_total = 
ggml_nelements(tensor); - ov::Tensor weights(node_type, node_shape); - memcpy(weights.data(), tensor->data, ne_total * node_type.size()); - return std::make_shared(weights); + + if (node_type != ov::element::dynamic) { + ov::Tensor weights(node_type, node_shape); + memcpy(weights.data(), tensor->data, ne_total * node_type.size()); + std::shared_ptr weight_node = std::make_shared(weights); + if (node_type == ov::element::f16) { + weight_node = std::make_shared(weight_node, ov::element::f32); + } + weight_node->set_friendly_name(tensor->name); + return weight_node; + } + + uint64_t weights_per_byte; + if (tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_1 || tensor->type == GGML_TYPE_Q4_K) { + weights_per_byte = 2; + } else { // tensor.type == GGUF_TYPE_Q8_0 || tensor.type == GGUF_TYPE_Q6_K + weights_per_byte = 1; + } + + uint64_t weights_per_block; + // here we only consider sub block, q6k:16 q4k:32 + if (tensor->type == GGML_TYPE_Q6_K) { + weights_per_block = 16; + } else { + weights_per_block = 32; + } + + OPENVINO_ASSERT(node_shape.back() % weights_per_block == 0, + "[load_gguf] tensor ", + tensor->name, + " has incompatible last dim shape: ", + node_shape.back()); + + auto weights_shape = node_shape; + weights_shape.back() /= (weights_per_byte * 4); // means u32 type can store 8 q4 or 4 q8 + + ov::Tensor weights(ov::element::u32, weights_shape); + // For scales and bias + node_shape[node_shape.size() - 1] = node_shape[node_shape.size() - 1] / weights_per_block; + + ov::Tensor scales(ov::element::f16, node_shape); + ov::Tensor biases(ov::element::f16, node_shape); + ov::Output weight_node; + if (tensor->type == GGML_TYPE_Q4_0) { + extract_q4_0_data(tensor, weights, scales, biases); + weight_node = make_int8_weights(weights, scales, biases, weights_per_block); + } else if (tensor->type == GGML_TYPE_Q4_1) { + extract_q4_1_data(tensor, weights, scales, biases); + weight_node = make_int4_weights(weights, scales, biases, weights_per_block); + } else if (tensor->type == GGML_TYPE_Q8_0) { + extract_q8_0_data(tensor, weights, scales, biases); + weight_node = make_int8_weights(weights, scales, biases, weights_per_block); + } else if (tensor->type == GGML_TYPE_Q6_K) { + // due to WA #2135, this case will not be used, extract_q6_k_data temporarily disabled. 
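+        // (Block layouts handled in this chain, per the extract_* helpers: Q4_0 packs 32 weights
+        // into 18 bytes, Q4_1 into 20, Q8_0 into 34; Q4_K and Q6_K use 144- and 210-byte
+        // super-blocks of 256 weights, decoded per 32- and 16-element sub-block respectively.)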
+ extract_q6_k_data(tensor, weights, scales, biases); + weight_node = make_int8_weights(weights, scales, biases, weights_per_block); + } else if (tensor->type == GGML_TYPE_Q4_K) { + extract_q4_k_data(tensor, weights, scales, biases); + weight_node = make_int4_weights(weights, scales, biases, weights_per_block); + } + weight_node.get_node_shared_ptr()->set_friendly_name(tensor->name); + return weight_node.get_node_shared_ptr(); } void GgmlOvDecoder::dump_cgraph(const struct ggml_cgraph* cgraph, std::string& filename) { @@ -537,7 +605,7 @@ ov::element::Type GgmlOvDecoder::get_ov_type(const ggml_tensor* tensor) { case GGML_TYPE_I64: return ov::element::i64; default: - throw std::runtime_error("Unsupported tensor type"); + return ov::element::dynamic; } } diff --git a/ggml/src/ggml-openvino/ggml-quant.cpp b/ggml/src/ggml-openvino/ggml-quant.cpp new file mode 100644 index 0000000000000..4311ab138ea0d --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-quant.cpp @@ -0,0 +1,313 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml.h" + +void unpack_32_4(const uint8_t* data, uint8_t* dst) { + std::fill_n(dst, 16, 0); + for (int j = 0; j < 16; ++j) { + uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes. + uint8_t y = (data[j + 2] >> 4); + if (j % 2 != 0) { + x <<= 4; + y <<= 4; + } + dst[j / 2] |= x; + dst[8 + j / 2] |= y; // Last 16 weights are in the higher bits + } +} + +// Extracts (weight, scales, biases) from Q4_0 tensors. +// Data layout is: |16 bit scale|32 x 4bit weights|. +void extract_q4_0_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr) { + const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights + auto data = static_cast(tensor->data); + auto weights = static_cast(weights_arr.data()); + auto scales = scales_arr.data::value_type>(); + auto biases = biases_arr.data::value_type>(); + + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t*)(data + i * bytes_per_block))); + biases[i] = ov::float16(-8.f * static_cast(scales[i])); + unpack_32_4(data + i * bytes_per_block, weights + i * 16); + }); +} + +// Extracts (weight, scales, biases) from Q4_1 tensors. +// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|. +void extract_q4_1_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr) { + const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights + auto data = static_cast(tensor->data); + auto weights = static_cast(weights_arr.data()); + auto scales = scales_arr.data::value_type>(); + auto biases = biases_arr.data::value_type>(); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t*)(data + i * bytes_per_block))); + biases[i] = ov::float16::from_bits(*((uint16_t*)(data + i * bytes_per_block + 1))); + unpack_32_4(data + i * bytes_per_block, weights + i * 16); + }); +} + +// Extracts (weight, scales, biases) from Q8_0 tensors. +// Data layout is: |16 bit scale|32 x 8bit weights|. 
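+// For example, one 4096-element row of Q8_0 data holds 4096 / 32 = 128 blocks, i.e. 128 * 34 = 4352 bytes.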
+void extract_q8_0_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr) { + const uint64_t weights_per_block = 32; + const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights + auto data = static_cast(tensor->data); + auto weights = static_cast(weights_arr.data()); + auto scales = scales_arr.data::value_type>(); + auto biases = biases_arr.data::value_type>(); + for (int64_t i = 0; i < scales_arr.get_size(); i++) { + uint8_t* block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t*)block_data); + biases[i] = ov::float16(-128.f * static_cast(scales[i])); + for (int64_t j = 0; j < weights_per_block; ++j) { + uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. + // Original data is in int8_t, so we add a bias of -128 and invert the + // first bit. + x ^= 1 << 7; + weights[i * weights_per_block + j] = x; + } + } +} + +void unpack_256_4(const uint8_t* data, uint8_t* dst) { + // Initialize the output array with zeros + std::fill_n(dst, 128, 0); + + for (size_t i = 0; i < 4; ++i) { + for (int j = 0; j < 32; ++j) { + uint8_t x = (data[i * 32 + j] & 0x0F); + uint8_t y = (data[i * 32 + j] >> 4); + if (j % 2 != 0) { + x <<= 4; + y <<= 4; + } + dst[i * 32 + j / 2] |= x; + dst[i * 32 + 16 + j / 2] |= y; // Last 16 weights are in the higher bits + } + } +} + +void extract_q4_k_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr) { + const uint64_t bytes_per_block = 2 + 2 + 12 + 128; + // TODO tensor->nb[3] + const uint64_t n_super_block = tensor->nb[3] / bytes_per_block; + auto data = static_cast(tensor->data); + auto weights = static_cast(weights_arr.data()); + auto scales = scales_arr.data::value_type>(); + auto biases = biases_arr.data::value_type>(); + + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t* block_data = data + i * bytes_per_block; + + // Extract scale factors and offsets + float scale_scales = static_cast(ov::float16::from_bits(*((uint16_t*)block_data))); + float scale_biases = static_cast(ov::float16::from_bits(*((uint16_t*)block_data + 1))); + + // Extract qs1 and qs2 + uint8_t* qs1 = block_data + 4; + uint8_t* qs2 = block_data + 16; + + scales[i * 8] = ov::float16(scale_scales * static_cast((*(qs1) & 0b111111))); + scales[i * 8 + 1] = ov::float16(scale_scales * static_cast((*(qs1 + 1) & 0b111111))); + scales[i * 8 + 2] = ov::float16(scale_scales * static_cast((*(qs1 + 2) & 0b111111))); + scales[i * 8 + 3] = ov::float16(scale_scales * static_cast((*(qs1 + 3) & 0b111111))); + scales[i * 8 + 4] = + ov::float16(scale_scales * static_cast((*(qs1 + 8) & 0b00001111) | ((*(qs1) >> 6) << 4))); + scales[i * 8 + 5] = + ov::float16(scale_scales * static_cast((*(qs1 + 9) & 0b00001111) | ((*(qs1 + 1) >> 6) << 4))); + scales[i * 8 + 6] = + ov::float16(scale_scales * static_cast((*(qs1 + 10) & 0b00001111) | ((*(qs1 + 2) >> 6) << 4))); + scales[i * 8 + 7] = + ov::float16(scale_scales * static_cast((*(qs1 + 11) & 0b00001111) | ((*(qs1 + 3) >> 6) << 4))); + + biases[i * 8] = ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 4) & 0b111111))); + biases[i * 8 + 1] = ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 5) & 0b111111))); + biases[i * 8 + 2] = ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 6) & 0b111111))); + biases[i * 8 + 3] = ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 7) & 0b111111))); + biases[i * 8 + 4] = + ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 8) >> 4) | ((*(qs1 
+ 4) >> 6) << 4))); + biases[i * 8 + 5] = + ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 9) >> 4) | ((*(qs1 + 5) >> 6) << 4))); + biases[i * 8 + 6] = + ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 10) >> 4) | ((*(qs1 + 6) >> 6) << 4))); + biases[i * 8 + 7] = + ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 11) >> 4) | ((*(qs1 + 7) >> 6) << 4))); + unpack_256_4(block_data + 16, weights + i * 128); + }); +} + +void extract_q6_k_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr) { + const uint64_t bytes_per_block = 128 + 64 + 16 + 2; + const uint64_t n_super_block = tensor->nb[3] / bytes_per_block; + auto data = static_cast(tensor->data); + auto weights = static_cast(weights_arr.data()); + auto scales = scales_arr.data::value_type>(); + auto biases = biases_arr.data::value_type>(); + // std::string name(tensor.name, tensor.namelen); + for (int64_t i = 0; i < n_super_block; i++) { + uint8_t* block_data = data + i * bytes_per_block; + + float scale_factor = + static_cast(ov::float16::from_bits(*((uint16_t*)block_data + 104))); // (128+64+16)/2 + + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast(*((int8_t*)(block_data + 128 + 64 + j)))); + biases[j + i * 16] = ov::float16(-32.f * static_cast(scales[j + i * 16])); + } + + // Extract ql and qh + uint8_t* ql = block_data; + uint8_t* qh = block_data + 128; + + // Extract weights + for (int64_t j = 0; j < 32; ++j) { + weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4); + weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4); + weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4); + weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4); + weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4); + weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4); + weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4); + weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4); + } + } +} + +// TODO Reorder for make_intX_weights + +ov::Output make_int8_weights(ov::Tensor& weight, ov::Tensor& scales, ov::Tensor& biases, size_t group_size) { + + // Reshape weight to (num_heads, -1, group_size) + ov::Shape orig_shape = weight.get_shape(); + orig_shape[1] *= sizeof(uint32_t) / sizeof(uint8_t); + size_t num_groups = orig_shape[1] / group_size; + + // Expand dimensions for scales and biases + auto scale_shape = scales.get_shape(); + scale_shape.push_back(1); + scales.set_shape(scale_shape); + biases.set_shape(scale_shape); + + // Create graph nodes + auto weights_node = std::make_shared(ov::element::u8, ov::Shape{orig_shape[0], num_groups, group_size}, static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto scales_f16 = std::make_shared(scales); + ov::Tensor biases_u8(ov::element::u8, scale_shape); + + // Calculate zero point + const ov::float16* bias_data = biases.data::value_type>(); + const ov::float16* scale_data = scales.data::value_type>(); + uint8_t* bias_u8_data = biases_u8.data(); + for (size_t i = 0; i < biases_u8.get_size(); ++i) { + bias_u8_data[i] = (uint8_t)std::round(-1.f * static_cast(bias_data[i]) / static_cast(scale_data[i])); + } + + auto zero_point = std::make_shared(biases_u8); + + // Quantization operations + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + auto 
zero_point_f16 = std::make_shared<ov::op::v0::Convert>(zero_point, ov::element::f16);
+
+    auto w_zp = std::make_shared<ov::op::v1::Subtract>(
+        weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY
+    );
+    auto w_zp_s = std::make_shared<ov::op::v1::Multiply>(
+        w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY
+    );
+
+    // Reshape back to original dimensions
+    auto final_shape = std::make_shared<ov::op::v0::Constant>(
+        ov::element::i64, ov::Shape{orig_shape.size()}, orig_shape
+    );
+    auto w_zp_s_r = std::make_shared<ov::op::v1::Reshape>(
+        w_zp_s, final_shape, false
+    );
+
+    return std::make_shared<ov::op::v0::Convert>(w_zp_s_r, ov::element::f32);
+}
+
+ov::Output<ov::Node> make_int4_weights(ov::Tensor& weight, ov::Tensor& scales, ov::Tensor& biases, size_t group_size) {
+
+    // Convert weight to uint8 view and adjust shape
+    ov::Shape orig_weight_shape = weight.get_shape();
+    orig_weight_shape[1] *= sizeof(uint32_t) / sizeof(uint8_t) * 2;  // Double number of columns for 4-bit representation
+
+    // Expand dimensions for scales and biases
+    ov::Shape scale_bias_shape = scales.get_shape();
+    scale_bias_shape.push_back(1);  // Add new axis at the end
+    scales.set_shape(scale_bias_shape);
+    biases.set_shape(scale_bias_shape);
+
+    // Create INT4 weight tensor
+    ov::Shape packed_shape = {
+        orig_weight_shape[0],
+        orig_weight_shape[1] / group_size,
+        group_size
+    };
+
+    auto weights_node = std::make_shared<ov::op::v0::Constant>(ov::element::u4, packed_shape, static_cast<uint8_t*>(weight.data()), nullptr);
+    weights_node->get_rt_info()["__gguf_tensor_holder"] = weight;
+    auto weights_f16 = std::make_shared<ov::op::v0::Convert>(weights_node, ov::element::f16);
+
+    // Pack zero points: two subsequent values into one
+    const ov::float16* bias_data = biases.data<ov::element_type_traits<ov::element::f16>::value_type>();
+    const ov::float16* scale_data = scales.data<ov::element_type_traits<ov::element::f16>::value_type>();
+    ov::Tensor zero_point_tensor(ov::element::u4, scale_bias_shape);
+    uint8_t* zero_point_data = static_cast<uint8_t*>(zero_point_tensor.data());
+    for (size_t i = 0; i < zero_point_tensor.get_byte_size(); ++i) {
+        uint8_t bias1 = (uint8_t)std::round(-1.f * static_cast<float>(bias_data[i * 2]) / static_cast<float>(scale_data[i * 2]));
+        uint8_t bias2 = (uint8_t)std::round(-1.f * static_cast<float>(bias_data[i * 2 + 1]) / static_cast<float>(scale_data[i * 2 + 1]));
+        zero_point_data[i] = (bias2 << 4) | (bias1 & 0x0F);
+    }
+
+    // CVS-166438: GGUF Q4_0 zp array (U4) with all same value (8) will be converted to single U4 scalar via ConvertU4WeightsZeroPointToScalar transformation.
+    // This corner case can be handled by CPU plugin properly, but will trigger compilation error on GPU plugin.
+    // Temporal WA by adding one small bias to keep zp array shape for GPU plugin, confirm no accuracy impact for final LLM generation results.
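+    // (For Q4_0 every packed zp byte is 0x88; the "+ 1" below turns the first group's zero-point
+    // from 8 into 9, a one-quantization-step shift for a single group, hence the negligible
+    // accuracy impact.)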
+ zero_point_data[0] += 1; + + auto zero_points_node = std::make_shared(zero_point_tensor); + auto zero_points_f16 = std::make_shared(zero_points_node, ov::element::f16); + + auto scales_f16 = std::make_shared(scales); + + // Perform dequantization + auto w_zp = std::make_shared( + weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); + + auto w_zp_s = std::make_shared( + w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); + + // Reshape back to original shape + auto final_shape = std::make_shared( + ov::element::i64, ov::Shape{orig_weight_shape.size()}, orig_weight_shape); + + auto w_zp_s_r = std::make_shared( + w_zp_s, final_shape, false); + + return std::make_shared(w_zp_s_r, ov::element::f32); +} diff --git a/ggml/src/ggml-openvino/ggml-quant.hpp b/ggml/src/ggml-openvino/ggml-quant.hpp new file mode 100644 index 0000000000000..9c0dd89a95aee --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-quant.hpp @@ -0,0 +1,44 @@ +#include +#include +#include "ggml.h" + +void unpack_32_4(const uint8_t* data, uint8_t* dst); + +void extract_q4_0_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr); + +void extract_q4_1_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr); + +void extract_q8_0_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr); + +void unpack_256_4(const uint8_t* data, uint8_t* dst); + +void extract_q4_k_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr); + +void extract_q6_k_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr); + +static constexpr size_t GGML_QUANTIZATION_GROUP_SIZE = 32; + +ov::Output make_int8_weights(ov::Tensor& weight, + ov::Tensor& scales, + ov::Tensor& biases, + size_t group_size = GGML_QUANTIZATION_GROUP_SIZE); + +ov::Output make_int4_weights(ov::Tensor& weight, + ov::Tensor& scales, + ov::Tensor& biases, + size_t group_size = GGML_QUANTIZATION_GROUP_SIZE); From 976c85bd70945e599f8f7ec03aec14c9c7006581 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Wed, 6 Aug 2025 15:54:40 +0800 Subject: [PATCH 03/30] Quant models run with accuracy issue --- ggml/src/ggml-openvino/ggml-decoder.cpp | 20 ++++++++++++++++++- ggml/src/ggml-openvino/ggml-quant.cpp | 4 +++- .../ggml-openvino/openvino/op/get_rows.cpp | 11 ++++++++-- .../openvino/translate_session.cpp | 1 - ggml/src/ggml-openvino/openvino/utils.cpp | 2 ++ 5 files changed, 33 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index c2e164b808baa..a3e7059fa2147 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -22,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -415,6 +417,9 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor) auto node_shape = get_shape(tensor); auto ne_total = ggml_nelements(tensor); + OPENVINO_ASSERT(node_shape[0] == 1, "Got 3D weights, expect all weights to be 2D: ", tensor->name); + + // F16 and F32 case if (node_type != ov::element::dynamic) { ov::Tensor weights(node_type, node_shape); memcpy(weights.data(), tensor->data, ne_total * node_type.size()); @@ -426,6 +431,9 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor) 
return weight_node; } + // Quantized case + node_shape.erase(node_shape.begin()); + uint64_t weights_per_byte; if (tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_1 || tensor->type == GGML_TYPE_Q4_K) { weights_per_byte = 2; @@ -459,7 +467,7 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor) ov::Output weight_node; if (tensor->type == GGML_TYPE_Q4_0) { extract_q4_0_data(tensor, weights, scales, biases); - weight_node = make_int8_weights(weights, scales, biases, weights_per_block); + weight_node = make_int4_weights(weights, scales, biases, weights_per_block); } else if (tensor->type == GGML_TYPE_Q4_1) { extract_q4_1_data(tensor, weights, scales, biases); weight_node = make_int4_weights(weights, scales, biases, weights_per_block); @@ -474,7 +482,17 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor) extract_q4_k_data(tensor, weights, scales, biases); weight_node = make_int4_weights(weights, scales, biases, weights_per_block); } + + OPENVINO_ASSERT(weight_node.get_shape().size() == 2, "Weight should be 2D"); + // weight_node = std::make_shared( + // weight_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0})); + weight_node.get_node_shared_ptr()->set_friendly_name(tensor->name); + // GGML_LOG_DEBUG("Created weight node: %s %s %s%s\n", + // tensor->name, + // ggml_type_name(tensor->type), + // weight_node.get_element_type().get_type_name().c_str(), + // weight_node.get_partial_shape().to_string().c_str()); return weight_node.get_node_shared_ptr(); } diff --git a/ggml/src/ggml-openvino/ggml-quant.cpp b/ggml/src/ggml-openvino/ggml-quant.cpp index 4311ab138ea0d..14ef58a3f777a 100644 --- a/ggml/src/ggml-openvino/ggml-quant.cpp +++ b/ggml/src/ggml-openvino/ggml-quant.cpp @@ -1,4 +1,7 @@ +#include "ggml-quant.hpp" + #include +#include #include #include #include @@ -6,7 +9,6 @@ #include #include #include -#include #include "ggml.h" diff --git a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp index 36795fd43eabd..0de77da59ffc5 100644 --- a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +++ b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -7,6 +6,7 @@ #include #include #include +#include #include "../node_context.hpp" #include "../op_table.hpp" @@ -31,11 +31,18 @@ OutputVector translate_get_rows(const NodeContext& context) { indices = process_view_input(context, 1); } - auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); + Output axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); if (indices.get_partial_shape()[1].get_length() == 1) { indices = std::make_shared(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1})); + if (data.get_partial_shape().rank() == 2) { + axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + } res = std::make_shared(data, indices, axis); + if (data.get_partial_shape().rank() == 2) { + res = + std::make_shared(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + } } else { indices = std::make_shared(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp index 3e27a689d52ff..62804670414ef 100644 --- a/ggml/src/ggml-openvino/openvino/translate_session.cpp +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -212,7 +212,6 @@ std::shared_ptr 
TranslateSession::apply_transformations(std::shared_ptr<ov::Model> model) {
     {
         ov::pass::Manager manager;
         manager.register_pass();
-        manager.register_pass();
         if (!ggml_model_decoder->is_static()) {
             const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp b/ggml/src/ggml-openvino/openvino/utils.cpp
index c4197ccc3abdc..ef5f51ebbc4a7 100644
--- a/ggml/src/ggml-openvino/openvino/utils.cpp
+++ b/ggml/src/ggml-openvino/openvino/utils.cpp
@@ -17,6 +17,8 @@
 #include
 #include

+#include "ggml-impl.h"
+
 namespace ov {
 namespace frontend {
 namespace ggml {

From 1f8e007f2f5478c97b76038369f58b7e0eb3b512 Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Thu, 7 Aug 2025 14:25:20 +0800
Subject: [PATCH 04/30] Fix accuracy: disable cpu_repack

---
 docs/build.md                            | 2 +-
 ggml/src/ggml-openvino/ggml-decoder.cpp  | 4 ++++
 ggml/src/ggml-openvino/ggml-openvino.cpp | 3 +--
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/docs/build.md b/docs/build.md
index c7e15a4e78482..e2ef8b4e08b5b 100644
--- a/docs/build.md
+++ b/docs/build.md
@@ -648,7 +648,7 @@ git switch dev_backend_openvino

 # Build with OpenVINO support
 source /opt/intel/openvino/setupvars.sh
-cmake -B build/ReleaseOV -G Ninja -DCMAKE_BUILD_TYPE=Release -DGGML_OPENVINO=ON
+cmake -B build/ReleaseOV -G Ninja -DCMAKE_BUILD_TYPE=Release -DGGML_OPENVINO=ON -DGGML_CPU_REPACK=OFF
 cmake --build build/ReleaseOV --config Release -j $(nproc)
 ```

diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp
index a3e7059fa2147..cd897e5f688bb 100644
--- a/ggml/src/ggml-openvino/ggml-decoder.cpp
+++ b/ggml/src/ggml-openvino/ggml-decoder.cpp
@@ -432,6 +432,10 @@ std::shared_ptr<ov::Node> GgmlOvDecoder::create_weight_node(ggml_tensor* tensor)
     }

     // Quantized case
+    OPENVINO_ASSERT(
+        tensor->extra == nullptr,
+        "Unsupported weight tensor: " + std::string(tensor->name) + " Possibly this is a repacked quantized weights");
+
     node_shape.erase(node_shape.begin());

     uint64_t weights_per_byte;
diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp
index f81b1ee4834de..23a92c58ac8b8 100644
--- a/ggml/src/ggml-openvino/ggml-openvino.cpp
+++ b/ggml/src/ggml-openvino/ggml-openvino.cpp
@@ -420,8 +420,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
 }

 static bool ggml_backend_openvino_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    // TODO quantized weights are cpu_repack_buffer_type which does not implement ggml_backend_buft_is_host
-    return ggml_backend_buft_is_host(buft) || strcmp(buft->device->iface.get_name(buft->device), "CPU") == 0;
+    return ggml_backend_buft_is_host(buft);
     GGML_UNUSED(dev);
 }

From c77cefefa98f39f8f7a403259e4da821fe4d6405 Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Thu, 7 Aug 2025 15:22:58 +0800
Subject: [PATCH 05/30] Fix CI; Disable test-backend-ops

---
 ci/run.sh                                             |  2 +-
 ggml/src/ggml-openvino/ggml-decoder.cpp               |  2 +-
 .../ggml-openvino/{ggml-quant.cpp => ggml-quants.cpp} | 10 +++++-----
 .../ggml-openvino/{ggml-quant.hpp => ggml-quants.hpp} |  0
 4 files changed, 7 insertions(+), 7 deletions(-)
 rename ggml/src/ggml-openvino/{ggml-quant.cpp => ggml-quants.cpp} (98%)
 rename ggml/src/ggml-openvino/{ggml-quant.hpp => ggml-quants.hpp} (100%)

diff --git a/ci/run.sh b/ci/run.sh
index 052ee999ad848..baabddf5fde9a 100755
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -103,7 +103,7 @@ if [ ! 
-z ${GG_BUILD_OPENVINO} ]; then echo "source /opt/intel/openvino/setupvars.sh" exit 1 fi - CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_OPENVINO=ON" + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_OPENVINO=ON -DGGML_CPU_REPACK=OFF" fi ## helpers diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index cd897e5f688bb..cde99f32883bf 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -32,7 +32,7 @@ #include "ggml-backend-impl.h" #include "ggml-backend.h" -#include "ggml-quant.hpp" +#include "ggml-quants.hpp" GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token, int context_size, int num_heads, int num_heads_kv, int head_size) : diff --git a/ggml/src/ggml-openvino/ggml-quant.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp similarity index 98% rename from ggml/src/ggml-openvino/ggml-quant.cpp rename to ggml/src/ggml-openvino/ggml-quants.cpp index 14ef58a3f777a..8d4fb141896f4 100644 --- a/ggml/src/ggml-openvino/ggml-quant.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -1,4 +1,4 @@ -#include "ggml-quant.hpp" +#include "ggml-quants.hpp" #include #include @@ -75,11 +75,11 @@ void extract_q8_0_data(const ggml_tensor* tensor, auto weights = static_cast(weights_arr.data()); auto scales = scales_arr.data::value_type>(); auto biases = biases_arr.data::value_type>(); - for (int64_t i = 0; i < scales_arr.get_size(); i++) { + for (size_t i = 0; i < scales_arr.get_size(); i++) { uint8_t* block_data = data + i * bytes_per_block; scales[i] = ov::float16::from_bits(*(uint16_t*)block_data); biases[i] = ov::float16(-128.f * static_cast(scales[i])); - for (int64_t j = 0; j < weights_per_block; ++j) { + for (size_t j = 0; j < weights_per_block; ++j) { uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. // Original data is in int8_t, so we add a bias of -128 and invert the // first bit. 
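             // (The XOR with 1 << 7 maps int8_t's two's-complement range [-128, 127] onto the
             // uint8_t range [0, 255]; the -128 offset is compensated by the biases computed above.)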
@@ -128,7 +128,7 @@ void extract_q4_k_data(const ggml_tensor* tensor, // Extract qs1 and qs2 uint8_t* qs1 = block_data + 4; - uint8_t* qs2 = block_data + 16; + // uint8_t* qs2 = block_data + 16; scales[i * 8] = ov::float16(scale_scales * static_cast((*(qs1) & 0b111111))); scales[i * 8 + 1] = ov::float16(scale_scales * static_cast((*(qs1 + 1) & 0b111111))); @@ -170,7 +170,7 @@ void extract_q6_k_data(const ggml_tensor* tensor, auto scales = scales_arr.data::value_type>(); auto biases = biases_arr.data::value_type>(); // std::string name(tensor.name, tensor.namelen); - for (int64_t i = 0; i < n_super_block; i++) { + for (size_t i = 0; i < n_super_block; i++) { uint8_t* block_data = data + i * bytes_per_block; float scale_factor = diff --git a/ggml/src/ggml-openvino/ggml-quant.hpp b/ggml/src/ggml-openvino/ggml-quants.hpp similarity index 100% rename from ggml/src/ggml-openvino/ggml-quant.hpp rename to ggml/src/ggml-openvino/ggml-quants.hpp From 30239813a6a12ea4b65688773de1371cd812026b Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 8 Aug 2025 11:07:10 +0800 Subject: [PATCH 06/30] Fix Q4_1 --- ggml/src/ggml-openvino/ggml-quants.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp index 8d4fb141896f4..e969b0b54adfc 100644 --- a/ggml/src/ggml-openvino/ggml-quants.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -15,8 +15,8 @@ void unpack_32_4(const uint8_t* data, uint8_t* dst) { std::fill_n(dst, 16, 0); for (int j = 0; j < 16; ++j) { - uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes. - uint8_t y = (data[j + 2] >> 4); + uint8_t x = (data[j] & 0x0F); + uint8_t y = (data[j] >> 4); if (j % 2 != 0) { x <<= 4; y <<= 4; @@ -41,7 +41,7 @@ void extract_q4_0_data(const ggml_tensor* tensor, ov::parallel_for(scales_arr.get_size(), [&](size_t i) { scales[i] = ov::float16::from_bits(*((uint16_t*)(data + i * bytes_per_block))); biases[i] = ov::float16(-8.f * static_cast(scales[i])); - unpack_32_4(data + i * bytes_per_block, weights + i * 16); + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); }); } @@ -58,8 +58,8 @@ void extract_q4_1_data(const ggml_tensor* tensor, auto biases = biases_arr.data::value_type>(); ov::parallel_for(scales_arr.get_size(), [&](size_t i) { scales[i] = ov::float16::from_bits(*((uint16_t*)(data + i * bytes_per_block))); - biases[i] = ov::float16::from_bits(*((uint16_t*)(data + i * bytes_per_block + 1))); - unpack_32_4(data + i * bytes_per_block, weights + i * 16); + biases[i] = ov::float16::from_bits(*((uint16_t*)(data + i * bytes_per_block + 2))); + unpack_32_4(data + i * bytes_per_block + 4, weights + i * 16); }); } From 7b8189ae225cd403e32a3433e280e1aabee2422c Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 8 Aug 2025 15:15:12 +0800 Subject: [PATCH 07/30] Fix test-thread-safety --- tests/CMakeLists.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d9e035df32cc6..860883bdce5db 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -185,9 +185,7 @@ llama_build_and_test(test-json-partial.cpp) llama_build_and_test(test-log.cpp) llama_build_and_test(test-regex-partial.cpp) -if (NOT GGML_OPENVINO) - llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4 -t 2) -endif() +llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff 
tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4 -t 2) # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135) if (NOT WIN32) From 0376a7ae32400e6e57a5493d35d9b44b74b36a14 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Tue, 12 Aug 2025 09:44:21 +0800 Subject: [PATCH 08/30] Fix test-backend-ops: Treat quantized tensors as weights --- ggml/src/ggml-openvino/ggml-decoder.cpp | 16 ++++++++++------ ggml/src/ggml-openvino/ggml-decoder.h | 5 +++-- ggml/src/ggml-openvino/ggml-openvino.cpp | 14 +++++++++++--- ggml/src/ggml-openvino/utils.cpp | 6 +++++- 4 files changed, 29 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index cde99f32883bf..b20bfd0c76f52 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -76,13 +76,15 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph, add_extra_inputs(); } -GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph) { +GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph, + std::map>& model_weights) { if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) { std::string filename = "cgraph.txt"; dump_cgraph(cgraph, filename); } m_cgraph = cgraph; + m_model_weights = model_weights; for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { auto* cur_node = cgraph->nodes[node_n]; if (cur_node->op == GGML_OP_NONE) { @@ -123,10 +125,12 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) { // Add model inputs and weights constants, if called for the whole graph if (naive) { - auto param_node = std::make_shared(get_ov_type(src), get_graph_input_shape(src)); - param_node->set_friendly_name(src_name); - param_node->output(0).get_tensor().set_names({src_name}); - m_model_inputs[src_name] = param_node; + if (m_model_weights.find(src_name) == m_model_weights.end()) { + auto param_node = std::make_shared(get_ov_type(src), get_graph_input_shape(src)); + param_node->set_friendly_name(src_name); + param_node->output(0).get_tensor().set_names({src_name}); + m_model_inputs[src_name] = param_node; + } } else if (!m_node && !src->view_src) { ggml_backend_buffer* buffer = src->buffer; @@ -381,7 +385,7 @@ std::map> GgmlOvDecoder::create_weight_no std::string src_name(src->name); if (!src->view_src) { ggml_backend_buffer* buffer = src->buffer; - if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { + if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS || ggml_is_quantized(src->type)) { bool should_create = false; { std::lock_guard lock(weights_mutex); diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index ae378273d32e0..df23c649f4f47 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -20,7 +20,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { int context_size, int num_heads, int num_heads_kv, int head_size); // Naive graph decoder - GgmlOvDecoder(struct ggml_cgraph* cgraph); + GgmlOvDecoder(struct ggml_cgraph* cgraph, std::map>& model_weights); virtual ov::Any get_attribute(const std::string& name) const override { return nullptr; @@ -115,6 +115,8 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { ov::PartialShape get_graph_input_shape(const ggml_tensor* src) const; + static void dump_cgraph(const struct ggml_cgraph* cgraph, std::string& filename); + static std::shared_ptr create_weight_node(ggml_tensor* tensor); static std::map> 
create_weight_nodes(struct ggml_cgraph* cgraph); @@ -126,7 +128,6 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { private: void set_input_output(ggml_tensor* node, bool naive = false); void add_extra_inputs(); - static void dump_cgraph(const struct ggml_cgraph* cgraph, std::string& filename); static std::vector get_shape(const ggml_tensor* tensor); static std::vector get_stride(const ggml_tensor* tensor); static ov::element::Type get_ov_type(const ggml_tensor* tensor); diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 23a92c58ac8b8..4b743be6884b0 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -403,14 +403,22 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con return false; } for (int i = 0; i < GGML_MAX_SRC; i++) { - if (supported_types.find(op->type) == supported_types.end()) { - GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(op->type)); + auto* src = op->src[i]; + if (src == nullptr) { + break; + } + if (supported_types.find(src->type) == supported_types.end()) { + GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(src->type)); return false; } - if (op->src[i] != nullptr && op->src[i]->ne[3] != 1) { + if (src->ne[3] != 1) { GGML_LOG_WARN("OpenVINO backend does not support tensors with ne[3] != 1\n"); return false; } + if (ggml_is_quantized(src->type) && src->ne[2] != 1) { + GGML_LOG_WARN("OpenVINO backend does not support 3D quantized tensors\n"); + return false; + } } if (is_op_unsupported_case(op)) { diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 473fa72f99fd5..43fa0c469d60a 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -281,10 +281,14 @@ enum ggml_status naive_compute(struct ggml_cgraph* cgraph, return GGML_STATUS_FAILED; } - auto decoder = std::make_shared(cgraph); + auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); + auto decoder = std::make_shared(cgraph, model_weights); auto input_model = std::make_shared(decoder); auto naive = true; auto model = ov::frontend::ggml::FrontEnd::convert(input_model, naive); + if (getenv("GGML_OPENVINO_DUMP_IR")) { + ov::serialize(model, "IR_naive.xml"); + } auto infer_request = core.compile_model(model, device, config).create_infer_request(); auto ov_params = model->get_parameters(); From e1a5f7ea3d9b6d963236f1e90d92d62e52e8d10d Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Tue, 19 Aug 2025 14:56:28 +0800 Subject: [PATCH 09/30] Add NPU Q4_0 support --- ggml/src/ggml-openvino/ggml-openvino.cpp | 28 +++++++++++++++--------- ggml/src/ggml-openvino/ggml-quants.cpp | 13 ++++++----- ggml/src/ggml-openvino/ggml-quants.hpp | 13 +++++++++++ 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 4b743be6884b0..a6ec1c64c2904 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -333,16 +333,24 @@ static bool is_op_unsupported_case(const ggml_tensor* op) { static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor* op) { GGML_ASSERT(dev->reg != nullptr); - static const std::set supported_types{GGML_TYPE_F32, - GGML_TYPE_F16, - GGML_TYPE_BF16, - GGML_TYPE_I64, - GGML_TYPE_I32, - GGML_TYPE_Q4_0, - GGML_TYPE_Q4_1, - GGML_TYPE_Q4_K, - GGML_TYPE_Q8_0, - 
GGML_TYPE_Q6_K}; + static std::set supported_types{GGML_TYPE_F32, + GGML_TYPE_F16, + GGML_TYPE_BF16, + GGML_TYPE_I64, + GGML_TYPE_I32, + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_1, + GGML_TYPE_Q4_K, + GGML_TYPE_Q8_0, + GGML_TYPE_Q6_K}; + + std::string device = std::string(getenv("GGML_OPENVINO_DEVICE")); + bool is_npu = device == "NPU"; + if (is_npu) { + // NPU has poor support for asymmetric quantization + supported_types.erase(GGML_TYPE_Q4_1); + supported_types.erase(GGML_TYPE_Q4_K); + } static const std::set supported_ops{GGML_OP_NONE, GGML_OP_ADD, diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp index e969b0b54adfc..97aa494ed85aa 100644 --- a/ggml/src/ggml-openvino/ggml-quants.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -230,6 +230,10 @@ ov::Output make_int8_weights(ov::Tensor& weight, ov::Tensor& scales, o } auto zero_point = std::make_shared(biases_u8); + float zp_value; + if (ov::op::util::get_single_value(zero_point, zp_value)) { + zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + } // Quantization operations auto weights_f16 = std::make_shared(weights_node, ov::element::f16); @@ -287,12 +291,11 @@ ov::Output make_int4_weights(ov::Tensor& weight, ov::Tensor& scales, o zero_point_data[i] = (bias2 << 4) | (bias1 & 0x0F); } - // CVS-166438: GGUF Q4_0 zp array (U4) with all same value (8) will be converted to single U4 scalar via ConvertU4WeightsZeroPointToScalar transformation. - // This corner case can be handled by CPU plugin properly, but will trigger compilation error on GPU plugin. - // Temporal WA by adding one small bias to keep zp array shape for GPU plugin, confirm no accuracy impact for final LLM generation results. - zero_point_data[0] += 1; - auto zero_points_node = std::make_shared(zero_point_tensor); + float zp_value; + if (ov::op::util::get_single_value(zero_points_node, zp_value)) { + zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); + } auto zero_points_f16 = std::make_shared(zero_points_node, ov::element::f16); auto scales_f16 = std::make_shared(scales); diff --git a/ggml/src/ggml-openvino/ggml-quants.hpp b/ggml/src/ggml-openvino/ggml-quants.hpp index 9c0dd89a95aee..ae37b1618ed14 100644 --- a/ggml/src/ggml-openvino/ggml-quants.hpp +++ b/ggml/src/ggml-openvino/ggml-quants.hpp @@ -1,5 +1,7 @@ #include +#include #include + #include "ggml.h" void unpack_32_4(const uint8_t* data, uint8_t* dst); @@ -42,3 +44,14 @@ ov::Output make_int4_weights(ov::Tensor& weight, ov::Tensor& scales, ov::Tensor& biases, size_t group_size = GGML_QUANTIZATION_GROUP_SIZE); + +namespace ov { +namespace op { +namespace util { +// From /src/common/transformations/include/transformations/utils/utils.hpp +bool get_single_value(const std::shared_ptr& const_node, + float& value, + bool check_value_range = true); +} // namespace util +} // namespace op +} // namespace ov From 489d45384612e9d39447131764086cf1d1c6d23e Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 22 Aug 2025 15:00:38 +0800 Subject: [PATCH 10/30] NPU perf: eliminate zp --- .../openvino/pass/eliminate_zp.cpp | 116 ++++++++++++++++++ .../openvino/pass/eliminate_zp.hpp | 17 +++ .../openvino/translate_session.cpp | 2 + 3 files changed, 135 insertions(+) create mode 100644 ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp create mode 100644 ggml/src/ggml-openvino/openvino/pass/eliminate_zp.hpp diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp 
b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp
new file mode 100644
index 0000000000000..d2e5a040dd28f
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp
@@ -0,0 +1,116 @@
+#include "eliminate_zp.hpp"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace pass {
+
+EliminateZeroPoints::EliminateZeroPoints() {
+    // Find pattern:
+    // (Multiply Any(scale)
+    //           (Subtract (Convert Constant(data))
+    //                     (Convert Constant(zero_point))))
+    // where zero_point is a scalar
+    // If data is u4 and zp value is 8 (q4_0), Replace the Subtract with an i4 Constant whose value is data - zp_val
+    // If data is u8 and zp value is 128 (q8_0) or 32 (q6_k), Replace the Subtract with an i8 Constant
+
+    auto m_data_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
+    auto m_data_convert = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_data_constant});
+
+    auto m_zp_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
+    auto m_zp_convert = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_zp_constant});
+
+    auto m_subtract = ov::pass::pattern::wrap_type<ov::op::v1::Subtract>({m_data_convert, m_zp_convert});
+    auto m_scale = ov::pass::pattern::any_input();
+    auto m_multiply = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_scale, m_subtract});
+
+    const auto callback = [=](ov::pass::pattern::Matcher& m) {
+        const auto& pattern_map = m.get_pattern_value_map();
+
+        auto multiply_node = std::dynamic_pointer_cast<ov::op::v1::Multiply>(pattern_map.at(m_multiply).get_node_shared_ptr());
+        auto subtract_node = std::dynamic_pointer_cast<ov::op::v1::Subtract>(pattern_map.at(m_subtract).get_node_shared_ptr());
+        auto data_constant = std::dynamic_pointer_cast<ov::op::v0::Constant>(pattern_map.at(m_data_constant).get_node_shared_ptr());
+        auto zp_constant = std::dynamic_pointer_cast<ov::op::v0::Constant>(pattern_map.at(m_zp_constant).get_node_shared_ptr());
+
+        if (!multiply_node || !subtract_node || !data_constant || !zp_constant) {
+            return false;
+        }
+
+        if (ov::shape_size(zp_constant->get_shape()) != 1) {
+            return false;
+        }
+
+        auto data_type = data_constant->get_element_type();
+        auto zp_data = zp_constant->cast_vector<int>();
+
+        if (zp_data.empty()) {
+            return false;
+        }
+
+        int zp_value = zp_data[0];
+
+        bool should_eliminate = false;
+        ov::element::Type target_type;
+
+        if (data_type == ov::element::u4 && zp_value == 8) {
+            should_eliminate = true;
+            target_type = ov::element::i4;
+        } else if (data_type == ov::element::u8 && (zp_value == 128 || zp_value == 32)) {
+            should_eliminate = true;
+            target_type = ov::element::i8;
+        }
+
+        if (!should_eliminate) {
+            return false;
+        }
+
+        auto data_shape = data_constant->get_shape();
+        size_t total_elements = ov::shape_size(data_shape);
+
+        std::shared_ptr<ov::op::v0::Constant> new_constant;
+
+        if (data_type == ov::element::u4) {
+            auto data_values = data_constant->cast_vector<uint8_t>();
+            std::vector<int8_t> adjusted_values(total_elements);
+
+            ov::parallel_for(total_elements, [&](size_t i) {
+                adjusted_values[i] = static_cast<int8_t>(static_cast<int>(data_values[i]) - 8);
+            });
+
+            new_constant = std::make_shared<ov::op::v0::Constant>(target_type, data_shape, adjusted_values);
+        } else if (data_type == ov::element::u8) {
+            auto data_values = data_constant->cast_vector<uint8_t>();
+            std::vector<int8_t> adjusted_values(total_elements);
+
+            ov::parallel_for(total_elements, [&, zp_value](size_t i) {
+                adjusted_values[i] = static_cast<int8_t>(static_cast<int>(data_values[i]) - zp_value);
+            });
+
+            new_constant = std::make_shared<ov::op::v0::Constant>(target_type, data_shape, adjusted_values);
+        }
+
+        auto new_convert = std::make_shared<ov::op::v0::Convert>(new_constant, subtract_node->get_output_element_type(0));
+        ov::replace_node(subtract_node, new_convert);
+
+        return true;
+    };
+
+    register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_multiply, "ov::frontend::ggml::pass::EliminateZeroPoints"),
+                     callback);
+}
+
+} // namespace pass
+} // namespace ggml
+} // namespace frontend
+} // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.hpp b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.hpp
new file mode 100644
index 0000000000000..edd3cd718d9b0
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.hpp
@@ -0,0 +1,17 @@
+#include "openvino/pass/matcher_pass.hpp"
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace pass {
+
+class EliminateZeroPoints : public ov::pass::MatcherPass {
+public:
+    OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::EliminateZeroPoints")
+    EliminateZeroPoints();
+};
+
+} // namespace pass
+} // namespace ggml
+} // namespace frontend
+} // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp
index 62804670414ef..634fea40e923f 100644
--- a/ggml/src/ggml-openvino/openvino/translate_session.cpp
+++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp
@@ -26,6 +26,7 @@
 #include "ggml-openvino/openvino/node_context.hpp"
 #include "ggml-openvino/openvino/utils.hpp"
 #include "input_model.hpp"
+#include "pass/eliminate_zp.hpp"
 #include "pass/fuse_to_sdpa.hpp"
 #include "pass/mark_decompression_convert_constant_folding.hpp"
@@ -219,6 +220,7 @@ std::shared_ptr<ov::Model> TranslateSession::apply_transformations(std::shared_ptr
             manager.register_pass(kv_param_res_pairs);
         }
+        manager.register_pass<pass::EliminateZeroPoints>();
         manager.register_pass();
         manager.run_passes(model);
     }

From c319ce5ad91890a51459aef84995a7810d9dc37e Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Tue, 26 Aug 2025 15:55:06 +0800
Subject: [PATCH 11/30] NPU perf: Faster compilation

---
 IR.xml                               | 462 +++++++++++++++++++++++++++
 ggml/src/ggml-openvino/utils.cpp.bak |  72 +++++
 2 files changed, 534 insertions(+)
 create mode 100644 IR.xml
 create mode 100644 ggml/src/ggml-openvino/utils.cpp.bak

diff --git a/IR.xml b/IR.xml
new file mode 100644
index 0000000000000..f5b1df8740a66
--- /dev/null
+++ b/IR.xml
@@ -0,0 +1,462 @@
+[462-line serialized OpenVINO IR dump of a small test subgraph; the XML markup did not
+ survive text extraction (only dimension values such as 2/128/64 and 1/1/32 remain), so
+ the body is omitted here.]
diff --git a/ggml/src/ggml-openvino/utils.cpp.bak 
b/ggml/src/ggml-openvino/utils.cpp.bak new file mode 100644 index 0000000000000..8fef1985f91ae --- /dev/null +++ b/ggml/src/ggml-openvino/utils.cpp.bak @@ -0,0 +1,72 @@ +void model_cut() { + ov::Core core; + std::shared_ptr model = + core.read_model("/home/zijun/dev/llama.cpp-ov/tmp/fold_graph/Model1_01_0x5555601c5ac0.xml"); + + ov::ParameterVector new_params; + + auto ops = model->get_ops(); + std::shared_ptr node_a; + std::shared_ptr node_b; + for (const auto& op : ops) { + if (op->get_friendly_name() == "Multiply_4636_ffn_norm-0") { + node_a = op; + } else if (op->get_friendly_name() == "Multiply_4645_ffn_gate_par-0") { + node_b = op; + } else if (op->get_friendly_name() == "Parameter_39914") { + auto param = std::dynamic_pointer_cast(op); + new_params.push_back(param); + } else if (op->get_friendly_name() == "Parameter_39915") { + auto param = std::dynamic_pointer_cast(op); + new_params.push_back(param); + } + } + + auto subgraph_input_tensor = node_a->output(0); + auto subgraph_output_tensor = node_b->output(0); + + auto new_input = std::make_shared(subgraph_input_tensor.get_element_type(), + subgraph_input_tensor.get_shape()); + new_input->set_friendly_name("subgraph_input"); + new_params.push_back(new_input); + + // Rewire: replace all consumers of original tensor with new input + subgraph_input_tensor.replace(new_input); + + auto result = std::make_shared(subgraph_output_tensor); + result->set_friendly_name("subgraph_output"); + + auto subgraph = std::make_shared(ov::ResultVector{result}, new_params, "trimmed_subgraph"); + + ov::serialize(subgraph, "/home/zijun/dev/llama.cpp-ov/tmp/subgraph.xml"); + + assert(false); +} + +void create_graph() { + // Input shapes: [256, 1, 1] + ov::Shape input_shape{256, 1, 1}; + + // Define input parameters + auto input0 = std::make_shared(ov::element::f32, input_shape); + auto input1 = std::make_shared(ov::element::f32, input_shape); + + // Concat on axis 2 -> shape becomes [256, 1, 2] + auto concat = std::make_shared(ov::OutputVector{input0, input1}, 2); + + // Target shape constant for reshape: [256, 2] + auto reshape_shape = ov::op::v0::Constant::create(ov::element::i64, {2}, {256, 2}); + + // special_zero = false + auto reshape = std::make_shared(concat, reshape_shape, false); + + // Define result node + auto result = std::make_shared(reshape); + + // Create model + auto model = std::make_shared(ov::ResultVector{result}, ov::ParameterVector{input0, input1}, "ReshapeConcatModel"); + + ov::serialize(subgraph, "/home/zijun/dev/llama.cpp-ov/tmp/subgraph3.xml"); + + exit(0); +} From 5b6418d76e4d8496d7e97c406c5a794c800d8332 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 29 Aug 2025 11:39:27 +0800 Subject: [PATCH 12/30] Dequantize q4_1 q4_k q6_k for NPU --- ggml/src/ggml-openvino/ggml-decoder.cpp | 25 +++++++++++++++++------- ggml/src/ggml-openvino/ggml-decoder.h | 5 +++-- ggml/src/ggml-openvino/ggml-openvino.cpp | 8 -------- ggml/src/ggml-openvino/utils.cpp | 6 +++++- 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index b20bfd0c76f52..fef8648ebdac4 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -370,7 +370,8 @@ std::map GgmlOvDecoder::get_kv_param_res_names() const return kv_param_res_names; } -std::map> GgmlOvDecoder::create_weight_nodes(struct ggml_cgraph* cgraph) { +std::map> GgmlOvDecoder::create_weight_nodes( + struct ggml_cgraph* cgraph, std::set types_to_dequantize) { std::map> 
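+    // types_to_dequantize: weights of these types are decompressed to plain f16 constants
+    // instead of being kept as compressed u4/u8 weights; the static (NPU) path passes
+    // Q4_1, Q4_K and Q6_K here.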
model_weights; static std::mutex weights_mutex; auto* nodes = cgraph->nodes; @@ -395,7 +396,7 @@ std::map> GgmlOvDecoder::create_weight_no } } if (should_create) { - auto weight_node = create_weight_node(src); + auto weight_node = create_weight_node(src, types_to_dequantize.count(src->type) > 0); weight_node->set_friendly_name(src_name); { std::lock_guard lock(weights_mutex); @@ -409,7 +410,7 @@ std::map> GgmlOvDecoder::create_weight_no return model_weights; } -std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor) { +std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor, bool to_dequantize) { std::set weight_types = { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K}; if (weight_types.find(tensor->type) == weight_types.end()) { @@ -422,15 +423,17 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor) auto ne_total = ggml_nelements(tensor); OPENVINO_ASSERT(node_shape[0] == 1, "Got 3D weights, expect all weights to be 2D: ", tensor->name); + node_shape.erase(node_shape.begin()); // F16 and F32 case if (node_type != ov::element::dynamic) { ov::Tensor weights(node_type, node_shape); memcpy(weights.data(), tensor->data, ne_total * node_type.size()); std::shared_ptr weight_node = std::make_shared(weights); - if (node_type == ov::element::f16) { - weight_node = std::make_shared(weight_node, ov::element::f32); - } + // Disabled because it triggers a bug in NPUW, no performance impact on CPU GPU + // if (node_type == ov::element::f16) { + // weight_node = std::make_shared(weight_node, ov::element::f32); + // } weight_node->set_friendly_name(tensor->name); return weight_node; } @@ -440,7 +443,15 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor) tensor->extra == nullptr, "Unsupported weight tensor: " + std::string(tensor->name) + " Possibly this is a repacked quantized weights"); - node_shape.erase(node_shape.begin()); + if (to_dequantize) { + std::vector weights_f32(ne_total); + ggml_get_type_traits(tensor->type)->to_float(tensor->data, weights_f32.data(), ggml_nelements(tensor)); + ov::Tensor weights(ov::element::f16, node_shape); + ggml_get_type_traits(GGML_TYPE_F16)->from_float_ref(weights_f32.data(), weights.data(), ggml_nelements(tensor)); + std::shared_ptr weight_node = std::make_shared(weights); + weight_node->set_friendly_name(tensor->name); + return weight_node; + } uint64_t weights_per_byte; if (tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_1 || tensor->type == GGML_TYPE_Q4_K) { diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index df23c649f4f47..b446841514794 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -117,8 +117,9 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { static void dump_cgraph(const struct ggml_cgraph* cgraph, std::string& filename); - static std::shared_ptr create_weight_node(ggml_tensor* tensor); - static std::map> create_weight_nodes(struct ggml_cgraph* cgraph); + static std::shared_ptr create_weight_node(ggml_tensor* tensor, bool to_dequantize); + static std::map> create_weight_nodes( + struct ggml_cgraph* cgraph, std::set types_to_dequantize = {}); const ggml_tensor* get_tensor_used_op(const ggml_tensor* tensor) const; const ggml_tensor* get_tensor_from_name(const std::string& name) const; diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 
a6ec1c64c2904..60a2eb388ea1e 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -344,14 +344,6 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con GGML_TYPE_Q8_0, GGML_TYPE_Q6_K}; - std::string device = std::string(getenv("GGML_OPENVINO_DEVICE")); - bool is_npu = device == "NPU"; - if (is_npu) { - // NPU has poor support for asymmetric quantization - supported_types.erase(GGML_TYPE_Q4_1); - supported_types.erase(GGML_TYPE_Q4_K); - } - static const std::set supported_ops{GGML_OP_NONE, GGML_OP_ADD, GGML_OP_MUL, diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 43fa0c469d60a..e49d941da4ab1 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -130,7 +130,11 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c compile_end_time = conversion_end_time; } else { std::shared_ptr model; - auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); + std::set types_to_dequantize; + if (is_static) { + types_to_dequantize = {GGML_TYPE_Q4_1, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K}; + } + auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph, types_to_dequantize); if (is_static) { ggml_decoder = std::make_shared(cgraph, model_weights, is_static, true); From 623f8630588a26636ce7c71aa20c798f76c73d7a Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Tue, 2 Sep 2025 13:52:45 +0800 Subject: [PATCH 13/30] Add custom quant type: q8_1_c, q4_0_128 --- ggml/src/ggml-openvino/ggml-decoder.cpp | 44 ++---- ggml/src/ggml-openvino/ggml-decoder.h | 7 +- ggml/src/ggml-openvino/ggml-quants.cpp | 194 +++++++++++++++++++----- ggml/src/ggml-openvino/ggml-quants.hpp | 10 ++ ggml/src/ggml-openvino/utils.cpp | 16 +- 5 files changed, 203 insertions(+), 68 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index fef8648ebdac4..d00b78e891ee0 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -371,7 +372,7 @@ std::map GgmlOvDecoder::get_kv_param_res_names() const } std::map> GgmlOvDecoder::create_weight_nodes( - struct ggml_cgraph* cgraph, std::set types_to_dequantize) { + struct ggml_cgraph* cgraph, std::map types_to_requantize) { std::map> model_weights; static std::mutex weights_mutex; auto* nodes = cgraph->nodes; @@ -396,7 +397,10 @@ std::map> GgmlOvDecoder::create_weight_no } } if (should_create) { - auto weight_node = create_weight_node(src, types_to_dequantize.count(src->type) > 0); + auto requant_type = types_to_requantize.count(src->type) ? 
+ std::optional(types_to_requantize.at(src->type)) : + std::nullopt; + auto weight_node = create_weight_node(src, requant_type); weight_node->set_friendly_name(src_name); { std::lock_guard lock(weights_mutex); @@ -410,7 +414,8 @@ std::map> GgmlOvDecoder::create_weight_no return model_weights; } -std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor, bool to_dequantize) { +std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor, + std::optional requant_type) { std::set weight_types = { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K}; if (weight_types.find(tensor->type) == weight_types.end()) { @@ -443,21 +448,15 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor, tensor->extra == nullptr, "Unsupported weight tensor: " + std::string(tensor->name) + " Possibly this is a repacked quantized weights"); - if (to_dequantize) { - std::vector weights_f32(ne_total); - ggml_get_type_traits(tensor->type)->to_float(tensor->data, weights_f32.data(), ggml_nelements(tensor)); - ov::Tensor weights(ov::element::f16, node_shape); - ggml_get_type_traits(GGML_TYPE_F16)->from_float_ref(weights_f32.data(), weights.data(), ggml_nelements(tensor)); - std::shared_ptr weight_node = std::make_shared(weights); - weight_node->set_friendly_name(tensor->name); - return weight_node; + if (requant_type.has_value()) { + return requantize(tensor, requant_type.value()); } - uint64_t weights_per_byte; + ov::element::Type weight_type; if (tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_1 || tensor->type == GGML_TYPE_Q4_K) { - weights_per_byte = 2; + weight_type = ov::element::u4; } else { // tensor.type == GGUF_TYPE_Q8_0 || tensor.type == GGUF_TYPE_Q6_K - weights_per_byte = 1; + weight_type = ov::element::u8; } uint64_t weights_per_block; @@ -474,15 +473,12 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor, " has incompatible last dim shape: ", node_shape.back()); - auto weights_shape = node_shape; - weights_shape.back() /= (weights_per_byte * 4); // means u32 type can store 8 q4 or 4 q8 - - ov::Tensor weights(ov::element::u32, weights_shape); - // For scales and bias + ov::Tensor weights(weight_type, node_shape); + // For scales and biases node_shape[node_shape.size() - 1] = node_shape[node_shape.size() - 1] / weights_per_block; - ov::Tensor scales(ov::element::f16, node_shape); ov::Tensor biases(ov::element::f16, node_shape); + ov::Output weight_node; if (tensor->type == GGML_TYPE_Q4_0) { extract_q4_0_data(tensor, weights, scales, biases); @@ -494,7 +490,6 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor, extract_q8_0_data(tensor, weights, scales, biases); weight_node = make_int8_weights(weights, scales, biases, weights_per_block); } else if (tensor->type == GGML_TYPE_Q6_K) { - // due to WA #2135, this case will not be used, extract_q6_k_data temporarily disabled. 
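// Illustration only: the u4/u8 weights plus f16 scales/biases extracted here
// encode a plain affine dequantization. Assuming ggml's Q4_0 layout (each
// group of 32 nibbles q[j] shares a scale d and a bias b = -8*d), a reference
// decode is:
static void dequant_group_ref(const uint8_t* q, float d, float b, float* w, int n) {
    for (int j = 0; j < n; ++j) {
        w[j] = d * float(q[j]) + b;  // == d * (q[j] - 8) when b == -8*d
    }
}
// make_int4_weights()/make_int8_weights() build this same identity as an OV
// subgraph (Convert -> Subtract(zero_point) -> Multiply(scale)), with the
// zero point recovered as zp = -b/d (8 for Q4_0). The ExtraQuantType variants
// introduced in this patch only change the grouping: Q4_0_128 shares one
// (d, b) pair per 128 weights, Q8_1_C one pair per output channel.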
extract_q6_k_data(tensor, weights, scales, biases); weight_node = make_int8_weights(weights, scales, biases, weights_per_block); } else if (tensor->type == GGML_TYPE_Q4_K) { @@ -503,15 +498,8 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor, } OPENVINO_ASSERT(weight_node.get_shape().size() == 2, "Weight should be 2D"); - // weight_node = std::make_shared( - // weight_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0})); weight_node.get_node_shared_ptr()->set_friendly_name(tensor->name); - // GGML_LOG_DEBUG("Created weight node: %s %s %s%s\n", - // tensor->name, - // ggml_type_name(tensor->type), - // weight_node.get_element_type().get_type_name().c_str(), - // weight_node.get_partial_shape().to_string().c_str()); return weight_node.get_node_shared_ptr(); } diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index b446841514794..24e1d92dcfd68 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -4,8 +4,10 @@ #include #include #include +#include #include +#include "ggml-quants.hpp" #include "ggml.h" #include "openvino/decoder.hpp" @@ -117,9 +119,10 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { static void dump_cgraph(const struct ggml_cgraph* cgraph, std::string& filename); - static std::shared_ptr create_weight_node(ggml_tensor* tensor, bool to_dequantize); + static std::shared_ptr create_weight_node(ggml_tensor* tensor, + std::optional requant_type = std::nullopt); static std::map> create_weight_nodes( - struct ggml_cgraph* cgraph, std::set types_to_dequantize = {}); + struct ggml_cgraph* cgraph, std::map types_to_requantize = {}); const ggml_tensor* get_tensor_used_op(const ggml_tensor* tensor) const; const ggml_tensor* get_tensor_from_name(const std::string& name) const; diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp index 97aa494ed85aa..1603e65355274 100644 --- a/ggml/src/ggml-openvino/ggml-quants.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -1,15 +1,20 @@ #include "ggml-quants.hpp" #include +#include +#include #include #include +#include #include #include #include #include #include #include +#include +#include "ggml-impl.h" #include "ggml.h" void unpack_32_4(const uint8_t* data, uint8_t* dst) { @@ -203,20 +208,24 @@ void extract_q6_k_data(const ggml_tensor* tensor, // TODO Reorder for make_intX_weights ov::Output make_int8_weights(ov::Tensor& weight, ov::Tensor& scales, ov::Tensor& biases, size_t group_size) { - - // Reshape weight to (num_heads, -1, group_size) ov::Shape orig_shape = weight.get_shape(); - orig_shape[1] *= sizeof(uint32_t) / sizeof(uint8_t); - size_t num_groups = orig_shape[1] / group_size; // Expand dimensions for scales and biases auto scale_shape = scales.get_shape(); - scale_shape.push_back(1); - scales.set_shape(scale_shape); - biases.set_shape(scale_shape); + + ov::Shape packed_shape = {orig_shape[0], orig_shape[1] / group_size, group_size}; + + if (packed_shape[1] == 1) { + packed_shape.erase(packed_shape.begin() + 1); + } else { + scale_shape.push_back(1); + scales.set_shape(scale_shape); + biases.set_shape(scale_shape); + } // Create graph nodes - auto weights_node = std::make_shared(ov::element::u8, ov::Shape{orig_shape[0], num_groups, group_size}, static_cast(weight.data()), nullptr); + auto weights_node = std::make_shared( + ov::element::u8, packed_shape, static_cast(weight.data()), nullptr); weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; auto 
scales_f16 = std::make_shared(scales); ov::Tensor biases_u8(ov::element::u8, scale_shape); @@ -242,32 +251,24 @@ ov::Output make_int8_weights(ov::Tensor& weight, ov::Tensor& scales, o auto w_zp = std::make_shared( weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY ); - auto w_zp_s = std::make_shared( - w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY - ); - - // Reshape back to original dimensions - auto final_shape = std::make_shared( - ov::element::i64, ov::Shape{orig_shape.size()}, orig_shape - ); - auto w_zp_s_r = std::make_shared( - w_zp_s, final_shape, false - ); + ov::Output w_zp_s = + std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); + + if (packed_shape.size() != 2) { + // If not requantized channel-wise case, reshape back to original shape + auto final_shape = + std::make_shared(ov::element::i64, ov::Shape{orig_shape.size()}, orig_shape); + w_zp_s = std::make_shared(w_zp_s, final_shape, false); + } - return std::make_shared(w_zp_s_r, ov::element::f32); + return std::make_shared(w_zp_s, ov::element::f32); } ov::Output make_int4_weights(ov::Tensor& weight, ov::Tensor& scales, ov::Tensor& biases, size_t group_size) { - - // Convert weight to uint8 view and adjust shape ov::Shape orig_weight_shape = weight.get_shape(); - orig_weight_shape[1] *= sizeof(uint32_t) / sizeof(uint8_t) * 2; // Double number of columns for 4-bit representation // Expand dimensions for scales and biases ov::Shape scale_bias_shape = scales.get_shape(); - scale_bias_shape.push_back(1); // Add new axis at the end - scales.set_shape(scale_bias_shape); - biases.set_shape(scale_bias_shape); // Create INT4 weight tensor ov::Shape packed_shape = { @@ -276,8 +277,17 @@ ov::Output make_int4_weights(ov::Tensor& weight, ov::Tensor& scales, o group_size }; + // Requantized channel-wise case + if (packed_shape[1] == 1) { + packed_shape.erase(packed_shape.begin() + 1); + } else { + scale_bias_shape.push_back(1); + scales.set_shape(scale_bias_shape); + biases.set_shape(scale_bias_shape); + } + auto weights_node = std::make_shared(ov::element::u4, packed_shape, static_cast(weight.data()), nullptr); - weights_node->get_rt_info()["__gguf_tensor_holde"] = weight; + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; auto weights_f16 = std::make_shared(weights_node, ov::element::f16); // Pack zero points: two subsequent values into one @@ -304,15 +314,129 @@ ov::Output make_int4_weights(ov::Tensor& weight, ov::Tensor& scales, o auto w_zp = std::make_shared( weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); - auto w_zp_s = std::make_shared( - w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); + ov::Output w_zp_s = + std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); + + if (packed_shape.size() != 2) { + // If not requantized channel-wise case, reshape back to original shape + auto final_shape = std::make_shared( + ov::element::i64, ov::Shape{orig_weight_shape.size()}, orig_weight_shape); + + w_zp_s = std::make_shared(w_zp_s, final_shape, false); + } + + return std::make_shared(w_zp_s, ov::element::f32); +} - // Reshape back to original shape - auto final_shape = std::make_shared( - ov::element::i64, ov::Shape{orig_weight_shape.size()}, orig_weight_shape); +std::shared_ptr requantize(const ggml_tensor* tensor, ExtraQuantType requant_type) { + std::vector weights_f32(tensor->ne[0] * tensor->ne[1]); + ggml_get_type_traits(tensor->type)->to_float(tensor->data, weights_f32.data(), ggml_nelements(tensor)); - auto w_zp_s_r = std::make_shared( - w_zp_s, 
final_shape, false); + std::shared_ptr weight_node; + ov::Shape node_shape = {(uint64_t) (tensor->ne[1]), (uint64_t) (tensor->ne[0])}; + + if (requant_type == ExtraQuantType::F16) { + ov::Tensor weights(ov::element::f16, node_shape); + ggml_get_type_traits(GGML_TYPE_F16)->from_float_ref(weights_f32.data(), weights.data(), ggml_nelements(tensor)); + std::shared_ptr weight_node = std::make_shared(weights); + weight_node->set_friendly_name(tensor->name); + return weight_node; + } - return std::make_shared(w_zp_s_r, ov::element::f32); + int64_t block_size = node_shape[1]; + if (requant_type == ExtraQuantType::Q4_0_128) { + block_size = 128; + } + auto scales_shape = ov::Shape{node_shape[0], node_shape[1] / block_size}; + + ov::Tensor weights; + ov::Tensor scales(ov::element::f16, scales_shape); + ov::Tensor bias(ov::element::f16, scales_shape); + + if (requant_type == ExtraQuantType::Q4_0_C) { + weights = ov::Tensor(ov::element::u4, node_shape); + quantize_q4_0(weights_f32.data(), weights, scales, bias, weights.get_size(), block_size); + weight_node = make_int4_weights(weights, scales, bias, block_size).get_node_shared_ptr(); + } else if (requant_type == ExtraQuantType::Q8_1_C) { + weights = ov::Tensor(ov::element::u8, node_shape); + quantize_q8_1(weights_f32.data(), weights, scales, bias, weights.get_size(), block_size); + weight_node = make_int8_weights(weights, scales, bias, block_size).get_node_shared_ptr(); + } else if (requant_type == ExtraQuantType::Q4_0_128) { + weights = ov::Tensor(ov::element::u4, node_shape); + quantize_q4_0(weights_f32.data(), weights, scales, bias, weights.get_size(), block_size); + weight_node = make_int4_weights(weights, scales, bias, block_size).get_node_shared_ptr(); + } + + weight_node->set_friendly_name(tensor->name); + return weight_node; +} + +void quantize_q4_0(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, + int64_t qk) { + assert(k % qk == 0); + const int nb = k / qk; + + auto* weights = static_cast(weights_arr.data()); + auto* scales = scales_arr.data::value_type>(); + auto* biases = biases_arr.data::value_type>(); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); + biases[i] = ov::float16(-8.f * d); + + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f)); + weights[i * qk / 2 + j] = xi0 | (xi1 << 4); + } + } +} + +void quantize_q8_1(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, + int64_t qk) { + assert(k % qk == 0); + const int nb = k / qk; + + auto* weights = static_cast(weights_arr.data()); + auto* scales = scales_arr.data::value_type>(); + auto* biases = biases_arr.data::value_type>(); + for (int i = 0; i < nb; i++) { + float min = std::numeric_limits::max(); + float max = std::numeric_limits::lowest(); + + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + + const float d = (max - min) / ((1 << 8) - 1); + const float id = d ? 
1.0f / d : 0.0f;
+        scales[i] = ov::float16(d);
+        biases[i] = ov::float16(min);
+
+        for (int j = 0; j < qk; ++j) {
+            const float x0 = (x[i * qk + j] - min) * id;
+            const uint8_t xi0 = roundf(x0);
+            weights[i * qk + j] = xi0;
+        }
+    }
 }
diff --git a/ggml/src/ggml-openvino/ggml-quants.hpp b/ggml/src/ggml-openvino/ggml-quants.hpp
index ae37b1618ed14..fbae2aa1f43ef 100644
--- a/ggml/src/ggml-openvino/ggml-quants.hpp
+++ b/ggml/src/ggml-openvino/ggml-quants.hpp
@@ -1,3 +1,4 @@
+#pragma once
 #include
 #include
 #include
@@ -45,6 +46,15 @@ ov::Output make_int4_weights(ov::Tensor& weight,
                              ov::Tensor& biases,
                              size_t group_size = GGML_QUANTIZATION_GROUP_SIZE);
+enum class ExtraQuantType { F16, Q4_0_C, Q8_1_C, Q4_0_128 };
+
+std::shared_ptr requantize(const ggml_tensor* tensor, ExtraQuantType requant_type);
+
+void quantize_q4_0(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k,
+                   int64_t qk);
+void quantize_q8_1(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k,
+                   int64_t qk);
+
 namespace ov {
 namespace op {
 namespace util {
diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp
index e49d941da4ab1..3f728c242dd41 100644
--- a/ggml/src/ggml-openvino/utils.cpp
+++ b/ggml/src/ggml-openvino/utils.cpp
@@ -130,11 +130,21 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
         compile_end_time = conversion_end_time;
     } else {
         std::shared_ptr model;
-        std::set types_to_dequantize;
+        std::map types_to_requantize;
         if (is_static) {
-            types_to_dequantize = {GGML_TYPE_Q4_1, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K};
+            types_to_requantize = {
+                {GGML_TYPE_Q4_0, ExtraQuantType::Q4_0_128},
+                {GGML_TYPE_Q4_1, ExtraQuantType::Q4_0_128},
+                {GGML_TYPE_Q4_K, ExtraQuantType::Q4_0_128},
+                {GGML_TYPE_Q6_K, ExtraQuantType::Q8_1_C },
+            };
+        } else if (device == "GPU") {
+            types_to_requantize = {
+                // CVS-166739
+                {GGML_TYPE_Q6_K, ExtraQuantType::Q8_1_C},
+            };
         }
-        auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph, types_to_dequantize);
+        auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph, types_to_requantize);
 
         if (is_static) {
             ggml_decoder = std::make_shared(cgraph, model_weights, is_static, true);
From 623f8630588a26636ce7c71aa20c798f76c73d7a Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Tue, 2 Sep 2025 14:52:04 +0800
Subject: [PATCH 14/30] Set m_is_static=false as default in decoder

---
 ggml/src/ggml-openvino/ggml-decoder.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h
index 24e1d92dcfd68..4ba147da20d62 100644
--- a/ggml/src/ggml-openvino/ggml-decoder.h
+++ b/ggml/src/ggml-openvino/ggml-decoder.h
@@ -161,7 +161,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
     int m_head_size;
     int32_t* m_rope_params;
     std::vector m_kv_names;
-    bool m_is_static;
+    bool m_is_static = false;
     bool m_is_first_token;
 };
From 1b8323ff0bea8f14e91778b69bfbfca57c577f19 Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Tue, 2 Sep 2025 14:53:09 +0800
Subject: [PATCH 15/30] Simplify translation of get_rows

---
 .../ggml-openvino/openvino/op/get_rows.cpp | 26 ++++++-------------
 1 file changed, 8 insertions(+), 18 deletions(-)

diff --git a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp
index 0de77da59ffc5..5e4c7d901ac32 100644
--- a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp
+++
b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp @@ -3,10 +3,7 @@ #include #include #include -#include -#include #include -#include #include "../node_context.hpp" #include "../op_table.hpp" @@ -31,22 +28,15 @@ OutputVector translate_get_rows(const NodeContext& context) { indices = process_view_input(context, 1); } - Output axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); - if (indices.get_partial_shape()[1].get_length() == 1) { - indices = - std::make_shared(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1})); - if (data.get_partial_shape().rank() == 2) { - axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); - } - res = std::make_shared(data, indices, axis); - if (data.get_partial_shape().rank() == 2) { - res = - std::make_shared(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); - } - } else { - indices = - std::make_shared(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + // data[b,x,y] ind[1,b,x'] test-backend-ops case + // data[x,y] ind[1,1,x'] normal case + indices = std::make_shared(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + if (data.get_partial_shape().rank() == 3) { + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); res = std::make_shared(data, indices, axis, 1); + } else { + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + res = std::make_shared(data, indices, axis); } if (res.get_element_type() != context.get_output_type(0)) { From 6da5a225236915de511c107caf5e0e198ebc93d9 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Mon, 8 Sep 2025 16:52:58 +0800 Subject: [PATCH 16/30] Fix after rebasing --- ggml/src/ggml-openvino/openvino/op/mulmat.cpp | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-openvino/openvino/op/mulmat.cpp b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp index bfccc28163522..b4103378ebb1b 100644 --- a/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +++ b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp @@ -41,13 +41,8 @@ OutputVector translate_mulmat(const NodeContext& context) { B = process_view_input(context, 0); A = process_view_input(context, 1); } - - bool convert_out_type = false; - if (ov::op::util::is_constant(B.get_node()) && context.get_input_type(0) != context.get_input_type(1)) { - B = std::make_shared(B, context.get_input_type(1)); - } else if (context.get_input_type(0) != context.get_input_type(1)) { - A = std::make_shared(A, context.get_input_type(0)); - convert_out_type = true; + if (A.get_element_type() != B.get_element_type()) { + B = std::make_shared(context.get_input(0), context.get_input_type(1)); } auto B_shape = context.get_input_shape(0).to_shape(); @@ -82,12 +77,7 @@ OutputVector translate_mulmat(const NodeContext& context) { A = Z; } - if (convert_out_type) { - auto result_lp = std::make_shared(A, B, false, transpose_b); - res = std::make_shared(result_lp, context.get_output_type(0)); - } else { - res = std::make_shared(A, B, false, transpose_b); - } + res = std::make_shared(A, B, false, transpose_b); return rename_outputs_with_suffix({res}, context.get_name()); } From e37cdd4878125938f8908b3f84dc7e6402e0f8d7 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Wed, 10 Sep 2025 15:38:15 +0800 Subject: [PATCH 17/30] Improve debug util; Eliminate nop ReshapeReshape --- ggml/src/ggml-openvino/ggml-decoder.cpp | 27 +++++---- .../src/ggml-openvino/openvino/op/reshape.cpp | 7 ++- ggml/src/ggml-openvino/utils.cpp | 55 
+++++++++++++++---- 3 files changed, 65 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index d00b78e891ee0..0dfc11e4905f6 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -154,22 +154,22 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) { // Add model outputs, if called for the whole graph if (naive) { - m_model_output_names.push_back(node->name); + m_model_output_names.push_back(node_name); } else if (!m_node) { + // Model outputs are tensors with GGML_TENSOR_FLAG_OUTPUT flag and kv_caches static std::set debug_output_names = {}; // Workaround: the final tensor "result_output" does not have GGML_TENSOR_FLAG_OUTPUT flag set in cgraph - if (node->buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY || node->flags & GGML_TENSOR_FLAG_OUTPUT || - std::string(node->name).find("result") == 0 || debug_output_names.count(node->name)) { - auto name = node->view_src ? std::string(node->view_src->name) : std::string(node->name); - if (node->buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY) { - assert(name.find("cache_k") == 0 || name.find("cache_v") == 0); + if (node->op == GGML_OP_SET_ROWS || node->flags & GGML_TENSOR_FLAG_OUTPUT || node_name.find("result") == 0 || + debug_output_names.count(node_name)) { + if (node->op == GGML_OP_SET_ROWS) { + assert(node_name.find("cache_k") == 0 || node_name.find("cache_v") == 0); + if (auto it = std::find(m_kv_names.begin(), m_kv_names.end(), node_name); it == m_kv_names.end()) { + m_kv_names.push_back(node_name); + } } - if (auto it = std::find(m_model_output_names.begin(), m_model_output_names.end(), name); + if (auto it = std::find(m_model_output_names.begin(), m_model_output_names.end(), node_name); it == m_model_output_names.end()) { - m_model_output_names.push_back(name); - } - if (auto it = std::find(m_kv_names.begin(), m_kv_names.end(), name); it == m_kv_names.end()) { - m_kv_names.push_back(name); + m_model_output_names.push_back(node_name); } } } @@ -177,7 +177,10 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) { if (m_node) { switch (node->op) { case GGML_OP_RESHAPE: { - if (node->ne[0] * node->ne[1] == node->src[0]->ne[0]) { + if (node->src[0]->op == GGML_OP_RESHAPE && node->src[0]->src[0]->ne[0] == node->ne[0] && + node->src[0]->src[0]->ne[1] == node->ne[1]) { + m_op_case = 4; + } else if (node->ne[0] * node->ne[1] == node->src[0]->ne[0]) { m_op_case = 1; } else if (node->src[0]->ne[0] * node->src[0]->ne[1] == node->ne[0]) { m_op_case = 2; diff --git a/ggml/src/ggml-openvino/openvino/op/reshape.cpp b/ggml/src/ggml-openvino/openvino/op/reshape.cpp index 4ef3833c90252..1ed6f4b880b0a 100644 --- a/ggml/src/ggml-openvino/openvino/op/reshape.cpp +++ b/ggml/src/ggml-openvino/openvino/op/reshape.cpp @@ -23,7 +23,8 @@ OutputVector translate_reshape(const NodeContext& context) { } int op_case = context.get_op_case(); - FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported RESHAPE case"); + FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4, + "Unsupported RESHAPE case"); auto output_shape = context.get_output_shape(0).to_shape(); std::shared_ptr new_shape_node; @@ -37,9 +38,11 @@ OutputVector translate_reshape(const NodeContext& context) { ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector{(int64_t)output_shape[0], -1, (int64_t)output_shape[2]}); - } else { + } else if (op_case == 3) { new_shape_node = 
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector{(int64_t) output_shape[0], -1, 1}); + } else if (op_case == 4) { + return {context.get_input(0).get_node_shared_ptr()->input_value(0)}; } auto res = std::make_shared(context.get_input(0), new_shape_node, false); return rename_outputs_with_suffix({res}, context.get_name()); diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 3f728c242dd41..588404df19d07 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -6,6 +6,8 @@ #include #include #include +#include +#include #include #include #include @@ -418,17 +420,50 @@ void print_output_tensor_info(const std::string& name, const ov::Tensor& tensor, std::map& output_dst) { std::cout << "Output name: " << name << ", Output shape: " << tensor.get_shape() << ", Address: " << output_dst[name] << std::endl; + + auto print_float_stats = [](const std::string& type_name, size_t size, auto get_value) { + if (size == 0) { + return; + } + + float first = get_value(0); + float min = first; + float max = first; + double sum = first; + + for (size_t i = 1; i < size; ++i) { + float v = get_value(i); + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + sum += v; + } + double mean = sum / size; + + std::cout << std::right << std::setw(6) << type_name << std::right << std::setw(12) << "First" << std::setw(12) + << "Min" << std::setw(12) << "Max" << std::setw(12) << "Mean" << std::endl; + std::cout << std::right << std::setw(6) << "" << std::right << std::setw(12) << first << std::setw(12) << min + << std::setw(12) << max << std::setw(12) << mean << std::endl; + }; + switch (tensor.get_element_type()) { - case ov::element::f32: - std::cout << *(tensor.data()) << std::endl; - std::cout << checksum(tensor.data(), tensor.get_byte_size()) << std::endl; - break; - case ov::element::f16: - std::cout << *(tensor.data()) << std::endl; - std::cout << checksum(tensor.data(), tensor.get_byte_size()) << std::endl; - break; - default: - break; + case ov::element::f32: { + const float* data = tensor.data(); + size_t size = tensor.get_size(); + print_float_stats("[f32]", size, [data](size_t i) { return data[i]; }); + break; + } + case ov::element::f16: { + const ov::float16* data = tensor.data(); + size_t size = tensor.get_size(); + print_float_stats("[f16]", size, [data](size_t i) { return static_cast(data[i]); }); + break; + } + default: + break; } } From b453d68dce31bdb83e4160a3d30580c0a14fd2db Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Wed, 10 Sep 2025 16:54:57 +0800 Subject: [PATCH 18/30] STYLE: make get_types_to_requant a function --- ggml/src/ggml-openvino/utils.cpp | 33 +++++++++++++++++--------------- ggml/src/ggml-openvino/utils.h | 2 ++ 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 588404df19d07..2438f2dd1191b 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -132,21 +132,7 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c compile_end_time = conversion_end_time; } else { std::shared_ptr model; - std::map types_to_requantize; - if (is_static) { - types_to_requantize = { - {GGML_TYPE_Q4_0, ExtraQuantType::Q4_0_128}, - {GGML_TYPE_Q4_1, ExtraQuantType::Q4_0_128}, - {GGML_TYPE_Q4_K, ExtraQuantType::Q4_0_128}, - {GGML_TYPE_Q6_K, ExtraQuantType::Q8_1_C }, - }; - } else if (device == "GPU") { - types_to_requantize = { - // CVS-166739 - {GGML_TYPE_Q6_K, 
ExtraQuantType::Q8_1_C}, - }; - } - auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph, types_to_requantize); + auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph, get_types_to_requant(device)); if (is_static) { ggml_decoder = std::make_shared(cgraph, model_weights, is_static, true); @@ -275,6 +261,23 @@ ov::AnyMap get_npu_prefill_config() { return config; } +std::map get_types_to_requant(const std::string& device) { + if (device == "NPU") { + return { + {GGML_TYPE_Q4_0, ExtraQuantType::Q4_0_128}, + {GGML_TYPE_Q4_1, ExtraQuantType::Q4_0_128}, + {GGML_TYPE_Q4_K, ExtraQuantType::Q4_0_128}, + {GGML_TYPE_Q6_K, ExtraQuantType::Q8_1_C }, + }; + } + if (device == "GPU") { + return { + // CVS-166739 + {GGML_TYPE_Q6_K, ExtraQuantType::Q8_1_C}, + }; + } +} + ov::AnyMap get_npu_generate_config() { ov::AnyMap config = get_npu_prefill_config(); config.emplace("NPUW_UNFOLD_IREQS", "YES"); diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h index f377fe9d2735d..42686c593b3ce 100644 --- a/ggml/src/ggml-openvino/utils.h +++ b/ggml/src/ggml-openvino/utils.h @@ -43,6 +43,8 @@ bool is_prefill(struct ggml_cgraph * cgraph); ov::AnyMap get_npu_prefill_config(); ov::AnyMap get_npu_generate_config(); +std::map get_types_to_requant(const std::string& device); + ov::Tensor get_ov_input_tensor(std::shared_ptr ggml_decoder, const std::string& param_name); bool is_naive(struct ggml_cgraph* cgraph); From 4ed251049432701e4610abc2ab9326ea0504795a Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 11 Sep 2025 14:34:17 +0800 Subject: [PATCH 19/30] Support BF16 model --- ggml/src/ggml-openvino/ggml-decoder.cpp | 10 ++++++++-- ggml/src/ggml-openvino/utils.cpp | 1 + 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 0dfc11e4905f6..0bdb9aa8971f6 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -419,8 +419,14 @@ std::map> GgmlOvDecoder::create_weight_no std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor, std::optional requant_type) { - std::set weight_types = { - GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K}; + std::set weight_types = {GGML_TYPE_F32, + GGML_TYPE_F16, + GGML_TYPE_BF16, + GGML_TYPE_Q8_0, + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_1, + GGML_TYPE_Q4_K, + GGML_TYPE_Q6_K}; if (weight_types.find(tensor->type) == weight_types.end()) { throw std::runtime_error("Unexpected weight tensor type: " + std::string(tensor->name) + " with type " + ggml_type_name(tensor->type)); diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 2438f2dd1191b..cf0a02c3ad671 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -276,6 +276,7 @@ std::map get_types_to_requant(const std::string& devi {GGML_TYPE_Q6_K, ExtraQuantType::Q8_1_C}, }; } + return {}; } ov::AnyMap get_npu_generate_config() { From 3eeb567e391a51ce48458a82d8828c6b1653d802 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 12 Sep 2025 11:42:02 +0800 Subject: [PATCH 20/30] Fix NPU compile --- ggml/src/ggml-openvino/utils.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index cf0a02c3ad671..c03ec1acbcf53 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -251,7 +251,6 @@ ov::AnyMap get_npu_prefill_config() { {"NPUW_DEVICES", "NPU" 
}, {"NPUW_FOLD", "YES" }, {"NPUW_WEIGHTS_BANK", "shared" }, - {"NPUW_SLICE_OUT", "YES" }, {"NPUW_FUNCALL_ASYNC", "YES" }, {"NPUW_FUNCALL_FOR_ALL", "YES" }, {"NPUW_DQ", "YES" }, From 3346a3395531304feecb65a7fa60916cf34d6a9e Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 12 Sep 2025 16:32:41 +0800 Subject: [PATCH 21/30] WA for npu 1st token acc issue --- ggml/src/ggml-openvino/utils.cpp | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index c03ec1acbcf53..7b696769fba84 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -218,7 +218,7 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c auto gguf_tensor_addrs = get_ggml_graph_output_dst(ggml_decoder); for (size_t i = 0; i < ov_output_names.size(); i++) { - auto result_name = ov_output_names[i]; + auto& result_name = ov_output_names[i]; const auto output_tensor = infer_request.get_output_tensor(i); std::memcpy(gguf_tensor_addrs[result_name], output_tensor.data(), output_tensor.get_byte_size()); @@ -243,20 +243,34 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c GGML_UNUSED(backend); } -ov::AnyMap get_npu_prefill_config() { - ov::AnyMap config = { +namespace { +ov::AnyMap get_npu_base_config() { + return { {"NPU_COMPILATION_MODE_PARAMS", "compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add_RMSNorm" }, {"NPU_COMPILER_DYNAMIC_QUANTIZATION", "YES" }, {"NPU_USE_NPUW", "YES" }, {"NPUW_DEVICES", "NPU" }, {"NPUW_FOLD", "YES" }, {"NPUW_WEIGHTS_BANK", "shared" }, - {"NPUW_FUNCALL_ASYNC", "YES" }, {"NPUW_FUNCALL_FOR_ALL", "YES" }, {"NPUW_DQ", "YES" }, {"NPUW_DQ_FULL", "NO" }, {"NPUW_CACHE_DIR", getenv("GGML_OPENVINO_CACHE_DIR") ? 
getenv("GGML_OPENVINO_CACHE_DIR") : ""}, }; +} +} // namespace + +ov::AnyMap get_npu_prefill_config() { + auto config = get_npu_base_config(); + config.emplace("NPUW_FUNCALL_ASYNC", "NO"); + config.emplace("NPUW_ACC_CHECK", "YES"); + config.emplace("NPUW_ACC_DEVICE", "CPU"); + return config; +} + +ov::AnyMap get_npu_generate_config() { + auto config = get_npu_base_config(); + config.emplace("NPUW_FUNCALL_ASYNC", "YES"); return config; } @@ -266,7 +280,7 @@ std::map get_types_to_requant(const std::string& devi {GGML_TYPE_Q4_0, ExtraQuantType::Q4_0_128}, {GGML_TYPE_Q4_1, ExtraQuantType::Q4_0_128}, {GGML_TYPE_Q4_K, ExtraQuantType::Q4_0_128}, - {GGML_TYPE_Q6_K, ExtraQuantType::Q8_1_C }, + {GGML_TYPE_Q6_K, ExtraQuantType::F16 }, }; } if (device == "GPU") { @@ -278,12 +292,6 @@ std::map get_types_to_requant(const std::string& devi return {}; } -ov::AnyMap get_npu_generate_config() { - ov::AnyMap config = get_npu_prefill_config(); - config.emplace("NPUW_UNFOLD_IREQS", "YES"); - return config; -} - bool is_naive(struct ggml_cgraph* cgraph) { constexpr int naive_graph_size_threshold = 20; return cgraph->n_nodes < naive_graph_size_threshold; @@ -373,7 +381,7 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr ggml_decoder, cons } else if (const auto* op = ggml_decoder->get_tensor_used_op(ggml_decoder->get_tensor_from_name(param_name)); op && op->op == GGML_OP_SET_ROWS && is_static && is_first_token) { - input_tensor = ov::Tensor(ov::element::i64, ov::Shape{1}); + input_tensor = ov::Tensor(ov::element::i64, ov::Shape{1, 1, 1}); } else { input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name); } From fa237a10dc453825aa243674522d24c1e9c6e439 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 12 Sep 2025 16:51:46 +0800 Subject: [PATCH 22/30] Apply EliminateZP only for npu --- ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp | 1 + ggml/src/ggml-openvino/openvino/translate_session.cpp | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp index c36579910d48c..f38c0837d1374 100644 --- a/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +++ b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp @@ -19,6 +19,7 @@ namespace ggml { namespace pass { FuseToSDPA::FuseToSDPA() { + // Not maintained since FLASH_ATTN_EXT has replaced this pattern const auto m_k = ov::pass::pattern::any_input(); const auto m_q = ov::pass::pattern::any_input(); const auto m_qk = ov::pass::pattern::wrap_type({m_q, m_k}); diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp index 634fea40e923f..3b8c30361a5e8 100644 --- a/ggml/src/ggml-openvino/openvino/translate_session.cpp +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -27,7 +27,6 @@ #include "ggml-openvino/openvino/utils.hpp" #include "input_model.hpp" #include "pass/eliminate_zp.hpp" -#include "pass/fuse_to_sdpa.hpp" #include "pass/mark_decompression_convert_constant_folding.hpp" namespace ov { @@ -220,8 +219,9 @@ std::shared_ptr TranslateSession::apply_transformations(std::shared_ptr(kv_param_res_pairs); } - manager.register_pass(); - manager.register_pass(); + if (ggml_model_decoder->is_static()) { + manager.register_pass(); + } manager.run_passes(model); } return model; From 9a85e532a5f6b18fa0745eb384280803558878f3 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Mon, 15 Sep 2025 11:13:59 +0800 Subject: [PATCH 23/30] Add GeGLU --- 
ggml/src/ggml-openvino/ggml-openvino.cpp | 37 ++++++++++---- .../ggml-openvino/openvino/op/glu_geglu.cpp | 50 +++++++++++++++++++ .../ggml-openvino/openvino/op/glu_swiglu.cpp | 7 +++ ggml/src/ggml-openvino/openvino/op_table.cpp | 1 + ggml/src/ggml-openvino/openvino/op_table.hpp | 1 + 5 files changed, 85 insertions(+), 11 deletions(-) create mode 100644 ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 60a2eb388ea1e..6da653716f7ed 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -249,17 +249,30 @@ static bool is_op_unsupported_case(const ggml_tensor* op) { const auto* op_params = op->op_params; memcpy(&scale, (const float*) op_params + 0, sizeof(float)); memcpy(&max_bias, (const float*) op_params + 1, sizeof(float)); - const uint32_t h = op->src[0]->ne[2]; - const uint32_t n_head = op->src[0]->ne[0]; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - const float slope = - (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + if (max_bias > 0) { + GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with max_bias > 0\n"); + return true; + } + } - if (slope != 1.0f) { - GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with slope != 1.0f\n"); + if (op->op == GGML_OP_FLASH_ATTN_EXT) { + if (op->src[4] != nullptr) { + GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with sinks\n"); + return true; + } + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + const auto* op_params = op->op_params; + memcpy(&scale, (const float*) op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float*) op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float*) op_params + 2, sizeof(float)); + if (max_bias > 0) { + GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with max_bias > 0\n"); + return true; + } + if (logit_softcap != 0) { + GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with logit_softcap != 0\n"); return true; } } @@ -357,7 +370,8 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con GGML_OP_ROPE, GGML_OP_RMS_NORM, GGML_OP_SCALE, - GGML_OP_SOFT_MAX, + // softmax is not updated due to replaced by flash_attn_ext + // GGML_OP_SOFT_MAX, GGML_OP_SET_ROWS, GGML_OP_FLASH_ATTN_EXT, GGML_OP_CPY}; @@ -366,6 +380,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con }; static const std::set supported_glu_ops{ GGML_GLU_OP_SWIGLU, + GGML_GLU_OP_GEGLU, }; switch (op->op) { diff --git a/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp b/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp new file mode 100644 index 0000000000000..4295bf7517c3c --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp @@ -0,0 +1,50 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_glu_geglu(const NodeContext& context) { + num_inputs_check(context, 1, 2); + + ov::Output src0; + ov::Output src1; + if (context.get_input_size() == 2) { + src0 = context.get_input(0); + src1 = context.get_input(1); + } else { + auto 
combined = context.get_input(0); + auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {2}); + auto split = std::make_shared(combined, split_axis, 2); + src0 = split->output(0); + src1 = split->output(1); + } + + int32_t* params = context.get_output_op_params(0); + const int32_t swapped = params[1]; + if (swapped) { + std::swap(src0, src1); + } + + auto gelu = std::make_shared(src0); + auto res = std::make_shared(gelu, src1); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp b/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp index 138ef650901fd..bef42fe4b70c0 100644 --- a/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +++ b/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp @@ -31,6 +31,13 @@ OutputVector translate_glu_swiglu(const NodeContext& context) { src0 = split->output(0); src1 = split->output(1); } + + int32_t* params = context.get_output_op_params(0); + const int32_t swapped = params[1]; + if (swapped) { + std::swap(src0, src1); + } + auto sigmoid = std::make_shared(src0); auto silu = std::make_shared(src0, sigmoid); auto res = std::make_shared(silu, src1); diff --git a/ggml/src/ggml-openvino/openvino/op_table.cpp b/ggml/src/ggml-openvino/openvino/op_table.cpp index ee55f84b96f80..e36e8f17cc94e 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.cpp +++ b/ggml/src/ggml-openvino/openvino/op_table.cpp @@ -34,6 +34,7 @@ std::unordered_map get_supported_ops() { {"GGML_UNARY_OP_SILU", op::translate_unary_silu }, {"GGML_OP_VIEW", op::translate_view }, {"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu }, + {"GGML_GLU_OP_GEGLU", op::translate_glu_geglu }, {"GGML_OP_SET_ROWS", op::translate_set_rows }, {"GGML_OP_CPY", op::translate_cpy }, {"GGML_OP_FLASH_ATTN_EXT", op::translate_flash_attn_ext }, diff --git a/ggml/src/ggml-openvino/openvino/op_table.hpp b/ggml/src/ggml-openvino/openvino/op_table.hpp index faa61f5f6c8d2..5d4f0538604d1 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.hpp +++ b/ggml/src/ggml-openvino/openvino/op_table.hpp @@ -25,6 +25,7 @@ GGML_OP_CONVERTER(translate_soft_max); GGML_OP_CONVERTER(translate_transpose); GGML_OP_CONVERTER(translate_view); GGML_OP_CONVERTER(translate_glu_swiglu); +GGML_OP_CONVERTER(translate_glu_geglu); GGML_OP_CONVERTER(translate_set_rows); GGML_OP_CONVERTER(translate_cpy); GGML_OP_CONVERTER(translate_flash_attn_ext); From 3d31fa6b3ad8ee1c403ac899e581dc72f21710d6 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Mon, 15 Sep 2025 15:56:03 +0800 Subject: [PATCH 24/30] Fix Hunyuan --- ggml/src/ggml-openvino/ggml-decoder.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 0bdb9aa8971f6..bc528e0cfb5a6 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -242,14 +242,17 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) { void GgmlOvDecoder::set_llm_params() { for (int i = 0; i < m_cgraph->n_nodes; i++) { auto* node = m_cgraph->nodes[i]; + std::string name = std::string(node->name); if (node->op == GGML_OP_VIEW && std::string(node->name) == "cache_k_l0 (view)") { auto* cache_k = node->src[0]; m_context_size = cache_k->ne[1]; - } else if (node->op == GGML_OP_ROPE && std::string(node->name) == "Qcur-0") { + } else if (node->op == GGML_OP_ROPE && + (name.find("Qcur-0") == 0 || 
std::string(node->src[0]->name).find("Qcur-0") == 0)) { m_head_size = node->ne[0]; m_num_heads = node->ne[1]; m_rope_params = node->op_params; - } else if (node->op == GGML_OP_ROPE && std::string(node->name) == "Kcur-0") { + } else if (node->op == GGML_OP_ROPE && + (name.find("Kcur-0") == 0 || std::string(node->src[0]->name).find("Kcur-0") == 0)) { m_num_heads_kv = node->ne[1]; } } From 500aeadc635529599b20aff49fd58174a443c852 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Tue, 16 Sep 2025 16:30:45 +0800 Subject: [PATCH 25/30] Support iSWA --- ggml/src/ggml-openvino/ggml-decoder.cpp | 103 ++++++++++++------ ggml/src/ggml-openvino/ggml-decoder.h | 13 ++- ggml/src/ggml-openvino/openvino/decoder.hpp | 2 + .../ggml-openvino/openvino/node_context.hpp | 13 +-- .../openvino/op/flash_attn_ext.cpp | 9 +- .../src/ggml-openvino/openvino/op/permute.cpp | 38 ++----- .../openvino/translate_session.cpp | 21 +++- ggml/src/ggml-openvino/utils.cpp | 2 +- src/llama-graph.cpp | 2 + 9 files changed, 124 insertions(+), 79 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index bc528e0cfb5a6..e3dd5e0c1dd91 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -30,17 +30,21 @@ #include #include #include +#include #include "ggml-backend-impl.h" #include "ggml-backend.h" #include "ggml-quants.hpp" GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token, - int context_size, int num_heads, int num_heads_kv, int head_size) : + int context_size, int context_size_swa, int num_heads, int num_heads_kv, int head_size, + const std::vector& swa_layers) : m_cgraph(cgraph), m_node(node), m_op_name(std::string(node->name)), m_context_size(context_size), + m_context_size_swa(context_size_swa), + m_swa_layers(swa_layers), m_num_heads(num_heads), m_num_heads_kv(num_heads_kv), m_head_size(head_size), @@ -204,11 +208,14 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) { if (node->src[0]->op != GGML_OP_VIEW) { m_op_case = 1; } else if (ggml_is_contiguous(node->src[0])) { - // Permute cache_k (view) - m_op_case = 2; - } else { - // Permute cache_v (view), deprecated, cache_v will also fall to case 2 - m_op_case = 3; + // Permute kv cache (view) + std::string src_name(node->view_src->name); + int layer = extract_layer_from_name(src_name); + if (!is_swa_layer(layer)) { + m_op_case = 2; + } else { + m_op_case = 3; + } } break; } @@ -239,13 +246,34 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) { } } +int extract_layer_from_name(const std::string& name) { + size_t pos1 = name.find("_l"); + assert(pos1 != std::string::npos); + pos1 += 2; + size_t pos2 = name.find(' ', pos1); + if (pos2 == std::string::npos) { + pos2 = name.length(); + } + std::string layer_str = name.substr(pos1, pos2 - pos1); + int layer = std::stoi(layer_str); + return layer; +} + void GgmlOvDecoder::set_llm_params() { for (int i = 0; i < m_cgraph->n_nodes; i++) { auto* node = m_cgraph->nodes[i]; std::string name = std::string(node->name); - if (node->op == GGML_OP_VIEW && std::string(node->name) == "cache_k_l0 (view)") { - auto* cache_k = node->src[0]; - m_context_size = cache_k->ne[1]; + if (node->op == GGML_OP_FLASH_ATTN_EXT) { + auto* cache_k = node->src[1]; + cache_k = cache_k->view_src ? 
cache_k->view_src : cache_k; + int layer = extract_layer_from_name(cache_k->name); + + if (std::string(node->src[3]->name).find("swa") != std::string::npos) { + m_swa_layers.push_back(layer); + m_context_size_swa = cache_k->ne[1]; + } else { + m_context_size = cache_k->ne[1]; + } } else if (node->op == GGML_OP_ROPE && (name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0)) { m_head_size = node->ne[0]; @@ -269,11 +297,11 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co input_shape = ov::PartialShape{1, 1, 1}; } } else { - input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_context_size)}; + input_shape = ov::PartialShape{1, 1, -1}; } } else if (name == "inp_out_ids" && !m_is_static) { - input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_context_size)}; - } else if (name == "KQ_mask") { + input_shape = ov::PartialShape{1, 1, -1}; + } else if (name.find("KQ_mask") == 0) { if (m_is_static) { if (m_is_first_token) { input_shape = ov::PartialShape{1, m_context_size, m_context_size}; @@ -281,13 +309,12 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co input_shape = ov::PartialShape{1, 1, m_context_size}; } } else { - auto max_mask_size = GGML_PAD(m_context_size, GGML_KQ_MASK_PAD); - input_shape = ov::PartialShape{1, ov::Dimension(1, max_mask_size), ov::Dimension(1, max_mask_size)}; + input_shape = ov::PartialShape{1, -1, -1}; } - } else if (name.find("cache_k") == 0) { - input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size}; - } else if (name.find("cache_v") == 0) { - input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size}; + } else if (name.find("cache_") == 0) { + int layer = extract_layer_from_name(name); + bool is_swa = is_swa_layer(layer); + input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size}; } else if (const auto* op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) { input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1}; } else if (src->op == GGML_OP_VIEW) { @@ -305,35 +332,35 @@ void GgmlOvDecoder::add_extra_inputs() { // see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding. 
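// Worked example (illustrative numbers): with 300 valid cells in the non-SWA
// KV cache and a padding granularity of 256, llama_kv_cache_unified reports
// n_kv = 512, the KQ mask arrives with ne[0] == 512, and attention_size is
// fed into the graph as 512 rather than the full context size. The SWA cache
// contributes its own attention_size_swa the same way, and translate_permute
// slices each cache view down to that length before the attention matmuls.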
// Not used for NPU int64_t attention_size = -1; + int64_t attention_size_swa = -1; for (const auto& node : m_nodes) { - if (node->op == GGML_OP_SOFT_MAX) { - auto* mask = node->src[1]; - if (std::string(mask->name).find("KQ_mask") != 0) { - throw std::runtime_error("Unexpected softmax node: " + std::string(mask->name)); - } - attention_size = mask->ne[0]; - break; - } if (node->op == GGML_OP_FLASH_ATTN_EXT) { auto* mask = node->src[3]; - if (std::string(mask->name).find("KQ_mask") != 0) { + std::string mask_name(mask->name); + if (mask_name.find("KQ_mask") != 0) { throw std::runtime_error("Unexpected flash attention node: " + std::string(mask->name)); } - attention_size = mask->ne[0]; + if (mask_name.find("swa") != std::string::npos) { + attention_size_swa = mask->ne[0]; + } else { + attention_size = mask->ne[0]; + } } } - { - std::string name = "attention_size"; + auto create_attention_size_input = [this](const std::string& name, int64_t size) { auto param_node = std::make_shared(ov::element::i64, ov::Shape{1}); param_node->set_friendly_name(name); param_node->output(0).get_tensor().set_names({name}); m_model_extra_inputs[name] = param_node; auto tensor = std::make_shared(ov::element::i64, ov::Shape{1}); - *tensor->data() = attention_size; + *tensor->data() = size; m_model_extra_input_values[name] = tensor; - } + }; + + create_attention_size_input("attention_size", attention_size); + create_attention_size_input("attention_size_swa", attention_size_swa); } const ggml_tensor* GgmlOvDecoder::get_tensor_used_op(const ggml_tensor* tensor) const { @@ -706,8 +733,16 @@ int32_t* GgmlOvDecoder::get_output_op_params(const std::string& name) const { void GgmlOvDecoder::visit_subgraph(std::function)> node_visitor) const { for (const auto& node : m_nodes) { - auto decoder = std::make_shared( - node, m_cgraph, m_is_static, m_is_first_token, m_context_size, m_num_heads, m_num_heads_kv, m_head_size); + auto decoder = std::make_shared(node, + m_cgraph, + m_is_static, + m_is_first_token, + m_context_size, + m_context_size_swa, + m_num_heads, + m_num_heads_kv, + m_head_size, + m_swa_layers); node_visitor(decoder); } } diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index 4ba147da20d62..35e79ecefc724 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -19,7 +19,8 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { // Node decoder, called in GgmlOvDecoder::visit_subgraph GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token, - int context_size, int num_heads, int num_heads_kv, int head_size); + int context_size, int context_size_swa, int num_heads, int num_heads_kv, int head_size, + const std::vector& swa_layers); // Naive graph decoder GgmlOvDecoder(struct ggml_cgraph* cgraph, std::map>& model_weights); @@ -101,6 +102,12 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { virtual int get_context_size() const override { return m_context_size; } + virtual int get_context_size_swa() const override { return m_context_size_swa; } + + virtual int is_swa_layer(int layer) const override { + return std::find(m_swa_layers.begin(), m_swa_layers.end(), layer) != m_swa_layers.end(); + } + virtual int get_num_heads() const override { return m_num_heads; } virtual int get_num_heads_kv() const override { return m_num_heads_kv; } @@ -156,6 +163,8 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { std::map> m_model_weights; std::vector 
m_model_output_names; int m_context_size; + int m_context_size_swa; + std::vector m_swa_layers; int m_num_heads; int m_num_heads_kv; int m_head_size; @@ -166,3 +175,5 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { }; void print_tensor_address_map(const struct ggml_cgraph* cgraph); + +int extract_layer_from_name(const std::string& name); diff --git a/ggml/src/ggml-openvino/openvino/decoder.hpp b/ggml/src/ggml-openvino/openvino/decoder.hpp index a3387ba3947a2..6f11ff1283e37 100644 --- a/ggml/src/ggml-openvino/openvino/decoder.hpp +++ b/ggml/src/ggml-openvino/openvino/decoder.hpp @@ -67,6 +67,8 @@ class GgmlDecoder : public DecoderBase { virtual bool is_static() const = 0; virtual bool is_first_token() const = 0; virtual int get_context_size() const = 0; + virtual int get_context_size_swa() const = 0; + virtual int is_swa_layer(int layer) const = 0; }; } // namespace ggml diff --git a/ggml/src/ggml-openvino/openvino/node_context.hpp b/ggml/src/ggml-openvino/openvino/node_context.hpp index cc1b5c03329c9..a64ae098ab3e9 100644 --- a/ggml/src/ggml-openvino/openvino/node_context.hpp +++ b/ggml/src/ggml-openvino/openvino/node_context.hpp @@ -2,6 +2,7 @@ #include #include +#include #include "decoder.hpp" @@ -30,6 +31,8 @@ class NodeContext : public frontend::NodeContext { return m_translate_session; } + const std::vector& get_input_names() const { return m_input_names; } + size_t get_input_size() const override { return m_decoder->get_input_size(); } @@ -101,15 +104,7 @@ class NodeContext : public frontend::NodeContext { return m_decoder->is_first_token(); } - int get_num_heads() const { return m_decoder->get_num_heads(); } - - int get_num_heads_kv() const { return m_decoder->get_num_heads_kv(); } - - int get_head_size() const { return m_decoder->get_head_size(); } - - int get_context_size() const { return m_decoder->get_context_size(); } - - private: +private: std::shared_ptr m_decoder; std::shared_ptr& m_tensor_map; TranslateSession* m_translate_session; diff --git a/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp b/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp index d97603d98a941..8b67778fb9373 100644 --- a/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +++ b/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include "../node_context.hpp" #include "../op_table.hpp" @@ -32,8 +33,12 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) { auto scale_node = std::make_shared(ov::element::f16, ov::Shape{}, std::vector{scale}); ov::Output mask_sliced; - if (context.has_input("KQ_mask_sliced")) { - mask_sliced = context.get_input("KQ_mask_sliced"); + std::string mask_name = "KQ_mask_sliced"; + if (context.get_input_names()[3].find("swa") != std::string::npos) { + mask_name = "KQ_mask_swa_sliced"; + } + if (context.has_input(mask_name)) { + mask_sliced = context.get_input(mask_name); } else { auto token_len = get_dimensions(q, {1}); auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); diff --git a/ggml/src/ggml-openvino/openvino/op/permute.cpp b/ggml/src/ggml-openvino/openvino/op/permute.cpp index fcb091016a4f1..086b1e4cdb172 100644 --- a/ggml/src/ggml-openvino/openvino/op/permute.cpp +++ b/ggml/src/ggml-openvino/openvino/op/permute.cpp @@ -29,43 +29,29 @@ OutputVector translate_permute(const NodeContext& context) { ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2})); } else { auto src = context.get_input(0); - auto attention_size = 
context.get_input("attention_size"); + Output attention_size; if (context.is_static()) { attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {INT_MAX}); + } else if (op_case == 2) { + attention_size = context.get_input("attention_size"); + } else { + attention_size = context.get_input("attention_size_swa"); } auto src_shape_ = context.get_input_shape(0).to_shape(); std::vector src_shape(src_shape_.begin(), src_shape_.end()); - std::shared_ptr src_reshaped; - if (op_case == 2) { - src_reshaped = std::make_shared( - src, - ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector{-1, src_shape[1], src_shape[2]}), - false); - } else { - src_reshaped = std::make_shared( - src, - ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector{src_shape[1], src_shape[0], -1}), - false); - } + auto src_reshaped = std::make_shared( + src, + ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector{-1, src_shape[1], src_shape[2]}), + false); auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); - auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); - std::shared_ptr slice_axis; - if (op_case == 2) { - slice_axis = zero; - } else { - slice_axis = two; - } - auto src_slice = std::make_shared(src_reshaped, zero, attention_size, one, slice_axis); + auto src_slice = std::make_shared(src_reshaped, zero, attention_size, one, zero); - if (op_case == 2) { - res = std::make_shared(src_slice, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2})); - } else { - res = src_slice; - } + res = std::make_shared(src_slice, + ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2})); } return rename_outputs_with_suffix({res}, context.get_name()); } diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp index 3b8c30361a5e8..9c82fe5f850a0 100644 --- a/ggml/src/ggml-openvino/openvino/translate_session.cpp +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -78,13 +78,22 @@ void add_token_len(TensorMap& tensor_map) { } void add_sliced_mask(TensorMap& tensor_map) { - auto mask = tensor_map.at("KQ_mask").get_node_shared_ptr(); auto token_len = tensor_map.at("token_len").get_node_shared_ptr(); - auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); - auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); - std::shared_ptr mask_sliced = std::make_shared(mask, zero, token_len, one, one); - mask_sliced->set_friendly_name("KQ_mask_sliced"); - tensor_map.insert({"KQ_mask_sliced", mask_sliced->output(0)}); + + auto create_sliced_mask = [&](const std::string& mask_name, const std::string& sliced_name) { + if (tensor_map.find(mask_name) != tensor_map.end()) { + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto mask = tensor_map.at(mask_name).get_node_shared_ptr(); + std::shared_ptr mask_sliced = + std::make_shared(mask, zero, token_len, one, one); + mask_sliced->set_friendly_name(sliced_name); + tensor_map.insert({sliced_name, mask_sliced->output(0)}); + } + }; + + create_sliced_mask("KQ_mask", "KQ_mask_sliced"); + create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced"); } void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) { diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 7b696769fba84..8724404098c05 100644 --- 
diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp
index 7b696769fba84..8724404098c05 100644
--- a/ggml/src/ggml-openvino/utils.cpp
+++ b/ggml/src/ggml-openvino/utils.cpp
@@ -362,7 +362,7 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr ggml_decoder, cons
             input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name);
         }
 
-    } else if (param_name == "KQ_mask") {
+    } else if (param_name.find("KQ_mask") == 0) {
         size_t context_size = ggml_decoder->get_context_size();
         const auto* input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name);
         if (is_first_token) {
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 9ee31344b40a4..b2c7428d62a2c 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1642,6 +1642,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
     inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
     inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
+    ggml_set_name(inp->self_kq_mask, "KQ_mask");
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1656,6 +1657,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
+        ggml_set_name(inp->self_kq_mask_swa, "KQ_mask_swa");
        ggml_set_input(inp->self_kq_mask_swa);
 
        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
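The two ggml_set_name() calls exist so the OpenVINO backend can recognize both masks among the graph inputs, and the utils.cpp change switches from exact equality to a prefix match so that "KQ_mask_swa" is handled by the same branch. The matching rule in isolation:

    #include <string>

    // Any graph input whose name starts with "KQ_mask" (so both "KQ_mask"
    // and "KQ_mask_swa") is treated as an attention-mask input.
    static bool is_kq_mask_input(const std::string& param_name) {
        return param_name.find("KQ_mask") == 0;  // prefix match, not equality
    }

    int main() {
        bool ok = is_kq_mask_input("KQ_mask") &&
                  is_kq_mask_input("KQ_mask_swa") &&
                  !is_kq_mask_input("inp_tokens");
        return ok ? 0 : 1;
    }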
From a079242748ef359fbc5e5316a4a97ed0315fd0ce Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Wed, 17 Sep 2025 11:16:14 +0800
Subject: [PATCH 26/30] Fix NPU accuracy

---
 .../openvino/translate_session.cpp | 25 +++++++++++--------
 ggml/src/ggml-openvino/utils.cpp   |  5 +---
 2 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp
index 9c82fe5f850a0..c37aa21602ff0 100644
--- a/ggml/src/ggml-openvino/openvino/translate_session.cpp
+++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp
@@ -77,23 +77,28 @@ void add_token_len(TensorMap& tensor_map) {
     tensor_map.insert({"token_len", token_len->output(0)});
 }
 
-void add_sliced_mask(TensorMap& tensor_map) {
+void add_sliced_mask(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
     auto token_len = tensor_map.at("token_len").get_node_shared_ptr();
 
-    auto create_sliced_mask = [&](const std::string& mask_name, const std::string& sliced_name) {
+    auto create_sliced_mask = [&](const std::string& mask_name, const std::string& sliced_name, bool is_static) {
         if (tensor_map.find(mask_name) != tensor_map.end()) {
-            auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
-            auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
             auto mask = tensor_map.at(mask_name).get_node_shared_ptr();
-            std::shared_ptr mask_sliced =
-                std::make_shared(mask, zero, token_len, one, one);
-            mask_sliced->set_friendly_name(sliced_name);
+            std::shared_ptr mask_sliced;
+            if (is_static) {
+                mask_sliced = mask;
+            } else {
+                auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+                auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+                mask_sliced = std::make_shared(mask, zero, token_len, one, one);
+                mask_sliced = std::make_shared(mask_sliced, ov::element::f16);
+                mask_sliced->set_friendly_name(sliced_name);
+            }
             tensor_map.insert({sliced_name, mask_sliced->output(0)});
         }
     };
 
-    create_sliced_mask("KQ_mask", "KQ_mask_sliced");
-    create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced");
+    create_sliced_mask("KQ_mask", "KQ_mask_sliced", ggml_model_decoder.is_static());
+    create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced", ggml_model_decoder.is_static());
 }
 
 void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
@@ -117,7 +122,7 @@ void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
 // Create common patterns
 void preprocess(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
     add_token_len(tensor_map);
-    add_sliced_mask(tensor_map);
+    add_sliced_mask(tensor_map, ggml_model_decoder);
     add_rope_sin_cos(tensor_map, ggml_model_decoder);
 }
 
diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp
index 8724404098c05..db471636452ad 100644
--- a/ggml/src/ggml-openvino/utils.cpp
+++ b/ggml/src/ggml-openvino/utils.cpp
@@ -253,6 +253,7 @@ ov::AnyMap get_npu_base_config() {
         {"NPUW_FOLD", "YES" },
         {"NPUW_WEIGHTS_BANK", "shared" },
         {"NPUW_FUNCALL_FOR_ALL", "YES" },
+        {"NPUW_FUNCALL_ASYNC", "YES" },
         {"NPUW_DQ", "YES" },
         {"NPUW_DQ_FULL", "NO" },
         {"NPUW_CACHE_DIR", getenv("GGML_OPENVINO_CACHE_DIR") ? getenv("GGML_OPENVINO_CACHE_DIR") : ""},
@@ -262,15 +263,11 @@ ov::AnyMap get_npu_base_config() {
 
 ov::AnyMap get_npu_prefill_config() {
     auto config = get_npu_base_config();
-    config.emplace("NPUW_FUNCALL_ASYNC", "NO");
-    config.emplace("NPUW_ACC_CHECK", "YES");
-    config.emplace("NPUW_ACC_DEVICE", "CPU");
     return config;
 }
 
 ov::AnyMap get_npu_generate_config() {
     auto config = get_npu_base_config();
-    config.emplace("NPUW_FUNCALL_ASYNC", "YES");
     return config;
 }
 
From b6c84afe8a55511ecc742e80612169f55dece957 Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Wed, 17 Sep 2025 15:35:27 +0800
Subject: [PATCH 27/30] Fix ROPE accuracy when freq_scale != 1

---
 ggml/src/ggml-openvino/ggml-openvino.cpp  | 6 +-----
 ggml/src/ggml-openvino/openvino/utils.cpp | 2 +-
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp
index 6da653716f7ed..683f768c5f170 100644
--- a/ggml/src/ggml-openvino/ggml-openvino.cpp
+++ b/ggml/src/ggml-openvino/ggml-openvino.cpp
@@ -319,12 +319,8 @@ static bool is_op_unsupported_case(const ggml_tensor* op) {
             return true;
         }
         float freq_scale;
-        memcpy(&freq_scale, op_params + 6, sizeof(float));
-        if (freq_scale != 0.0f && freq_scale != 1.0f) {
-            GGML_LOG_WARN("OpenVINO backend does not support ROPE with freq_scale %f != 1.0f\n", freq_scale);
-            return true;
-        }
         float ext_factor;
+        memcpy(&freq_scale, op_params + 6, sizeof(float));
         memcpy(&ext_factor, op_params + 7, sizeof(float));
         if (ext_factor != 0.0f) {
             GGML_LOG_WARN("OpenVINO backend does not support ROPE with ext_factor %f != 0.0f\n", ext_factor);
             return true;
diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp b/ggml/src/ggml-openvino/openvino/utils.cpp
index ef5f51ebbc4a7..f70cb91a17fe0 100644
--- a/ggml/src/ggml-openvino/openvino/utils.cpp
+++ b/ggml/src/ggml-openvino/openvino/utils.cpp
@@ -140,7 +140,7 @@ std::pair, ov::Output> make_sin_cos(int32_t* rope_params,
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     std::vector factor(n_dims / 2);
-    factor[0] = freq_scale;
+    factor[0] = 1.0f;
     for (size_t i = 1; i < factor.size(); i++) {
         factor[i] = theta_scale * factor[i - 1];
     }
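The one-line change in make_sin_cos is easiest to see algebraically: the loop builds a geometric progression factor[i] = theta_scale^i * factor[0], so seeding factor[0] with freq_scale scaled every per-dimension frequency by freq_scale instead of only the intended term; seeding with 1.0 restores the plain progression. A standalone check with illustrative values:

    #include <cassert>
    #include <cmath>
    #include <vector>

    int main() {
        const float freq_scale = 0.5f, theta_scale = 0.8f;  // illustrative
        const int n = 8;
        std::vector<float> buggy(n), fixed(n);
        buggy[0] = freq_scale;  // old seeding
        fixed[0] = 1.0f;        // seeding after the fix
        for (int i = 1; i < n; ++i) {
            buggy[i] = theta_scale * buggy[i - 1];
            fixed[i] = theta_scale * fixed[i - 1];
        }
        for (int i = 0; i < n; ++i) {
            // every term, not just the first, ended up scaled by freq_scale
            assert(std::fabs(buggy[i] - freq_scale * fixed[i]) < 1e-6f);
        }
    }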
From b3eb6fb8dcf184ff2ac22d19368e7884d0829603 Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Wed, 17 Sep 2025 16:50:54 +0800
Subject: [PATCH 28/30] Minor: do not add attention_size_swa for non-SWA
 models

---
 ggml/src/ggml-openvino/ggml-decoder.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp
index e3dd5e0c1dd91..8286052f8bf2e 100644
--- a/ggml/src/ggml-openvino/ggml-decoder.cpp
+++ b/ggml/src/ggml-openvino/ggml-decoder.cpp
@@ -360,7 +360,9 @@ void GgmlOvDecoder::add_extra_inputs() {
     };
 
     create_attention_size_input("attention_size", attention_size);
-    create_attention_size_input("attention_size_swa", attention_size_swa);
+    if (attention_size_swa != -1) {
+        create_attention_size_input("attention_size_swa", attention_size_swa);
+    }
 }
 
 const ggml_tensor* GgmlOvDecoder::get_tensor_used_op(const ggml_tensor* tensor) const {
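The -1 initializer from the earlier add_extra_inputs hunk doubles as a "no SWA layer seen" sentinel here: the extra input is only created when some FLASH_ATTN_EXT node actually carried a "KQ_mask_swa" mask. A minimal sketch of the same guard in isolation:

    #include <cstdio>

    int main() {
        long attention_size = 512;      // always present
        long attention_size_swa = -1;   // stays -1 for non-SWA models
        printf("attention_size input: %ld\n", attention_size);
        if (attention_size_swa != -1) {  // skipped for non-SWA models
            printf("attention_size_swa input: %ld\n", attention_size_swa);
        }
    }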
From bb3353045cda50de66af4f6695549ffd5d453e28 Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Fri, 19 Sep 2025 16:50:27 +0800
Subject: [PATCH 29/30] Minor refactor

---
 ggml/src/ggml-openvino/ggml-decoder.cpp | 10 ----------
 ggml/src/ggml-openvino/utils.cpp        |  5 +++++
 2 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp
index 8286052f8bf2e..a5d9d6967fd92 100644
--- a/ggml/src/ggml-openvino/ggml-decoder.cpp
+++ b/ggml/src/ggml-openvino/ggml-decoder.cpp
@@ -65,11 +65,6 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph,
         print_tensor_address_map(cgraph);
     }
 
-    if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) {
-        std::string filename = "cgraph.txt";
-        dump_cgraph(cgraph, filename);
-    }
-
     set_llm_params();
 
     for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
@@ -83,11 +78,6 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph,
 
 GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph, std::map>& model_weights) {
-    if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) {
-        std::string filename = "cgraph.txt";
-        dump_cgraph(cgraph, filename);
-    }
-
     m_cgraph = cgraph;
     m_model_weights = model_weights;
     for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp
index db471636452ad..07cbb2e437f6a 100644
--- a/ggml/src/ggml-openvino/utils.cpp
+++ b/ggml/src/ggml-openvino/utils.cpp
@@ -86,6 +86,11 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
         };
     }
 
+    if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) {
+        std::string filename = "cgraph.txt";
+        GgmlOvDecoder::dump_cgraph(cgraph, filename);
+    }
+
     if (is_naive(cgraph)) {
         return naive_compute(cgraph, core, device, config);
     }
 
From 812590bc1516d745640ba7f2449ebe04f1892b29 Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Tue, 23 Sep 2025 16:07:51 +0800
Subject: [PATCH 30/30] Add Q5_K to support phi-3-q4_k_m

---
 ggml/src/ggml-openvino/ggml-decoder.cpp  |   8 +-
 ggml/src/ggml-openvino/ggml-openvino.cpp |   1 +
 ggml/src/ggml-openvino/ggml-quants.cpp   | 143 ++++++++++++++++++-----
 ggml/src/ggml-openvino/ggml-quants.hpp   |   5 +
 ggml/src/ggml-openvino/utils.cpp         |   1 +
 5 files changed, 124 insertions(+), 34 deletions(-)

diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp
index a5d9d6967fd92..38b0fa3db4f1c 100644
--- a/ggml/src/ggml-openvino/ggml-decoder.cpp
+++ b/ggml/src/ggml-openvino/ggml-decoder.cpp
@@ -448,6 +448,7 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor,
                                            GGML_TYPE_Q4_0,
                                            GGML_TYPE_Q4_1,
                                            GGML_TYPE_Q4_K,
+                                           GGML_TYPE_Q5_K,
                                            GGML_TYPE_Q6_K};
     if (weight_types.find(tensor->type) == weight_types.end()) {
         throw std::runtime_error("Unexpected weight tensor type: " + std::string(tensor->name) + " with type " +
                                  ggml_type_name(tensor->type));
     }
@@ -486,12 +487,12 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor,
     ov::element::Type weight_type;
     if (tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_1 || tensor->type == GGML_TYPE_Q4_K) {
         weight_type = ov::element::u4;
-    } else {  // tensor.type == GGUF_TYPE_Q8_0 || tensor.type == GGUF_TYPE_Q6_K
+    } else {  // tensor.type == GGUF_TYPE_Q8_0 || tensor.type == GGUF_TYPE_Q6_K || tensor.type == GGUF_TYPE_Q5_K
         weight_type = ov::element::u8;
     }
 
     uint64_t weights_per_block;
-    // here we only consider sub block, q6k:16 q4k:32
+    // here we only consider sub block, q6k:16 q4k:32 q5k:32
     if (tensor->type == GGML_TYPE_Q6_K) {
         weights_per_block = 16;
     } else {
         weights_per_block = 32;
     }
@@ -526,6 +527,9 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor,
     } else if (tensor->type == GGML_TYPE_Q4_K) {
         extract_q4_k_data(tensor, weights, scales, biases);
         weight_node = make_int4_weights(weights, scales, biases, weights_per_block);
+    } else if (tensor->type == GGML_TYPE_Q5_K) {
+        extract_q5_k_data(tensor, weights, scales, biases);
+        weight_node = make_int8_weights(weights, scales, biases, weights_per_block);
     }
 
     OPENVINO_ASSERT(weight_node.get_shape().size() == 2, "Weight should be 2D");
diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp
index 683f768c5f170..648acb4e35ede 100644
--- a/ggml/src/ggml-openvino/ggml-openvino.cpp
+++ b/ggml/src/ggml-openvino/ggml-openvino.cpp
@@ -350,6 +350,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
                                                     GGML_TYPE_Q4_0,
                                                     GGML_TYPE_Q4_1,
                                                     GGML_TYPE_Q4_K,
+                                                    GGML_TYPE_Q5_K,
                                                     GGML_TYPE_Q8_0,
                                                     GGML_TYPE_Q6_K};
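The bytes_per_block constants used by the extractors in ggml-quants.cpp below follow directly from the GGUF block layouts (the K-quant formats use 256-weight super-blocks). A compile-time restatement of those sizes; the Q5_K figure 4 + 12 + 32 + 128 matches the new extractor added by this patch:

    #include <cstdint>

    // Per 32-weight block:
    static_assert(18 == 2 + 32 / 2, "Q4_0: f16 scale + 32 4-bit weights");
    static_assert(20 == 2 + 2 + 32 / 2, "Q4_1: f16 scale + f16 min + 32 4-bit weights");
    static_assert(34 == 2 + 32, "Q8_0: f16 scale + 32 8-bit weights");
    // Per 256-weight super-block:
    static_assert(144 == 2 + 2 + 12 + 128, "Q4_K: d, dmin, 12B packed scales/mins, 128B low nibbles");
    static_assert(176 == 4 + 12 + 32 + 128, "Q5_K: d+dmin, packed scales, 32B high bits, 128B low nibbles");
    static_assert(210 == 128 + 64 + 16 + 2, "Q6_K: 128B ql, 64B qh, 16B scales, f16 d");

    int main() { return 0; }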
diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp
index 1603e65355274..9b8bfff072570 100644
--- a/ggml/src/ggml-openvino/ggml-quants.cpp
+++ b/ggml/src/ggml-openvino/ggml-quants.cpp
@@ -1,9 +1,17 @@
 #include "ggml-quants.hpp"
 
+#include 
+#include 
+#include 
+#include 
 #include 
 #include 
 #include 
+#include 
+#include 
 #include 
+#include 
+#include 
 #include 
 #include 
 #include 
@@ -11,9 +19,12 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
+#include 
+
+#include "ggml-common.h"
 #include "ggml-impl.h"
 #include "ggml.h"
 
@@ -38,10 +49,10 @@ void extract_q4_0_data(const ggml_tensor* tensor,
                        ov::Tensor& scales_arr,
                        ov::Tensor& biases_arr) {
     const uint64_t bytes_per_block = 18;  // 2 bytes scale, 32x0.5 byte weights
-    auto data = static_cast(tensor->data);
-    auto weights = static_cast(weights_arr.data());
-    auto scales = scales_arr.data::value_type>();
-    auto biases = biases_arr.data::value_type>();
+    auto* data = static_cast(tensor->data);
+    auto* weights = static_cast(weights_arr.data());
+    auto* scales = scales_arr.data::value_type>();
+    auto* biases = biases_arr.data::value_type>();
 
     ov::parallel_for(scales_arr.get_size(), [&](size_t i) {
         scales[i] = ov::float16::from_bits(*((uint16_t*)(data + i * bytes_per_block)));
@@ -57,10 +68,10 @@ void extract_q4_1_data(const ggml_tensor* tensor,
                        ov::Tensor& scales_arr,
                        ov::Tensor& biases_arr) {
     const uint64_t bytes_per_block = 20;  // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
-    auto data = static_cast(tensor->data);
-    auto weights = static_cast(weights_arr.data());
-    auto scales = scales_arr.data::value_type>();
-    auto biases = biases_arr.data::value_type>();
+    auto* data = static_cast(tensor->data);
+    auto* weights = static_cast(weights_arr.data());
+    auto* scales = scales_arr.data::value_type>();
+    auto* biases = biases_arr.data::value_type>();
 
     ov::parallel_for(scales_arr.get_size(), [&](size_t i) {
         scales[i] = ov::float16::from_bits(*((uint16_t*)(data + i * bytes_per_block)));
         biases[i] = ov::float16::from_bits(*((uint16_t*)(data + i * bytes_per_block + 2)));
@@ -76,22 +87,22 @@ void extract_q8_0_data(const ggml_tensor* tensor,
                        ov::Tensor& biases_arr) {
     const uint64_t weights_per_block = 32;
     const uint64_t bytes_per_block = 34;  // 2 bytes scale, 32x1 byte weights
-    auto data = static_cast(tensor->data);
-    auto weights = static_cast(weights_arr.data());
-    auto scales = scales_arr.data::value_type>();
-    auto biases = biases_arr.data::value_type>();
-    for (size_t i = 0; i < scales_arr.get_size(); i++) {
+    auto* data = static_cast(tensor->data);
+    auto* weights = static_cast(weights_arr.data());
+    auto* scales = scales_arr.data::value_type>();
+    auto* biases = biases_arr.data::value_type>();
+
+    ov::parallel_for(scales_arr.get_size(), [&](size_t i) {
         uint8_t* block_data = data + i * bytes_per_block;
-        scales[i] = ov::float16::from_bits(*(uint16_t*)block_data);
+        scales[i] = ov::float16::from_bits(*(uint16_t*) block_data);
         biases[i] = ov::float16(-128.f * static_cast(scales[i]));
         for (size_t j = 0; j < weights_per_block; ++j) {
             uint8_t x = block_data[j + 2];  // j+2 to skip the scale bytes.
-            // Original data is in int8_t, so we add a bias of -128 and invert the
-            // first bit.
+            // Original data is in int8_t, so we add a bias of -128 and invert the first bit.
             x ^= 1 << 7;
             weights[i * weights_per_block + j] = x;
         }
-    }
+    });
 }
 
 void unpack_256_4(const uint8_t* data, uint8_t* dst) {
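The "x ^= 1 << 7" in extract_q8_0_data converts a two's-complement int8 weight to the unsigned (offset-binary) form the backend stores: flipping the sign bit is the same as adding 128. The recorded bias of -128 * scale then makes u8 * scale + bias reproduce i8 * scale exactly. A standalone check over the full int8 range:

    #include <cassert>
    #include <cstdint>

    int main() {
        for (int v = -128; v <= 127; ++v) {
            uint8_t x = static_cast<uint8_t>(static_cast<int8_t>(v));
            x ^= 1 << 7;  // what the extractor does
            assert(x == static_cast<uint8_t>(v + 128));  // equivalent offset
            float scale = 0.125f;  // illustrative, power-of-two so math is exact
            float bias = -128.f * scale;
            assert(x * scale + bias == v * scale);  // dequantization agrees
        }
    }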
@@ -117,12 +128,11 @@ void extract_q4_k_data(const ggml_tensor* tensor,
                        ov::Tensor& scales_arr,
                        ov::Tensor& biases_arr) {
     const uint64_t bytes_per_block = 2 + 2 + 12 + 128;
-    // TODO tensor->nb[3]
     const uint64_t n_super_block = tensor->nb[3] / bytes_per_block;
-    auto data = static_cast(tensor->data);
-    auto weights = static_cast(weights_arr.data());
-    auto scales = scales_arr.data::value_type>();
-    auto biases = biases_arr.data::value_type>();
+    auto* data = static_cast(tensor->data);
+    auto* weights = static_cast(weights_arr.data());
+    auto* scales = scales_arr.data::value_type>();
+    auto* biases = biases_arr.data::value_type>();
 
     ov::parallel_for(n_super_block, [&](size_t i) {
         uint8_t* block_data = data + i * bytes_per_block;
@@ -170,28 +180,26 @@ void extract_q6_k_data(const ggml_tensor* tensor,
                        ov::Tensor& biases_arr) {
     const uint64_t bytes_per_block = 128 + 64 + 16 + 2;
     const uint64_t n_super_block = tensor->nb[3] / bytes_per_block;
-    auto data = static_cast(tensor->data);
-    auto weights = static_cast(weights_arr.data());
-    auto scales = scales_arr.data::value_type>();
-    auto biases = biases_arr.data::value_type>();
-    // std::string name(tensor.name, tensor.namelen);
-    for (size_t i = 0; i < n_super_block; i++) {
+    auto* data = static_cast(tensor->data);
+    auto* weights = static_cast(weights_arr.data());
+    auto* scales = scales_arr.data::value_type>();
+    auto* biases = biases_arr.data::value_type>();
+
+    ov::parallel_for(n_super_block, [&](size_t i) {
         uint8_t* block_data = data + i * bytes_per_block;
 
         float scale_factor =
-            static_cast(ov::float16::from_bits(*((uint16_t*)block_data + 104)));  // (128+64+16)/2
+            static_cast(ov::float16::from_bits(*((uint16_t*) block_data + 104)));  // (128+64+16)/2
 
         for (size_t j = 0; j < 16; j++) {
             scales[j + i * 16] =
-                ov::float16(scale_factor * static_cast(*((int8_t*)(block_data + 128 + 64 + j))));
+                ov::float16(scale_factor * static_cast(*((int8_t*) (block_data + 128 + 64 + j))));
             biases[j + i * 16] = ov::float16(-32.f * static_cast(scales[j + i * 16]));
         }
 
-        // Extract ql and qh
         uint8_t* ql = block_data;
         uint8_t* qh = block_data + 128;
 
-        // Extract weights
         for (int64_t j = 0; j < 32; ++j) {
             weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4);
             weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4);
@@ -202,9 +210,80 @@ void extract_q6_k_data(const ggml_tensor* tensor,
             weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4);
             weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4);
         }
+    });
+}
+
+static inline void get_scale_min_k4(int j, const uint8_t* q, uint8_t* d, uint8_t* m) {
+    if (j < 4) {
+        *d = q[j] & 63;
+        *m = q[j + 4] & 63;
+    } else {
+        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
+        *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
+    }
+}
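get_scale_min_k4 reads the K-quant 12-byte field that packs 8 six-bit scales and 8 six-bit mins per super-block (8 x 6 + 8 x 6 = 96 bits = 12 bytes). A round-trip check of that layout; the pack() helper below is written here only to demonstrate the format (ggml itself packs at quantization time):

    #include <cassert>
    #include <cstdint>

    // Same unpacking logic as the patch above.
    static void get_scale_min_k4(int j, const uint8_t* q, uint8_t* d, uint8_t* m) {
        if (j < 4) {
            *d = q[j] & 63;
            *m = q[j + 4] & 63;
        } else {
            *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
            *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
        }
    }

    // Inverse layout: low 6 bits of the first four scales/mins go in bytes
    // 0..7, their top-2-bit "spill" plus the low nibbles of the last four
    // values fill bytes 8..11.
    static void pack(const uint8_t sc[8], const uint8_t mn[8], uint8_t q[12]) {
        for (int k = 0; k < 4; ++k) {
            q[k]     = sc[k] | ((sc[k + 4] & 0x30) << 2);
            q[k + 4] = mn[k] | ((mn[k + 4] & 0x30) << 2);
            q[k + 8] = (sc[k + 4] & 0xF) | ((mn[k + 4] & 0xF) << 4);
        }
    }

    int main() {
        uint8_t sc[8] = {1, 9, 17, 25, 33, 41, 49, 63};  // arbitrary 6-bit values
        uint8_t mn[8] = {2, 10, 18, 26, 34, 42, 50, 62};
        uint8_t q[12];
        pack(sc, mn, q);
        for (int j = 0; j < 8; ++j) {
            uint8_t d, m;
            get_scale_min_k4(j, q, &d, &m);
            assert(d == sc[j] && m == mn[j]);  // lossless round trip
        }
    }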
+
+void extract_q5_k_data(const ggml_tensor* tensor,
+                       ov::Tensor& weights_arr,
+                       ov::Tensor& scales_arr,
+                       ov::Tensor& biases_arr) {
+    const uint64_t bytes_per_block = 4 + 12 + 32 + 128;
+    const uint64_t n_super_block = tensor->nb[3] / bytes_per_block;
+    auto* data = static_cast(tensor->data);
+    auto* weights = static_cast(weights_arr.data());
+    auto* scales = scales_arr.data::value_type>();
+    auto* biases = biases_arr.data::value_type>();
+
+    ov::parallel_for(n_super_block, [&](size_t i) {
+        uint8_t* block_data = data + i * bytes_per_block;
+
+        const float d = static_cast(ov::float16::from_bits(*((uint16_t*) block_data)));
+        const float min = static_cast(ov::float16::from_bits(*((uint16_t*) block_data + 1)));
+
+        const uint8_t* scales_data = block_data + 4;   // 12 bytes of scales
+        const uint8_t* qh = block_data + 4 + 12;       // 32 bytes of high bits
+        const uint8_t* ql = block_data + 4 + 12 + 32;  // 128 bytes of low bits
+
+        int is = 0;
+        uint8_t u1 = 1;
+        uint8_t u2 = 2;
+
+        // Process 2 blocks in one iteration
+        for (int j = 0; j < 256; j += 64) {  // 256 = QK_K, so 4 iterations of 64
+            uint8_t sc;
+            uint8_t m;
+
+            // Get scale and min for first 32 elements
+            get_scale_min_k4(is + 0, scales_data, &sc, &m);
+            const float d1 = d * sc;
+            const float m1 = min * m;
+
+            // Get scale and min for second 32 elements
+            get_scale_min_k4(is + 1, scales_data, &sc, &m);
+            const float d2 = d * sc;
+            const float m2 = min * m;
+
+            scales[i * 8 + is] = ov::float16(d1);
+            biases[i * 8 + is] = ov::float16(-m1);
+            scales[i * 8 + is + 1] = ov::float16(d2);
+            biases[i * 8 + is + 1] = ov::float16(-m2);
+
+            // Extract weights for first 32 elements (matching deq formula exactly)
+            for (int l = 0; l < 32; ++l) {
+                weights[i * 256 + j + l] = (ql[l] & 0xF) + ((qh[l] & u1) ? 16 : 0);
+            }
+
+            // Extract weights for second 32 elements
+            for (int l = 0; l < 32; ++l) {
+                weights[i * 256 + j + l + 32] = (ql[l] >> 4) + ((qh[l] & u2) ? 16 : 0);
+            }
+
+            ql += 32;
+            is += 2;
+            u1 <<= 2;
+            u2 <<= 2;
+        }
+    });
+}
+
 // TODO Reorder for make_intX_weights
 ov::Output make_int8_weights(ov::Tensor& weight, ov::Tensor& scales, ov::Tensor& biases, size_t group_size) {
diff --git a/ggml/src/ggml-openvino/ggml-quants.hpp b/ggml/src/ggml-openvino/ggml-quants.hpp
index fbae2aa1f43ef..5496785eb1fbd 100644
--- a/ggml/src/ggml-openvino/ggml-quants.hpp
+++ b/ggml/src/ggml-openvino/ggml-quants.hpp
@@ -29,6 +29,11 @@ void extract_q4_k_data(const ggml_tensor* tensor,
                        ov::Tensor& scales_arr,
                        ov::Tensor& biases_arr);
 
+void extract_q5_k_data(const ggml_tensor* tensor,
+                       ov::Tensor& weights_arr,
+                       ov::Tensor& scales_arr,
+                       ov::Tensor& biases_arr);
+
 void extract_q6_k_data(const ggml_tensor* tensor,
                        ov::Tensor& weights_arr,
                        ov::Tensor& scales_arr,
diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp
index 07cbb2e437f6a..e9084cf387f7a 100644
--- a/ggml/src/ggml-openvino/utils.cpp
+++ b/ggml/src/ggml-openvino/utils.cpp
@@ -283,6 +283,7 @@ std::map get_types_to_requant(const std::string& devi
             {GGML_TYPE_Q4_1, ExtraQuantType::Q4_0_128},
             {GGML_TYPE_Q4_K, ExtraQuantType::Q4_0_128},
             {GGML_TYPE_Q6_K, ExtraQuantType::F16     },
+            {GGML_TYPE_Q5_K, ExtraQuantType::F16     },
         };
     }
     if (device == "GPU") {
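The get_types_to_requant table is how the backend decides, per device, which weight formats are rewritten at model-load time rather than consumed natively; Q4_0_128 presumably denotes requantization to Q4_0 with a 128-wide group. A sketch of how such a table is consulted, using mock enums (the real ggml_type and ExtraQuantType live in ggml and in ggml-quants.hpp):

    #include <map>
    #include <string>

    enum class MockType { Q4_0, Q4_1, Q4_K, Q5_K, Q6_K, Q8_0 };
    enum class MockExtraQuant { F16, Q4_0_128 };

    int main() {
        // Mirrors the NPU entries in the diff above: K-quants the device
        // cannot run natively are mapped to a supported format; types absent
        // from the table are kept as-is.
        const std::map<MockType, MockExtraQuant> npu_requant = {
            {MockType::Q4_1, MockExtraQuant::Q4_0_128},
            {MockType::Q4_K, MockExtraQuant::Q4_0_128},
            {MockType::Q6_K, MockExtraQuant::F16},
            {MockType::Q5_K, MockExtraQuant::F16},
        };
        MockType t = MockType::Q5_K;
        auto it = npu_requant.find(t);
        bool requantize = (it != npu_requant.end());  // Q5_K -> F16 on NPU
        return requantize ? 0 : 1;
    }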