Align Int4Tensor implementation details with the design of Float8Tensor #2687

Open · wants to merge 1 commit into base: main
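The diff below renames Int4Tensor's internals to match Float8Tensor: the packed int4 payload moves from `_data` to `qdata`, the constructor from `from_float` to `from_hp`, and the tests switch to the shared TorchAOIntegrationTestCase helpers plus parametrized device and cat coverage. A minimal sketch of the resulting surface, inferred from the diff; the extra Int4WeightOnlyConfig fields used in the tests' setUp are elided here, so treat the exact config as an assumption:

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Assumes a CUDA (sm90+) device and the full config from the tests' setUp;
# only group_size is shown, the remaining fields are elided in this sketch.
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Int4WeightOnlyConfig(group_size=128))

w = linear.weight
# Float8Tensor-style component names replace the old `_data` attribute:
w.qdata, w.scale, w.zero_point

# Slicing keeps the components aliased to the original storage,
# which is what test_slice_preserves_aliasing below checks:
sliced = w.narrow(0, 0, 64)
assert sliced.qdata.data_ptr() == w.qdata.data_ptr()
```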
145 changes: 110 additions & 35 deletions test/quantization/quantize_/workflows/int4/test_int4_tensor.py
@@ -8,25 +8,21 @@

import torch
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)

from torchao.quantization import (
Int4WeightOnlyConfig,
quantize_,
)
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_90,
)
from torchao.testing.utils import TorchAOIntegrationTestCase
from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_90


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
class TestInt4Tensor(TestCase):
class TestInt4Tensor(TorchAOIntegrationTestCase):
def setUp(self):
self.config = Int4WeightOnlyConfig(
group_size=128,
@@ -61,50 +57,46 @@ def test_slice(self):
quantize_(dummy, self.config)
weight1 = dummy.weight.narrow(0, 0, 64)
weight2 = dummy.weight.narrow(1, 0, 128)
self.assertEqual(weight1._data, dummy.weight._data.narrow(0, 0, 64))
self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
self.assertEqual(weight2._data, dummy.weight._data.narrow(1, 0, 64))
self.assertEqual(weight1.zero_point, dummy.weight.zero_point.narrow(1, 0, 64))
self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 64))
self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))
self.assertEqual(weight2.zero_point, dummy.weight.zero_point.narrow(0, 0, 1))

# check that the output with the sliced weight, before and after
# int4 quantization, does not differ too much
input = torch.randn(2, 256, dtype=dtype, device=device)
res_ref = dummy1(input)
dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
dummy.weight = torch.nn.Parameter(weight1.contiguous(), requires_grad=False)
res = dummy(input)
assert compute_error(res, res_ref) > 20

input = torch.randn(2, 128, dtype=dtype, device=device)
res_ref = dummy2(input)
dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
dummy.weight = torch.nn.Parameter(weight2.contiguous(), requires_grad=False)
res = dummy(input)
assert compute_error(res, res_ref) > 15

def test_slice_and_copy_(self):
def test_slice_preserves_aliasing(self):
config = self.config
l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
l.weight = torch.nn.Parameter(
torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
)
quantize_(l, self.config)
quantize_(l, config)
param = l.weight
param_data = param.data
param_data = param_data.narrow(0, 0, 512)
assert param.data._data.data_ptr() == param_data._data.data_ptr()
# Making sure the aliasing is preserved in sliced quantized Tensor
assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
orig_value = param.data._data[0][0].item()

# dummy_l has random input (shouldn't be 0)
dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
quantize_(dummy_l, self.config)
quantized = dummy_l.weight
quantized = quantized.narrow(0, 0, 512)

param_data.copy_(quantized)

# making sure param.data is updated
assert param.data._data[0][0] != orig_value
def test_slice_and_copy_similar_to_vllm(self):
self._test_slice_and_copy_similar_to_vllm(self.config)

@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
def test_bmm(self):
class M(torch.nn.Module):
def __init__(self, weight):
@@ -126,20 +118,103 @@ def forward(self, x):
quantized = m(input)
self.assertTrue(compute_error(original, quantized) > 18)

def test_to_device(self):
@parametrize(
"sizes",
[
((128,), 256, 128),
((32, 128), 64, 256),
((2, 32, 128), 64, 256),
],
)
def test_to_device(self, sizes):
config = self.config
M, N, K = sizes
dtype = torch.bfloat16
for device in self.GPU_DEVICES:
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
input_tensor = torch.randn(*M, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, dtype=dtype)
quantize_(linear, config)
linear.to(device)
linear(input_tensor)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
linear = torch.nn.Linear(K, N, dtype=dtype)
quantize_(linear, config)
linear.to(device=device)
linear(input_tensor)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
linear = torch.nn.Linear(K, N, dtype=dtype)
quantize_(linear, config)
linear.to(device)
linear(input_tensor)

@parametrize(
"sizes",
[
((128,), 256, 128),
((32, 128), 64, 256),
((2, 32, 128), 64, 256),
],
)
def test_cat(self, sizes):
config = self.config
dtype = torch.bfloat16
device = "cuda"
M, N, K = sizes
linear1 = torch.nn.Linear(K, N, dtype=dtype, device=device)
linear2 = torch.nn.Linear(K, N, dtype=dtype, device=device)
input_cat1 = torch.randn(*M, K, dtype=dtype, device=device)

cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
dummy_linear1 = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)

dummy_linear1.weight = torch.nn.Parameter(cat_weight1)
quantize_(dummy_linear1, config)

quantize_(linear1, config)
quantize_(linear2, config)

cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
self.assertEqual(cat_qweight1.shape, (2 * N, K))
self.assertEqual(
dummy_linear1.weight.qdata,
cat_qweight1.qdata,
)
self.assertEqual(
dummy_linear1.weight.scale,
cat_qweight1.scale,
)
self.assertEqual(
dummy_linear1.weight.zero_point,
cat_qweight1.zero_point,
)

# making sure cat_qweight1 can be used for inference
dummy_linear1.weight = torch.nn.Parameter(cat_qweight1, requires_grad=False)
dummy_linear1(input_cat1)

# align the scale and zero_point before concatenation
linear2.weight.scale = linear1.weight.scale
linear2.weight.zero_point = linear1.weight.zero_point
cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
self.assertEqual(cat_qweight2.shape, (N, 2 * K))
ref_data = torch.cat(
[
linear1.weight.qdata,
linear2.weight.qdata,
],
dim=1,
)
ref_scale = linear1.weight.scale
ref_zero_point = linear1.weight.zero_point
self.assertEqual(cat_qweight2.qdata, ref_data)
self.assertEqual(cat_qweight2.scale, ref_scale)
self.assertEqual(cat_qweight2.zero_point, ref_zero_point)

def test_moe_weight_reshape_ops(self):
self._test_moe_weight_reshape_ops(self.config)


instantiate_parametrized_tests(TestInt4Tensor)

if __name__ == "__main__":
run_tests()
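For context on `test_slice_and_copy_similar_to_vllm`: the inline body it replaces (visible in the hunk above) checked the slice-then-`copy_` pattern that vLLM-style weight loading relies on, where a checkpoint shard is copied into a pre-sliced parameter. A condensed sketch of that pattern with the new attribute names, assuming the shared `_test_slice_and_copy_similar_to_vllm` helper exercises the same steps:

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

config = Int4WeightOnlyConfig(group_size=128)  # stand-in for the tests' self.config

l = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
quantize_(l, config)
param_slice = l.weight.data.narrow(0, 0, 512)
# the slice aliases the original quantized components
assert l.weight.data.qdata.data_ptr() == param_slice.qdata.data_ptr()

src = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
quantize_(src, config)
# copying a matching slice writes through to the original parameter
param_slice.copy_(src.weight.narrow(0, 0, 512))
```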
14 changes: 10 additions & 4 deletions test/test_utils.py
@@ -76,7 +76,7 @@ def __init__(self, qdata, attr, device=None):
self.qdata = qdata
self.attr = attr

l = torch.nn.Linear(1, 1)
l = torch.nn.Linear(2, 3)
l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
lp_tensor = l.weight
# test __tensor_flatten__ and __tensor_unflatten__
@@ -107,18 +107,24 @@ def __init__(self, qdata, attr, device=None):
# explicitly testing aten.alias
lp_tensor = torch.ops.aten.alias(lp_tensor)
lp_tensor = lp_tensor.clone()
# making qdata not contiguous
lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1).contiguous()
lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1)
self.assertFalse(lp_tensor.qdata.is_contiguous())
lp_tensor = lp_tensor.contiguous()
# making sure contiguous call works
self.assertTrue(lp_tensor.qdata.is_contiguous())

# copy_
another_tensor = torch.nn.Linear(1, 1).weight
another_tensor = torch.nn.Linear(2, 3).weight
# attribute has to be the same
another_lp_tensor = MyTensor(another_tensor, "attr")
# initially tensor values are not the same
self.assertNotEqual(lp_tensor.qdata[0], another_lp_tensor.qdata[0])
self.assertNotEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0])
lp_tensor.copy_(another_lp_tensor)
self.assertEqual(lp_tensor.attr, "attr")
# after copy_, the tensor values should match
self.assertEqual(lp_tensor.qdata[0], another_lp_tensor.qdata[0])
self.assertEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0])


if __name__ == "__main__":
5 changes: 3 additions & 2 deletions torchao/quantization/quant_api.py
@@ -1160,6 +1160,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])

if config.VERSION == 2:
block_size = list(block_size)
if packing_format == PackingFormat.PRESHUFFLED:
new_weight = Int4PreshuffledTensor.from_float(
weight,
@@ -1168,7 +1169,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
)
return new_weight
elif packing_format == PackingFormat.PLAIN:
new_weight = Int4Tensor.from_float(
new_weight = Int4Tensor.from_hp(
weight,
block_size,
)
@@ -2212,7 +2213,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
activation_dtype=torch.bfloat16,
)
else:
weight = Int4Tensor.from_float(
weight = Int4Tensor.from_hp(
module.weight,
config.block_size,
)
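For readers skimming this hunk: the `block_size` built above is all-ones except for the last dimension, and the VERSION == 2 path now converts it to a list before handing it to the tensor constructors. A small worked example with illustrative shapes (no GPU needed):

```python
# e.g. a (256, 128) linear weight with group_size = 128
weight_ndim = 2
group_size = 128
block_size = tuple([1 for _ in range(weight_ndim - 1)] + [group_size])
assert block_size == (1, 128)
# under config.VERSION == 2 it becomes a list before the
# Int4PreshuffledTensor.from_float / Int4Tensor.from_hp calls:
block_size = list(block_size)  # [1, 128]
```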