Skip to content

Commit 8fc9258

Browse files
Fiddle-Config Teamcopybara-github
authored andcommitted
Generate experiment diffs from baselines.
PiperOrigin-RevId: 549813869
1 parent 01780f0 commit 8fc9258

File tree

1 file changed

+54
-6
lines changed

1 file changed

+54
-6
lines changed

fiddle/_src/codegen/codegen_diff.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
"""Library for converting generating fiddlers from diffs."""
1717

1818
import collections
19+
import dataclasses
1920
import functools
2021
import re
2122
import types
22-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
23+
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple
2324

2425
from fiddle import daglish
2526
from fiddle import diffing
@@ -31,12 +32,49 @@
3132
import libcst as cst
3233

3334

35+
@dataclasses.dataclass(frozen=True)
36+
class ObjectToName:
37+
prefix: str
38+
path: daglish.Path
39+
40+
def __hash__(self):
41+
return id(self)
42+
43+
44+
def assign_explicit_names(all_to_name: List[ObjectToName]) -> List[str]:
45+
"""Returns suggested names for a list of objects."""
46+
return [
47+
to_name.prefix + _path_to_name(to_name.path) for to_name in all_to_name
48+
]
49+
50+
51+
def assign_short_names(all_to_name: List[ObjectToName]) -> List[str]:
52+
"""Returns suggested names for a list of objects."""
53+
name_to_paths = {}
54+
for to_name in all_to_name:
55+
sub_path = to_name.path[-1:]
56+
name_to_paths.setdefault(
57+
to_name.prefix + _path_to_name(sub_path), []
58+
).append(to_name)
59+
60+
result_as_dict = {}
61+
for name, group in name_to_paths.items():
62+
if len(group) == 1:
63+
result_as_dict[group[0]] = name
64+
else:
65+
for to_name in group:
66+
sub_path = to_name.path[-2:]
67+
result_as_dict[to_name] = to_name.prefix + _path_to_name(sub_path)
68+
return [result_as_dict[to_name] for to_name in all_to_name]
69+
70+
3471
def fiddler_from_diff(
3572
diff: diffing.Diff,
3673
old: Any = None,
3774
func_name: str = 'fiddler',
3875
param_name: str = 'cfg',
3976
import_manager: Optional[import_manager_lib.ImportManager] = None,
77+
variable_naming: Literal['explicit', 'short'] = 'explicit',
4078
):
4179
"""Returns the CST for a fiddler function that applies the changes in `diff`.
4280
@@ -72,6 +110,8 @@ def fiddler_from_diff(
72110
import_manager: Existing import manager. Usually set to None, but if you are
73111
integrating this with other code generation tasks, it can be nice to
74112
share.
113+
variable_naming: Whether to create intermediate variables with long,
114+
explicit names, or just capture the last elements of a path.
75115
76116
Returns:
77117
An `cst.Module` object. You can convert this to a string using
@@ -97,18 +137,26 @@ def fiddler_from_diff(
97137
# ancestors) will be replaced by a change in the diff. If we don't have an
98138
# `old` structure, then we pessimistically assume that we need to create
99139
# variables for all used paths.
100-
moved_value_names = {}
140+
moved_values_to_name = []
101141
if old is not None:
102142
modified_paths = set([change.target for change in diff.changes])
103143
_add_path_aliases(modified_paths, old)
104144
for path in sorted(used_paths, key=daglish.path_str):
105145
if any(path[:i] in modified_paths for i in range(len(path) + 1)):
106-
moved_value_names[path] = namespace.get_new_name(
107-
_path_to_name(path), f'moved_{param_name}_')
146+
moved_values_to_name.append(ObjectToName(f'moved_{param_name}_', path))
108147
else:
109148
for path in sorted(used_paths, key=daglish.path_str):
110-
moved_value_names[path] = namespace.get_new_name(
111-
_path_to_name(path), f'original_{param_name}_')
149+
moved_values_to_name.append(ObjectToName(f'original_{param_name}_', path))
150+
151+
if variable_naming == 'explicit':
152+
initial_names = assign_explicit_names(moved_values_to_name)
153+
else:
154+
initial_names = assign_short_names(moved_values_to_name)
155+
156+
moved_value_names = {
157+
to_name.path: namespace.get_new_name(name, prefix='')
158+
for to_name, name in zip(moved_values_to_name, initial_names)
159+
}
112160

113161
# Add variables for new shared values added by the diff.
114162
new_shared_value_names = [

0 commit comments

Comments
 (0)