Skip to content

Commit ebc99e2

Browse files
JW1992Orbax Authors
authored andcommitted
Adds a function to check the inputs have default values in orbax savedmodel.
PiperOrigin-RevId: 787197816
1 parent 21946f3 commit ebc99e2

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

checkpoint/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Added
1111

12+
- Add `assert_tensor_spec_with_default` for input signature testing
1213
- Add support for loading SafeTensors checkpoints
1314
- #v1 Add `is_orbax_checkpoint()` method for validation checks
1415

export/orbax/export/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,16 @@ def __post_init__(self):
8989
]
9090

9191

92+
def assert_tensor_spec_with_default(
93+
input_signature: PyTree,
94+
) -> PyTree:
95+
"""Asserts that the input signature is a TensorSpecWithDefault."""
96+
def check_fn(x):
97+
assert isinstance(x, TensorSpecWithDefault), f'x: {x}'
98+
return x
99+
return jax.tree_util.tree_map(check_fn, input_signature)
100+
101+
92102
def remove_signature_defaults(input_signature: PyTree) -> PyTree:
93103
"""Removes TensorSpecWithDefault from an input_signature."""
94104

export/orbax/export/utils_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,24 @@ def test_missing_default(self):
6565
):
6666
utils.with_default_args(lambda x: x[0] + x[1], input_signature)
6767

68+
def test_assert_tensor_spec_with_default(self):
69+
input_signature = [
70+
TensorSpecWithDefault(
71+
tf.TensorSpec([None], tf.int32),
72+
np.asarray([1, 2]),
73+
)
74+
]
75+
utils.assert_tensor_spec_with_default(input_signature)
76+
77+
input_signature_bad_type = [
78+
tf.TensorSpec([None], tf.int32),
79+
]
80+
with self.assertRaisesRegex(
81+
AssertionError,
82+
'x: TensorSpec',
83+
):
84+
utils.assert_tensor_spec_with_default(input_signature_bad_type)
85+
6886
def test_with_default_args_nested(self):
6987
def f(required_arg, optional_args):
7088
return (

0 commit comments

Comments
 (0)