Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions tf_keras/saving/saving_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from tf_keras.saving import saving_lib
from tf_keras.saving.legacy import save as legacy_sm_saving_lib
from tf_keras.saving.legacy import saving_utils
from tf_keras.utils import io_utils

try:
Expand Down Expand Up @@ -75,8 +76,7 @@ class SupportWriteToRemote:
supports remoted saved model out of the box.
"""

def __init__(self, filepath, overwrite=True, save_format=None):
save_format = get_save_format(filepath, save_format=save_format)
def __init__(self, filepath, overwrite, save_format):
self.overwrite = overwrite
if saving_lib.is_remote_path(filepath) and save_format != "tf":
self.temp_directory = tempfile.TemporaryDirectory()
Expand Down Expand Up @@ -191,14 +191,14 @@ def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):
when loading the model. See the `custom_objects` argument in
`tf.keras.saving.load_model`.
"""
save_format = get_save_format(filepath, save_format)

# Supports remote paths via a temporary file
with SupportWriteToRemote(
filepath,
overwrite=overwrite,
save_format=save_format,
) as local_filepath:
save_format = get_save_format(filepath, save_format)

# Deprecation warnings
if save_format == "h5":
warnings.warn(
Expand Down Expand Up @@ -307,8 +307,12 @@ def load_model(


def save_weights(model, filepath, overwrite=True, **kwargs):
save_format = get_save_weights_format(filepath)

# Supports remote paths via a temporary file
with SupportWriteToRemote(filepath, overwrite=overwrite) as local_filepath:
with SupportWriteToRemote(
filepath, overwrite=overwrite, save_format=save_format
) as local_filepath:
if str(local_filepath).endswith(".weights.h5"):
# If file exists and should not be overwritten.
try:
Expand Down Expand Up @@ -385,3 +389,12 @@ def get_save_format(filepath, save_format):
return "tf"
else:
return "h5"


def get_save_weights_format(filepath):
filepath = io_utils.path_to_string(filepath)
filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
if filepath_is_h5:
return "h5"
else:
return "tf"
89 changes: 80 additions & 9 deletions tf_keras/saving/saving_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,26 +533,97 @@ def test_metadata(self):
self.assertIn("keras_version", metadata)
self.assertIn("date_saved", metadata)

def test_gfile_copy_called(self):
temp_filepath = Path(
os.path.join(self.get_temp_dir(), "my_model.keras")
def test_save_keras_gfile_copy_called(self):
path = Path(os.path.join(self.get_temp_dir(), "my_model.keras"))
model = keras.Sequential(
[
keras.Input(shape=(1, 1)),
keras.layers.Dense(4),
]
)
model = CompileOverridingModel()
with mock.patch(
"re.match", autospec=True
) as mock_re_match, mock.patch.object(
tf.io.gfile, "copy"
) as mock_gfile_copy:
# Check regex matching
mock_re_match.return_value = True
model.save(temp_filepath, save_format="keras_v3")
model.save(path, save_format="keras_v3")
mock_re_match.assert_called()
self.assertIn(str(temp_filepath), mock_re_match.call_args.args)
self.assertIn(str(path), mock_re_match.call_args.args)

# Check gfile copied with filepath specified as destination
self.assertEqual(
str(temp_filepath), str(mock_gfile_copy.call_args.args[1])
)
mock_gfile_copy.assert_called()
self.assertEqual(str(path), str(mock_gfile_copy.call_args.args[1]))

def test_save_tf_gfile_copy_not_called(self):
path = Path(os.path.join(self.get_temp_dir(), "my_model.keras"))
model = keras.Sequential(
[
keras.Input(shape=(1, 1)),
keras.layers.Dense(4),
]
)
with mock.patch(
"re.match", autospec=True
) as mock_re_match, mock.patch.object(
tf.io.gfile, "copy"
) as mock_gfile_copy:
# Check regex matching
mock_re_match.return_value = True
model.save(path, save_format="tf")
mock_re_match.assert_called()
self.assertIn(str(path), mock_re_match.call_args.args)

# Check gfile.copy was not used.
mock_gfile_copy.assert_not_called()

def test_save_weights_h5_gfile_copy_called(self):
path = Path(os.path.join(self.get_temp_dir(), "my_model.weights.h5"))
model = keras.Sequential(
[
keras.Input(shape=(1, 1)),
keras.layers.Dense(4),
]
)
model(tf.constant([[1.0]]))
with mock.patch(
"re.match", autospec=True
) as mock_re_match, mock.patch.object(
tf.io.gfile, "copy"
) as mock_gfile_copy:
# Check regex matching
mock_re_match.return_value = True
model.save_weights(path)
mock_re_match.assert_called()
self.assertIn(str(path), mock_re_match.call_args.args)

# Check gfile copied with filepath specified as destination
mock_gfile_copy.assert_called()
self.assertEqual(str(path), str(mock_gfile_copy.call_args.args[1]))

def test_save_weights_tf_gfile_copy_not_called(self):
path = Path(os.path.join(self.get_temp_dir(), "my_model.ckpt"))
model = keras.Sequential(
[
keras.Input(shape=(1, 1)),
keras.layers.Dense(4),
]
)
model(tf.constant([[1.0]]))
with mock.patch(
"re.match", autospec=True
) as mock_re_match, mock.patch.object(
tf.io.gfile, "copy"
) as mock_gfile_copy:
# Check regex matching
mock_re_match.return_value = True
model.save_weights(path)
mock_re_match.assert_called()
self.assertIn(str(path), mock_re_match.call_args.args)

# Check gfile.copy was not used.
mock_gfile_copy.assert_not_called()

def test_load_model_api_endpoint(self):
temp_filepath = Path(os.path.join(self.get_temp_dir(), "mymodel.keras"))
Expand Down
Loading