From ebc99e2724014cede35058d35bc2e4ee7d5468ce Mon Sep 17 00:00:00 2001 From: Jiawei Xia Date: Fri, 25 Jul 2025 12:16:18 -0700 Subject: [PATCH] Adds a function to check the inputs have default values in orbax savedmodel. PiperOrigin-RevId: 787197816 --- checkpoint/CHANGELOG.md | 1 + export/orbax/export/utils.py | 10 ++++++++++ export/orbax/export/utils_test.py | 18 ++++++++++++++++++ 3 files changed, 29 insertions(+) diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 22a123f55..e3fa14dd7 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add `assert_tensor_spec_with_default` for input signature testing - Add support for loading SafeTensors checkpoints - #v1 Add `is_orbax_checkpoint()` method for validation checks diff --git a/export/orbax/export/utils.py b/export/orbax/export/utils.py index e71d3dc4a..09ebaaea0 100644 --- a/export/orbax/export/utils.py +++ b/export/orbax/export/utils.py @@ -89,6 +89,16 @@ def __post_init__(self): ] +def assert_tensor_spec_with_default( + input_signature: PyTree, +) -> PyTree: + """Asserts that the input signature is a TensorSpecWithDefault.""" + def check_fn(x): + assert isinstance(x, TensorSpecWithDefault), f'x: {x}' + return x + return jax.tree_util.tree_map(check_fn, input_signature) + + def remove_signature_defaults(input_signature: PyTree) -> PyTree: """Removes TensorSpecWithDefault from an input_signature.""" diff --git a/export/orbax/export/utils_test.py b/export/orbax/export/utils_test.py index 8199f818e..8a1f4763b 100644 --- a/export/orbax/export/utils_test.py +++ b/export/orbax/export/utils_test.py @@ -65,6 +65,24 @@ def test_missing_default(self): ): utils.with_default_args(lambda x: x[0] + x[1], input_signature) + def test_assert_tensor_spec_with_default(self): + input_signature = [ + TensorSpecWithDefault( + tf.TensorSpec([None], tf.int32), + np.asarray([1, 2]), + ) + ] + utils.assert_tensor_spec_with_default(input_signature) + + input_signature_bad_type = [ + tf.TensorSpec([None], tf.int32), + ] + with self.assertRaisesRegex( + AssertionError, + 'x: TensorSpec', + ): + utils.assert_tensor_spec_with_default(input_signature_bad_type) + def test_with_default_args_nested(self): def f(required_arg, optional_args): return (