Skip to content

Commit 886bfbe

Browse files
fix quantization save and load error
1 parent 129e3d7 commit 886bfbe

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

keras/src/layers/layer.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1359,10 +1359,32 @@ def save_own_variables(self, store):
13591359
Args:
13601360
store: Dict where the state of the model will be saved.
13611361
"""
1362+
if not getattr(self, "_is_quantized", False):
1363+
all_vars = self._trainable_variables + self._non_trainable_variables
1364+
for i, v in enumerate(all_vars):
1365+
store[f"{i}"] = v
1366+
return
1367+
1368+
# Case: quantized layer
1369+
quantized_vars = self._get_quantized_variables()
1370+
for i, v in enumerate(quantized_vars):
1371+
store[f"quantized_{i}"] = v
1372+
1373+
# Save non-quantized variables
13621374
all_vars = self._trainable_variables + self._non_trainable_variables
1363-
for i, v in enumerate(all_vars):
1375+
non_quantized_vars = [
1376+
v for v in all_vars if v not in quantized_vars and v.trainable
1377+
]
1378+
for i, v in enumerate(non_quantized_vars):
13641379
store[f"{i}"] = v
13651380

1381+
def _get_quantized_variables(self):
1382+
quantized_vars = []
1383+
for v in self._trainable_variables + self._non_trainable_variables:
1384+
if not backend.is_float_dtype(v.dtype):
1385+
quantized_vars.append(v)
1386+
return quantized_vars
1387+
13661388
def load_own_variables(self, store):
13671389
"""Loads the state of the layer.
13681390
@@ -1372,6 +1394,10 @@ def load_own_variables(self, store):
13721394
Args:
13731395
store: Dict from which the state of the model will be loaded.
13741396
"""
1397+
if any(key.startswith("quantized_") for key in store.keys()):
1398+
self._load_quantized_variables(store)
1399+
return
1400+
13751401
all_vars = self._trainable_variables + self._non_trainable_variables
13761402
if len(store.keys()) != len(all_vars):
13771403
if len(all_vars) == 0 and not self.built:
@@ -1407,6 +1433,19 @@ def load_own_variables(self, store):
14071433
for i, v in enumerate(all_vars):
14081434
v.assign(store[f"{i}"])
14091435

1436+
def _load_quantized_variables(self, store):
1437+
quantized_vars = self._get_quantized_variables()
1438+
for i, v in enumerate(quantized_vars):
1439+
v.assign(store[f"quantized_{i}"])
1440+
1441+
# Load non-quantized variables
1442+
all_vars = self._trainable_variables + self._non_trainable_variables
1443+
non_quantized_vars = [
1444+
v for v in all_vars if v not in quantized_vars and v.trainable
1445+
]
1446+
for i, v in enumerate(non_quantized_vars):
1447+
v.assign(store[f"{i}"])
1448+
14101449
def _track_variable(self, variable):
14111450
if variable.trainable:
14121451
self._tracker.add_to_store("trainable_variables", variable)

keras/src/layers/layer_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import pickle
23
from unittest import mock
34

@@ -12,6 +13,7 @@
1213
from keras.src import metrics
1314
from keras.src import models
1415
from keras.src import ops
16+
from keras.src import saving
1517
from keras.src import testing
1618
from keras.src.backend.common import global_state
1719
from keras.src.backend.common.remat import RematScope
@@ -1758,3 +1760,16 @@ def call(self, x):
17581760
# foo_mode omitted -> foo_mode defaults to False -> no change
17591761
y2 = model(sample_input)
17601762
self.assertAllClose(y2, sample_input)
1763+
1764+
def test_quantized_model_save_and_load(self):
1765+
inputs = layers.Input(shape=(None,))
1766+
x = layers.Embedding(input_dim=10, output_dim=10)(inputs)
1767+
x = layers.Dense(10)(x)
1768+
model = models.Model(inputs=inputs, outputs=x)
1769+
path = os.path.join(self.get_temp_dir(), "quantized_model.keras")
1770+
model.quantize(mode="int8")
1771+
model.save(path)
1772+
1773+
quantized_model = saving.load_model(path)
1774+
1775+
self.assertTrue(quantized_model.built)

0 commit comments

Comments
 (0)