Skip to content

Commit 430d7b9

Browse files
implements priority-based dtype resolution + tests
1 parent 84f26d2 commit 430d7b9

File tree

3 files changed

+124
-23
lines changed

3 files changed

+124
-23
lines changed

keras_hub/src/models/backbone.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,21 +91,24 @@ def get_config(self):
9191
}
9292

9393
# Add quantization support by utilizing `DTypePolicyMap`
94+
dtype = None
9495
try:
9596
if isinstance(
9697
self.dtype_policy, keras.dtype_policies.DTypePolicyMap
9798
):
98-
config.update({"dtype": self.dtype_policy})
99+
dtype = self.dtype_policy
99100
else:
100101
policy_map = keras.dtype_policies.DTypePolicyMap()
101102
for layer in self._flatten_layers():
102103
if layer.quantization_mode is not None:
103104
policy_map[layer.path] = layer.dtype_policy
104105
if len(policy_map) > 0:
105-
config.update({"dtype": policy_map})
106+
dtype = policy_map
106107
# Before Keras 3.2, there is no `keras.dtype_policies.get`.
107108
except AttributeError:
108109
pass
110+
111+
config.update({"dtype": dtype})
109112
return config
110113

111114
@classmethod

keras_hub/src/models/task_test.py

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import keras
55
import numpy as np
66
import pytest
7+
from absl.testing import parameterized
78

89
from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier
910
from keras_hub.src.models.causal_lm import CausalLM
@@ -107,48 +108,100 @@ def test_summary_without_preprocessor(self):
107108
model.summary(print_fn=lambda x, line_break=False: summary.append(x))
108109
self.assertNotRegex("\n".join(summary), "Preprocessor:")
109110

110-
# @pytest.mark.large
111-
def test_save_to_preset_with_quantization(self):
111+
@pytest.mark.large
@parameterized.named_parameters(
    {
        "testcase_name": "load_with_quantized_weights",
        "load_weights": True,
        "dtype_override": None,
        "expected_dtype": "int8",
    },
    {
        "testcase_name": "override_dtype_without_loading_weights",
        "load_weights": False,
        "dtype_override": "float32",
        "expected_dtype": "float32",
    },
)
def test_quantized_preset_loading_and_saving(
    self, load_weights, dtype_override, expected_dtype
):
    """Round-trip an int8-quantized task through a saved preset.

    Quantizes a BERT classifier, saves it as a preset, then restores it
    with the parameterized `load_weights`/`dtype` combination and checks
    that quantized Dense kernels end up with `expected_dtype`.
    """
    # Create, quantize, and save the model preset.
    save_dir = self.get_temp_dir()
    task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
    task.quantize(mode="int8")
    task.save_to_preset(save_dir)

    # Verify that all necessary files were created.
    path = pathlib.Path(save_dir)
    self.assertTrue(os.path.exists(path / CONFIG_FILE))
    self.assertTrue(os.path.exists(path / MODEL_WEIGHTS_FILE))
    self.assertTrue(os.path.exists(path / METADATA_FILE))
    self.assertTrue(os.path.exists(path / TASK_CONFIG_FILE))
    self.assertTrue(os.path.exists(path / TASK_WEIGHTS_FILE))

    # Verify the contents of the task config file.
    task_config = load_json(save_dir, TASK_CONFIG_FILE)
    self.assertNotIn("build_config", task_config)
    self.assertNotIn("compile_config", task_config)
    self.assertIn("backbone", task_config["config"])
    self.assertIn("preprocessor", task_config["config"])
    self.assertEqual(BertTextClassifier, check_config_class(task_config))

    # Restore the task from the preset using parameterized arguments.
    restored_task = TextClassifier.from_preset(
        save_dir,
        num_classes=2,
        load_weights=load_weights,
        dtype=dtype_override,
    )

    # Check that the layers have the expected data type. The "logits"
    # head is excluded because it is never quantized.
    for layer in restored_task._flatten_layers():
        if isinstance(layer, keras.layers.Dense) and layer.name != "logits":
            self.assertEqual(
                layer.kernel.dtype,
                expected_dtype,
                # Both fragments must be f-strings; previously the second
                # one was a plain literal, so "{expected_dtype}" appeared
                # verbatim in failure messages.
                f"Layer '{layer.name}' kernel "
                f"should have dtype '{expected_dtype}'",
            )

    # Ensure inference runs without errors.
    data = ["the quick brown fox.", "the slow brown fox."]
    _ = restored_task.predict(data)
151172

173+
@pytest.mark.large
def test_load_quantized_preset_with_dtype_override(self):
    """Loading a quantized preset with a float dtype override must fail."""
    save_dir = self.get_temp_dir()
    classifier = TextClassifier.from_preset(
        "bert_tiny_en_uncased", num_classes=2
    )
    classifier.quantize(mode="int8")
    classifier.save_to_preset(save_dir)

    # Check existence of files.
    preset_path = pathlib.Path(save_dir)
    for file_name in (
        CONFIG_FILE,
        MODEL_WEIGHTS_FILE,
        METADATA_FILE,
        TASK_CONFIG_FILE,
        TASK_WEIGHTS_FILE,
    ):
        self.assertTrue(os.path.exists(preset_path / file_name))

    # Check the task config (`task.json`).
    task_config = load_json(save_dir, TASK_CONFIG_FILE)
    for absent_key in ("build_config", "compile_config"):
        self.assertTrue(absent_key not in task_config)
    for present_key in ("backbone", "preprocessor"):
        self.assertTrue(present_key in task_config["config"])

    # Check the preset directory task class.
    self.assertEqual(BertTextClassifier, check_config_class(task_config))

    # Loading the model in full precision should raise during
    # initialization: the serialized quantized layers carry extra
    # quantization-specific weights (kernel_scale, etc.) that a
    # full-precision layer does not know how to handle.
    with self.assertRaises(ValueError):
        TextClassifier.from_preset(save_dir, num_classes=2, dtype="float32")
204+
152205
@pytest.mark.large
153206
def test_save_to_preset(self):
154207
save_dir = self.get_temp_dir()

keras_hub/src/utils/preset_utils.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -687,10 +687,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
687687
)
688688
# We found a `task.json` with a complete config for our class.
689689
# Forward backbone args.
690-
if "config" in self.config and "dtype" in self.config["config"]:
691-
# Forward the serialized dtype from the config. This is critical for
692-
# restoring quantized models, which rely on a `DTypePolicyMap`.
693-
kwargs["dtype"] = self.config["config"]["dtype"]
690+
kwargs["dtype"] = self._resolve_dtype(self.config, kwargs)
694691
backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
695692
if "backbone" in task_config["config"]:
696693
backbone_config = task_config["config"]["backbone"]["config"]
@@ -712,6 +709,54 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
712709
self._load_backbone_weights(task.backbone)
713710
return task
714711

712+
def _resolve_dtype(self, config, kwargs):
    """Resolve the model's dtype from the saved config and user kwargs.

    Resolution priority:
    1. A user-specified `dtype` in `kwargs` always wins.
    2. Otherwise, if the saved dtype is a plain float type name, cast to
       the current Keras default float type on load (float-to-float
       conversion is safe).
    3. Otherwise (e.g. a `DTypePolicyMap` serialized for quantized
       models), load the saved dtype verbatim.

    Args:
        config: The model configuration.
        kwargs: Additional keyword arguments, potentially including `dtype`.

    Returns:
        The resolved dtype.
    """
    # 1. If a user specified dtype is passed, use that.
    user_dtype = kwargs.get("dtype")
    if user_dtype is not None:
        return user_dtype

    # If there's no saved dtype, we don't need to do anything.
    saved_dtype = config.get("config", {}).get("dtype")
    if saved_dtype is None:
        return None

    # A string dtype (e.g. "float32") that names a floating point type
    # can be safely replaced by the backend default float type.
    if isinstance(saved_dtype, str) and keras.backend.is_float_dtype(
        saved_dtype
    ):
        # 2. Cast float saves to the default backend float type.
        logging.info(
            "No dtype specified during loading. "
            f"Using {keras.backend.floatx()} as default. "
            "This may result in type casting."
        )
        return keras.backend.floatx()

    # 3. Anything else is a complex object (e.g. a DTypePolicyMap for
    # quantization) and must be used as is.
    return saved_dtype
759+
715760
def load_preprocessor(
716761
self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs
717762
):

0 commit comments

Comments
 (0)