@@ -1359,10 +1359,32 @@ def save_own_variables(self, store):
1359
1359
Args:
1360
1360
store: Dict where the state of the model will be saved.
1361
1361
"""
1362
+ if not getattr (self , "_is_quantized" , False ):
1363
+ all_vars = self ._trainable_variables + self ._non_trainable_variables
1364
+ for i , v in enumerate (all_vars ):
1365
+ store [f"{ i } " ] = v
1366
+ return
1367
+
1368
+ # Case: quantized layer
1369
+ quantized_vars = self ._get_quantized_variables ()
1370
+ for i , v in enumerate (quantized_vars ):
1371
+ store [f"quantized_{ i } " ] = v
1372
+
1373
+ # Save non-quantized variables
1362
1374
all_vars = self ._trainable_variables + self ._non_trainable_variables
1363
- for i , v in enumerate (all_vars ):
1375
+ non_quantized_vars = [
1376
+ v for v in all_vars if v not in quantized_vars and v .trainable
1377
+ ]
1378
+ for i , v in enumerate (non_quantized_vars ):
1364
1379
store [f"{ i } " ] = v
1365
1380
1381
+ def _get_quantized_variables (self ):
1382
+ quantized_vars = []
1383
+ for v in self ._trainable_variables + self ._non_trainable_variables :
1384
+ if not backend .is_float_dtype (v .dtype ):
1385
+ quantized_vars .append (v )
1386
+ return quantized_vars
1387
+
1366
1388
def load_own_variables (self , store ):
1367
1389
"""Loads the state of the layer.
1368
1390
@@ -1372,6 +1394,10 @@ def load_own_variables(self, store):
1372
1394
Args:
1373
1395
store: Dict from which the state of the model will be loaded.
1374
1396
"""
1397
+ if any (key .startswith ("quantized_" ) for key in store .keys ()):
1398
+ self ._load_quantized_variables (store )
1399
+ return
1400
+
1375
1401
all_vars = self ._trainable_variables + self ._non_trainable_variables
1376
1402
if len (store .keys ()) != len (all_vars ):
1377
1403
if len (all_vars ) == 0 and not self .built :
@@ -1407,6 +1433,19 @@ def load_own_variables(self, store):
1407
1433
for i , v in enumerate (all_vars ):
1408
1434
v .assign (store [f"{ i } " ])
1409
1435
1436
+ def _load_quantized_variables (self , store ):
1437
+ quantized_vars = self ._get_quantized_variables ()
1438
+ for i , v in enumerate (quantized_vars ):
1439
+ v .assign (store [f"quantized_{ i } " ])
1440
+
1441
+ # Load non-quantized variables
1442
+ all_vars = self ._trainable_variables + self ._non_trainable_variables
1443
+ non_quantized_vars = [
1444
+ v for v in all_vars if v not in quantized_vars and v .trainable
1445
+ ]
1446
+ for i , v in enumerate (non_quantized_vars ):
1447
+ v .assign (store [f"{ i } " ])
1448
+
1410
1449
def _track_variable (self , variable ):
1411
1450
if variable .trainable :
1412
1451
self ._tracker .add_to_store ("trainable_variables" , variable )
0 commit comments