From 4b610c71a7166a10660a116f4ed0b2bde39a69a7 Mon Sep 17 00:00:00 2001
From: David Fan
Date: Thu, 10 Apr 2025 17:45:36 +0000
Subject: [PATCH 1/6] k quant

Signed-off-by: David Fan
---
 .../adaptor/ox_utils/weight_only.py           | 175 +++++++++++++++++-
 1 file changed, 172 insertions(+), 3 deletions(-)

diff --git a/neural_compressor/adaptor/ox_utils/weight_only.py b/neural_compressor/adaptor/ox_utils/weight_only.py
index f6e575fd9f3..d5743170de2 100644
--- a/neural_compressor/adaptor/ox_utils/weight_only.py
+++ b/neural_compressor/adaptor/ox_utils/weight_only.py
@@ -246,6 +246,168 @@ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ra
     return q_weight, scale, zero_point
 
 
+def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    data = np.reshape(data, (-1, group_size)).astype(np.float32) # (nb, group_size)
+    maxq = 2**num_bits - 1
+    minq = 0
+    sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1)
+    av_x = np.sqrt(sum_x2 / group_size) # (nb, 1)
+    weights = np.add(av_x, np.abs(data)) # (nb, group_size)
+    rmin = np.min(data, axis=1, keepdims=True) # (nb, 1)
+    rmax = np.max(data, axis=1, keepdims=True) # (nb, 1)
+    sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1)
+    sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, 1)
+    iscale = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+    mask = rmin != rmax
+    iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+    scale = 1 / iscale
+    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
+    diff = scale * quant_data + rmin - data # (nb, group_size)
+    best_mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+    nstep = 20
+    rdelta = 0.1
+    # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
+    rrmin = -1
+    for is_ in range(nstep):
+        iscale_new = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+        factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+        mask = rmin != rmax
+        iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
+        mul_weights_quant_data_new = weights * quant_data_new
+        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
+        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
+        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
+        D = np.subtract(sum_w * sum_l2, sum_l ** 2) # (nb, 1)
+
+        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
+        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
+
+        diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
+        mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+
+        mad_1 = np.array(mad)
+        best_mad_1 = np.array(best_mad)
+        idx_to_replace = np.where(mad_1 < best_mad_1)[0]
+        quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+        best_mad[idx_to_replace] = mad[idx_to_replace]
+        scale[idx_to_replace] = this_scale[idx_to_replace]
+        rmin[idx_to_replace] = this_min[idx_to_replace]
+
+    zero_point = np.clip((( - rmin) / scale).round(), 0, maxq).astype("uint8")
+    scale = scale.astype(np.float64)
+    q_weight = np.empty_like(data, dtype=scale.dtype)
+    np.divide(data, scale, out=q_weight)
+    np.add(q_weight, zero_point, out=q_weight)
+    np.round(q_weight, out=q_weight)
+    np.clip(q_weight, minq, maxq, out=q_weight)
+
+    return q_weight, scale, zero_point
+
+def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    try:
+        import cupy as cp
+        import torch
+        if torch.cuda.is_available():
+            data = cp.asarray(data)
+            data = data.reshape((-1, group_size)).astype(np.float32) # (nb, group_size)
+            nb = data.shape[0]
+            maxq = 2**num_bits - 1
+            minq = 0
+            sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1)
+            av_x = np.sqrt(sum_x2 / group_size) # (nb, 1)
+            weights = np.add(av_x, np.abs(data)) # (nb, group_size)
+            rmin = np.min(data, axis=1, keepdims=True) # (nb, 1)
+            rmax = np.max(data, axis=1, keepdims=True) # (nb, 1)
+            sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1)
+            sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, 1)
+            iscale = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+            mask = rmin != rmax
+            iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+            scale = 1 / iscale
+            quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
+            diff = scale * quant_data + rmin - data # (nb, group_size)
+            best_mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+            nstep = 20
+            rdelta = 0.1
+            rrmin = -1
+            for is_ in range(nstep):
+                iscale_new = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+                factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+                mask = rmin != rmax
+                iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+                quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
+                mul_weights_quant_data_new = weights * quant_data_new
+                sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
+                sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
+                sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
+                D = np.subtract(sum_w * sum_l2, sum_l ** 2) # (nb, 1)
+
+                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
+                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
+
+                diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
+                mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+
+                mad_1 = cp.array(mad)
+                best_mad_1 = cp.array(best_mad)
+                idx_to_replace = np.where(mad_1 < best_mad_1)[0]
+                quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+                best_mad[idx_to_replace] = mad[idx_to_replace]
+                scale[idx_to_replace] = this_scale[idx_to_replace]
+                rmin[idx_to_replace] = this_min[idx_to_replace]
+
+            zero_point = np.clip((( - rmin) / scale).round(), 0, maxq).astype("uint8")
+            scale = scale.astype(np.float64)
+            q_weight = np.empty_like(data, dtype=scale.dtype)
+            np.divide(data, scale, out=q_weight)
+            np.add(q_weight, zero_point, out=q_weight)
+            np.round(q_weight, out=q_weight)
+            np.clip(q_weight, minq, maxq, out=q_weight)
+
+            return q_weight.get(), scale.get(), zero_point.get()
+        else:
+            logger.warning("Try to use k-quant quantization on CUDA. However, CUDA is not available. " \
+                "Fall back to k-quant quantization on CPU.")
+            return quant_tensor_k_quant_cpu(
+                data, num_bits, group_size
+            )
+    except ImportError:
+        logger.info(
+            "Now we are using k-quant quantization on CPU, which is time-consuming. " \
+            "Please consider installing cupy to speed up on CUDA. See https://cupy.dev/. " \
+            "Please also install torch to check CUDA availability."
+        )
+        return quant_tensor_k_quant_cpu(
+            data, num_bits, group_size
+        )
+
 def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
     """Quant dequant tensor per group.
 
@@ -299,6 +461,7 @@ def rtn_quantize(
     ratios={},
     accuracy_level=0,
     providers=["CPUExecutionProvider"],
+    algorithm="rtn",
 ):
     """Quant the model with round to nearest method.
 
@@ -372,9 +535,15 @@
         ): # pragma: no cover
             # MatMulFpQ4 supports 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP
             # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP
-            q_weight, scale, zp = quant_tensor(
-                weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
-            )
+            if algorithm == "k_quant":
+                q_weight, scale, zp = quant_tensor_k_quant_cuda(
+                    weight.T, num_bits, group_size
+                )
+            else:
+                q_weight, scale, zp = quant_tensor(
+                    weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
+                )
+
             q_matmul_node, new_inits = make_matmul_weight_only_node(
                 node=node,
                 weight_shape=org_w_shape,
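The new helper is self-contained, so it can be exercised on its own before being wired into rtn_quantize. A minimal round-trip sketch (illustrative only — the random weight, its shape, and the driver code are assumptions, not part of the patch):

    import numpy as np
    from neural_compressor.adaptor.ox_utils.weight_only import quant_tensor_k_quant_cpu

    w = np.random.randn(64, 128).astype(np.float32)  # toy weight: 256 groups of 32
    q, scale, zp = quant_tensor_k_quant_cpu(w, num_bits=4, group_size=32)
    # q holds (nb, group_size) integer levels in [0, 15] stored as floats;
    # scale and zp are (nb, 1), so dequantization broadcasts per group
    w_hat = ((q - zp) * scale).reshape(w.shape)
    print(np.abs(w_hat - w).max())  # worst-case per-element quantization error

Note that the function flattens whatever it is given into (nb, group_size) groups, so reshaping back to the original layout is the caller's responsibility.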
From 6015feb2f1e6f74e7f36a7a9731bc45df2a72a8a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 10 Apr 2025 17:55:26 +0000
Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../adaptor/ox_utils/weight_only.py           | 117 +++++++++---------
 1 file changed, 58 insertions(+), 59 deletions(-)

diff --git a/neural_compressor/adaptor/ox_utils/weight_only.py b/neural_compressor/adaptor/ox_utils/weight_only.py
index d5743170de2..34ef17e28ac 100644
--- a/neural_compressor/adaptor/ox_utils/weight_only.py
+++ b/neural_compressor/adaptor/ox_utils/weight_only.py
@@ -246,6 +246,7 @@ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ra
     return q_weight, scale, zero_point
 
 
+
 def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
     """Quantize tensor per group based on k quant.
     Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
@@ -260,44 +261,44 @@ def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
         scale: scale
         zero_point: zero point
     """
-    data = np.reshape(data, (-1, group_size)).astype(np.float32) # (nb, group_size)
+    data = np.reshape(data, (-1, group_size)).astype(np.float32)  # (nb, group_size)
     maxq = 2**num_bits - 1
     minq = 0
-    sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1)
-    av_x = np.sqrt(sum_x2 / group_size) # (nb, 1)
-    weights = np.add(av_x, np.abs(data)) # (nb, group_size)
-    rmin = np.min(data, axis=1, keepdims=True) # (nb, 1)
-    rmax = np.max(data, axis=1, keepdims=True) # (nb, 1)
-    sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1)
-    sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, 1)
-    iscale = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+    sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+    av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+    weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+    rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+    rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+    sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+    sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+    iscale = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
     mask = rmin != rmax
     iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
     scale = 1 / iscale
-    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
-    diff = scale * quant_data + rmin - data # (nb, group_size)
-    best_mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+    diff = scale * quant_data + rmin - data  # (nb, group_size)
+    best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
     nstep = 20
     rdelta = 0.1
     # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
     rrmin = -1
     for is_ in range(nstep):
-        iscale_new = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+        iscale_new = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
         factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
         mask = rmin != rmax
         iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
-        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
+        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
         mul_weights_quant_data_new = weights * quant_data_new
-        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
-        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
-        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
-        D = np.subtract(sum_w * sum_l2, sum_l ** 2) # (nb, 1)
+        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+        D = np.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
 
-        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
-        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
+        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
 
-        diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
-        mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+        diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+        mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
 
         mad_1 = np.array(mad)
         best_mad_1 = np.array(best_mad)
@@ -307,7 +308,7 @@ def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
         scale[idx_to_replace] = this_scale[idx_to_replace]
         rmin[idx_to_replace] = this_min[idx_to_replace]
 
-    zero_point = np.clip((( - rmin) / scale).round(), 0, maxq).astype("uint8")
+    zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
     scale = scale.astype(np.float64)
     q_weight = np.empty_like(data, dtype=scale.dtype)
     np.divide(data, scale, out=q_weight)
@@ -317,6 +318,7 @@ def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
 
     return q_weight, scale, zero_point
 
+
 def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
     """Quantize tensor per group based on k quant.
     Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
@@ -334,46 +336,47 @@ def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
     try:
         import cupy as cp
         import torch
+
         if torch.cuda.is_available():
             data = cp.asarray(data)
-            data = data.reshape((-1, group_size)).astype(np.float32) # (nb, group_size)
+            data = data.reshape((-1, group_size)).astype(np.float32)  # (nb, group_size)
             nb = data.shape[0]
             maxq = 2**num_bits - 1
             minq = 0
-            sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1)
-            av_x = np.sqrt(sum_x2 / group_size) # (nb, 1)
-            weights = np.add(av_x, np.abs(data)) # (nb, group_size)
-            rmin = np.min(data, axis=1, keepdims=True) # (nb, 1)
-            rmax = np.max(data, axis=1, keepdims=True) # (nb, 1)
-            sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1)
-            sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, 1)
-            iscale = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+            sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+            av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+            weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+            rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+            rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+            sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+            sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+            iscale = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
             mask = rmin != rmax
             iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
             scale = 1 / iscale
-            quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
-            diff = scale * quant_data + rmin - data # (nb, group_size)
-            best_mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+            quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+            diff = scale * quant_data + rmin - data  # (nb, group_size)
+            best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
             nstep = 20
             rdelta = 0.1
             rrmin = -1
             for is_ in range(nstep):
-                iscale_new = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+                iscale_new = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
                 factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
                 mask = rmin != rmax
                 iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
-                quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
+                quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
                 mul_weights_quant_data_new = weights * quant_data_new
-                sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
-                sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
-                sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
-                D = np.subtract(sum_w * sum_l2, sum_l ** 2) # (nb, 1)
+                sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+                D = np.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
 
-                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
-                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
+                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
 
-                diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
-                mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+                diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+                mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
 
                 mad_1 = cp.array(mad)
                 best_mad_1 = cp.array(best_mad)
@@ -383,7 +386,7 @@ def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
                 scale[idx_to_replace] = this_scale[idx_to_replace]
                 rmin[idx_to_replace] = this_min[idx_to_replace]
 
-            zero_point = np.clip((( - rmin) / scale).round(), 0, maxq).astype("uint8")
+            zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
             scale = scale.astype(np.float64)
             q_weight = np.empty_like(data, dtype=scale.dtype)
             np.divide(data, scale, out=q_weight)
@@ -393,20 +396,18 @@ def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
 
             return q_weight.get(), scale.get(), zero_point.get()
         else:
-            logger.warning("Try to use k-quant quantization on CUDA. However, CUDA is not available. " \
-                "Fall back to k-quant quantization on CPU.")
-            return quant_tensor_k_quant_cpu(
-                data, num_bits, group_size
+            logger.warning(
+                "Try to use k-quant quantization on CUDA. However, CUDA is not available. "
+                "Fall back to k-quant quantization on CPU."
             )
+            return quant_tensor_k_quant_cpu(data, num_bits, group_size)
     except ImportError:
         logger.info(
-            "Now we are using k-quant quantization on CPU, which is time-consuming. " \
-            "Please consider installing cupy to speed up on CUDA. See https://cupy.dev/. " \
-            "Please also install torch to check CUDA availability."
-        )
-        return quant_tensor_k_quant_cpu(
-            data, num_bits, group_size
+            "Now we are using k-quant quantization on CPU, which is time-consuming. "
+            "Please consider installing cupy to speed up on CUDA. See https://cupy.dev/. "
+            "Please also install torch to check CUDA availability."
         )
+        return quant_tensor_k_quant_cpu(data, num_bits, group_size)
 
 
 def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
@@ -536,9 +537,7 @@ def rtn_quantize(
             # MatMulFpQ4 supports 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP
             # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP
             if algorithm == "k_quant":
-                q_weight, scale, zp = quant_tensor_k_quant_cuda(
-                    weight.T, num_bits, group_size
-                )
+                q_weight, scale, zp = quant_tensor_k_quant_cuda(weight.T, num_bits, group_size)
             else:
                 q_weight, scale, zp = quant_tensor(
                     weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
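For reference, the this_scale/this_min update that both implementations share is the closed-form solution of a weighted least-squares fit of the original values x_i against the candidate quantized levels l_i (a sketch of the math in the code's variable names, not text from the patch):

    \min_{s,\,m} \sum_i w_i \,(s\,l_i + m - x_i)^2
    \;\Longrightarrow\;
    \begin{cases}
    s\,\mathtt{sum\_l2} + m\,\mathtt{sum\_l} = \mathtt{sum\_xl}\\[2pt]
    s\,\mathtt{sum\_l} + m\,\mathtt{sum\_w} = \mathtt{sum\_x}
    \end{cases}
    \;\Longrightarrow\;
    s = \frac{\mathtt{sum\_w}\cdot\mathtt{sum\_xl} - \mathtt{sum\_x}\cdot\mathtt{sum\_l}}{D},
    \qquad
    m = \frac{\mathtt{sum\_l2}\cdot\mathtt{sum\_x} - \mathtt{sum\_l}\cdot\mathtt{sum\_xl}}{D},
    \qquad
    D = \mathtt{sum\_w}\cdot\mathtt{sum\_l2} - \mathtt{sum\_l}^2.

Groups where a candidate lowers the weighted error (mad < best_mad) adopt the new scale and minimum; all other groups keep their previous best.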
From c3318cf09acef2374ce222bac5f464c0f5e40d76 Mon Sep 17 00:00:00 2001
From: David Fan
Date: Thu, 10 Apr 2025 17:45:36 +0000
Subject: [PATCH 3/6] k quant

Signed-off-by: David Fan
---
 .../adaptor/ox_utils/weight_only.py           | 174 +++++++++++++++++-
 1 file changed, 171 insertions(+), 3 deletions(-)

diff --git a/neural_compressor/adaptor/ox_utils/weight_only.py b/neural_compressor/adaptor/ox_utils/weight_only.py
index f6e575fd9f3..2925b53fb7d 100644
--- a/neural_compressor/adaptor/ox_utils/weight_only.py
+++ b/neural_compressor/adaptor/ox_utils/weight_only.py
@@ -246,6 +246,167 @@ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ra
     return q_weight, scale, zero_point
 
 
+def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    data = np.reshape(data, (-1, group_size)).astype(np.float32) # (nb, group_size)
+    maxq = 2**num_bits - 1
+    minq = 0
+    sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1)
+    av_x = np.sqrt(sum_x2 / group_size) # (nb, 1)
+    weights = np.add(av_x, np.abs(data)) # (nb, group_size)
+    rmin = np.min(data, axis=1, keepdims=True) # (nb, 1)
+    rmax = np.max(data, axis=1, keepdims=True) # (nb, 1)
+    sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1)
+    sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, 1)
+    iscale = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+    mask = rmin != rmax
+    iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+    scale = 1 / iscale
+    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
+    diff = scale * quant_data + rmin - data # (nb, group_size)
+    best_mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+    nstep = 20
+    rdelta = 0.1
+    # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
+    rrmin = -1
+    for is_ in range(nstep):
+        iscale_new = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+        factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+        mask = rmin != rmax
+        iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
+        mul_weights_quant_data_new = weights * quant_data_new
+        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
+        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
+        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
+        D = np.subtract(sum_w * sum_l2, sum_l ** 2) # (nb, 1)
+
+        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
+        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
+
+        diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
+        mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+
+        mad_1 = np.array(mad)
+        best_mad_1 = np.array(best_mad)
+        idx_to_replace = np.where(mad_1 < best_mad_1)[0]
+        quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+        best_mad[idx_to_replace] = mad[idx_to_replace]
+        scale[idx_to_replace] = this_scale[idx_to_replace]
+        rmin[idx_to_replace] = this_min[idx_to_replace]
+
+    zero_point = np.clip((( - rmin) / scale).round(), 0, maxq).astype("uint8")
+    scale = scale.astype(np.float64)
+    q_weight = np.empty_like(data, dtype=scale.dtype)
+    np.divide(data, scale, out=q_weight)
+    np.add(q_weight, zero_point, out=q_weight)
+    np.round(q_weight, out=q_weight)
+    np.clip(q_weight, minq, maxq, out=q_weight)
+
+    return q_weight, scale, zero_point
+
+def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    try:
+        import cupy as cp
+        import torch
+        if torch.cuda.is_available():
+            data = cp.asarray(data)
+            data = data.reshape((-1, group_size)).astype(cp.float32) # nb = data.shape[0], (nb, group_size)
+            maxq = 2**num_bits - 1
+            minq = 0
+            sum_x2 = cp.sum(data**2, axis=1, keepdims=True) # (nb, 1)
+            av_x = cp.sqrt(sum_x2 / group_size) # (nb, 1)
+            weights = cp.add(av_x, cp.abs(data)) # (nb, group_size)
+            rmin = cp.min(data, axis=1, keepdims=True) # (nb, 1)
+            rmax = cp.max(data, axis=1, keepdims=True) # (nb, 1)
+            sum_w = cp.sum(weights, axis=1, keepdims=True) # (nb, 1)
+            sum_x = cp.sum(weights * data, axis=1, keepdims=True) # (nb, 1)
+            iscale = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+            mask = rmin != rmax
+            iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+            scale = 1 / iscale
+            quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
+            diff = scale * quant_data + rmin - data # (nb, group_size)
+            best_mad = cp.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+            nstep = 20
+            rdelta = 0.1
+            rrmin = -1
+            for is_ in range(nstep):
+                iscale_new = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+                factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+                mask = rmin != rmax
+                iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+                quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
+                mul_weights_quant_data_new = weights * quant_data_new
+                sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
+                sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
+                sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
+                D = cp.subtract(sum_w * sum_l2, sum_l ** 2) # (nb, 1)
+
+                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
+                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
+
+                diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
+                mad = cp.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+
+                mad_1 = cp.array(mad)
+                best_mad_1 = cp.array(best_mad)
+                idx_to_replace = cp.where(mad_1 < best_mad_1)[0]
+                quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+                best_mad[idx_to_replace] = mad[idx_to_replace]
+                scale[idx_to_replace] = this_scale[idx_to_replace]
+                rmin[idx_to_replace] = this_min[idx_to_replace]
+
+            zero_point = cp.clip((( - rmin) / scale).round(), 0, maxq).astype("uint8")
+            scale = scale.astype(cp.float64)
+            q_weight = cp.empty_like(data, dtype=scale.dtype)
+            cp.divide(data, scale, out=q_weight)
+            cp.add(q_weight, zero_point, out=q_weight)
+            cp.round(q_weight, out=q_weight)
+            cp.clip(q_weight, minq, maxq, out=q_weight)
+
+            return q_weight.get(), scale.get(), zero_point.get()
+        else:
+            logger.warning("Try to use k-quant quantization on CUDA. However, CUDA is not available. " \
+                "Fall back to k-quant quantization on CPU.")
+            return quant_tensor_k_quant_cpu(
+                data, num_bits, group_size
+            )
+    except ImportError:
+        logger.info(
+            "Now we are using k-quant quantization on CPU, which is time-consuming. " \
+            "Please consider installing cupy to speed up on CUDA. See https://cupy.dev/. " \
+            "Please also install torch to check CUDA availability."
+        )
+        return quant_tensor_k_quant_cpu(
+            data, num_bits, group_size
+        )
+
 def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
     """Quant dequant tensor per group.
 
@@ -299,6 +460,7 @@ def rtn_quantize(
     ratios={},
     accuracy_level=0,
     providers=["CPUExecutionProvider"],
+    algorithm="rtn",
 ):
     """Quant the model with round to nearest method.
 
@@ -372,9 +534,15 @@
         ): # pragma: no cover
             # MatMulFpQ4 supports 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP
             # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP
-            q_weight, scale, zp = quant_tensor(
-                weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
-            )
+            if algorithm == "k_quant":
+                q_weight, scale, zp = quant_tensor_k_quant_cuda(
+                    weight.T, num_bits, group_size
+                )
+            else:
+                q_weight, scale, zp = quant_tensor(
+                    weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
+                )
+
             q_matmul_node, new_inits = make_matmul_weight_only_node(
                 node=node,
                 weight_shape=org_w_shape,
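Because patch 3 mirrors the CPU path operation for operation in CuPy, the two implementations can be cross-checked on a CUDA machine. A sketch of such a check (illustrative, not part of the patch; float32 reductions are not bit-identical between NumPy and CuPy, so it compares dequantized values rather than raw integer levels):

    import numpy as np
    from neural_compressor.adaptor.ox_utils.weight_only import (
        quant_tensor_k_quant_cpu,
        quant_tensor_k_quant_cuda,
    )

    w = np.random.randn(32, 64).astype(np.float32)
    q_c, s_c, zp_c = quant_tensor_k_quant_cpu(w.copy(), 4, 32)
    q_g, s_g, zp_g = quant_tensor_k_quant_cuda(w.copy(), 4, 32)  # falls back to CPU without cupy/torch
    # dequantized tensors should agree up to float rounding
    print(np.abs((q_c - zp_c) * s_c - (q_g - zp_g) * s_g).max())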
From 1b3518a1da685689d42123676e5526e42297f368 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 12 Apr 2025 15:30:09 +0000
Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../adaptor/ox_utils/weight_only.py           | 113 +++++++++---------
 1 file changed, 57 insertions(+), 56 deletions(-)

diff --git a/neural_compressor/adaptor/ox_utils/weight_only.py b/neural_compressor/adaptor/ox_utils/weight_only.py
index 10d09600225..6a99cb2bd32 100644
--- a/neural_compressor/adaptor/ox_utils/weight_only.py
+++ b/neural_compressor/adaptor/ox_utils/weight_only.py
@@ -246,6 +246,7 @@ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ra
     return q_weight, scale, zero_point
 
 
+
 def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
     """Quantize tensor per group based on k quant.
     Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
@@ -260,44 +261,44 @@ def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
         scale: scale
         zero_point: zero point
     """
-    data = np.reshape(data, (-1, group_size)).astype(np.float32) # (nb, group_size)
+    data = np.reshape(data, (-1, group_size)).astype(np.float32)  # (nb, group_size)
     maxq = 2**num_bits - 1
     minq = 0
-    sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1)
-    av_x = np.sqrt(sum_x2 / group_size) # (nb, 1)
-    weights = np.add(av_x, np.abs(data)) # (nb, group_size)
-    rmin = np.min(data, axis=1, keepdims=True) # (nb, 1)
-    rmax = np.max(data, axis=1, keepdims=True) # (nb, 1)
-    sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1)
-    sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, 1)
-    iscale = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+    sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+    av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+    weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+    rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+    rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+    sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+    sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+    iscale = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
     mask = rmin != rmax
     iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
     scale = 1 / iscale
-    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
-    diff = scale * quant_data + rmin - data # (nb, group_size)
-    best_mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+    diff = scale * quant_data + rmin - data  # (nb, group_size)
+    best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
     nstep = 20
     rdelta = 0.1
     # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
     rrmin = -1
     for is_ in range(nstep):
-        iscale_new = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+        iscale_new = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
         factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
         mask = rmin != rmax
         iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
-        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
+        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
         mul_weights_quant_data_new = weights * quant_data_new
-        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
-        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
-        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
-        D = np.subtract(sum_w * sum_l2, sum_l ** 2) # (nb, 1)
+        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+        D = np.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
 
-        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
-        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
+        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
 
-        diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
-        mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+        diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+        mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
 
         mad_1 = np.array(mad)
         best_mad_1 = np.array(best_mad)
@@ -307,7 +308,7 @@ def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
         scale[idx_to_replace] = this_scale[idx_to_replace]
         rmin[idx_to_replace] = this_min[idx_to_replace]
 
-    zero_point = np.clip((( - rmin) / scale).round(), 0, maxq).astype("uint8")
+    zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
     scale = scale.astype(np.float64)
     q_weight = np.empty_like(data, dtype=scale.dtype)
     np.divide(data, scale, out=q_weight)
@@ -317,6 +318,7 @@ def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
 
     return q_weight, scale, zero_point
 
+
 def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
     """Quantize tensor per group based on k quant.
     Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
@@ -334,45 +336,46 @@ def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
     try:
         import cupy as cp
         import torch
+
         if torch.cuda.is_available():
             data = cp.asarray(data)
-            data = data.reshape((-1, group_size)).astype(cp.float32) # nb = data.shape[0], (nb, group_size)
+            data = data.reshape((-1, group_size)).astype(cp.float32)  # nb = data.shape[0], (nb, group_size)
             maxq = 2**num_bits - 1
             minq = 0
-            sum_x2 = cp.sum(data**2, axis=1, keepdims=True) # (nb, 1)
-            av_x = cp.sqrt(sum_x2 / group_size) # (nb, 1)
-            weights = cp.add(av_x, cp.abs(data)) # (nb, group_size)
-            rmin = cp.min(data, axis=1, keepdims=True) # (nb, 1)
-            rmax = cp.max(data, axis=1, keepdims=True) # (nb, 1)
-            sum_w = cp.sum(weights, axis=1, keepdims=True) # (nb, 1)
-            sum_x = cp.sum(weights * data, axis=1, keepdims=True) # (nb, 1)
-            iscale = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+            sum_x2 = cp.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+            av_x = cp.sqrt(sum_x2 / group_size)  # (nb, 1)
+            weights = cp.add(av_x, cp.abs(data))  # (nb, group_size)
+            rmin = cp.min(data, axis=1, keepdims=True)  # (nb, 1)
+            rmax = cp.max(data, axis=1, keepdims=True)  # (nb, 1)
+            sum_w = cp.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+            sum_x = cp.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+            iscale = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
             mask = rmin != rmax
             iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
             scale = 1 / iscale
-            quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
-            diff = scale * quant_data + rmin - data # (nb, group_size)
-            best_mad = cp.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+            quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+            diff = scale * quant_data + rmin - data  # (nb, group_size)
+            best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
             nstep = 20
             rdelta = 0.1
             rrmin = -1
             for is_ in range(nstep):
-                iscale_new = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+                iscale_new = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
                 factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
                 mask = rmin != rmax
                 iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
-                quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
+                quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
                 mul_weights_quant_data_new = weights * quant_data_new
-                sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
-                sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
-                sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
-                D = cp.subtract(sum_w * sum_l2, sum_l ** 2) # (nb, 1)
+                sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+                D = cp.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
 
-                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
-                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
+                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
 
-                diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
-                mad = cp.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+                diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+                mad = cp.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
 
                 mad_1 = cp.array(mad)
                 best_mad_1 = cp.array(best_mad)
@@ -382,7 +385,7 @@ def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
                 scale[idx_to_replace] = this_scale[idx_to_replace]
                 rmin[idx_to_replace] = this_min[idx_to_replace]
 
-            zero_point = cp.clip((( - rmin) / scale).round(), 0, maxq).astype("uint8")
+            zero_point = cp.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
             scale = scale.astype(cp.float64)
             q_weight = cp.empty_like(data, dtype=scale.dtype)
             cp.divide(data, scale, out=q_weight)
@@ -392,20 +395,18 @@ def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
 
             return q_weight.get(), scale.get(), zero_point.get()
         else:
-            logger.warning("Try to use k-quant quantization on CUDA. However, CUDA is not available. " \
-                "Fall back to k-quant quantization on CPU.")
-            return quant_tensor_k_quant_cpu(
-                data, num_bits, group_size
+            logger.warning(
+                "Try to use k-quant quantization on CUDA. However, CUDA is not available. "
+                "Fall back to k-quant quantization on CPU."
             )
+            return quant_tensor_k_quant_cpu(data, num_bits, group_size)
     except ImportError:
         logger.info(
-            "Now we are using k-quant quantization on CPU, which is time-consuming. " \
-            "Please consider installing cupy to speed up on CUDA. See https://cupy.dev/. " \
-            "Please also install torch to check CUDA availability."
-        )
-        return quant_tensor_k_quant_cpu(
-            data, num_bits, group_size
+            "Now we are using k-quant quantization on CPU, which is time-consuming. "
+            "Please consider installing cupy to speed up on CUDA. See https://cupy.dev/. "
+            "Please also install torch to check CUDA availability."
         )
+        return quant_tensor_k_quant_cpu(data, num_bits, group_size)
 
 
 def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
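The loop bounds are tied together by the in-code comment `nstep * rdelta = -2 * rrmin`: with nstep = 20, rdelta = 0.1, and rrmin = -1, each group's inverse scale sweeps twenty candidates around the initial (maxq - minq) / (rmax - rmin). A standalone sketch of the factors being tried for 4-bit groups (same arithmetic as the loop, shown here only for illustration):

    num_bits = 4
    maxq, minq = 2**num_bits - 1, 0
    nstep, rdelta, rrmin = 20, 0.1, -1
    factors = [rrmin + rdelta * i + (maxq - minq) for i in range(nstep)]
    print(factors[0], factors[-1])  # 14.0 and 15.9: iscale candidates per unit of group range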
From 0a1a0d47c0eb38d8d78a2eb16ce2e899b7625c89 Mon Sep 17 00:00:00 2001
From: David Fan
Date: Sat, 12 Apr 2025 15:15:26 +0000
Subject: [PATCH 5/6] test

---
 .../adaptor/ox_utils/weight_only.py           | 25 +++++++++++++------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/neural_compressor/adaptor/ox_utils/weight_only.py b/neural_compressor/adaptor/ox_utils/weight_only.py
index f6e575fd9f3..79a80d57728 100644
--- a/neural_compressor/adaptor/ox_utils/weight_only.py
+++ b/neural_compressor/adaptor/ox_utils/weight_only.py
@@ -40,7 +40,7 @@
 ONNXRT1161_VERSION = Version("1.16.1")
 
 
-def get_blob_size(group_size, has_zp): # pragma: no cover
+def get_blob_size(group_size, num_bits, has_zp): # pragma: no cover
     """Get blob_size.
 
     Args:
@@ -48,11 +48,11 @@ def get_blob_size(group_size, has_zp): # pragma: no cover
         has_zp (bool): whether zero_point is None
     """
     if Version(ort.__version__) > ONNXRT1161_VERSION:
-        blob_size = group_size // 2
+        blob_size = group_size * num_bits // 8
     elif has_zp:
-        blob_size = group_size // 2 + 4 + 1
+        blob_size = group_size * num_bits // 8 + 4 + 1
     else:
-        blob_size = group_size // 2 + 4
+        blob_size = group_size * num_bits // 8 + 4
     return blob_size
 
 
@@ -86,7 +86,7 @@ def make_matmul_weight_only_node(
         matmul_weight_only_node: MatMulFpQ4 or MatMulNBits node
         new_inits: initializers of the new node
     """
-    blob_size = get_blob_size(group_size, zero_point is not None)
+    blob_size = get_blob_size(group_size, num_bits, zero_point is not None)
     packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
     q_weight_name = node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size))
     input_names = [node.input[0], q_weight_name]
@@ -97,8 +97,16 @@ def make_matmul_weight_only_node(
         op_type = "MatMulNBits"
 
         # pack quantized weight
-        q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
-        packed[:, :] = q_weight_pairs[:, :blob_size]
+        if num_bits == 4:
+            q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
+            packed[:, :] = q_weight_pairs[:, :blob_size]
+        elif num_bits == 8:
+            packed = q_weight
+        else:
+            logger.error(
+                "MatMulNBits does not have kernel support for num_bits = {}.".format(num_bits)
+            )
+
         packed = np.reshape(packed, (-1, k_blocks, blob_size))
 
         # build scale tensor
@@ -362,7 +370,8 @@
 
             weight = pad_tensor(weight, group_size, k_blocks)
 
-            satisfy_MatMulNBits_condition = Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4
+            enable_MatMulNBits_8bits = True
+            satisfy_MatMulNBits_condition = (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (enable_MatMulNBits_8bits and num_bits == 8)
             satisfy_MatMulFpQ4_condition = (
                 Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
             )
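The 4-bit branch above packs two quantized levels per byte, with the even-indexed element in the low nibble; the 8-bit branch stores levels as-is, which is why blob_size generalizes to group_size * num_bits // 8. A tiny standalone check of the packing expression (illustrative values only):

    import numpy as np

    q = np.array([[1, 2, 3, 4]], dtype="uint8")
    packed = q[:, ::2] | q[:, 1::2] << 4  # even index -> low nibble, odd index -> high nibble
    print([hex(v) for v in packed[0]])    # ['0x21', '0x43']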
From 04409050e43b8c229b57a8d4bad48305ff26dfff Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 26 Apr 2025 18:59:47 +0000
Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../adaptor/ox_utils/weight_only.py           | 52 +++++++++----------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/neural_compressor/adaptor/ox_utils/weight_only.py b/neural_compressor/adaptor/ox_utils/weight_only.py
index 011bd2c6685..7439d2c8272 100644
--- a/neural_compressor/adaptor/ox_utils/weight_only.py
+++ b/neural_compressor/adaptor/ox_utils/weight_only.py
@@ -103,9 +103,7 @@ def make_matmul_weight_only_node(
         elif num_bits == 8:
             packed = q_weight
         else:
-            logger.error(
-                "MatMulNBits does not have kernel support for num_bits = {}.".format(num_bits)
-            )
+            logger.error("MatMulNBits does not have kernel support for num_bits = {}.".format(num_bits))
 
         packed = np.reshape(packed, (-1, k_blocks, blob_size))
 
@@ -273,44 +271,44 @@ def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
         scale: scale
         zero_point: zero point
     """
-    data = np.reshape(data, (-1, group_size)).astype(np.float32) # nb = data.shape[0], (nb, group_size)
+    data = np.reshape(data, (-1, group_size)).astype(np.float32)  # nb = data.shape[0], (nb, group_size)
     maxq = 2**num_bits - 1
     minq = 0
-    sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1)
-    av_x = np.sqrt(sum_x2 / group_size) # (nb, 1)
-    weights = np.add(av_x, np.abs(data)) # (nb, group_size)
-    rmin = np.min(data, axis=1, keepdims=True) # (nb, 1)
-    rmax = np.max(data, axis=1, keepdims=True) # (nb, 1)
-    sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1)
-    sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, 1)
-    iscale = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+    sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+    av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+    weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+    rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+    rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+    sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+    sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+    iscale = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
     mask = rmin != rmax
     iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
     scale = 1 / iscale
-    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
-    diff = scale * quant_data + rmin - data # (nb, group_size)
-    best_mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+    diff = scale * quant_data + rmin - data  # (nb, group_size)
+    best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
     nstep = 20
     rdelta = 0.1
     # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
     rrmin = -1
     for is_ in range(nstep):
-        iscale_new = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+        iscale_new = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
         factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
         mask = rmin != rmax
         iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
-        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
+        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
         mul_weights_quant_data_new = weights * quant_data_new
-        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
-        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
-        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
-        D = np.subtract(sum_w * sum_l2, sum_l ** 2) # (nb, 1)
+        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+        D = np.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
 
-        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
-        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
+        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
 
-        diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
-        mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+        diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+        mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
 
         mad_1 = np.array(mad)
         best_mad_1 = np.array(best_mad)
@@ -539,7 +537,9 @@ def rtn_quantize(
             weight = pad_tensor(weight, group_size, k_blocks)
 
             enable_MatMulNBits_8bits = True
-            satisfy_MatMulNBits_condition = (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (enable_MatMulNBits_8bits and num_bits == 8)
+            satisfy_MatMulNBits_condition = (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
+                enable_MatMulNBits_8bits and num_bits == 8
+            )
             satisfy_MatMulFpQ4_condition = (
                 Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
             )
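With the series applied, k-quant is reachable through the existing RTN entry point; the algorithm keyword is the only new knob. A hedged sketch of a call site (the leading model argument and the remaining defaults follow the pre-existing rtn_quantize signature, which this series does not change):

    from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize

    # `model` is the ONNX model object rtn_quantize already accepts;
    # algorithm="k_quant" routes MatMul weights through quant_tensor_k_quant_cuda,
    # which itself falls back to the CPU implementation when CUDA is unavailable.
    q_model = rtn_quantize(model, num_bits=4, group_size=32, algorithm="k_quant")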