diff --git a/auto_round/data_type/mxfp.py b/auto_round/data_type/mxfp.py
index 862ff0a9a..9a7315a11 100644
--- a/auto_round/data_type/mxfp.py
+++ b/auto_round/data_type/mxfp.py
@@ -50,13 +50,17 @@ def quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding="even"):
     if ebits != 0:
         private_exp = floor_ste(torch.log2(torch.abs(tensor) + (tensor == 0).type(tensor.dtype)))
         # The minimum representable exponent for 8 exp bits is -126
-        min_exp = -(2 ** (ebits - 1)) + 2
+        min_exp = -(2.0 ** float(ebits - 1)) + 2
         private_exp = private_exp.clip(min=min_exp)
     else:
         private_exp = None
 
     # Scale up so appropriate number of mbits are in the integer portion of the number
-    tensor = tensor * (2 ** (mbits - 2)) if private_exp is None else tensor / (2**private_exp) * (2 ** (mbits - 2))
+    tensor = (
+        tensor * (2.0 ** float(mbits - 2))
+        if private_exp is None
+        else tensor / (2.0 ** private_exp.float()) * (2.0 ** float(mbits - 2))
+    )
     if mantissa_rounding == "even":
         abs_tensor = torch.abs(tensor)
         mask_tensor = ((abs_tensor - 0.5) % 2 == torch.zeros_like(abs_tensor)).type(tensor.dtype)
@@ -71,7 +75,11 @@ def quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding="even"):
         raise ValueError("mantissa_rounding only supports even, nearest or floor.")
 
     # Undo scaling
-    tensor = tensor / (2 ** (mbits - 2)) if private_exp is None else tensor / (2 ** (mbits - 2)) * (2**private_exp)
+    tensor = (
+        tensor / (2.0 ** float(mbits - 2))
+        if private_exp is None
+        else tensor / (2.0 ** float(mbits - 2)) * (2.0 ** private_exp.float())
+    )
     tensor = torch.clamp(tensor, min=-max_norm, max=max_norm)
     return tensor
 
@@ -114,10 +122,10 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_roundin
     # shared_exp = torch.log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype))
     shared_exp = torch.where(max_val == 0, torch.ones_like(max_val), torch.log2(max_val))
     shared_exp = floor_ste(shared_exp)
-    scale_emax = 2 ** (8 - 1) - 1
+    scale_emax = 2.0 ** float(8 - 1) - 1
     shared_exp = (shared_exp - emax).clamp(min=-scale_emax, max=scale_emax)
 
-    scale = torch.pow(2, shared_exp)
+    scale = torch.pow(2.0, shared_exp.float())
     tensor = tensor / scale + v
     tensor = torch.clamp(tensor, min=-max_norm, max=max_norm)
     tensor = quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding)
@@ -165,10 +173,10 @@ def quant_mx_rceil(
     # shared_exp = torch.log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype))
     shared_exp = torch.where(max_val == 0, torch.ones_like(max_val), ceil_ste(torch.log2(max_val / max_norm)))
 
-    scale_emax = 2 ** (8 - 1) - 1
+    scale_emax = 2.0 ** float(8 - 1) - 1
     shared_exp = shared_exp.clamp(min=-scale_emax, max=scale_emax)
 
-    scale = torch.pow(2, shared_exp)
+    scale = torch.pow(2.0, shared_exp.float())
     tensor = tensor / scale + v
     tensor = torch.clamp(tensor, min=-max_norm, max=max_norm)
     tensor = quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding)
diff --git a/test/test_cpu/test_alg_ext.py b/test/test_cpu/test_alg_ext.py
index f5a7d306b..b0c909bd3 100644
--- a/test/test_cpu/test_alg_ext.py
+++ b/test/test_cpu/test_alg_ext.py
@@ -20,6 +20,13 @@ def test_alg_ext(self):
         ar = AutoRound(model_name, scheme="gguf:q4_k_s", iters=1, nsamples=1, enable_alg_ext=True)
         ar.quantize()
 
+        from auto_round.auto_scheme import AutoScheme
+
+        scheme = AutoScheme(options=["mxfp4", "mxfp8"], avg_bits=5.5, ignore_scale_zp_bits=True)
+        model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B"
+        ar = AutoRound(model_name, scheme=scheme, iters=1, nsamples=1, enable_alg_ext=True, enable_torch_compile=True)
+        ar.quantize()
+
     def test_alg_ext_import(self):
         from auto_round.alg_ext import wrapper_autoround
 
diff --git a/test/test_cuda/test_alg_ext.py b/test/test_cuda/test_alg_ext.py
index 4bc493316..c83d6f3b4 100644
--- a/test/test_cuda/test_alg_ext.py
+++ b/test/test_cuda/test_alg_ext.py
@@ -46,11 +46,29 @@ def test_cli(self):
         python_path = sys.executable
 
         res = os.system(
-            f"cd ../.. && CUDA_VISIBLE_DEVICES=0 {python_path} -m auto_round --model {model_name} --device auto --enable_alg_ext --avg_bits 2 --options=W2A16,W4A16 --ignore_scale_zp_bits"
+            f"cd ../.. && CUDA_VISIBLE_DEVICES=0 {python_path} -m auto_round --model {model_name} --iters 1 --device auto --enable_alg_ext --avg_bits 2 --options=W2A16,W4A16 --ignore_scale_zp_bits"
         )
         if res > 0 or res == -1:
             assert False, "cmd line test fail, please have a check"
 
+        res = os.system(
+            f"cd ../.. && CUDA_VISIBLE_DEVICES=0 {python_path} -m auto_round --model {model_name} --iters 1 --device auto --enable_alg_ext --avg_bits 5.5 --options=mxfp4,mxfp8 --ignore_scale_zp_bits --enable_torch_compile"
+        )
+        if res > 0 or res == -1:
+            assert False, "cmd line test fail, please have a check"
+
+    def test_all_support_dtype(self):
+        from auto_round.auto_scheme import AutoScheme
+
+        model_name = "/models/Qwen3-0.6B"
+        for scheme in ["MXFP4", "NVFP4", "W2A16G64", "gguf:q2_k_s,gguf:q4_k_s"]:
+            avg_bits = 2 if scheme == "W2A16G64" else 4
+            scheme = AutoScheme(options=scheme, avg_bits=avg_bits, ignore_scale_zp_bits=True)
+            ar = AutoRound(
+                model_name, scheme=scheme, iters=1, nsamples=1, enable_alg_ext=True, enable_torch_compile=True
+            )
+            ar.quantize()
+
 
 if __name__ == "__main__":
     unittest.main()
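
Note: the new tests reduce to handing an AutoScheme over MXFP options to AutoRound together with enable_alg_ext and enable_torch_compile, which drives inputs through the quant_mx path touched in mxfp.py above. A minimal sketch of that flow follows; the Hugging Face model id is only an example (the tests use local dataset paths), everything else mirrors the test code:

    from auto_round import AutoRound
    from auto_round.auto_scheme import AutoScheme

    # Mixed MXFP4/MXFP8 search targeting ~5.5 average bits; scale/zero-point
    # storage is excluded from the bit budget, matching the test settings.
    scheme = AutoScheme(options=["mxfp4", "mxfp8"], avg_bits=5.5, ignore_scale_zp_bits=True)

    ar = AutoRound(
        "Qwen/Qwen3-0.6B",          # example checkpoint; tests point at a local copy
        scheme=scheme,
        iters=1,                    # kept minimal, as in the tests
        nsamples=1,
        enable_alg_ext=True,
        enable_torch_compile=True,  # same flag combination the new tests exercise
    )
    ar.quantize()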