Skip to content

Commit c882b09

Browse files
Fiddle-Config Teamcopybara-github
authored andcommitted
Validator that checks for custom user objects in configs. Custom user objects break or don't work well with serialization/visualization/codegen.
PiperOrigin-RevId: 550975396
1 parent 7b034aa commit c882b09

File tree

6 files changed

+349
-8
lines changed

6 files changed

+349
-8
lines changed

fiddle/_src/daglish.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def initial_state(self) -> State:
438438
The initial state has a reference to the traversal, the empty path, and no
439439
parent states.
440440
"""
441-
return State(self, (), self.root_obj, _parent=None) # pytype: disable=attribute-error
441+
return State(self, (), self.root_obj, parent=None)
442442

443443
@classmethod
444444
def begin(cls, fn: Callable[..., Any], root_obj: Any) -> State:
@@ -484,8 +484,9 @@ class State:
484484
current_path: A path that can be followed to the current object. In the case
485485
of shared objects, there will be other paths to the current object, and
486486
often these are determined by a somewhat arbitrary DAG traversal order.
487+
parent: The parent state.
487488
"""
488-
__slots__ = ("traversal", "current_path", "_value", "_parent")
489+
__slots__ = ("traversal", "current_path", "_value", "parent")
489490

490491
traversal: Traversal
491492
current_path: Path
@@ -494,7 +495,7 @@ class State:
494495
# accessors.
495496
_value: Any # pylint: disable=invalid-name
496497

497-
_parent: Optional[State]
498+
parent: Optional[State]
498499

499500
@property
500501
def _object_id(self) -> int:
@@ -504,12 +505,23 @@ def _object_id(self) -> int:
504505
def _is_memoizable(self) -> bool:
505506
return is_memoizable(self._value)
506507

508+
@property
509+
def original_value(self):
510+
"""Original value constructed with this state.
511+
512+
Generally please don't use this value, it's much more clear to use the
513+
first argument of your `traverse(value, state)` function, especially since
514+
for post-order traversals, you'll often write `value =
515+
state.map_children(value)`.
516+
"""
517+
return self._value
518+
507519
@property
508520
def ancestors_inclusive(self) -> Iterable[State]:
509521
"""Gets ancestors, including the current state."""
510522
yield self
511-
if self._parent is not None:
512-
yield from self._parent.ancestors_inclusive
523+
if self.parent is not None:
524+
yield from self.parent.ancestors_inclusive
513525

514526
def get_all_paths(self, allow_caching: bool = True) -> List[Path]:
515527
"""Gets all paths to the current value.

fiddle/_src/history.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,20 @@ class Location:
5555
line_number: int
5656
function_name: Optional[str]
5757

58-
def __str__(self) -> str:
58+
def format(self, max_filename_parts: Optional[int] = None):
59+
filename = self.filename
60+
if max_filename_parts is not None:
61+
filename_parts = filename.split(os.path.sep)
62+
if len(filename_parts) > max_filename_parts:
63+
filename = os.path.sep.join(
64+
["...", *filename_parts[-max_filename_parts:]]
65+
)
5966
if self.function_name is None:
60-
return f"{self.filename}:{self.line_number}"
61-
return f"{self.filename}:{self.line_number}:{self.function_name}"
67+
return f"{filename}:{self.line_number}"
68+
return f"{filename}:{self.line_number}:{self.function_name}"
69+
70+
def __str__(self) -> str:
71+
return self.format()
6272

6373
def __deepcopy__(self, memo):
6474
del memo # unused

fiddle/_src/history_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,22 @@ def test_location_formatting(self):
4141
function_name="make_config")
4242
self.assertEqual(str(location), "my_other_file.py:321:make_config")
4343

44+
def test_location_formatting_concise(self):
45+
location = history.Location(
46+
filename="foo/bar/baz/my_file.py", line_number=123, function_name=None
47+
)
48+
self.assertEqual(location.format(3), ".../bar/baz/my_file.py:123")
49+
self.assertEqual(location.format(2), ".../baz/my_file.py:123")
50+
location = history.Location(
51+
filename="foo/bar/baz/my_other_file.py",
52+
line_number=321,
53+
function_name="make_config",
54+
)
55+
self.assertEqual(
56+
location.format(4), "foo/bar/baz/my_other_file.py:321:make_config"
57+
)
58+
self.assertEqual(location.format(1), ".../my_other_file.py:321:make_config")
59+
4460
def test_entry_simple(self):
4561
entry = history.new_value("x", 1)
4662
self.assertEqual(entry.param_name, "x")
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# coding=utf-8
2+
# Copyright 2022 The Fiddle-Config Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Checks that custom objects/instances are not present in a configuration."""
17+
18+
import dataclasses
19+
from typing import Any, List, Optional
20+
21+
import fiddle as fdl
22+
from fiddle import daglish
23+
from fiddle import history
24+
25+
26+
def _concise_history(entries: List[history.HistoryEntry]) -> str:
27+
"""Returns a concise string for a history entry."""
28+
if not entries:
29+
return "(no history)"
30+
set_or_deleted = (
31+
"Set in "
32+
if entries[-1].new_value is not history.DELETED
33+
else "Deleted in "
34+
)
35+
return set_or_deleted + entries[-1].location.format(max_filename_parts=3)
36+
37+
38+
def _get_history_from_state(
39+
state: daglish.State,
40+
) -> Optional[List[history.HistoryEntry]]:
41+
"""Returns the history from a Buildable for a state."""
42+
while state.current_path and state.parent is not None:
43+
attr = state.current_path[-1]
44+
state = state.parent
45+
if isinstance(state.original_value, fdl.Buildable):
46+
assert isinstance(attr, daglish.Attr)
47+
entries = state.original_value.__argument_history__[attr.name]
48+
return entries if entries else None
49+
return None
50+
51+
52+
def get_config_errors(config: Any) -> List[str]:
53+
"""Returns a list of errors found in the given config.
54+
55+
Args:
56+
config: Fiddle config object, or nested structure of configs.
57+
"""
58+
errors = []
59+
60+
def history_str(state):
61+
return ", " + _concise_history(_get_history_from_state(state))
62+
63+
def traverse(value, state: daglish.State):
64+
path_str = daglish.path_str(state.current_path)
65+
if isinstance(value, tuple) and hasattr(type(value), "_fields"):
66+
errors.append(f"Found namedtuple at {path_str}{history_str(state)}")
67+
elif dataclasses.is_dataclass(value):
68+
errors.append(f"Found dataclass at {path_str}{history_str(state)}")
69+
elif (not state.is_traversable(value)) and not daglish.is_unshareable(
70+
value
71+
):
72+
errors.append(f"Found {type(value)} at {path_str}{history_str(state)}")
73+
return state.map_children(value)
74+
75+
daglish.MemoizedTraversal.run(traverse, config)
76+
return errors
77+
78+
79+
def check_no_custom_objects(config: Any) -> None:
80+
"""Checks that no custom objects are present in the given config.
81+
82+
Args:
83+
config: Fiddle config object, or nested structure of configs.
84+
85+
Raises:
86+
ValueError: If the configuration contains custom objects.
87+
"""
88+
errors = get_config_errors(config)
89+
if errors:
90+
raise ValueError(
91+
"Custom objects were found in the config. In general, you should "
92+
"be able to convert these to fdl.Config's. Custom objects:\n "
93+
+ "\n ".join(errors),
94+
)
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# coding=utf-8
2+
# Copyright 2022 The Fiddle-Config Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for no_custom_objects."""
17+
18+
import dataclasses
19+
import typing
20+
21+
from absl.testing import absltest
22+
import fiddle as fdl
23+
from fiddle import daglish
24+
from fiddle._src.testing.example import fake_encoder_decoder
25+
from fiddle._src.validation import no_custom_objects
26+
27+
28+
def foo(x):
29+
return x
30+
31+
32+
class MyNamedtuple(typing.NamedTuple):
33+
a: int
34+
b: str
35+
36+
37+
@dataclasses.dataclass(frozen=True)
38+
class MyDataclass:
39+
a: int
40+
b: str
41+
42+
43+
class NoCustomObjectsTest(absltest.TestCase):
44+
45+
def test_empty_history(self):
46+
self.assertEqual(
47+
no_custom_objects._concise_history(entries=[]), "(no history)"
48+
)
49+
50+
def test_concise_history(self):
51+
config = fake_encoder_decoder.fixture.as_buildable()
52+
config.encoder.attention.dtype = "float64"
53+
with self.subTest("overridden"):
54+
history = no_custom_objects._concise_history(
55+
entries=config.encoder.attention.__argument_history__["dtype"]
56+
)
57+
self.assertRegex(history, r"Set in .+:\d+:test_concise_history")
58+
with self.subTest("not_overridden"):
59+
history = no_custom_objects._concise_history(
60+
entries=config.encoder.__argument_history__["attention"]
61+
)
62+
self.assertRegex(history, r"Set in .+fake_encoder.+:\d+:fixture")
63+
with self.subTest("deleted"):
64+
del config.encoder.attention.dtype
65+
history = no_custom_objects._concise_history(
66+
entries=config.encoder.attention.__argument_history__["dtype"]
67+
)
68+
self.assertRegex(history, r"Deleted in .+:\d+:test_concise_history")
69+
70+
def test_get_history_from_state(self):
71+
config = fdl.Config(foo, {"a": {"b": 1}})
72+
traversal = daglish.MemoizedTraversal(NotImplemented, config) # pytype: disable=wrong-arg-types
73+
state = traversal.initial_state()
74+
state = daglish.State(
75+
state.traversal,
76+
(*state.current_path, daglish.Attr("x")),
77+
config.x,
78+
state,
79+
)
80+
state = daglish.State(
81+
state.traversal,
82+
(*state.current_path, daglish.Key("a")),
83+
config.x["a"],
84+
state,
85+
)
86+
history1 = no_custom_objects._get_history_from_state(state=state)
87+
state = daglish.State(
88+
state.traversal,
89+
(*state.current_path, daglish.Key("b")),
90+
config.x["a"]["b"],
91+
state,
92+
)
93+
history2 = no_custom_objects._get_history_from_state(state=state)
94+
self.assertNotEmpty(history2)
95+
self.assertIs(
96+
history1,
97+
history2,
98+
msg=(
99+
"Since dictionaries do not have state, _get_history_from_state"
100+
" should traverse to an ancestor to find the history."
101+
),
102+
)
103+
104+
def test_get_config_errors_empty(self):
105+
config = fake_encoder_decoder.fixture.as_buildable()
106+
self.assertEmpty(no_custom_objects.get_config_errors(config=config))
107+
108+
def test_get_config_errors_namedtuple(self):
109+
config = fake_encoder_decoder.fixture.as_buildable()
110+
config.encoder.attention = MyNamedtuple(1, "a")
111+
errors = no_custom_objects.get_config_errors(config=config)
112+
self.assertLen(errors, 1)
113+
self.assertRegex(
114+
errors[0],
115+
r"Found.*namedtuple.*at \.encoder\.attention.*Set"
116+
r" in.*:\d+:test_get_config_errors_namedtuple",
117+
)
118+
119+
def test_get_config_errors_dataclass(self):
120+
config = fake_encoder_decoder.fixture.as_buildable()
121+
config.encoder.attention = MyDataclass(1, "a")
122+
errors = no_custom_objects.get_config_errors(config=config)
123+
self.assertLen(errors, 1)
124+
self.assertRegex(
125+
errors[0],
126+
r"Found.*dataclass.*at \.encoder\.attention.*Set"
127+
r" in.*:\d+:test_get_config_errors_dataclass",
128+
)
129+
130+
def test_get_config_errors_not_empty(self):
131+
config = fake_encoder_decoder.fixture.as_buildable()
132+
config.encoder.attention = object()
133+
errors = no_custom_objects.get_config_errors(config=config)
134+
self.assertLen(errors, 1)
135+
self.assertRegex(
136+
errors[0],
137+
r"Found.*object.*at \.encoder\.attention.*Set"
138+
r" in.*:\d+:test_get_config_errors_not_empty",
139+
)
140+
141+
def test_check_no_custom_objects_okay(self):
142+
config = fake_encoder_decoder.fixture.as_buildable()
143+
no_custom_objects.check_no_custom_objects(config)
144+
145+
def test_check_no_custom_objects_error(self):
146+
config = fake_encoder_decoder.fixture.as_buildable()
147+
config.encoder.attention = object()
148+
config.decoder.self_attention = object()
149+
with self.assertRaisesRegex(
150+
ValueError,
151+
r"Custom objects were found.*Custom objects:\n Found.*object.*at"
152+
r" \.encoder\.attention, Set in"
153+
r" .*:\d+:test_check_no_custom_objects_error\n Found.*object.*at"
154+
r" \.decoder\.self_attention, Set in"
155+
r" .*:\d+:test_check_no_custom_objects_error",
156+
):
157+
no_custom_objects.check_no_custom_objects(config=config)
158+
159+
def test_no_history_custom_objects_error(self):
160+
config = {
161+
"encoder_attention": object(),
162+
"decoder_self_attention": object(),
163+
}
164+
with self.assertRaisesRegex(
165+
ValueError,
166+
r"Custom objects were found.*Custom objects:"
167+
r"\n Found.*object.*at \['encoder_attention'\],.*no history.*"
168+
r"\n Found.*object.*at \['decoder_self_attention'\],.*no history.*",
169+
):
170+
no_custom_objects.check_no_custom_objects(config=config)
171+
172+
173+
if __name__ == "__main__":
174+
absltest.main()

0 commit comments

Comments
 (0)