Skip to content

Commit b050f05

Browse files
Fiddle-Config Team authored and copybara-github committed
Correctly pass the old config to the model sharding fiddler generation, which makes the output significantly less verbose. Also uses the new short namer in case intermediate variables are needed.
PiperOrigin-RevId: 549703159
1 parent 01780f0 commit b050f05

File tree

1 file changed

+54
-6
lines changed

1 file changed

+54
-6
lines changed

fiddle/_src/codegen/codegen_diff.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
"""Library for converting generating fiddlers from diffs."""
1717

1818
import collections
19+
import dataclasses
1920
import functools
2021
import re
2122
import types
22-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
23+
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple
2324

2425
from fiddle import daglish
2526
from fiddle import diffing
@@ -31,12 +32,49 @@
3132
import libcst as cst
3233

3334

35+
@dataclasses.dataclass(frozen=True)
class ObjectToName:
  """A request to name the value found at `path`.

  Attributes:
    prefix: String prepended to the name derived from `path`.
    path: Path of the value that needs a variable name.
  """

  prefix: str
  path: daglish.Path

  def __hash__(self) -> int:
    # Hash by object identity, so every instance is its own dictionary key
    # even when another instance carries the same prefix and path.
    # NOTE(review): the dataclass-generated `__eq__` still compares fields, so
    # equal instances can hash differently -- confirm this is intentional.
    return id(self)
42+
43+
44+
def assign_explicit_names(all_to_name: List[ObjectToName]) -> List[str]:
  """Suggests a long, explicit name for every object, using its full path.

  Args:
    all_to_name: Objects that need names.

  Returns:
    One suggested name per input, in input order. Each name is the object's
    prefix followed by the name derived from its complete path.
  """
  suggestions = []
  for item in all_to_name:
    suggestions.append(item.prefix + _path_to_name(item.path))
  return suggestions
49+
50+
51+
def assign_short_names(all_to_name: List[ObjectToName]) -> List[str]:
  """Suggests a short name for every object, using only the tail of its path.

  Each object is first named from the last element of its path. When several
  objects end up with the same name, each of them is instead named from the
  last two elements of its path.

  Args:
    all_to_name: Objects that need names.

  Returns:
    One suggested name per input, in input order. The two-element fallback is
    not checked for uniqueness, so the returned names may still repeat.
  """
  # First attempt: name every object from the final path element only.
  one_element_names = [
      item.prefix + _path_to_name(item.path[-1:]) for item in all_to_name
  ]
  occurrences = collections.Counter(one_element_names)

  # Keep names that are unique; widen to two path elements for the rest.
  return [
      candidate
      if occurrences[candidate] == 1
      else item.prefix + _path_to_name(item.path[-2:])
      for item, candidate in zip(all_to_name, one_element_names)
  ]
69+
70+
3471
def fiddler_from_diff(
3572
diff: diffing.Diff,
3673
old: Any = None,
3774
func_name: str = 'fiddler',
3875
param_name: str = 'cfg',
3976
import_manager: Optional[import_manager_lib.ImportManager] = None,
77+
variable_naming: Literal['explicit', 'short'] = 'explicit',
4078
):
4179
"""Returns the CST for a fiddler function that applies the changes in `diff`.
4280
@@ -72,6 +110,8 @@ def fiddler_from_diff(
72110
import_manager: Existing import manager. Usually set to None, but if you are
73111
integrating this with other code generation tasks, it can be nice to
74112
share.
113+
variable_naming: Whether to create intermediate variables with long,
114+
explicit names, or just capture the last elements of a path.
75115
76116
Returns:
77117
A `cst.Module` object. You can convert this to a string using
@@ -97,18 +137,26 @@ def fiddler_from_diff(
97137
# ancestors) will be replaced by a change in the diff. If we don't have an
98138
# `old` structure, then we pessimistically assume that we need to create
99139
# variables for all used paths.
100-
moved_value_names = {}
140+
moved_values_to_name = []
101141
if old is not None:
102142
modified_paths = set([change.target for change in diff.changes])
103143
_add_path_aliases(modified_paths, old)
104144
for path in sorted(used_paths, key=daglish.path_str):
105145
if any(path[:i] in modified_paths for i in range(len(path) + 1)):
106-
moved_value_names[path] = namespace.get_new_name(
107-
_path_to_name(path), f'moved_{param_name}_')
146+
moved_values_to_name.append(ObjectToName(f'moved_{param_name}_', path))
108147
else:
109148
for path in sorted(used_paths, key=daglish.path_str):
110-
moved_value_names[path] = namespace.get_new_name(
111-
_path_to_name(path), f'original_{param_name}_')
149+
moved_values_to_name.append(ObjectToName(f'original_{param_name}_', path))
150+
151+
if variable_naming == 'explicit':
152+
initial_names = assign_explicit_names(moved_values_to_name)
153+
else:
154+
initial_names = assign_short_names(moved_values_to_name)
155+
156+
moved_value_names = {
157+
to_name.path: namespace.get_new_name(name, prefix='')
158+
for to_name, name in zip(moved_values_to_name, initial_names)
159+
}
112160

113161
# Add variables for new shared values added by the diff.
114162
new_shared_value_names = [

0 commit comments

Comments
 (0)