diff --git a/tf_keras/saving/saving_api.py b/tf_keras/saving/saving_api.py index 1b3933c18..fca15f981 100644 --- a/tf_keras/saving/saving_api.py +++ b/tf_keras/saving/saving_api.py @@ -24,6 +24,7 @@ from tf_keras.saving import saving_lib from tf_keras.saving.legacy import save as legacy_sm_saving_lib +from tf_keras.saving.legacy import saving_utils from tf_keras.utils import io_utils try: @@ -75,8 +76,7 @@ class SupportWriteToRemote: supports remoted saved model out of the box. """ - def __init__(self, filepath, overwrite=True, save_format=None): - save_format = get_save_format(filepath, save_format=save_format) + def __init__(self, filepath, overwrite, save_format): self.overwrite = overwrite if saving_lib.is_remote_path(filepath) and save_format != "tf": self.temp_directory = tempfile.TemporaryDirectory() @@ -191,14 +191,14 @@ def save_model(model, filepath, overwrite=True, save_format=None, **kwargs): when loading the model. See the `custom_objects` argument in `tf.keras.saving.load_model`. """ + save_format = get_save_format(filepath, save_format) + # Supports remote paths via a temporary file with SupportWriteToRemote( filepath, overwrite=overwrite, save_format=save_format, ) as local_filepath: - save_format = get_save_format(filepath, save_format) - # Deprecation warnings if save_format == "h5": warnings.warn( @@ -307,8 +307,12 @@ def load_model( def save_weights(model, filepath, overwrite=True, **kwargs): + save_format = get_save_weights_format(filepath) + # Supports remote paths via a temporary file - with SupportWriteToRemote(filepath, overwrite=overwrite) as local_filepath: + with SupportWriteToRemote( + filepath, overwrite=overwrite, save_format=save_format + ) as local_filepath: if str(local_filepath).endswith(".weights.h5"): # If file exists and should not be overwritten. try: @@ -385,3 +389,12 @@ def get_save_format(filepath, save_format): return "tf" else: return "h5" + + +def get_save_weights_format(filepath): + filepath = io_utils.path_to_string(filepath) + filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath) + if filepath_is_h5: + return "h5" + else: + return "tf" diff --git a/tf_keras/saving/saving_lib_test.py b/tf_keras/saving/saving_lib_test.py index 41a301694..73b8f4fca 100644 --- a/tf_keras/saving/saving_lib_test.py +++ b/tf_keras/saving/saving_lib_test.py @@ -533,11 +533,14 @@ def test_metadata(self): self.assertIn("keras_version", metadata) self.assertIn("date_saved", metadata) - def test_gfile_copy_called(self): - temp_filepath = Path( - os.path.join(self.get_temp_dir(), "my_model.keras") + def test_save_keras_gfile_copy_called(self): + path = Path(os.path.join(self.get_temp_dir(), "my_model.keras")) + model = keras.Sequential( + [ + keras.Input(shape=(1, 1)), + keras.layers.Dense(4), + ] ) - model = CompileOverridingModel() with mock.patch( "re.match", autospec=True ) as mock_re_match, mock.patch.object( @@ -545,14 +548,82 @@ def test_gfile_copy_called(self): ) as mock_gfile_copy: # Check regex matching mock_re_match.return_value = True - model.save(temp_filepath, save_format="keras_v3") + model.save(path, save_format="keras_v3") mock_re_match.assert_called() - self.assertIn(str(temp_filepath), mock_re_match.call_args.args) + self.assertIn(str(path), mock_re_match.call_args.args) # Check gfile copied with filepath specified as destination - self.assertEqual( - str(temp_filepath), str(mock_gfile_copy.call_args.args[1]) - ) + mock_gfile_copy.assert_called() + self.assertEqual(str(path), str(mock_gfile_copy.call_args.args[1])) + + def test_save_tf_gfile_copy_not_called(self): + path = Path(os.path.join(self.get_temp_dir(), "my_model.keras")) + model = keras.Sequential( + [ + keras.Input(shape=(1, 1)), + keras.layers.Dense(4), + ] + ) + with mock.patch( + "re.match", autospec=True + ) as mock_re_match, mock.patch.object( + tf.io.gfile, "copy" + ) as mock_gfile_copy: + # Check regex matching + mock_re_match.return_value = True + model.save(path, save_format="tf") + mock_re_match.assert_called() + self.assertIn(str(path), mock_re_match.call_args.args) + + # Check gfile.copy was not used. + mock_gfile_copy.assert_not_called() + + def test_save_weights_h5_gfile_copy_called(self): + path = Path(os.path.join(self.get_temp_dir(), "my_model.weights.h5")) + model = keras.Sequential( + [ + keras.Input(shape=(1, 1)), + keras.layers.Dense(4), + ] + ) + model(tf.constant([[1.0]])) + with mock.patch( + "re.match", autospec=True + ) as mock_re_match, mock.patch.object( + tf.io.gfile, "copy" + ) as mock_gfile_copy: + # Check regex matching + mock_re_match.return_value = True + model.save_weights(path) + mock_re_match.assert_called() + self.assertIn(str(path), mock_re_match.call_args.args) + + # Check gfile copied with filepath specified as destination + mock_gfile_copy.assert_called() + self.assertEqual(str(path), str(mock_gfile_copy.call_args.args[1])) + + def test_save_weights_tf_gfile_copy_not_called(self): + path = Path(os.path.join(self.get_temp_dir(), "my_model.ckpt")) + model = keras.Sequential( + [ + keras.Input(shape=(1, 1)), + keras.layers.Dense(4), + ] + ) + model(tf.constant([[1.0]])) + with mock.patch( + "re.match", autospec=True + ) as mock_re_match, mock.patch.object( + tf.io.gfile, "copy" + ) as mock_gfile_copy: + # Check regex matching + mock_re_match.return_value = True + model.save_weights(path) + mock_re_match.assert_called() + self.assertIn(str(path), mock_re_match.call_args.args) + + # Check gfile.copy was not used. + mock_gfile_copy.assert_not_called() def test_load_model_api_endpoint(self): temp_filepath = Path(os.path.join(self.get_temp_dir(), "mymodel.keras"))