|
4 | 4 | import keras
|
5 | 5 | import numpy as np
|
6 | 6 | import pytest
|
| 7 | +from absl.testing import parameterized |
7 | 8 |
|
8 | 9 | from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier
|
9 | 10 | from keras_hub.src.models.causal_lm import CausalLM
|
@@ -107,48 +108,100 @@ def test_summary_without_preprocessor(self):
|
107 | 108 | model.summary(print_fn=lambda x, line_break=False: summary.append(x))
|
108 | 109 | self.assertNotRegex("\n".join(summary), "Preprocessor:")
|
109 | 110 |
|
110 |
| - # @pytest.mark.large |
111 |
| - def test_save_to_preset_with_quantization(self): |
| 111 | + @pytest.mark.large |
| 112 | + @parameterized.named_parameters( |
| 113 | + { |
| 114 | + "testcase_name": "load_with_quantized_weights", |
| 115 | + "load_weights": True, |
| 116 | + "dtype_override": None, |
| 117 | + "expected_dtype": "int8", |
| 118 | + }, |
| 119 | + { |
| 120 | + "testcase_name": "override_dtype_without_loading_weights", |
| 121 | + "load_weights": False, |
| 122 | + "dtype_override": "float32", |
| 123 | + "expected_dtype": "float32", |
| 124 | + }, |
| 125 | + ) |
| 126 | + def test_quantized_preset_loading_and_saving( |
| 127 | + self, load_weights, dtype_override, expected_dtype |
| 128 | + ): |
| 129 | + # Create, quantize, and save the model preset. |
112 | 130 | save_dir = self.get_temp_dir()
|
113 | 131 | task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
|
114 | 132 | task.quantize(mode="int8")
|
115 | 133 | task.save_to_preset(save_dir)
|
116 | 134 |
|
117 |
| - # Check existence of files. |
| 135 | + # Verify that all necessary files were created. |
118 | 136 | path = pathlib.Path(save_dir)
|
119 | 137 | self.assertTrue(os.path.exists(path / CONFIG_FILE))
|
120 | 138 | self.assertTrue(os.path.exists(path / MODEL_WEIGHTS_FILE))
|
121 | 139 | self.assertTrue(os.path.exists(path / METADATA_FILE))
|
122 | 140 | self.assertTrue(os.path.exists(path / TASK_CONFIG_FILE))
|
123 | 141 | self.assertTrue(os.path.exists(path / TASK_WEIGHTS_FILE))
|
124 | 142 |
|
125 |
| - # Check the task config (`task.json`). |
| 143 | + # Verify the contents of the task config file. |
126 | 144 | task_config = load_json(save_dir, TASK_CONFIG_FILE)
|
127 |
| - self.assertTrue("build_config" not in task_config) |
128 |
| - self.assertTrue("compile_config" not in task_config) |
129 |
| - self.assertTrue("backbone" in task_config["config"]) |
130 |
| - self.assertTrue("preprocessor" in task_config["config"]) |
131 |
| - |
132 |
| - # Check the preset directory task class. |
| 145 | + self.assertNotIn("build_config", task_config) |
| 146 | + self.assertNotIn("compile_config", task_config) |
| 147 | + self.assertIn("backbone", task_config["config"]) |
| 148 | + self.assertIn("preprocessor", task_config["config"]) |
133 | 149 | self.assertEqual(BertTextClassifier, check_config_class(task_config))
|
134 | 150 |
|
135 |
| - # Try loading the model from preset directory. |
136 |
| - restored_task = TextClassifier.from_preset(save_dir, num_classes=2) |
| 151 | + # Restore the task from the preset using parameterized arguments. |
| 152 | + restored_task = TextClassifier.from_preset( |
| 153 | + save_dir, |
| 154 | + num_classes=2, |
| 155 | + load_weights=load_weights, |
| 156 | + dtype=dtype_override, |
| 157 | + ) |
137 | 158 |
|
138 |
| - # Validate dtypes for quantized layers are in lower precision. |
| 159 | + # Check that the layers have the expected data type. |
139 | 160 | for layer in restored_task._flatten_layers():
|
140 | 161 | if isinstance(layer, keras.layers.Dense) and layer.name != "logits":
|
141 | 162 | self.assertEqual(
|
142 | 163 | layer.kernel.dtype,
|
143 |
| - "int8", |
144 |
| - f"{layer.name=} should be in lower precision (int8)", |
| 164 | + expected_dtype, |
| 165 | + f"Layer '{layer.name}' kernel " |
| 166 | + "should have dtype '{expected_dtype}'", |
145 | 167 | )
|
146 | 168 |
|
147 |
| - # Test whether inference works. |
| 169 | + # Ensure inference runs without errors. |
148 | 170 | data = ["the quick brown fox.", "the slow brown fox."]
|
149 |
| - |
150 | 171 | _ = restored_task.predict(data)
|
151 | 172 |
|
| 173 | + @pytest.mark.large |
| 174 | + def test_load_quantized_preset_with_dtype_override(self): |
| 175 | + save_dir = self.get_temp_dir() |
| 176 | + task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2) |
| 177 | + task.quantize(mode="int8") |
| 178 | + task.save_to_preset(save_dir) |
| 179 | + |
| 180 | + # Check existence of files. |
| 181 | + path = pathlib.Path(save_dir) |
| 182 | + self.assertTrue(os.path.exists(path / CONFIG_FILE)) |
| 183 | + self.assertTrue(os.path.exists(path / MODEL_WEIGHTS_FILE)) |
| 184 | + self.assertTrue(os.path.exists(path / METADATA_FILE)) |
| 185 | + self.assertTrue(os.path.exists(path / TASK_CONFIG_FILE)) |
| 186 | + self.assertTrue(os.path.exists(path / TASK_WEIGHTS_FILE)) |
| 187 | + |
| 188 | + # Check the task config (`task.json`). |
| 189 | + task_config = load_json(save_dir, TASK_CONFIG_FILE) |
| 190 | + self.assertTrue("build_config" not in task_config) |
| 191 | + self.assertTrue("compile_config" not in task_config) |
| 192 | + self.assertTrue("backbone" in task_config["config"]) |
| 193 | + self.assertTrue("preprocessor" in task_config["config"]) |
| 194 | + |
| 195 | + # Check the preset directory task class. |
| 196 | + self.assertEqual(BertTextClassifier, check_config_class(task_config)) |
| 197 | + |
| 198 | + # Loading the model in full-precision should cause an error during |
| 199 | + # initialization. The serialized quantized layers include additional |
| 200 | + # quantization specific weights (kernel_scale, etc.) which the |
| 201 | + # full-precision layer is not aware about and can't handle. |
| 202 | + with self.assertRaises(ValueError): |
| 203 | + TextClassifier.from_preset(save_dir, num_classes=2, dtype="float32") |
| 204 | + |
152 | 205 | @pytest.mark.large
|
153 | 206 | def test_save_to_preset(self):
|
154 | 207 | save_dir = self.get_temp_dir()
|
|
0 commit comments