22 changes: 15 additions & 7 deletions auto_round/data_type/mxfp.py
@@ -50,13 +50,17 @@ def quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding="even"):
     if ebits != 0:
         private_exp = floor_ste(torch.log2(torch.abs(tensor) + (tensor == 0).type(tensor.dtype)))
         # The minimum representable exponent for 8 exp bits is -126
-        min_exp = -(2 ** (ebits - 1)) + 2
+        min_exp = -(2.0 ** float(ebits - 1)) + 2
         private_exp = private_exp.clip(min=min_exp)
     else:
         private_exp = None

     # Scale up so appropriate number of mbits are in the integer portion of the number
-    tensor = tensor * (2 ** (mbits - 2)) if private_exp is None else tensor / (2**private_exp) * (2 ** (mbits - 2))
+    tensor = (
+        tensor * (2.0 ** float(mbits - 2))
+        if private_exp is None
+        else tensor / (2.0 ** private_exp.float()) * (2.0 ** float(mbits - 2))
+    )
     if mantissa_rounding == "even":
         abs_tensor = torch.abs(tensor)
         mask_tensor = ((abs_tensor - 0.5) % 2 == torch.zeros_like(abs_tensor)).type(tensor.dtype)
@@ -71,7 +75,11 @@ def quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding="even"):
         raise ValueError("mantissa_rounding only supports even, nearest or floor.")

     # Undo scaling
-    tensor = tensor / (2 ** (mbits - 2)) if private_exp is None else tensor / (2 ** (mbits - 2)) * (2**private_exp)
+    tensor = (
+        tensor / (2.0 ** float(mbits - 2))
+        if private_exp is None
+        else tensor / (2.0 ** float(mbits - 2)) * (2.0 ** private_exp.float())
+    )

     tensor = torch.clamp(tensor, min=-max_norm, max=max_norm)
     return tensor
@@ -114,10 +122,10 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_roundin
     # shared_exp = torch.log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype))
     shared_exp = torch.where(max_val == 0, torch.ones_like(max_val), torch.log2(max_val))
     shared_exp = floor_ste(shared_exp)
-    scale_emax = 2 ** (8 - 1) - 1
+    scale_emax = 2.0 ** float(8 - 1) - 1
     shared_exp = (shared_exp - emax).clamp(min=-scale_emax, max=scale_emax)

-    scale = torch.pow(2, shared_exp)
+    scale = torch.pow(2.0, shared_exp.float())
     tensor = tensor / scale + v
     tensor = torch.clamp(tensor, min=-max_norm, max=max_norm)
     tensor = quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding)
@@ -165,10 +173,10 @@ def quant_mx_rceil(

     # shared_exp = torch.log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype))
     shared_exp = torch.where(max_val == 0, torch.ones_like(max_val), ceil_ste(torch.log2(max_val / max_norm)))
-    scale_emax = 2 ** (8 - 1) - 1
+    scale_emax = 2.0 ** float(8 - 1) - 1
     shared_exp = shared_exp.clamp(min=-scale_emax, max=scale_emax)

-    scale = torch.pow(2, shared_exp)
+    scale = torch.pow(2.0, shared_exp.float())
     tensor = tensor / scale + v
     tensor = torch.clamp(tensor, min=-max_norm, max=max_norm)
     tensor = quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding)
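The two hunks above touch how the per-group shared power-of-two scale is derived: quant_mx takes the floor of log2 of the group max and shifts by the element format's emax, while quant_mx_rceil takes the ceiling of log2(group max / max_norm). A rough sketch of both strategies follows; it omits the straight-through estimators and zero-group handling of the real code, and the MXFP4-style parameters (emax=2, max_norm=6.0) and the function name are assumptions for illustration.

import torch

def shared_scale_sketch(group_absmax, emax, max_norm):
    scale_emax = 2.0 ** float(8 - 1) - 1  # shared exponent is clamped to +/-127 (E8M0 range)
    # quant_mx style: floor of log2(group max), shifted down by the element format's emax
    exp_floor = (torch.floor(torch.log2(group_absmax)) - emax).clamp(min=-scale_emax, max=scale_emax)
    # quant_mx_rceil style: smallest power of two with max_norm * scale >= group max
    exp_rceil = torch.ceil(torch.log2(group_absmax / max_norm)).clamp(min=-scale_emax, max=scale_emax)
    return torch.pow(2.0, exp_floor.float()), torch.pow(2.0, exp_rceil.float())

# Illustrative MXFP4-style parameters (assumed): emax=2, max_norm=6.0.
# For a group max of 7.0 the two disagree: floor picks scale 1 (the 7 then saturates at max_norm),
# while rceil picks scale 2 so the largest element stays representable.
print(shared_scale_sketch(torch.tensor([7.0, 0.3]), emax=2, max_norm=6.0))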
7 changes: 7 additions & 0 deletions test/test_cpu/test_alg_ext.py
@@ -20,6 +20,13 @@ def test_alg_ext(self):
         ar = AutoRound(model_name, scheme="gguf:q4_k_s", iters=1, nsamples=1, enable_alg_ext=True)
         ar.quantize()

+        from auto_round.auto_scheme import AutoScheme
+
+        scheme = AutoScheme(options=["mxfp4", "mxfp8"], avg_bits=5.5, ignore_scale_zp_bits=True)
+        model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B"
+        ar = AutoRound(model_name, scheme=scheme, iters=1, nsamples=1, enable_alg_ext=True, enable_torch_compile=True)
+        ar.quantize()
+
     def test_alg_ext_import(self):
         from auto_round.alg_ext import wrapper_autoround

20 changes: 19 additions & 1 deletion test/test_cuda/test_alg_ext.py
@@ -46,11 +46,29 @@ def test_cli(self):
         python_path = sys.executable

         res = os.system(
-            f"cd ../.. && CUDA_VISIBLE_DEVICES=0 {python_path} -m auto_round --model {model_name} --device auto --enable_alg_ext --avg_bits 2 --options=W2A16,W4A16 --ignore_scale_zp_bits"
+            f"cd ../.. && CUDA_VISIBLE_DEVICES=0 {python_path} -m auto_round --model {model_name} --iters 1 --device auto --enable_alg_ext --avg_bits 2 --options=W2A16,W4A16 --ignore_scale_zp_bits"
         )
         if res > 0 or res == -1:
             assert False, "cmd line test fail, please have a check"

+        res = os.system(
+            f"cd ../.. && CUDA_VISIBLE_DEVICES=0 {python_path} -m auto_round --model {model_name} --iters 1 --device auto --enable_alg_ext --avg_bits 5.5 --options=mxfp4,mxfp8 --ignore_scale_zp_bits --enable_torch_compile"
+        )
+        if res > 0 or res == -1:
+            assert False, "cmd line test fail, please have a check"
+
+    def test_all_support_dtype(self):
+        from auto_round.auto_scheme import AutoScheme
+
+        model_name = "/models/Qwen3-0.6B"
+        for scheme in ["MXFP4", "NVFP4", "W2A16G64", "gguf:q2_k_s,gguf:q4_k_s"]:
+            avg_bits = 2 if scheme == "W2A16G64" else 4
+            scheme = AutoScheme(options=scheme, avg_bits=avg_bits, ignore_scale_zp_bits=True)
+            ar = AutoRound(
+                model_name, scheme=scheme, iters=1, nsamples=1, enable_alg_ext=True, enable_torch_compile=True
+            )
+            ar.quantize()
+

 if __name__ == "__main__":
     unittest.main()