From 3796f20220376434a5b68a32d84078145366a658 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Fri, 5 Dec 2025 13:52:13 +0530 Subject: [PATCH 1/8] Introduces customizable quantization API --- keras/src/layers/core/dense.py | 66 +++++-- keras/src/layers/core/dense_test.py | 58 ++++++ keras/src/layers/core/einsum_dense.py | 73 +++++--- keras/src/layers/core/einsum_dense_test.py | 62 +++++++ keras/src/layers/core/embedding.py | 43 +++-- keras/src/layers/core/embedding_test.py | 40 +++++ keras/src/layers/core/reversible_embedding.py | 96 +++++++--- .../layers/core/reversible_embedding_test.py | 53 ++++++ keras/src/models/model.py | 20 +-- keras/src/quantizers/__init__.py | 13 +- keras/src/quantizers/gptq_config.py | 8 +- keras/src/quantizers/gptq_test.py | 19 +- keras/src/quantizers/quantization_config.py | 166 ++++++++++++++++++ .../quantizers/quantization_config_test.py | 106 +++++++++++ keras/src/quantizers/quantizers.py | 9 +- 15 files changed, 721 insertions(+), 111 deletions(-) create mode 100644 keras/src/quantizers/quantization_config.py create mode 100644 keras/src/quantizers/quantization_config_test.py diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py index 7066a1a6dd9a..87fd606d1151 100644 --- a/keras/src/layers/core/dense.py +++ b/keras/src/layers/core/dense.py @@ -11,6 +11,8 @@ from keras.src.api_export import keras_export from keras.src.layers.input_spec import InputSpec from keras.src.layers.layer import Layer +from keras.src.quantizers.quantization_config import QuantizationConfig +from keras.src.quantizers.quantization_config import validate_and_resolve_config from keras.src.quantizers.quantizers import dequantize_with_sz_map @@ -378,9 +380,9 @@ def variable_serialization_spec(self): def quantized_build(self, kernel_shape, mode, config=None): if mode == "int8": - self._int8_build(kernel_shape) + self._int8_build(kernel_shape, config) elif mode == "int4": - self._int4_build(kernel_shape) + self._int4_build(kernel_shape, config) elif mode == "float8": self._float8_build() elif mode == "gptq": @@ -389,8 +391,13 @@ def quantized_build(self, kernel_shape, mode, config=None): raise self._quantization_mode_error(mode) self._is_quantized = True - def _int8_build(self, kernel_shape): - self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) + def _int8_build(self, kernel_shape, config=None): + self.inputs_quantizer = ( + QuantizationConfig.activation_quantizer_or_default( + config, quantizers.AbsMaxQuantizer(axis=-1) + ) + ) + self._kernel = self.add_weight( name="kernel", shape=kernel_shape, @@ -489,7 +496,7 @@ def _gptq_call(self, inputs, training=False): y = self.activation(y) return y - def _int4_build(self, kernel_shape): + def _int4_build(self, kernel_shape, config=None): """Build variables for int4 quantization. `kernel_shape` is the *original* float32 kernel shape @@ -498,8 +505,10 @@ def _int4_build(self, kernel_shape): int8 byte. """ # Per-channel int8 quantizer for the last axis (features). 
- self.inputs_quantizer = quantizers.AbsMaxQuantizer( - axis=-1, + self.inputs_quantizer = ( + QuantizationConfig.activation_quantizer_or_default( + config, quantizers.AbsMaxQuantizer(axis=-1) + ) ) input_dim, output_dim = kernel_shape packed_rows = (input_dim + 1) // 2 # ceil for odd dims @@ -588,7 +597,15 @@ def grad_fn(*args, upstream=None): inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel)) return (inputs_grad, None, None) - inputs, inputs_scale = self.inputs_quantizer(inputs) + if self.inputs_quantizer: + inputs, inputs_scale = self.inputs_quantizer(inputs) + else: + # Weight-only quantization: inputs are not quantized + # We still need inputs_scale for the formula: + # x = x / (inputs_scale * kernel_scale) + # If inputs are not quantized, inputs_scale should be 1. + inputs_scale = ops.ones((1,), dtype=self.compute_dtype) + x = ops.matmul(inputs, kernel) # De-scale outputs x = ops.cast(x, self.compute_dtype) @@ -639,7 +656,10 @@ def grad_fn(*args, upstream=None): inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel)) return (inputs_grad, None, None) - inputs, inputs_scale = self.inputs_quantizer(inputs) + if self.inputs_quantizer: + inputs, inputs_scale = self.inputs_quantizer(inputs) + else: + inputs_scale = ops.ones((1,), dtype=self.compute_dtype) x = ops.matmul(inputs, unpacked_kernel) x = ops.cast(x, self.compute_dtype) x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) @@ -759,25 +779,33 @@ def quantize(self, mode, type_check=True, config=None): if type_check and (type(self) is not Dense): raise self._not_implemented_error(self.quantize) + config = validate_and_resolve_config(mode, config) + mode = config.mode + kernel_shape = self._kernel.shape if mode == "int8": - kernel_value, kernel_scale = quantizers.abs_max_quantize( - self._kernel, axis=0, to_numpy=True + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, quantizers.AbsMaxQuantizer(axis=0) + ) + kernel_value, kernel_scale = weight_quantizer( + self._kernel, to_numpy=True ) kernel_scale = ops.squeeze(kernel_scale, axis=0) del self._kernel # Build variables for int8 mode - self.quantized_build(kernel_shape, mode) + self.quantized_build(kernel_shape, mode, config) self._kernel.assign(kernel_value) self.kernel_scale.assign(kernel_scale) elif mode == "int4": # 1. Quantize to int4 values (still int8 dtype, range [-8,7]) - kernel_value_int4, kernel_scale = quantizers.abs_max_quantize( - self._kernel, - axis=0, - value_range=(-8, 7), - dtype="int8", - to_numpy=True, + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer( + axis=0, value_range=(-8, 7), output_dtype="int8" + ), + ) + kernel_value_int4, kernel_scale = weight_quantizer( + self._kernel, to_numpy=True ) kernel_scale = ops.squeeze(kernel_scale, axis=0) # 2. Pack two int4 values into a single int8 byte. @@ -785,7 +813,7 @@ def quantize(self, mode, type_check=True, config=None): del self._kernel # Build variables using the original kernel shape; _int4_build will # compute the packed shape internally. - self.quantized_build(kernel_shape, mode) + self.quantized_build(kernel_shape, mode, config) # Assign packed values. 
self._kernel.assign(packed_kernel_value) self.kernel_scale.assign(kernel_scale) diff --git a/keras/src/layers/core/dense_test.py b/keras/src/layers/core/dense_test.py index 6550367a3faf..a4a25e60e081 100644 --- a/keras/src/layers/core/dense_test.py +++ b/keras/src/layers/core/dense_test.py @@ -17,9 +17,67 @@ from keras.src import testing from keras.src.backend.common import keras_tensor from keras.src.quantizers.gptq_config import GPTQConfig +from keras.src.quantizers.quantization_config import Int4QuantizationConfig +from keras.src.quantizers.quantization_config import Int8QuantizationConfig +from keras.src.quantizers.quantizers import AbsMaxQuantizer class DenseTest(testing.TestCase): + @parameterized.named_parameters( + ("int8", "int8", {"axis": 0}, {"axis": -1}), + ( + "int4", + "int4", + {"axis": 0, "value_range": (-8, 7), "output_dtype": "int8"}, + {"axis": -1}, + ), + ("int8_weight_only", "int8", {"axis": 0}, None), + ) + def test_dense_quantize_config( + self, mode, weight_quantizer_args, activation_quantizer_args + ): + """Test Dense quantization with QuantizationConfig.""" + layer = layers.Dense(units=32) + layer.build((None, 8)) + + weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args) + if activation_quantizer_args is not None: + activation_quantizer = AbsMaxQuantizer(**activation_quantizer_args) + else: + activation_quantizer = None + + if mode == "int8": + config = Int8QuantizationConfig( + weight_quantizer=weight_quantizer, + activation_quantizer=activation_quantizer, + ) + elif mode == "int4": + config = Int4QuantizationConfig( + weight_quantizer=weight_quantizer, + activation_quantizer=activation_quantizer, + ) + + layer.quantize(mode, config=config) + + if activation_quantizer_args is not None: + # Verify inputs_quantizer is set correctly + self.assertIsInstance(layer.inputs_quantizer, AbsMaxQuantizer) + self.assertEqual(layer.inputs_quantizer.axis, (-1,)) + else: + # Verify inputs_quantizer is None + self.assertIsNone(layer.inputs_quantizer) + + # Verify call works + x = np.random.random((2, 8)).astype("float32") + y = layer(x) + self.assertEqual(y.shape, (2, 32)) + + if mode == "int4": + # Verify kernel is int8 (packed int4) + self.assertEqual( + backend.standardize_dtype(layer._kernel.dtype), "int8" + ) + @pytest.mark.requires_trainable_backend def test_dense_basics(self): # 2D case, no bias. 
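Reviewer note: the config-driven Dense path above can be exercised standalone. A minimal sketch, assuming this patch is applied and using the same internal import paths as the new tests (the public `keras.quantizers` exports are only added later in this series):

```python
import numpy as np

from keras.src import layers
from keras.src.quantizers.quantization_config import Int8QuantizationConfig
from keras.src.quantizers.quantizers import AbsMaxQuantizer

layer = layers.Dense(units=32)
layer.build((None, 8))

# Weight-only int8: quantize the kernel, leave the activations in float.
config = Int8QuantizationConfig(
    weight_quantizer=AbsMaxQuantizer(axis=0),
    activation_quantizer=None,
)
layer.quantize("int8", config=config)

assert layer.inputs_quantizer is None  # no input quantization
y = layer(np.random.random((2, 8)).astype("float32"))
assert y.shape == (2, 32)
```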
diff --git a/keras/src/layers/core/einsum_dense.py b/keras/src/layers/core/einsum_dense.py index 23d98fe3ec04..1676cdac70f8 100644 --- a/keras/src/layers/core/einsum_dense.py +++ b/keras/src/layers/core/einsum_dense.py @@ -15,6 +15,7 @@ from keras.src.api_export import keras_export from keras.src.layers.input_spec import InputSpec from keras.src.layers.layer import Layer +from keras.src.quantizers.quantization_config import QuantizationConfig from keras.src.quantizers.quantizers import dequantize_with_sz_map @@ -444,9 +445,9 @@ def variable_serialization_spec(self): def quantized_build(self, kernel_shape, mode, config=None): if mode == "int8": - self._int8_build(kernel_shape) + self._int8_build(kernel_shape, config) elif mode == "int4": - self._int4_build(kernel_shape) + self._int4_build(kernel_shape, config) elif mode == "float8": self._float8_build() elif mode == "gptq": @@ -455,10 +456,13 @@ def quantized_build(self, kernel_shape, mode, config=None): raise self._quantization_mode_error(mode) self._is_quantized = True - def _int8_build(self, kernel_shape): + def _int8_build(self, kernel_shape, config=None): self._set_quantization_info() - self.inputs_quantizer = quantizers.AbsMaxQuantizer( - axis=self._input_reduced_axes + self.inputs_quantizer = ( + QuantizationConfig.activation_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer(axis=self._input_reduced_axes), + ) ) self._kernel = self.add_weight( name="kernel", @@ -591,7 +595,7 @@ def _gptq_call(self, inputs, training=False): y = self.activation(y) return y - def _int4_build(self, kernel_shape): + def _int4_build(self, kernel_shape, config=None): """Build variables for int4 quantization. The packed int4 kernel stores two int4 values within a single int8 @@ -603,8 +607,11 @@ def _int4_build(self, kernel_shape): self._set_quantization_info() # Quantizer for the inputs (per the reduced axes) - self.inputs_quantizer = quantizers.AbsMaxQuantizer( - axis=self._input_reduced_axes + self.inputs_quantizer = ( + QuantizationConfig.activation_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer(axis=self._input_reduced_axes), + ) ) # Choose the axis to perform int4 packing - use the first reduced axis @@ -727,10 +734,16 @@ def grad_fn(*args, upstream=None): ) return (inputs_grad, None, None) - inputs, inputs_scale = self.inputs_quantizer(inputs) + if self.inputs_quantizer: + inputs, inputs_scale = self.inputs_quantizer(inputs) + # Align `inputs_scale` axes with the output + # for correct broadcasting + inputs_scale = self._adjust_scale_for_quant( + inputs_scale, "input" + ) + else: + inputs_scale = ops.ones((1,), dtype=self.compute_dtype) x = ops.einsum(self.equation, inputs, kernel) - # Deal with `inputs_scale` - inputs_scale = self._adjust_scale_for_quant(inputs_scale, "input") # De-scale outputs x = ops.cast(x, self.compute_dtype) x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) @@ -803,14 +816,20 @@ def grad_fn(*args, upstream=None): return (inputs_grad, None, None) # Quantize inputs per `self.inputs_quantizer`. - inputs_q, inputs_scale = self.inputs_quantizer(inputs) + if self.inputs_quantizer: + inputs_q, inputs_scale = self.inputs_quantizer(inputs) + # Align `inputs_scale` axes with the output + # for correct broadcasting + inputs_scale = self._adjust_scale_for_quant( + inputs_scale, "input" + ) + else: + inputs_q = inputs + inputs_scale = ops.ones((1,), dtype=self.compute_dtype) # Compute einsum on quantized inputs and unpacked int4 kernel. 
x = ops.einsum(self.equation, inputs_q, unpacked_kernel) - # Align `inputs_scale` axes with the output for correct broadcasting - inputs_scale = self._adjust_scale_for_quant(inputs_scale, "input") - # De-scale outputs. x = ops.cast(x, self.compute_dtype) x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) @@ -938,19 +957,27 @@ def quantize(self, mode, type_check=True, config=None): if mode == "int8": # Quantize `self._kernel` to int8 and compute corresponding scale - kernel_value, kernel_scale = quantizers.abs_max_quantize( - self._kernel, axis=self._kernel_reduced_axes, to_numpy=True + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer(axis=self._kernel_reduced_axes), + ) + kernel_value, kernel_scale = weight_quantizer( + self._kernel, to_numpy=True ) kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel") del self._kernel elif mode == "int4": # Quantize to int4 values (stored in int8 dtype, range [-8, 7]) - kernel_value_int4, kernel_scale = quantizers.abs_max_quantize( - self._kernel, - axis=self._kernel_reduced_axes, - value_range=(-8, 7), - dtype="int8", - to_numpy=True, + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer( + axis=self._kernel_reduced_axes, + value_range=(-8, 7), + output_dtype="int8", + ), + ) + kernel_value_int4, kernel_scale = weight_quantizer( + self._kernel, to_numpy=True ) kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel") diff --git a/keras/src/layers/core/einsum_dense_test.py b/keras/src/layers/core/einsum_dense_test.py index 92496f5f9d7a..e1d0308088b6 100644 --- a/keras/src/layers/core/einsum_dense_test.py +++ b/keras/src/layers/core/einsum_dense_test.py @@ -16,9 +16,71 @@ from keras.src import saving from keras.src import testing from keras.src.quantizers.gptq_config import GPTQConfig +from keras.src.quantizers.quantization_config import Int4QuantizationConfig +from keras.src.quantizers.quantization_config import Int8QuantizationConfig +from keras.src.quantizers.quantizers import AbsMaxQuantizer class EinsumDenseTest(testing.TestCase): + @parameterized.named_parameters( + ("int8", "int8", {"axis": 0}, {"axis": -1}), + ( + "int4", + "int4", + {"axis": 0, "value_range": (-8, 7), "output_dtype": "int8"}, + {"axis": -1}, + ), + ("int8_weight_only", "int8", {"axis": 0}, None), + ) + def test_einsum_dense_quantize( + self, mode, weight_quantizer_args, activation_quantizer_args + ): + """Test EinsumDense quantization with QuantizationConfig.""" + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer.build((None, 3)) + + weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args) + if activation_quantizer_args is not None: + activation_quantizer = AbsMaxQuantizer(**activation_quantizer_args) + else: + activation_quantizer = None + + if mode == "int8": + config = Int8QuantizationConfig( + weight_quantizer=weight_quantizer, + activation_quantizer=activation_quantizer, + ) + elif mode == "int4": + config = Int4QuantizationConfig( + weight_quantizer=weight_quantizer, + activation_quantizer=activation_quantizer, + ) + + layer.quantize(mode, config=config) + + if activation_quantizer_args is not None: + # Verify inputs_quantizer is set correctly + self.assertIsInstance(layer.inputs_quantizer, AbsMaxQuantizer) + self.assertEqual(layer.inputs_quantizer.axis, (-1,)) + else: + # Verify inputs_quantizer is None + self.assertIsNone(layer.inputs_quantizer) + + # Verify call works + x 
= np.random.random((2, 3)).astype("float32") + y = layer(x) + self.assertEqual(y.shape, (2, 8, 32)) + + if mode == "int4": + # Verify kernel is int8 (packed int4) + self.assertEqual( + backend.standardize_dtype(layer._kernel.dtype), "int8" + ) + @parameterized.named_parameters( { "testcase_name": "_1d_end_weight", diff --git a/keras/src/layers/core/embedding.py b/keras/src/layers/core/embedding.py index c1cb3b6b0117..f4ee2d4e52d9 100644 --- a/keras/src/layers/core/embedding.py +++ b/keras/src/layers/core/embedding.py @@ -10,6 +10,8 @@ from keras.src.api_export import keras_export from keras.src.backend import KerasTensor from keras.src.layers.layer import Layer +from keras.src.quantizers.quantization_config import QuantizationConfig +from keras.src.quantizers.quantization_config import validate_and_resolve_config @keras_export("keras.layers.Embedding") @@ -315,16 +317,16 @@ def variable_serialization_spec(self): ], } - def quantized_build(self, embeddings_shape, mode): + def quantized_build(self, embeddings_shape, mode, config=None): if mode == "int8": - self._int8_build(embeddings_shape) + self._int8_build(embeddings_shape, config) elif mode == "int4": - self._int4_build(embeddings_shape) + self._int4_build(embeddings_shape, config) else: raise self._quantization_mode_error(mode) self._is_quantized = True - def _int8_build(self, embeddings_shape): + def _int8_build(self, embeddings_shape, config=None): self._embeddings = self.add_weight( name="embeddings", shape=embeddings_shape, @@ -342,7 +344,7 @@ def _int8_build(self, embeddings_shape): trainable=False, ) - def _int4_build(self, embeddings_shape): + def _int4_build(self, embeddings_shape, config=None): input_dim, output_dim = embeddings_shape packed_rows = (output_dim + 1) // 2 # ceil for odd dims @@ -412,26 +414,37 @@ def quantize(self, mode, type_check=True, config=None): if type_check and (type(self) is not Embedding): raise self._not_implemented_error(self.quantize) + config = validate_and_resolve_config(mode, config) + mode = config.mode + embeddings_shape = (self.input_dim, self.output_dim) if mode == "int8": # Quantize `self._embeddings` to int8 and compute corresponding # scale. - embeddings_value, embeddings_scale = quantizers.abs_max_quantize( - self._embeddings, axis=-1, to_numpy=True + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer(axis=-1), + ) + embeddings_value, embeddings_scale = weight_quantizer( + self._embeddings, to_numpy=True ) embeddings_scale = ops.squeeze(embeddings_scale, axis=-1) del self._embeddings - self.quantized_build(embeddings_shape, mode) + self.quantized_build(embeddings_shape, mode, config) self._embeddings.assign(embeddings_value) self.embeddings_scale.assign(embeddings_scale) elif mode == "int4": # Quantize to int4 values (stored in int8 dtype, range [-8, 7]). - embeddings_value, embeddings_scale = quantizers.abs_max_quantize( - self._embeddings, - axis=-1, - value_range=(-8, 7), - dtype="int8", - to_numpy=True, + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer( + axis=-1, + value_range=(-8, 7), + output_dtype="int8", + ), + ) + embeddings_value, embeddings_scale = weight_quantizer( + self._embeddings, to_numpy=True ) embeddings_scale = ops.squeeze(embeddings_scale, axis=-1) # 2. Pack two int4 values into a single int8 byte. 
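The Embedding changes route weight quantization through the same config mechanism. A minimal sketch under the same assumptions as above (embedding lookups have no activation matmul, so only the weight quantizer applies, and int4 requires `value_range=(-8, 7)`):

```python
import numpy as np

from keras.src import layers
from keras.src.quantizers.quantization_config import Int4QuantizationConfig
from keras.src.quantizers.quantizers import AbsMaxQuantizer

layer = layers.Embedding(input_dim=10, output_dim=6)
layer.build((None,))

config = Int4QuantizationConfig(
    weight_quantizer=AbsMaxQuantizer(
        axis=-1, value_range=(-8, 7), output_dtype="int8"
    ),
    activation_quantizer=None,
)
layer.quantize("int4", config=config)

# Two int4 values are packed per int8 byte; lookups still return floats.
out = layer(np.random.randint(0, 10, size=(2, 3)))
assert out.shape == (2, 3, 6)
```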
@@ -439,7 +452,7 @@ def quantize(self, mode, type_check=True, config=None): embeddings_value, axis=-1 ) del self._embeddings - self.quantized_build(embeddings_shape, mode) + self.quantized_build(embeddings_shape, mode, config) self._embeddings.assign(packed_embeddings_value) self.embeddings_scale.assign(embeddings_scale) else: diff --git a/keras/src/layers/core/embedding_test.py b/keras/src/layers/core/embedding_test.py index 68b4ca1d9c15..4337fda7cd6a 100644 --- a/keras/src/layers/core/embedding_test.py +++ b/keras/src/layers/core/embedding_test.py @@ -12,10 +12,50 @@ from keras.src import ops from keras.src import quantizers from keras.src import saving +from keras.src.quantizers.quantization_config import Int4QuantizationConfig +from keras.src.quantizers.quantization_config import Int8QuantizationConfig +from keras.src.quantizers.quantizers import AbsMaxQuantizer from keras.src.testing import test_case class EmbeddingTest(test_case.TestCase): + @parameterized.named_parameters( + ("int8", "int8", {"axis": -1}), + ( + "int4", + "int4", + {"axis": -1, "value_range": (-8, 7), "output_dtype": "int8"}, + ), + ("int8_custom", "int8", {"axis": -1}), + ) + def test_embedding_quantize_config(self, mode, weight_quantizer_args): + """Test Embedding quantization with QuantizationConfig.""" + layer = layers.Embedding(input_dim=10, output_dim=6) + layer.build((None,)) + + weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args) + if mode == "int8": + config = Int8QuantizationConfig( + weight_quantizer=weight_quantizer, activation_quantizer=None + ) + elif mode == "int4": + config = Int4QuantizationConfig( + weight_quantizer=weight_quantizer, activation_quantizer=None + ) + + layer.quantize(mode, config=config) + + # Verify weights are quantized + self.assertEqual( + backend.standardize_dtype(layer._embeddings.dtype), "int8" + ) + self.assertTrue(hasattr(layer, "embeddings_scale")) + + # Verify call works + x = np.random.randint(0, 10, size=(2, 3)) + y = layer(x) + self.assertEqual(y.shape, (2, 3, 6)) + @pytest.mark.requires_trainable_backend def test_embedding_basics(self): self.run_layer_test( diff --git a/keras/src/layers/core/reversible_embedding.py b/keras/src/layers/core/reversible_embedding.py index ae8ea8f4c4f7..41b0f88b0aea 100644 --- a/keras/src/layers/core/reversible_embedding.py +++ b/keras/src/layers/core/reversible_embedding.py @@ -6,6 +6,8 @@ from keras.src import quantizers from keras.src.api_export import keras_export from keras.src.backend import KerasTensor +from keras.src.quantizers.quantization_config import QuantizationConfig +from keras.src.quantizers.quantization_config import validate_and_resolve_config @keras_export("keras.layers.ReversibleEmbedding") @@ -172,20 +174,25 @@ def variable_serialization_spec(self): variable_spec.append("reverse_embeddings_scale") return _spec - def quantized_build(self, embeddings_shape, mode): + def quantized_build(self, embeddings_shape, mode, config=None): if mode == "int8": - self._int8_build(embeddings_shape) + self._int8_build(embeddings_shape, config) elif mode == "int4": - self._int4_build(embeddings_shape) + self._int4_build(embeddings_shape, config) else: raise self._quantization_mode_error(mode) self._is_quantized = True - def _int8_build(self, embeddings_shape): + def _int8_build(self, embeddings_shape, config=None): if embeddings_shape is None: embeddings_shape = (self.input_dim, self.output_dim) super()._int8_build(embeddings_shape=embeddings_shape) - self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) + + 
self.inputs_quantizer = ( + QuantizationConfig.activation_quantizer_or_default( + config, quantizers.AbsMaxQuantizer(axis=-1) + ) + ) if not self.tie_weights: self.reverse_embeddings = self.add_weight( name="reverse_embeddings", @@ -201,11 +208,16 @@ def _int8_build(self, embeddings_shape): trainable=False, ) - def _int4_build(self, embeddings_shape): + def _int4_build(self, embeddings_shape, config=None): if embeddings_shape is None: embeddings_shape = (self.input_dim, self.output_dim) - super()._int4_build(embeddings_shape=embeddings_shape) - self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) + super()._int4_build(embeddings_shape=embeddings_shape, config=config) + + self.inputs_quantizer = ( + QuantizationConfig.activation_quantizer_or_default( + config, quantizers.AbsMaxQuantizer(axis=-1) + ) + ) if not self.tie_weights: packed_rows = (self.output_dim + 1) // 2 # ceil for odd dims self.reverse_embeddings = self.add_weight( @@ -232,7 +244,10 @@ def _int8_call(self, inputs, reverse=False): else: kernel = self.reverse_embeddings scale = self.reverse_embeddings_scale - inputs, inputs_scale = self.inputs_quantizer(inputs) + if self.inputs_quantizer: + inputs, inputs_scale = self.inputs_quantizer(inputs) + else: + inputs_scale = ops.ones((1,), dtype=self.compute_dtype) logits = ops.matmul(inputs, kernel) # De-scale outputs logits = ops.cast(logits, self.compute_dtype) @@ -258,7 +273,10 @@ def _int4_call(self, inputs, reverse=False): unpacked_embeddings = quantizers.unpack_int4( embeddings, self.output_dim, axis=0 ) - inputs, inputs_scale = self.inputs_quantizer(inputs) + if self.inputs_quantizer: + inputs, inputs_scale = self.inputs_quantizer(inputs) + else: + inputs_scale = ops.ones((1,), dtype=self.compute_dtype) logits = ops.matmul(inputs, unpacked_embeddings) # De-scale outputs logits = ops.cast(logits, self.compute_dtype) @@ -272,30 +290,40 @@ def _int4_call(self, inputs, reverse=False): return logits def quantize(self, mode, type_check=True, config=None): - del config if type_check and type(self) is not ReversibleEmbedding: raise self._not_implemented_error(self.quantize) + config = validate_and_resolve_config(mode, config) + mode = config.mode + embeddings_shape = (self.input_dim, self.output_dim) if mode == "int8": # Quantize `self._embeddings` to int8 and compute corresponding # scale. 
- embeddings_value, embeddings_scale = quantizers.abs_max_quantize( - self._embeddings, axis=-1, to_numpy=True + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, quantizers.AbsMaxQuantizer(axis=-1) + ) + embeddings_value, embeddings_scale = weight_quantizer( + self._embeddings, to_numpy=True ) embeddings_scale = ops.squeeze(embeddings_scale, axis=-1) del self._embeddings if not self.tie_weights: + reverse_weight_quantizer = ( + QuantizationConfig.weight_quantizer_or_default( + config, quantizers.AbsMaxQuantizer(axis=0) + ) + ) reverse_embeddings_value, reverse_embeddings_scale = ( - quantizers.abs_max_quantize( - self.reverse_embeddings, axis=0, to_numpy=True + reverse_weight_quantizer( + self.reverse_embeddings, to_numpy=True ) ) reverse_embeddings_scale = ops.squeeze( reverse_embeddings_scale, axis=0 ) del self.reverse_embeddings - self.quantized_build(embeddings_shape, mode) + self.quantized_build(embeddings_shape, mode, config) self._embeddings.assign(embeddings_value) self.embeddings_scale.assign(embeddings_scale) if not self.tie_weights: @@ -303,12 +331,16 @@ def quantize(self, mode, type_check=True, config=None): self.reverse_embeddings_scale.assign(reverse_embeddings_scale) elif mode == "int4": # Quantize to int4 values (stored in int8 dtype, range [-8, 7]). - embeddings_value, embeddings_scale = quantizers.abs_max_quantize( - self._embeddings, - axis=-1, - value_range=(-8, 7), - dtype="int8", - to_numpy=True, + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer( + axis=-1, + value_range=(-8, 7), + output_dtype="int8", + ), + ) + embeddings_value, embeddings_scale = weight_quantizer( + self._embeddings, to_numpy=True ) embeddings_scale = ops.squeeze(embeddings_scale, axis=-1) # 2. Pack two int4 values into a single int8 byte. 
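The reverse (logits) projection honors the same config. A minimal sketch mirroring the new ReversibleEmbedding test: with `activation_quantizer=None`, the reverse matmul runs on unquantized inputs and de-scales by the embedding scale alone:

```python
import numpy as np

from keras.src import layers
from keras.src.quantizers.quantization_config import Int8QuantizationConfig
from keras.src.quantizers.quantizers import AbsMaxQuantizer

layer = layers.ReversibleEmbedding(
    input_dim=10, output_dim=6, tie_weights=True
)
layer.build((None,))

config = Int8QuantizationConfig(
    weight_quantizer=AbsMaxQuantizer(axis=-1),
    activation_quantizer=None,
)
layer.quantize("int8", config=config)

# The reverse call maps hidden states back to vocabulary logits.
logits = layer(np.random.random((2, 6)).astype("float32"), reverse=True)
assert logits.shape == (2, 10)
```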
@@ -317,13 +349,19 @@ def quantize(self, mode, type_check=True, config=None): ) del self._embeddings if not self.tie_weights: + reverse_weight_quantizer = ( + QuantizationConfig.weight_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer( + axis=0, + value_range=(-8, 7), + output_dtype="int8", + ), + ) + ) reverse_embeddings_value, reverse_embeddings_scale = ( - quantizers.abs_max_quantize( - self.reverse_embeddings, - axis=0, - value_range=(-8, 7), - dtype="int8", - to_numpy=True, + reverse_weight_quantizer( + self.reverse_embeddings, to_numpy=True ) ) reverse_embeddings_scale = ops.squeeze( @@ -334,7 +372,7 @@ def quantize(self, mode, type_check=True, config=None): reverse_embeddings_value, axis=0 ) del self.reverse_embeddings - self.quantized_build(embeddings_shape, mode) + self.quantized_build(embeddings_shape, mode, config) self._embeddings.assign(packed_embeddings_value) self.embeddings_scale.assign(embeddings_scale) if not self.tie_weights: diff --git a/keras/src/layers/core/reversible_embedding_test.py b/keras/src/layers/core/reversible_embedding_test.py index 043c734aea01..95822ea45a2d 100644 --- a/keras/src/layers/core/reversible_embedding_test.py +++ b/keras/src/layers/core/reversible_embedding_test.py @@ -9,11 +9,64 @@ from keras.src import models from keras.src import ops from keras.src import saving +from keras.src.quantizers.quantization_config import Int4QuantizationConfig +from keras.src.quantizers.quantization_config import Int8QuantizationConfig +from keras.src.quantizers.quantizers import AbsMaxQuantizer from keras.src.testing import test_case from keras.src.testing.test_utils import named_product class ReversibleEmbeddingTest(test_case.TestCase): + @parameterized.named_parameters( + ("int8", "int8", {"axis": -1}, {"axis": -1}), + ( + "int4", + "int4", + {"axis": -1, "value_range": (-8, 7), "output_dtype": "int8"}, + {"axis": -1}, + ), + ("int8_weight_only", "int8", {"axis": -1}, None), + ) + def test_reversible_embedding_quantize( + self, mode, weight_quantizer_args, activation_quantizer_args + ): + """Test ReversibleEmbedding quantization with QuantizationConfig.""" + layer = layers.ReversibleEmbedding( + input_dim=10, output_dim=6, tie_weights=True + ) + layer.build((None,)) + + weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args) + if activation_quantizer_args is not None: + activation_quantizer = AbsMaxQuantizer(**activation_quantizer_args) + else: + activation_quantizer = None + + if mode == "int8": + config = Int8QuantizationConfig( + weight_quantizer=weight_quantizer, + activation_quantizer=activation_quantizer, + ) + elif mode == "int4": + config = Int4QuantizationConfig( + weight_quantizer=weight_quantizer, + activation_quantizer=activation_quantizer, + ) + + layer.quantize(mode, config=config) + + if activation_quantizer_args is not None: + # Verify inputs_quantizer is set correctly + self.assertIsInstance(layer.inputs_quantizer, AbsMaxQuantizer) + else: + # Verify inputs_quantizer is None + self.assertIsNone(layer.inputs_quantizer) + + # Verify reverse call works + x = np.random.random((2, 6)).astype("float32") + y = layer(x, reverse=True) + self.assertEqual(y.shape, (2, 10)) + @parameterized.named_parameters( ("tie_weights", True), ("untie_weights", False), diff --git a/keras/src/models/model.py b/keras/src/models/model.py index 37f4b3bef7ef..2c30ef580221 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -9,8 +9,8 @@ from keras.src.api_export import keras_export from keras.src.layers.layer import Layer from 
keras.src.models.variable_mapping import map_saveable_variables -from keras.src.quantizers.gptq_config import GPTQConfig from keras.src.quantizers.gptq_core import gptq_quantize +from keras.src.quantizers.quantization_config import validate_and_resolve_config from keras.src.quantizers.utils import should_quantize_layer from keras.src.saving import saving_api from keras.src.trainers import trainer as base_trainer @@ -424,7 +424,7 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs): **kwargs, ) - def get_quantization_layer_structure(self, mode): + def get_quantization_layer_structure(self, mode=None): """Returns the quantization structure for the model. This method is intended to be overridden by model authors to provide @@ -464,8 +464,6 @@ def quantize(self, mode, config=None, filters=None, **kwargs): layers which match the filter conditions will be quantized. """ - from keras.src.dtype_policies import QUANTIZATION_MODES - # Validate inputs. type_check = kwargs.pop("type_check", True) if kwargs: @@ -488,18 +486,8 @@ def quantize(self, mode, config=None, filters=None, **kwargs): f"{type(filters)}" ) - if mode == "gptq": - if not isinstance(config, GPTQConfig): - raise ValueError( - "Mode 'gptq' requires a valid `config` argument of type " - f"`GPTQConfig`. Received: {type(config)}" - ) - elif config is not None: - # All other modes must not receive a config - raise ValueError( - f"The `config` argument is only supported for 'gptq' mode, " - f"but received mode='{mode}' and a non-None config." - ) + config = validate_and_resolve_config(mode, config) + mode = config.mode graph_modified = False for layer in self._flatten_layers(): diff --git a/keras/src/quantizers/__init__.py b/keras/src/quantizers/__init__.py index 586530204588..1e80a9cb7dc3 100644 --- a/keras/src/quantizers/__init__.py +++ b/keras/src/quantizers/__init__.py @@ -1,6 +1,10 @@ import inspect from keras.src.api_export import keras_export +from keras.src.quantizers.quantization_config import Float8QuantizationConfig +from keras.src.quantizers.quantization_config import Int4QuantizationConfig +from keras.src.quantizers.quantization_config import Int8QuantizationConfig +from keras.src.quantizers.quantization_config import QuantizationConfig from keras.src.quantizers.quantizers import AbsMaxQuantizer from keras.src.quantizers.quantizers import Quantizer from keras.src.quantizers.quantizers import abs_max_quantize @@ -13,7 +17,14 @@ from keras.src.saving import serialization_lib from keras.src.utils.naming import to_snake_case -ALL_OBJECTS = {Quantizer, AbsMaxQuantizer} +ALL_OBJECTS = { + Quantizer, + AbsMaxQuantizer, + QuantizationConfig, + Int8QuantizationConfig, + Int4QuantizationConfig, + Float8QuantizationConfig, +} ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS} ALL_OBJECTS_DICT.update( {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS} diff --git a/keras/src/quantizers/gptq_config.py b/keras/src/quantizers/gptq_config.py index edcb465ce4c2..0ea159b56548 100644 --- a/keras/src/quantizers/gptq_config.py +++ b/keras/src/quantizers/gptq_config.py @@ -1,8 +1,9 @@ from keras.src.api_export import keras_export +from keras.src.quantizers.quantization_config import QuantizationConfig @keras_export("keras.quantizers.GPTQConfig") -class GPTQConfig: +class GPTQConfig(QuantizationConfig): """Configuration class for the GPTQ (Gradient-based Post-Training Quantization) algorithm. 
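Making `GPTQConfig` a `QuantizationConfig` subclass (with `mode` returning `"gptq"`, per the hunks below) lets it flow through the shared `validate_and_resolve_config` path introduced in this patch. A sketch of the resolution behavior, using the same constructor arguments as the existing tests:

```python
from keras.src.quantizers.gptq_config import GPTQConfig
from keras.src.quantizers.quantization_config import (
    validate_and_resolve_config,
)

config = GPTQConfig(dataset=["a"], tokenizer=lambda x: x)
assert config.mode == "gptq"

# `mode` may be omitted when a config is passed; it is inferred.
resolved = validate_and_resolve_config(None, config)
assert resolved is config

# A contradictory mode/config pair raises.
try:
    validate_and_resolve_config("int8", config)
except ValueError as e:
    print(e)  # Contradictory arguments: mode='int8' but config.mode='gptq'
```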
@@ -154,6 +155,7 @@ def __init__( activation_order: bool = False, quantization_layer_structure: dict = None, ): + super().__init__() if weight_bits not in [2, 3, 4, 8]: raise ValueError( f"Unsupported weight_bits {weight_bits}. " @@ -183,6 +185,10 @@ def __init__( self.activation_order = activation_order self.quantization_layer_structure = quantization_layer_structure + @property + def mode(self): + return "gptq" + def dtype_policy_string(self): """Returns the dtype policy string for this configuration. diff --git a/keras/src/quantizers/gptq_test.py b/keras/src/quantizers/gptq_test.py index d6fe0048ac3f..d53022724bb6 100644 --- a/keras/src/quantizers/gptq_test.py +++ b/keras/src/quantizers/gptq_test.py @@ -14,6 +14,7 @@ from keras.src.quantizers.gptq import _stable_permutation from keras.src.quantizers.gptq import gptq_quantize_matrix from keras.src.quantizers.gptq_config import GPTQConfig +from keras.src.quantizers.quantization_config import QuantizationConfig from keras.src.quantizers.quantizers import dequantize_with_sz_map from keras.src.quantizers.quantizers import dequantize_with_zero_point from keras.src.quantizers.quantizers import quantize_with_zero_point @@ -621,18 +622,26 @@ def test_quantize_gptq_combinations(self, dataset, config): @parameterized.named_parameters( { - "testcase_name": "gptq_with_invalid_config", + "testcase_name": "gptq_with_invalid_config_type", "mode": "gptq", "config": {"weight_bits": 4}, "expected_exception": ValueError, + "error_msg": "Argument `config` must be an instance of " + "`QuantizationConfig`", + }, + { + "testcase_name": "gptq_with_none_config", + "mode": "gptq", + "config": None, + "expected_exception": ValueError, "error_msg": "Mode 'gptq' requires a valid `config`", }, { - "testcase_name": "non_gptq_with_unsupported_config", - "mode": "int8", - "config": GPTQConfig(dataset=["a"], tokenizer=lambda x: x), + "testcase_name": "gptq_with_base_quantization_config", + "mode": "gptq", + "config": QuantizationConfig(), "expected_exception": ValueError, - "error_msg": "only supported for 'gptq'", + "error_msg": "Mode 'gptq' requires a valid `config`", }, { "testcase_name": "gptq_missing_structure", diff --git a/keras/src/quantizers/quantization_config.py b/keras/src/quantizers/quantization_config.py new file mode 100644 index 000000000000..357d806c6c9c --- /dev/null +++ b/keras/src/quantizers/quantization_config.py @@ -0,0 +1,166 @@ +from keras.src.api_export import keras_export +from keras.src.dtype_policies import QUANTIZATION_MODES +from keras.src.saving import serialization_lib + + +@keras_export("keras.quantizers.QuantizationConfig") +class QuantizationConfig: + def __init__(self, weight_quantizer=None, activation_quantizer=None): + self.weight_quantizer = weight_quantizer + self.activation_quantizer = activation_quantizer + + @property + def mode(self): + raise NotImplementedError + + def get_config(self): + return { + "weight_quantizer": serialization_lib.serialize_keras_object( + self.weight_quantizer + ), + "activation_quantizer": serialization_lib.serialize_keras_object( + self.activation_quantizer + ), + } + + @classmethod + def from_config(cls, config): + weight_quantizer = serialization_lib.deserialize_keras_object( + config.get("weight_quantizer") + ) + activation_quantizer = serialization_lib.deserialize_keras_object( + config.get("activation_quantizer") + ) + return cls( + weight_quantizer=weight_quantizer, + activation_quantizer=activation_quantizer, + ) + + @staticmethod + def weight_quantizer_or_default(config, default): + if config 
and config.weight_quantizer: + return config.weight_quantizer + return default + + @staticmethod + def activation_quantizer_or_default(config, default): + if config and config.activation_quantizer: + return config.activation_quantizer + elif config and config.activation_quantizer is None: + return None + return default + + +@keras_export("keras.quantizers.Int8QuantizationConfig") +class Int8QuantizationConfig(QuantizationConfig): + def __init__(self, weight_quantizer=None, activation_quantizer="default"): + from keras.src.quantizers.quantizers import AbsMaxQuantizer + + if activation_quantizer == "default": + activation_quantizer = AbsMaxQuantizer(axis=-1) + super().__init__(weight_quantizer, activation_quantizer) + if self.weight_quantizer: + if hasattr(self.weight_quantizer, "value_range"): + if self.weight_quantizer.value_range != (-127, 127): + raise ValueError( + "Int8QuantizationConfig requires a weight_quantizer " + "with value_range=(-127, 127). Received: " + f"value_range={self.weight_quantizer.value_range}" + ) + + @property + def mode(self): + return "int8" + + +@keras_export("keras.quantizers.Int4QuantizationConfig") +class Int4QuantizationConfig(QuantizationConfig): + def __init__(self, weight_quantizer=None, activation_quantizer="default"): + from keras.src.quantizers.quantizers import AbsMaxQuantizer + + if activation_quantizer == "default": + activation_quantizer = AbsMaxQuantizer(axis=-1) + super().__init__(weight_quantizer, activation_quantizer) + if self.weight_quantizer: + if hasattr(self.weight_quantizer, "value_range"): + if self.weight_quantizer.value_range != (-8, 7): + raise ValueError( + "Int4QuantizationConfig requires a weight_quantizer " + "with value_range=(-8, 7). Received: " + f"value_range={self.weight_quantizer.value_range}" + ) + + @property + def mode(self): + return "int4" + + +@keras_export("keras.quantizers.Float8QuantizationConfig") +class Float8QuantizationConfig(QuantizationConfig): + def __init__(self, weight_quantizer=None, activation_quantizer=None): + super().__init__(weight_quantizer, activation_quantizer) + + @property + def mode(self): + return "float8" + + +def validate_and_resolve_config(mode, config, name=None): + # 1. Backwards Compatibility: Handle string shortcuts + if isinstance(config, str): + mode = config + config = None + + # 2. Resolve "mode" into a Config object + if config is None: + if mode == "int8": + config = Int8QuantizationConfig() + elif mode == "int4": + config = Int4QuantizationConfig() + elif mode == "float8": + config = Float8QuantizationConfig() + elif mode == "gptq": + raise ValueError( + "For GPTQ, you must pass a GPTQConfig object explicitly." + ) + else: + if mode is not None: + raise ValueError( + f"Invalid quantization mode. Received: mode={mode}" + ) + raise ValueError( + "You must provide either `mode` or `config` to `quantize`." + ) + else: + if not isinstance(config, QuantizationConfig): + raise ValueError( + "Argument `config` must be an instance of " + "`QuantizationConfig`. " + f"Received: config={config} (of type {type(config)})" + ) + + # 3. Validation: Prevent contradictions + if mode is not None and config.mode != mode: + raise ValueError( + f"Contradictory arguments: mode='{mode}' but " + f"config.mode='{config.mode}'" + ) + + # 4. Execution + mode = config.mode # Ensure mode is consistent + if mode not in QUANTIZATION_MODES: + raise ValueError( + "Invalid quantization mode. " + f"Expected one of {QUANTIZATION_MODES}. 
Received: mode={mode}" + ) + + if mode == "gptq": + from keras.src.quantizers.gptq_config import GPTQConfig + + if not isinstance(config, GPTQConfig): + raise ValueError( + "Mode 'gptq' requires a valid `config` argument of type " + f"`GPTQConfig`. Received: {type(config)}" + ) + + return config diff --git a/keras/src/quantizers/quantization_config_test.py b/keras/src/quantizers/quantization_config_test.py new file mode 100644 index 000000000000..f7c94c6cc2e4 --- /dev/null +++ b/keras/src/quantizers/quantization_config_test.py @@ -0,0 +1,106 @@ +from keras.src import testing +from keras.src.quantizers.quantization_config import Int4QuantizationConfig +from keras.src.quantizers.quantization_config import Int8QuantizationConfig +from keras.src.quantizers.quantization_config import QuantizationConfig +from keras.src.quantizers.quantization_config import validate_and_resolve_config +from keras.src.quantizers.quantizers import AbsMaxQuantizer + + +class QuantizationConfigTest(testing.TestCase): + def test_base_quantization_config(self): + config = QuantizationConfig() + with self.assertRaises(NotImplementedError): + _ = config.mode + + def test_int8_quantization_config_valid(self): + config = Int8QuantizationConfig() + self.assertEqual(config.mode, "int8") + self.assertIsNone(config.weight_quantizer) + + # Valid weight quantizer + q = AbsMaxQuantizer(axis=0, value_range=(-127, 127)) + config = Int8QuantizationConfig(weight_quantizer=q) + self.assertEqual(config.weight_quantizer, q) + + def test_int8_quantization_config_invalid(self): + # Invalid value_range + q = AbsMaxQuantizer(axis=0, value_range=(-8, 7)) + with self.assertRaisesRegex(ValueError, "value_range"): + Int8QuantizationConfig(weight_quantizer=q) + + def test_int4_quantization_config_valid(self): + config = Int4QuantizationConfig() + self.assertEqual(config.mode, "int4") + self.assertIsNone(config.weight_quantizer) + + # Valid weight quantizer + q = AbsMaxQuantizer(axis=0, value_range=(-8, 7)) + config = Int4QuantizationConfig(weight_quantizer=q) + self.assertEqual(config.weight_quantizer, q) + + def test_int4_quantization_config_invalid(self): + # Invalid value_range + q = AbsMaxQuantizer(axis=0, value_range=(-127, 127)) + with self.assertRaisesRegex(ValueError, "value_range"): + Int4QuantizationConfig(weight_quantizer=q) + + def test_quantization_config_serialization(self): + config = Int8QuantizationConfig( + weight_quantizer=AbsMaxQuantizer(axis=0), + activation_quantizer=AbsMaxQuantizer(axis=-1), + ) + serialized = config.get_config() + deserialized = Int8QuantizationConfig.from_config(serialized) + self.assertIsInstance(deserialized, Int8QuantizationConfig) + self.assertIsInstance(deserialized.weight_quantizer, AbsMaxQuantizer) + self.assertIsInstance( + deserialized.activation_quantizer, AbsMaxQuantizer + ) + self.assertEqual(deserialized.weight_quantizer.axis, (0,)) + self.assertEqual(deserialized.activation_quantizer.axis, (-1,)) + + def test_validate_and_resolve_config(self): + # 1. String mode + config = validate_and_resolve_config("int8", None) + self.assertIsInstance(config, Int8QuantizationConfig) + self.assertEqual(config.mode, "int8") + + config = validate_and_resolve_config("int4", None) + self.assertIsInstance(config, Int4QuantizationConfig) + self.assertEqual(config.mode, "int4") + + # 2. Config object + config_in = Int8QuantizationConfig() + config_out = validate_and_resolve_config(None, config_in) + self.assertIs(config_out, config_in) + + # 3. 
Mode + Config (matching) + config_in = Int8QuantizationConfig() + config_out = validate_and_resolve_config("int8", config_in) + self.assertIs(config_out, config_in) + + # 4. Mode + Config (mismatch) + config_in = Int8QuantizationConfig() + with self.assertRaisesRegex(ValueError, "Contradictory arguments"): + validate_and_resolve_config("int4", config_in) + + # 5. Invalid mode + with self.assertRaisesRegex(ValueError, "Invalid quantization mode"): + validate_and_resolve_config("invalid_mode", None) + + # 6. GPTQ without config + with self.assertRaisesRegex(ValueError, "must pass a GPTQConfig"): + validate_and_resolve_config("gptq", None) + + # 7. Contradictory config + with self.assertRaisesRegex(ValueError, "Contradictory arguments"): + validate_and_resolve_config("gptq", Int8QuantizationConfig()) + + # 8. GPTQ with invalid config type (but correct mode) + class FakeGPTQConfig(QuantizationConfig): + @property + def mode(self): + return "gptq" + + with self.assertRaisesRegex(ValueError, "requires a valid `config`"): + validate_and_resolve_config("gptq", FakeGPTQConfig()) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index d9ef671b6fc9..708a143504c9 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -117,9 +117,14 @@ def __init__( self.value_range = value_range self.epsilon = epsilon - def __call__(self, x): + def __call__(self, x, to_numpy=False): quantized_x, scale = abs_max_quantize( - x, self.axis, self.value_range, self.output_dtype, self.epsilon + x, + self.axis, + self.value_range, + self.output_dtype, + self.epsilon, + to_numpy, ) return quantized_x, scale From 8b22843a4a2072c815ac77dfb92f4edd2e00b48a Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Mon, 8 Dec 2025 10:08:11 +0530 Subject: [PATCH 2/8] mixed precision einsum fix for torch + fixed tf/jax tests --- keras/src/layers/core/einsum_dense.py | 57 ++++++++++++++++----- keras/src/layers/core/einsum_dense_test.py | 6 +++ keras/src/quantizers/gptq_test.py | 8 +-- keras/src/quantizers/quantization_config.py | 5 +- 4 files changed, 58 insertions(+), 18 deletions(-) diff --git a/keras/src/layers/core/einsum_dense.py b/keras/src/layers/core/einsum_dense.py index 1676cdac70f8..110b96efa096 100644 --- a/keras/src/layers/core/einsum_dense.py +++ b/keras/src/layers/core/einsum_dense.py @@ -6,6 +6,7 @@ import numpy as np from keras.src import activations +from keras.src import backend from keras.src import constraints from keras.src import dtype_policies from keras.src import initializers @@ -741,12 +742,27 @@ def grad_fn(*args, upstream=None): inputs_scale = self._adjust_scale_for_quant( inputs_scale, "input" ) + x = ops.einsum(self.equation, inputs, kernel) + # De-scale outputs + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) else: - inputs_scale = ops.ones((1,), dtype=self.compute_dtype) - x = ops.einsum(self.equation, inputs, kernel) - # De-scale outputs - x = ops.cast(x, self.compute_dtype) - x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + # Weight-only quantization: dequantize kernel and use float + # einsum. This is a workaround for PyTorch's einsum which + # doesn't support mixed-precision inputs (float input, + # int8 kernel). 
+ if backend.backend() == "torch": + kernel_scale = self._adjust_scale_for_dequant(kernel_scale) + float_kernel = ops.divide( + ops.cast(kernel, dtype=self.compute_dtype), + kernel_scale, + ) + x = ops.einsum(self.equation, inputs, float_kernel) + else: + x = ops.einsum(self.equation, inputs, kernel) + # De-scale outputs + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, kernel_scale) return x, grad_fn x = einsum_with_inputs_gradient( @@ -823,16 +839,29 @@ def grad_fn(*args, upstream=None): inputs_scale = self._adjust_scale_for_quant( inputs_scale, "input" ) + x = ops.einsum(self.equation, inputs_q, unpacked_kernel) + # De-scale outputs + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) else: - inputs_q = inputs - inputs_scale = ops.ones((1,), dtype=self.compute_dtype) - - # Compute einsum on quantized inputs and unpacked int4 kernel. - x = ops.einsum(self.equation, inputs_q, unpacked_kernel) - - # De-scale outputs. - x = ops.cast(x, self.compute_dtype) - x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + # Weight-only quantization: dequantize kernel and use float + # einsum. This is a workaround for PyTorch's einsum which + # doesn't support mixed-precision inputs (float input, + # int4 kernel). + if backend.backend() == "torch": + # Align `kernel_scale` to the same layout as + # `unpacked_kernel`. + kernel_scale = self._adjust_scale_for_dequant(kernel_scale) + float_kernel = ops.divide( + ops.cast(unpacked_kernel, dtype=self.compute_dtype), + kernel_scale, + ) + x = ops.einsum(self.equation, inputs, float_kernel) + else: + x = ops.einsum(self.equation, inputs, unpacked_kernel) + # De-scale outputs + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, kernel_scale) return x, grad_fn x = einsum_with_inputs_gradient( diff --git a/keras/src/layers/core/einsum_dense_test.py b/keras/src/layers/core/einsum_dense_test.py index e1d0308088b6..4f7dfef9fd5b 100644 --- a/keras/src/layers/core/einsum_dense_test.py +++ b/keras/src/layers/core/einsum_dense_test.py @@ -31,6 +31,12 @@ class EinsumDenseTest(testing.TestCase): {"axis": -1}, ), ("int8_weight_only", "int8", {"axis": 0}, None), + ( + "int4_weight_only", + "int4", + {"axis": 0, "value_range": (-8, 7), "output_dtype": "int8"}, + None, + ), ) def test_einsum_dense_quantize( self, mode, weight_quantizer_args, activation_quantizer_args diff --git a/keras/src/quantizers/gptq_test.py b/keras/src/quantizers/gptq_test.py index d53022724bb6..a69479b8f9ff 100644 --- a/keras/src/quantizers/gptq_test.py +++ b/keras/src/quantizers/gptq_test.py @@ -634,14 +634,16 @@ def test_quantize_gptq_combinations(self, dataset, config): "mode": "gptq", "config": None, "expected_exception": ValueError, - "error_msg": "Mode 'gptq' requires a valid `config`", + "error_msg": "For GPTQ, you must pass a GPTQConfig " + "object explicitly.", }, { "testcase_name": "gptq_with_base_quantization_config", "mode": "gptq", "config": QuantizationConfig(), - "expected_exception": ValueError, - "error_msg": "Mode 'gptq' requires a valid `config`", + "expected_exception": NotImplementedError, + "error_msg": "Do not instantiate " + "QuantizationConfig directly.", }, { "testcase_name": "gptq_missing_structure", diff --git a/keras/src/quantizers/quantization_config.py b/keras/src/quantizers/quantization_config.py index 357d806c6c9c..ebf593e54aed 100644 --- a/keras/src/quantizers/quantization_config.py +++ b/keras/src/quantizers/quantization_config.py @@ -11,7 +11,10 @@ def __init__(self, weight_quantizer=None, 
activation_quantizer=None): @property def mode(self): - raise NotImplementedError + raise NotImplementedError( + "Subclasses must implement this property. Do not instantiate " + "QuantizationConfig directly." + ) def get_config(self): return { From c3195f770f78f291a60e95e32ac10220f253d881 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Mon, 8 Dec 2025 10:23:57 +0530 Subject: [PATCH 3/8] fixed minor errors + api export --- keras/api/_tf_keras/keras/quantizers/__init__.py | 12 ++++++++++++ keras/api/quantizers/__init__.py | 12 ++++++++++++ keras/src/models/model.py | 1 + keras/src/quantizers/gptq_test.py | 3 +-- 4 files changed, 26 insertions(+), 2 deletions(-) diff --git a/keras/api/_tf_keras/keras/quantizers/__init__.py b/keras/api/_tf_keras/keras/quantizers/__init__.py index 299e467ac1bb..205183264c03 100644 --- a/keras/api/_tf_keras/keras/quantizers/__init__.py +++ b/keras/api/_tf_keras/keras/quantizers/__init__.py @@ -8,6 +8,18 @@ from keras.src.quantizers import get as get from keras.src.quantizers import serialize as serialize from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig +from keras.src.quantizers.quantization_config import ( + Float8QuantizationConfig as Float8QuantizationConfig, +) +from keras.src.quantizers.quantization_config import ( + Int4QuantizationConfig as Int4QuantizationConfig, +) +from keras.src.quantizers.quantization_config import ( + Int8QuantizationConfig as Int8QuantizationConfig, +) +from keras.src.quantizers.quantization_config import ( + QuantizationConfig as QuantizationConfig, +) from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer from keras.src.quantizers.quantizers import Quantizer as Quantizer from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize diff --git a/keras/api/quantizers/__init__.py b/keras/api/quantizers/__init__.py index 299e467ac1bb..205183264c03 100644 --- a/keras/api/quantizers/__init__.py +++ b/keras/api/quantizers/__init__.py @@ -8,6 +8,18 @@ from keras.src.quantizers import get as get from keras.src.quantizers import serialize as serialize from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig +from keras.src.quantizers.quantization_config import ( + Float8QuantizationConfig as Float8QuantizationConfig, +) +from keras.src.quantizers.quantization_config import ( + Int4QuantizationConfig as Int4QuantizationConfig, +) +from keras.src.quantizers.quantization_config import ( + Int8QuantizationConfig as Int8QuantizationConfig, +) +from keras.src.quantizers.quantization_config import ( + QuantizationConfig as QuantizationConfig, +) from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer from keras.src.quantizers.quantizers import Quantizer as Quantizer from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize diff --git a/keras/src/models/model.py b/keras/src/models/model.py index 2c30ef580221..5b671401ee5f 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -7,6 +7,7 @@ from keras.src import backend from keras.src import utils from keras.src.api_export import keras_export +from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES from keras.src.layers.layer import Layer from keras.src.models.variable_mapping import map_saveable_variables from keras.src.quantizers.gptq_core import gptq_quantize diff --git a/keras/src/quantizers/gptq_test.py b/keras/src/quantizers/gptq_test.py index a69479b8f9ff..a2af07c27155 100644 --- 
a/keras/src/quantizers/gptq_test.py +++ b/keras/src/quantizers/gptq_test.py @@ -642,8 +642,7 @@ def test_quantize_gptq_combinations(self, dataset, config): "mode": "gptq", "config": QuantizationConfig(), "expected_exception": NotImplementedError, - "error_msg": "Do not instantiate " - "QuantizationConfig directly.", + "error_msg": "Do not instantiate QuantizationConfig directly.", }, { "testcase_name": "gptq_missing_structure", From 255fe0e2d6deba89b34d3ab7750d11272f5abc8d Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Mon, 8 Dec 2025 10:36:26 +0530 Subject: [PATCH 4/8] Removed redundant matmuls + added docs --- keras/src/layers/core/dense.py | 18 ++++---- keras/src/quantizers/quantization_config.py | 48 ++++++++++++++++++++- 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py index 87fd606d1151..cbe7bc59e0f9 100644 --- a/keras/src/layers/core/dense.py +++ b/keras/src/layers/core/dense.py @@ -597,19 +597,15 @@ def grad_fn(*args, upstream=None): inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel)) return (inputs_grad, None, None) + output_scale = kernel_scale if self.inputs_quantizer: inputs, inputs_scale = self.inputs_quantizer(inputs) - else: - # Weight-only quantization: inputs are not quantized - # We still need inputs_scale for the formula: - # x = x / (inputs_scale * kernel_scale) - # If inputs are not quantized, inputs_scale should be 1. - inputs_scale = ops.ones((1,), dtype=self.compute_dtype) + output_scale = ops.multiply(output_scale, inputs_scale) x = ops.matmul(inputs, kernel) # De-scale outputs x = ops.cast(x, self.compute_dtype) - x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + x = ops.divide(x, output_scale) return x, grad_fn x = matmul_with_inputs_gradient( @@ -656,13 +652,15 @@ def grad_fn(*args, upstream=None): inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel)) return (inputs_grad, None, None) + output_scale = kernel_scale + if self.inputs_quantizer: inputs, inputs_scale = self.inputs_quantizer(inputs) - else: - inputs_scale = ops.ones((1,), dtype=self.compute_dtype) + output_scale = ops.multiply(output_scale, inputs_scale) + x = ops.matmul(inputs, unpacked_kernel) x = ops.cast(x, self.compute_dtype) - x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + x = ops.divide(x, output_scale) return x, grad_fn x = matmul_with_inputs_gradient( diff --git a/keras/src/quantizers/quantization_config.py b/keras/src/quantizers/quantization_config.py index ebf593e54aed..362bc866e67d 100644 --- a/keras/src/quantizers/quantization_config.py +++ b/keras/src/quantizers/quantization_config.py @@ -5,6 +5,16 @@ @keras_export("keras.quantizers.QuantizationConfig") class QuantizationConfig: + """Base class for quantization configs. + + Subclasses must implement the `mode` property and the `get_config` and + `from_config` class methods. + + Args: + weight_quantizer: Quantizer for weights. + activation_quantizer: Quantizer for activations. + """ + def __init__(self, weight_quantizer=None, activation_quantizer=None): self.weight_quantizer = weight_quantizer self.activation_quantizer = activation_quantizer @@ -56,6 +66,14 @@ def activation_quantizer_or_default(config, default): @keras_export("keras.quantizers.Int8QuantizationConfig") class Int8QuantizationConfig(QuantizationConfig): + """Int8 quantization config. + + Args: + weight_quantizer: Quantizer for weights. + activation_quantizer: Quantizer for activations. 
+            AbsMaxQuantizer with axis=-1.
+    """
+
     def __init__(self, weight_quantizer=None, activation_quantizer="default"):
         from keras.src.quantizers.quantizers import AbsMaxQuantizer
 
@@ -78,6 +96,14 @@ def mode(self):
 
 @keras_export("keras.quantizers.Int4QuantizationConfig")
 class Int4QuantizationConfig(QuantizationConfig):
+    """Int4 quantization config.
+
+    Args:
+        weight_quantizer: Quantizer for weights.
+        activation_quantizer: Quantizer for activations. If "default", uses
+            AbsMaxQuantizer with axis=-1.
+    """
+
     def __init__(self, weight_quantizer=None, activation_quantizer="default"):
         from keras.src.quantizers.quantizers import AbsMaxQuantizer
 
@@ -100,8 +126,15 @@ def mode(self):
 
 @keras_export("keras.quantizers.Float8QuantizationConfig")
 class Float8QuantizationConfig(QuantizationConfig):
-    def __init__(self, weight_quantizer=None, activation_quantizer=None):
-        super().__init__(weight_quantizer, activation_quantizer)
+    """FP8 quantization config.
+
+    FP8 mixed-precision training does not support user-defined quantizers.
+    This config is only used to indicate that FP8 mixed-precision training
+    should be used.
+    """
+
+    def __init__(self):
+        super().__init__(None, None)
 
     @property
     def mode(self):
@@ -109,6 +142,17 @@ def mode(self):
 
 
 def validate_and_resolve_config(mode, config, name=None):
+    """Validate and resolve quantization config.
+
+    This function validates the quantization config and resolves the mode.
+    If mode is not provided, it is inferred from the config.
+    If config is not provided, a default config is inferred from the mode.
+
+    Args:
+        mode: Quantization mode.
+        config: Quantization config.
+        name: Name of the quantization config.
+    """
     # 1. Backwards Compatibility: Handle string shortcuts
     if isinstance(config, str):
         mode = config

From e4b891d9f5e2e84834129c9c9470743a68ad07f7 Mon Sep 17 00:00:00 2001
From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Date: Wed, 10 Dec 2025 13:02:18 +0530
Subject: [PATCH 5/8] minor cleanup + docstring improvements

---
 keras/src/models/model.py                   | 69 ++++++++++++++++++---
 keras/src/quantizers/quantization_config.py |  8 +--
 2 files changed, 65 insertions(+), 12 deletions(-)

diff --git a/keras/src/models/model.py b/keras/src/models/model.py
index 5b671401ee5f..7e78fd049de7 100644
--- a/keras/src/models/model.py
+++ b/keras/src/models/model.py
@@ -448,23 +448,76 @@ def get_quantization_layer_structure(self, mode=None):
         del mode  # Unused.
         return None
 
-    def quantize(self, mode, config=None, filters=None, **kwargs):
+    def quantize(self, mode=None, config=None, filters=None, **kwargs):
         """Quantize the weights of the model.
 
         Note that the model must be built first before calling this method.
-        `quantize` will recursively call `quantize(mode)` in all layers and
+        `quantize` will recursively call `quantize(...)` in all layers and
         will be skipped if the layer doesn't implement the function.
 
+        This method can be called by passing a `mode` string, which uses the
+        default configuration for that mode. Alternatively, a `config` object
+        can be passed to customize the behavior of the quantization (e.g. to
+        use specific quantizers for weights or activations).
+
         Args:
-            mode: The mode of the quantization. Supported modes are: 'int4',
-                'int8', 'float8', 'gptq'.
+            mode: The mode of the quantization. Supported modes are:
+                `"int8"`, `"int4"`, `"float8"`, `"gptq"`. This is
+                optional if `config` is provided.
             config: The configuration object specifying additional
-                quantization options for supported modes.
+                quantization options. This argument allows configuring
+                the weight and activation quantizers. Must be an instance of
+                `keras.quantizers.QuantizationConfig`.
             filters: Optional filters to apply to the quantization. Can be a
-                regex string, a list of regex strings, or a callable. Only the
-                layers which match the filter conditions will be quantized.
-        """
+                regex string, a list of regex strings, or a callable. Only the
+                layers which match the filter conditions will be quantized.
+            **kwargs: Additional keyword arguments.
+
+        Example:
+
+        Quantize a model to int8 with default configuration:
+        ```python
+        # Build the model
+        >>> model = keras.Sequential([
+            keras.Input(shape=(10,)),
+            keras.layers.Dense(10),
+            ])
+        >>> model.build((None, 10))
+
+        # Quantize with default int8 config
+        >>> model.quantize("int8")
+        ```
+
+        Quantize a model to int8 with a custom configuration:
+
+        ```python
+        from keras.quantizers import Int8QuantizationConfig
+        from keras.quantizers import AbsMaxQuantizer
+
+        # Build the model
+        >>> model = keras.Sequential([
+            keras.Input(shape=(10,)),
+            keras.layers.Dense(10),
+            ])
+        >>> model.build((None, 10))
+
+        # Create a custom config
+        >>> config = Int8QuantizationConfig(
+            weight_quantizer=AbsMaxQuantizer(
+                axis=0,
+                value_range=(-127, 127)
+            ),
+            activation_quantizer=AbsMaxQuantizer(
+                axis=-1,
+                value_range=(-127, 127)
+            ),
+        )
+
+        # Quantize with custom config
+        >>> model.quantize(config=config)
+        ```
+        """
         # Validate inputs.
         type_check = kwargs.pop("type_check", True)
         if kwargs:
diff --git a/keras/src/quantizers/quantization_config.py b/keras/src/quantizers/quantization_config.py
index 362bc866e67d..c9f1053025ab 100644
--- a/keras/src/quantizers/quantization_config.py
+++ b/keras/src/quantizers/quantization_config.py
@@ -141,7 +141,7 @@ def mode(self):
         return "float8"
 
 
-def validate_and_resolve_config(mode, config, name=None):
+def validate_and_resolve_config(mode, config):
     """Validate and resolve quantization config.
 
     This function validates the quantization config and resolves the mode.
@@ -151,7 +151,6 @@ def validate_and_resolve_config(mode, config):
     Args:
         mode: Quantization mode.
         config: Quantization config.
-        name: Name of the quantization config.
     """
     # 1. Backwards Compatibility: Handle string shortcuts
     if isinstance(config, str):
         mode = config
@@ -193,8 +192,9 @@ def validate_and_resolve_config(mode, config):
             f"config.mode='{config.mode}'"
         )
 
-    # 4. Execution
-    mode = config.mode  # Ensure mode is consistent
+    # Ensure mode is consistent
+    mode = config.mode
+
     if mode not in QUANTIZATION_MODES:
         raise ValueError(
             "Invalid quantization mode. "
" From fe61e8fd5ea4561e1fcd897d5fac1a003457f46d Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Mon, 15 Dec 2025 13:46:05 +0530 Subject: [PATCH 6/8] address comments --- keras/src/models/model.py | 18 +++++----- keras/src/quantizers/gptq_test.py | 4 +-- keras/src/quantizers/quantization_config.py | 39 ++++++++++++--------- 3 files changed, 34 insertions(+), 27 deletions(-) diff --git a/keras/src/models/model.py b/keras/src/models/model.py index 7e78fd049de7..14dc639dee2d 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -479,14 +479,14 @@ def quantize(self, mode=None, config=None, filters=None, **kwargs): ```python # Build the model - >>> model = keras.Sequential([ + model = keras.Sequential([ keras.Input(shape=(10,)), keras.layers.Dense(10), - ]) - >>> model.build((None, 10)) + ]) + model.build((None, 10)) # Quantize with default int8 config - >>> model.quantize("int8") + model.quantize("int8") ``` Quantize a model to int8 with a custom configuration: @@ -496,14 +496,14 @@ def quantize(self, mode=None, config=None, filters=None, **kwargs): from keras.quantizers import AbsMaxQuantizer # Build the model - >>> model = keras.Sequential([ + model = keras.Sequential([ keras.Input(shape=(10,)), keras.layers.Dense(10), - ]) - >>> model.build((None, 10)) + ]) + model.build((None, 10)) # Create a custom config - >>> config = Int8QuantizationConfig( + config = Int8QuantizationConfig( weight_quantizer=AbsMaxQuantizer( axis=0, value_range=(-127, 127) @@ -515,7 +515,7 @@ def quantize(self, mode=None, config=None, filters=None, **kwargs): ) # Quantize with custom config - >>> model.quantize(config=config) + model.quantize(config=config) ``` """ # Validate inputs. diff --git a/keras/src/quantizers/gptq_test.py b/keras/src/quantizers/gptq_test.py index a2af07c27155..1d0598df3a44 100644 --- a/keras/src/quantizers/gptq_test.py +++ b/keras/src/quantizers/gptq_test.py @@ -634,8 +634,8 @@ def test_quantize_gptq_combinations(self, dataset, config): "mode": "gptq", "config": None, "expected_exception": ValueError, - "error_msg": "For GPTQ, you must pass a GPTQConfig " - "object explicitly.", + "error_msg": "For GPTQ, you must pass a `GPTQConfig` object " + "in the `config` argument.", }, { "testcase_name": "gptq_with_base_quantization_config", diff --git a/keras/src/quantizers/quantization_config.py b/keras/src/quantizers/quantization_config.py index c9f1053025ab..b62d3dbc6c88 100644 --- a/keras/src/quantizers/quantization_config.py +++ b/keras/src/quantizers/quantization_config.py @@ -51,16 +51,14 @@ def from_config(cls, config): @staticmethod def weight_quantizer_or_default(config, default): - if config and config.weight_quantizer: + if config is not None and config.weight_quantizer is not None: return config.weight_quantizer return default @staticmethod def activation_quantizer_or_default(config, default): - if config and config.activation_quantizer: + if config is not None: return config.activation_quantizer - elif config and config.activation_quantizer is None: - return None return default @@ -80,7 +78,7 @@ def __init__(self, weight_quantizer=None, activation_quantizer="default"): if activation_quantizer == "default": activation_quantizer = AbsMaxQuantizer(axis=-1) super().__init__(weight_quantizer, activation_quantizer) - if self.weight_quantizer: + if self.weight_quantizer is not None: if hasattr(self.weight_quantizer, "value_range"): if self.weight_quantizer.value_range != (-127, 127): raise ValueError( @@ -110,7 +108,7 @@ 
@@ -110,7 +108,7 @@ def __init__(self, weight_quantizer=None, activation_quantizer="default"):
         if activation_quantizer == "default":
             activation_quantizer = AbsMaxQuantizer(axis=-1)
         super().__init__(weight_quantizer, activation_quantizer)
-        if self.weight_quantizer:
+        if self.weight_quantizer is not None:
             if hasattr(self.weight_quantizer, "value_range"):
                 if self.weight_quantizer.value_range != (-8, 7):
                     raise ValueError(
@@ -152,12 +150,14 @@ def validate_and_resolve_config(mode, config):
         mode: Quantization mode.
         config: Quantization config.
     """
-    # 1. Backwards Compatibility: Handle string shortcuts
+    # 1. Backwards Compatibility: Handle string shortcuts.
     if isinstance(config, str):
         mode = config
         config = None
 
-    # 2. Resolve "mode" into a Config object
+    _validate_mode(mode)
+
+    # 2. Resolve "mode" into a Config object.
     if config is None:
         if mode == "int8":
             config = Int8QuantizationConfig()
@@ -167,7 +167,8 @@ def validate_and_resolve_config(mode, config):
             config = Float8QuantizationConfig()
         elif mode == "gptq":
             raise ValueError(
-                "For GPTQ, you must pass a GPTQConfig object explicitly."
+                "For GPTQ, you must pass a `GPTQConfig` object in the "
+                "`config` argument."
             )
         else:
             if mode is not None:
@@ -185,21 +186,18 @@ def validate_and_resolve_config(mode, config):
             f"Received: config={config} (of type {type(config)})"
         )
 
-    # 3. Validation: Prevent contradictions
+    # 3. Validation: Prevent contradictions.
     if mode is not None and config.mode != mode:
         raise ValueError(
             f"Contradictory arguments: mode='{mode}' but "
             f"config.mode='{config.mode}'"
         )
 
-    # Ensure mode is consistent
+    # Ensure mode is consistent.
     mode = config.mode
 
-    if mode not in QUANTIZATION_MODES:
-        raise ValueError(
-            "Invalid quantization mode. "
-            f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
-        )
+    # Ensure the mode derived from the config is valid.
+    _validate_mode(mode)
 
     if mode == "gptq":
         from keras.src.quantizers.gptq_config import GPTQConfig
@@ -211,3 +209,12 @@ def validate_and_resolve_config(mode, config):
         )
 
     return config
+
+
+def _validate_mode(mode):
+    """Validates quantization mode."""
+    if mode is not None and mode not in QUANTIZATION_MODES:
+        raise ValueError(
+            "Invalid quantization mode. "
+            f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
+        )

From 45bcb7f2777e7a7d234765f15bb5ecdebfd6a1a3 Mon Sep 17 00:00:00 2001
From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Date: Mon, 15 Dec 2025 14:14:16 +0530
Subject: [PATCH 7/8] refactor validation

---
 keras/src/quantizers/quantization_config.py   | 15 +++++++++----
 .../quantizers/quantization_config_test.py    | 19 ++++++++++++++++---
 keras/src/quantizers/quantizers.py            |  7 +++++++
 3 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/keras/src/quantizers/quantization_config.py b/keras/src/quantizers/quantization_config.py
index b62d3dbc6c88..6ef8e4e041a9 100644
--- a/keras/src/quantizers/quantization_config.py
+++ b/keras/src/quantizers/quantization_config.py
@@ -79,12 +79,12 @@ def __init__(self, weight_quantizer=None, activation_quantizer="default"):
             activation_quantizer = AbsMaxQuantizer(axis=-1)
         super().__init__(weight_quantizer, activation_quantizer)
         if self.weight_quantizer is not None:
-            if hasattr(self.weight_quantizer, "value_range"):
-                if self.weight_quantizer.value_range != (-127, 127):
+            if hasattr(self.weight_quantizer, "output_dtype"):
+                if self.weight_quantizer.output_dtype != "int8":
                     raise ValueError(
                         "Int8QuantizationConfig requires a weight_quantizer "
-                        "with value_range=(-127, 127). Received: "
-                        f"value_range={self.weight_quantizer.value_range}"
+                        "with output_dtype='int8'. Received: "
+                        f"output_dtype={self.weight_quantizer.output_dtype}"
                     )
 
     @property
@@ -116,6 +116,13 @@ def __init__(self, weight_quantizer=None, activation_quantizer="default"):
                         "with value_range=(-8, 7). Received: "
                         f"value_range={self.weight_quantizer.value_range}"
                     )
+            if hasattr(self.weight_quantizer, "output_dtype"):
+                if self.weight_quantizer.output_dtype != "int8":
+                    raise ValueError(
+                        "Int4QuantizationConfig requires a weight_quantizer "
+                        "with output_dtype='int8'. Received: "
+                        f"output_dtype={self.weight_quantizer.output_dtype}"
+                    )
 
     @property
     def mode(self):
diff --git a/keras/src/quantizers/quantization_config_test.py b/keras/src/quantizers/quantization_config_test.py
index f7c94c6cc2e4..cd6803401ead 100644
--- a/keras/src/quantizers/quantization_config_test.py
+++ b/keras/src/quantizers/quantization_config_test.py
@@ -24,9 +24,8 @@ def test_int8_quantization_config_valid(self):
 
     def test_int8_quantization_config_invalid(self):
         # Invalid value_range
-        q = AbsMaxQuantizer(axis=0, value_range=(-8, 7))
         with self.assertRaisesRegex(ValueError, "value_range"):
-            Int8QuantizationConfig(weight_quantizer=q)
+            AbsMaxQuantizer(axis=0, value_range=(-256, 256))
 
     def test_int4_quantization_config_valid(self):
         config = Int4QuantizationConfig()
@@ -89,7 +88,7 @@ def test_validate_and_resolve_config(self):
             validate_and_resolve_config("invalid_mode", None)
 
         # 6. GPTQ without config
-        with self.assertRaisesRegex(ValueError, "must pass a GPTQConfig"):
+        with self.assertRaisesRegex(ValueError, "must pass a `GPTQConfig`"):
             validate_and_resolve_config("gptq", None)
 
         # 7. Contradictory config
@@ -104,3 +103,17 @@ def mode(self):
 
         with self.assertRaisesRegex(ValueError, "requires a valid `config`"):
             validate_and_resolve_config("gptq", FakeGPTQConfig())
+
+    def test_int8_quantization_config_output_dtype_mismatch(self):
+        # Invalid output_dtype
+        q = AbsMaxQuantizer(
+            axis=0, value_range=(-127, 127), output_dtype="int16"
+        )
+        with self.assertRaisesRegex(ValueError, "output_dtype='int8'"):
+            Int8QuantizationConfig(weight_quantizer=q)
+
+    def test_int4_quantization_config_output_dtype_mismatch(self):
+        # Invalid output_dtype
+        q = AbsMaxQuantizer(axis=0, value_range=(-8, 7), output_dtype="int16")
+        with self.assertRaisesRegex(ValueError, "output_dtype='int8'"):
+            Int4QuantizationConfig(weight_quantizer=q)
diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py
index 708a143504c9..f274600b6e9c 100644
--- a/keras/src/quantizers/quantizers.py
+++ b/keras/src/quantizers/quantizers.py
@@ -116,6 +116,13 @@ def __init__(
         self.axis = tuple(axis)
         self.value_range = value_range
         self.epsilon = epsilon
+        if output_dtype == "int8":
+            if value_range[0] < -128 or value_range[1] > 127:
+                raise ValueError(
+                    f"Quantizer with output_dtype='int8' requires value_range "
+                    f"to be within the interval [-128, 127]. Received: "
+                    f"value_range={value_range}"
+                )
 
     def __call__(self, x, to_numpy=False):
         quantized_x, scale = abs_max_quantize(

From 7619fd6bf0b10b34666ddb2d6bb3844778c5428d Mon Sep 17 00:00:00 2001
From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Date: Mon, 15 Dec 2025 14:19:28 +0530
Subject: [PATCH 8/8] make mode optional

---
 keras/src/layers/core/dense.py                | 2 +-
 keras/src/layers/core/einsum_dense.py         | 2 +-
 keras/src/layers/core/embedding.py            | 2 +-
 keras/src/layers/core/reversible_embedding.py | 2 +-
 keras/src/layers/layer.py                     | 2 +-
 5 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py
index cbe7bc59e0f9..4f78b73cd324 100644
--- a/keras/src/layers/core/dense.py
+++ b/keras/src/layers/core/dense.py
@@ -772,7 +772,7 @@ def grad(*args, upstream=None, variables=None):
         x = self.activation(x)
         return x
 
-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not Dense):
             raise self._not_implemented_error(self.quantize)
diff --git a/keras/src/layers/core/einsum_dense.py b/keras/src/layers/core/einsum_dense.py
index 110b96efa096..546fea67aad9 100644
--- a/keras/src/layers/core/einsum_dense.py
+++ b/keras/src/layers/core/einsum_dense.py
@@ -975,7 +975,7 @@ def grad(*args, upstream=None, variables=None):
         x = self.activation(x)
         return x
 
-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not EinsumDense):
             raise self._not_implemented_error(self.quantize)
diff --git a/keras/src/layers/core/embedding.py b/keras/src/layers/core/embedding.py
index f4ee2d4e52d9..2534a5c59b26 100644
--- a/keras/src/layers/core/embedding.py
+++ b/keras/src/layers/core/embedding.py
@@ -409,7 +409,7 @@ def _int4_call(self, inputs, training=None):
         )
         return outputs
 
-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses.
         if type_check and (type(self) is not Embedding):
             raise self._not_implemented_error(self.quantize)
diff --git a/keras/src/layers/core/reversible_embedding.py b/keras/src/layers/core/reversible_embedding.py
index 41b0f88b0aea..2f4a24f16a9e 100644
--- a/keras/src/layers/core/reversible_embedding.py
+++ b/keras/src/layers/core/reversible_embedding.py
@@ -289,7 +289,7 @@ def _int4_call(self, inputs, reverse=False):
         )
         return logits
 
-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         if type_check and type(self) is not ReversibleEmbedding:
             raise self._not_implemented_error(self.quantize)
diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py
index 5922cf54e423..bff56c2f3525 100644
--- a/keras/src/layers/layer.py
+++ b/keras/src/layers/layer.py
@@ -1277,7 +1277,7 @@ def _clear_losses(self):
     def quantized_build(self, input_shape, mode):
         raise self._not_implemented_error(self.quantized_build)
 
-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         raise self._not_implemented_error(self.quantize)
 
     def _check_quantize_args(self, mode, compute_dtype):
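
Taken together, the series leaves `Model.quantize` accepting either a `mode` string or a `config` object. The sketch below exercises the API as the docstring added in PATCH 5/8 describes; the weight-only variant at the end is an assumption inferred from the `inputs_quantizer` branches in `dense.py`, not an example taken from the patches.

```python
import keras
from keras.quantizers import AbsMaxQuantizer, Int8QuantizationConfig

def make_model():
    model = keras.Sequential([
        keras.Input(shape=(10,)),
        keras.layers.Dense(10),
    ])
    model.build((None, 10))
    return model

# Mode string only: a default int8 config is resolved internally.
make_model().quantize("int8")

# Explicit config: custom weight/activation quantizers; mode is inferred
# from the config, so it can be omitted (PATCH 8/8).
config = Int8QuantizationConfig(
    weight_quantizer=AbsMaxQuantizer(axis=0, value_range=(-127, 127)),
    activation_quantizer=AbsMaxQuantizer(axis=-1, value_range=(-127, 127)),
)
make_model().quantize(config=config)

# Weight-only int8 (assumed usage): passing activation_quantizer=None
# leaves inputs_quantizer unset, so activations stay in the compute dtype.
make_model().quantize(config=Int8QuantizationConfig(activation_quantizer=None))
```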
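The `output_scale` refactor in PATCH 4/8 works because a per-row input scale and a per-column kernel scale both factor out of the matmul, so the two divisions collapse into one. Below is a minimal sketch of that identity, assuming `abs_max_quantize` returns a multiplicative scale (i.e. `q ≈ round(x * scale)`), which matches how `dense.py` divides by the scales after the matmul:

```python
import numpy as np
from keras import ops
from keras.quantizers import abs_max_quantize

x = np.random.randn(2, 4).astype("float32")  # (batch, in)
w = np.random.randn(4, 3).astype("float32")  # (in, out)

# Per-output-channel kernel scale and per-row input scale (keepdims shapes
# (1, 3) and (2, 1)), mirroring the axes used in dense.py.
w_q, w_scale = abs_max_quantize(w, axis=0)
x_q, x_scale = abs_max_quantize(x, axis=-1)

# PATCH 4/8's refactor: fold both scales into a single divisor.
output_scale = ops.multiply(w_scale, x_scale)  # broadcasts to (2, 3)
y = ops.matmul(ops.cast(x_q, "float32"), ops.cast(w_q, "float32"))
y = ops.divide(y, output_scale)

# The de-scaled integer product approximates the float matmul up to
# quantization error.
print(np.max(np.abs(ops.convert_to_numpy(y) - x @ w)))
```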
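PATCH 7/8 moves the int8 range check into `AbsMaxQuantizer.__init__` and re-points the config classes at `output_dtype`. A short sketch of the failure modes the new tests pin down (the comments paraphrase the errors raised above):

```python
from keras.quantizers import (
    AbsMaxQuantizer,
    Int4QuantizationConfig,
    Int8QuantizationConfig,
)

# value_range is now checked in AbsMaxQuantizer itself for int8 outputs.
try:
    AbsMaxQuantizer(axis=0, value_range=(-256, 256))
except ValueError as e:
    print(e)  # value_range must lie within [-128, 127]

# The int8/int4 configs check the quantizer's output_dtype instead.
q16 = AbsMaxQuantizer(axis=0, value_range=(-127, 127), output_dtype="int16")
try:
    Int8QuantizationConfig(weight_quantizer=q16)
except ValueError as e:
    print(e)  # requires a weight_quantizer with output_dtype='int8'

# Int4 additionally keeps its value_range check.
try:
    Int4QuantizationConfig(
        weight_quantizer=AbsMaxQuantizer(axis=0, value_range=(-127, 127))
    )
except ValueError as e:
    print(e)  # requires a weight_quantizer with value_range=(-8, 7)
```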