Add all fbgemm kernel Tensors into Int4WeightOnlyConfig and Float8DynamicActivationInt4WeightConfig #2474

Merged 1 commit on Aug 7, 2025
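Before the per-file diff, a minimal usage sketch of the two configuration paths this PR routes the fbgemm int4 kernels through. Parameter names are copied from the test changes below; treat the snippet as a sketch under assumptions (a torchao build with the fbgemm kernels, PyTorch 2.8+, and an sm90+ GPU), not as settled public API.

# Sketch only: mirrors the configs exercised by the updated tests in this PR.
import torch
from torchao.quantization import (
    Float8ActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    quantize_,
)

linear = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")

# bf16 activation + int4 weight, backed by the plain or preshuffled fbgemm tensor.
# VERSION=1 (the default) keeps the pre-existing Int4WeightOnlyConfig behavior.
bf16_act_config = Int4WeightOnlyConfig(
    group_size=128,
    packing_format="preshuffled",  # or "plain" for the non-preshuffled kernel
    VERSION=2,
)

# float8 (e4m3) dynamic activation + int4 weight; only "preshuffled" is supported here.
fp8_act_config = Float8ActivationInt4WeightConfig(
    group_size=128,
    packing_format="preshuffled",
)

quantize_(linear, bf16_act_config)  # or fp8_act_config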
1 change: 1 addition & 0 deletions docs/source/api_ref_quantization.rst
@@ -24,6 +24,7 @@ Inference APIs for quantize\_
:nosignatures:

Int4WeightOnlyConfig
Float8ActivationInt4WeightConfig
Float8DynamicActivationFloat8WeightConfig
Float8WeightOnlyConfig
Float8StaticActivationFloat8WeightConfig
5 changes: 4 additions & 1 deletion test/integration/test_loading_deprecated_checkpoint.py
@@ -26,7 +26,9 @@
class TestLoadingDeprecatedCheckpoint(TestCase):
@common_utils.parametrize("model_name_and_version", _MODEL_NAME_AND_VERSIONS)
def test_load_model_and_run(self, model_name_and_version):
"""Test that we print correct warning message when loading a deprecated checkpoint"""
"""Test that we print correct warning message when loading a deprecated checkpoint
and making sure the deprecated checkpoints can still be loaded
"""
# Load and quantize model
model_name, version = model_name_and_version
with warnings.catch_warnings(record=True) as caught_warnings:
@@ -41,6 +43,7 @@ def test_load_model_and_run(self, model_name_and_version):
for w in caught_warnings
), "Didn't get expected warning message for version mismatch"

# TODO: generalize when we test more checkpoints
assert any(
"Models quantized with version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
in str(w.message)
@@ -15,9 +15,9 @@
run_tests,
)

from torchao.float8.config import e4m3_dtype
from torchao.quantization import (
FbgemmConfig,
Float8ActivationInt4WeightConfig,
Int4WeightOnlyConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
@@ -27,44 +27,16 @@
is_sm_at_least_90,
)

if TORCH_VERSION_AT_LEAST_2_8:
BF16_ACT_CONFIG = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
preshuffle=True,
)

BF16_ACT_BMM_CONFIG = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
preshuffle=True,
)

FP8_ACT_CONFIG = FbgemmConfig(
input_dtype=e4m3_dtype,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
preshuffle=True,
)

FP8_ACT_BMM_CONFIG = FbgemmConfig(
input_dtype=e4m3_dtype,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
preshuffle=True,
)

else:
BF16_ACT_CONFIG = None
BF16_ACT_BMM_CONFIG = None
FP8_ACT_CONFIG = None
FP8_ACT_BMM_CONFIG = None
BF16_ACT_CONFIG = Int4WeightOnlyConfig(
group_size=128,
packing_format="preshuffled",
VERSION=2,
)

FP8_ACT_CONFIG = Float8ActivationInt4WeightConfig(
group_size=128,
packing_format="preshuffled",
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@@ -90,7 +62,7 @@ def test_linear(self, config):

# Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
# @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
@parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
@parametrize("bmm_config", [FP8_ACT_CONFIG, BF16_ACT_CONFIG])
def test_bmm(self, bmm_config):
class M(torch.nn.Module):
def __init__(self, weight):
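One consequence of this hunk worth spelling out: the separate BF16_ACT_BMM_CONFIG and FP8_ACT_BMM_CONFIG entries with block_size=[1, 1, 128] disappear because the new path derives the block size from the weight's rank inside _int4_weight_only_quantize_tensor (shown later in quant_api.py), so a single config covers both linear and bmm weights. A small sketch of that derivation; derive_block_size is an illustrative helper of mine, not a torchao function.

# Hypothetical helper showing how block_size falls out of the weight rank;
# the real logic is a one-liner in _int4_weight_only_quantize_tensor.
import torch


def derive_block_size(weight: torch.Tensor, group_size: int = 128) -> tuple:
    # Every leading dim gets block size 1; only the last (grouped) dim uses group_size.
    return tuple([1] * (weight.ndim - 1) + [group_size])


linear_weight = torch.empty(1024, 1024, dtype=torch.bfloat16)
bmm_weight = torch.empty(10, 1024, 1024, dtype=torch.bfloat16)

assert derive_block_size(linear_weight) == (1, 128)   # old FbgemmConfig block_size=[1, 128]
assert derive_block_size(bmm_weight) == (1, 1, 128)   # old FbgemmConfig block_size=[1, 1, 128]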
@@ -13,7 +13,7 @@
)

from torchao.quantization import (
FbgemmConfig,
Int4WeightOnlyConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
@@ -26,19 +26,12 @@
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
class TestFbgemmInt4Tensor(TestCase):
class TestInt4Tensor(TestCase):
def setUp(self):
self.config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
)
self.bmm_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
self.config = Int4WeightOnlyConfig(
group_size=128,
packing_format="plain",
VERSION=2,
)
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

@@ -68,13 +61,9 @@ def test_slice(self):
quantize_(dummy, self.config)
weight1 = dummy.weight.narrow(0, 0, 64)
weight2 = dummy.weight.narrow(1, 0, 128)
self.assertEqual(
weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64)
)
self.assertEqual(weight1._data, dummy.weight._data.narrow(0, 0, 64))
self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
self.assertEqual(
weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64)
)
self.assertEqual(weight2._data, dummy.weight._data.narrow(1, 0, 64))
self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))

# check for sliced weight, before and after float8 quantization
@@ -100,12 +89,10 @@ def test_slice_and_copy_(self):
param = l.weight
param_data = param.data
param_data = param_data.narrow(0, 0, 512)
assert (
param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr()
)
assert param.data._data.data_ptr() == param_data._data.data_ptr()
assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
orig_value = param.data.packed_weight[0][0].item()
orig_value = param.data._data[0][0].item()

# dummy_l has random input (shouldn't be 0)
dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
@@ -116,7 +103,7 @@ def test_slice_and_copy_(self):
param_data.copy_(quantized)

# making sure param.data is updated
assert param.data.packed_weight[0][0] != orig_value
assert param.data._data[0][0] != orig_value

def test_bmm(self):
class M(torch.nn.Module):
@@ -135,7 +122,7 @@ def forward(self, x):
original = m(input)
# we need to transpose the weight first for bmm
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
quantize_(m, self.config, filter_fn=lambda x, fqn: True)
quantized = m(input)
self.assertTrue(compute_error(original, quantized) > 18)

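The mechanical renames in this file reflect one structural change: the plain int4 tensor now exposes its packed payload as _data (previously packed_weight), alongside scale and zero_point. A hedged sketch of inspecting that surface after quantization; it assumes the same CUDA, sm90+, and fbgemm setup as the test above, and the comment about the halved last dimension is inferred from the narrow() arguments in test_slice rather than stated in the diff.

# Sketch: inspect the Int4Tensor subclass produced by the VERSION=2 plain config.
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

linear = torch.nn.Linear(256, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Int4WeightOnlyConfig(group_size=128, packing_format="plain", VERSION=2))

w = linear.weight
print(type(w).__name__)  # expected: Int4Tensor
print(w._data.shape)     # packed int4 payload; last dim looks halved (two values per byte)
print(w.scale.shape)     # one scale per group of 128 elements
print(w.zero_point.shape)

# Slicing the logical weight slices the packed payload consistently (mirrors test_slice).
rows = w.narrow(0, 0, 64)
assert torch.equal(rows._data, w._data.narrow(0, 0, 64))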
3 changes: 0 additions & 3 deletions torchao/dtypes/__init__.py
@@ -9,7 +9,6 @@
to_affine_quantized_intx_static,
)
from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4
from .floatx import (
CutlassSemiSparseLayout,
Float8Layout,
@@ -64,8 +63,6 @@
"PackedLinearInt8DynamicActivationIntxWeightLayout",
"to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
"Int4XPULayout",
"to_fbgemm_int4",
"FbgemmInt4Tensor",
"to_fbgemm_fp8",
"FbgemmFp8Tensor",
"Int8DynamicActInt4WeightCPULayout",
4 changes: 4 additions & 0 deletions torchao/quantization/__init__.py
@@ -44,6 +44,7 @@
from .quant_api import (
CutlassInt4PackedLayout,
FbgemmConfig,
Review comment (Contributor): What's the plan for FbgemmConfig? It looks like it was added only ~1.5 months ago, but it's technically public API. Do we know if anyone is using it already? I don't think it has been released yet, so I wonder if it's OK to just remove it.

Reply (Contributor Author): We'll remove it. It is used in some internal scripts, but we'll update those as well.

Float8ActivationInt4WeightConfig,
Float8DynamicActivationFloat8SemiSparseWeightConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8MMConfig,
@@ -90,6 +91,7 @@
from .quantize_.workflows import (
Float8Tensor,
Int4PreshuffledTensor,
Int4Tensor,
)
from .smoothquant import (
SmoothFakeDynamicallyQuantizedLinear,
@@ -141,6 +143,7 @@
"Int8DynamicActivationInt8WeightConfig",
"Int8DynamicActivationIntxWeightConfig",
"Int4WeightOnlyConfig",
"Float8ActivationInt4WeightConfig",
"Int8WeightOnlyConfig",
"Float8WeightOnlyConfig",
"Float8DynamicActivationFloat8WeightConfig",
@@ -154,6 +157,7 @@
"ModuleFqnToConfig",
"FbgemmConfig",
# tensor subclasses
"Int4Tensor",
"Int4PreshuffledTensor",
"Float8Tensor",
# smooth quant - subject to change
73 changes: 70 additions & 3 deletions torchao/quantization/quant_api.py
@@ -49,7 +49,6 @@
to_affine_quantized_floatx_static,
to_affine_quantized_intx,
to_fbgemm_fp8,
to_fbgemm_int4,
to_marlinqqq_quantized_intx,
)
from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
@@ -71,10 +70,12 @@
from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
from torchao.quantization.quantize_.common import (
KernelPreference,
PackingFormat,
)
from torchao.quantization.quantize_.workflows import (
Float8Tensor,
Int4PreshuffledTensor,
Int4Tensor,
QuantizeTensorToFloat8Kwargs,
)
from torchao.quantization.transform_module import (
@@ -1119,6 +1120,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
`zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE]
`set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
`preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT
`packing_format`: the packing format for int4 tensor, available from VERSION 2 and above
"""

group_size: int = 128
Expand All @@ -1127,6 +1129,9 @@ class Int4WeightOnlyConfig(AOBaseConfig):
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
set_inductor_config: bool = True
preserve_zero: Optional[bool] = None
# only used in VERSION >= 2
packing_format: PackingFormat = PackingFormat.PLAIN
VERSION: int = 1


# for BC
@@ -1144,15 +1149,36 @@ def _int4_weight_only_quantize_tensor(weight, config):
layout = config.layout
use_hqq = config.use_hqq
zero_point_domain = config.zero_point_domain
packing_format = config.packing_format

if weight.shape[-1] % group_size != 0:
logger.info(
f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
)
return weight

block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])

if config.VERSION == 2:
if packing_format == PackingFormat.PRESHUFFLED:
new_weight = Int4PreshuffledTensor.from_float(
weight,
block_size,
activation_dtype=torch.bfloat16,
)
return new_weight
elif packing_format == PackingFormat.PLAIN:
new_weight = Int4Tensor.from_float(
weight,
block_size,
)
return new_weight
else:
raise ValueError(f"Unsupported packing format: {packing_format}")

assert config.VERSION == 1

mapping_type = MappingType.ASYMMETRIC
block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size])
target_dtype = torch.int32
quant_min = 0
quant_max = 15
@@ -1224,6 +1250,46 @@ def _int4_weight_only_transform(
return module


@dataclass
class Float8ActivationInt4WeightConfig(AOBaseConfig):
"""Configuration for apply float8 dynamic per row quantization and int4
per group weight quantization to linear

Args:
`group_size`: group size for groupwise quantization for weight
`packing_format`: how the weight is packed, only preshuffled is supported
"""

group_size: int = 128
packing_format: PackingFormat = "preshuffled"


@register_quantize_module_handler(Float8ActivationInt4WeightConfig)
def _float8_activation_int4_weight_transform(
module: torch.nn.Module, config: Float8ActivationInt4WeightConfig
) -> torch.nn.Module:
assert hasattr(module, "weight"), (
"applying int8 weight only quant requires module to have weight attribute"
+ " but {module} does not have one"
)
group_size = config.group_size
packing_format = config.packing_format

assert packing_format == "preshuffled", (
f"only preshuffled packing_format supported right now, got: {packing_format}"
)
weight = module.weight
block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
new_weight = Int4PreshuffledTensor.from_float(
module.weight,
block_size,
activation_dtype=torch.float8_e4m3fn,
)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


@dataclass
class Int8WeightOnlyConfig(AOBaseConfig):
"""
@@ -1677,6 +1743,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
# TODO(future PR): this should really throw an exception instead of silently
# not doing what the user asked
return weight

if isinstance(weight_granularity, PerRow):
assert weight.dtype == torch.bfloat16, (
"PerRow quantization only works for bfloat16 precision input weight"
@@ -2145,7 +2212,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
activation_dtype=torch.bfloat16,
)
else:
weight = to_fbgemm_int4(
weight = Int4Tensor.from_float(
module.weight,
config.block_size,
)
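Condensing the two transforms added in this file: the VERSION=2 branch of Int4WeightOnlyConfig and the new Float8ActivationInt4WeightConfig both end in a from_float call, and the preshuffled variants differ only in the activation dtype they record. The summary function below is mine (quantize_int4_weight is not a torchao API); the from_float signatures are taken from the diff above.

# Hypothetical condensed view of the dispatch added in quant_api.py.
import torch
from torchao.quantization.quantize_.workflows import Int4PreshuffledTensor, Int4Tensor


def quantize_int4_weight(
    weight: torch.Tensor,
    group_size: int = 128,
    packing_format: str = "plain",
    fp8_activation: bool = False,
) -> torch.Tensor:
    block_size = tuple([1] * (weight.ndim - 1) + [group_size])
    if packing_format == "preshuffled":
        # bf16 activations for Int4WeightOnlyConfig, e4m3 float8 for Float8ActivationInt4WeightConfig.
        act_dtype = torch.float8_e4m3fn if fp8_activation else torch.bfloat16
        return Int4PreshuffledTensor.from_float(weight, block_size, activation_dtype=act_dtype)
    if packing_format == "plain":
        if fp8_activation:
            raise ValueError("only the preshuffled packing format supports float8 activations")
        return Int4Tensor.from_float(weight, block_size)
    raise ValueError(f"Unsupported packing format: {packing_format}")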
2 changes: 2 additions & 0 deletions torchao/quantization/quantize_/common/__init__.py
@@ -1,4 +1,5 @@
from .kernel_preference import KernelPreference
from .packing_format import PackingFormat
from .quantize_tensor_kwargs import (
QuantizeTensorKwargs,
_choose_quant_func_and_quantize_tensor,
@@ -7,5 +8,6 @@
__all__ = [
"QuantizeTensorKwargs",
"KernelPreference",
"PackingFormat",
"_choose_quant_func_and_quantize_tensor",
]
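A side note on the new PackingFormat export: quant_api.py above compares config.packing_format both against PackingFormat.PLAIN and against the bare string "preshuffled", which suggests a str-valued enum. That is an inference, so the sketch below models the assumption rather than the real class.

# Assumption only: a str-valued enum would make the enum and string comparisons
# in quant_api.py interchangeable.
from enum import Enum


class PackingFormatSketch(str, Enum):
    PLAIN = "plain"
    PRESHUFFLED = "preshuffled"


assert PackingFormatSketch.PRESHUFFLED == "preshuffled"
assert PackingFormatSketch("plain") is PackingFormatSketch.PLAIN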