@dataclasses.dataclass(frozen=True)
class ObjectToName:
  """A value that needs a variable name, described by a prefix and a path.

  Equality and hashing are both generated by `dataclasses` from the fields, so
  two instances with the same `prefix` and `path` compare equal *and* hash
  equal. (A previous identity-based `__hash__` combined with field-wise
  `__eq__` violated the "equal objects have equal hashes" rule.)
  """

  # String prepended to the name derived from `path`.
  prefix: str
  # Path identifying the value; it is turned into a name by `_path_to_name`.
  path: daglish.Path


def assign_explicit_names(all_to_name: List[ObjectToName]) -> List[str]:
  """Returns one name per object, built from each object's entire path.

  Args:
    all_to_name: Objects to name.

  Returns:
    A list parallel to `all_to_name`, where each entry is the object's prefix
    followed by `_path_to_name` applied to its full path.
  """
  return [
      to_name.prefix + _path_to_name(to_name.path) for to_name in all_to_name
  ]


def assign_short_names(all_to_name: List[ObjectToName]) -> List[str]:
  """Returns one name per object, built from only the tail of each path.

  Each object is first named from the last element of its path. Whenever that
  short name is produced for more than one entry of `all_to_name`, every entry
  sharing it is instead named from the last two elements of its path.

  The returned names are not guaranteed to be unique (two-element suffixes can
  still coincide), so callers that need unique identifiers must de-duplicate.

  Args:
    all_to_name: Objects to name.

  Returns:
    A list of names parallel to `all_to_name`.
  """
  # Candidate names using only the final path element.
  short_names = [
      to_name.prefix + _path_to_name(to_name.path[-1:])
      for to_name in all_to_name
  ]
  # Counting by name (rather than keying a dict by object) keeps the result
  # independent of how `ObjectToName` instances hash or compare.
  occurrences = collections.Counter(short_names)

  result = []
  for to_name, short_name in zip(all_to_name, short_names):
    if occurrences[short_name] == 1:
      result.append(short_name)
    else:
      # Ambiguous short name: fall back to the last two path elements.
      result.append(to_name.prefix + _path_to_name(to_name.path[-2:]))
  return result
Usually set to None, but if you are integrating this with other code generation tasks, it can be nice to share. + variable_naming: Whether to create intermediate variables with long, + explicit names, or just capture the last elements of a path. Returns: An `cst.Module` object. You can convert this to a string using @@ -97,18 +137,26 @@ def fiddler_from_diff( # ancestors) will be replaced by a change in the diff. If we don't have an # `old` structure, then we pessimistically assume that we need to create # variables for all used paths. - moved_value_names = {} + moved_values_to_name = [] if old is not None: modified_paths = set([change.target for change in diff.changes]) _add_path_aliases(modified_paths, old) for path in sorted(used_paths, key=daglish.path_str): if any(path[:i] in modified_paths for i in range(len(path) + 1)): - moved_value_names[path] = namespace.get_new_name( - _path_to_name(path), f'moved_{param_name}_') + moved_values_to_name.append(ObjectToName(f'moved_{param_name}_', path)) else: for path in sorted(used_paths, key=daglish.path_str): - moved_value_names[path] = namespace.get_new_name( - _path_to_name(path), f'original_{param_name}_') + moved_values_to_name.append(ObjectToName(f'original_{param_name}_', path)) + + if variable_naming == 'explicit': + initial_names = assign_explicit_names(moved_values_to_name) + else: + initial_names = assign_short_names(moved_values_to_name) + + moved_value_names = { + to_name.path: namespace.get_new_name(name, prefix='') + for to_name, name in zip(moved_values_to_name, initial_names) + } # Add variables for new shared values added by the diff. new_shared_value_names = [