Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions fiddle/_src/codegen/codegen_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
"""Library for converting generating fiddlers from diffs."""

import collections
import dataclasses
import functools
import re
import types
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple

from fiddle import daglish
from fiddle import diffing
Expand All @@ -31,12 +32,49 @@
import libcst as cst


@dataclasses.dataclass(frozen=True)
class ObjectToName:
prefix: str
path: daglish.Path

def __hash__(self):
return id(self)


def assign_explicit_names(all_to_name: List[ObjectToName]) -> List[str]:
"""Returns suggested names for a list of objects."""
return [
to_name.prefix + _path_to_name(to_name.path) for to_name in all_to_name
]


def assign_short_names(all_to_name: List[ObjectToName]) -> List[str]:
"""Returns suggested names for a list of objects."""
name_to_paths = {}
for to_name in all_to_name:
sub_path = to_name.path[-1:]
name_to_paths.setdefault(
to_name.prefix + _path_to_name(sub_path), []
).append(to_name)

result_as_dict = {}
for name, group in name_to_paths.items():
if len(group) == 1:
result_as_dict[group[0]] = name
else:
for to_name in group:
sub_path = to_name.path[-2:]
result_as_dict[to_name] = to_name.prefix + _path_to_name(sub_path)
return [result_as_dict[to_name] for to_name in all_to_name]


def fiddler_from_diff(
diff: diffing.Diff,
old: Any = None,
func_name: str = 'fiddler',
param_name: str = 'cfg',
import_manager: Optional[import_manager_lib.ImportManager] = None,
variable_naming: Literal['explicit', 'short'] = 'explicit',
):
"""Returns the CST for a fiddler function that applies the changes in `diff`.

Expand Down Expand Up @@ -72,6 +110,8 @@ def fiddler_from_diff(
import_manager: Existing import manager. Usually set to None, but if you are
integrating this with other code generation tasks, it can be nice to
share.
variable_naming: Whether to create intermediate variables with long,
explicit names, or just capture the last elements of a path.

Returns:
An `cst.Module` object. You can convert this to a string using
Expand All @@ -97,18 +137,26 @@ def fiddler_from_diff(
# ancestors) will be replaced by a change in the diff. If we don't have an
# `old` structure, then we pessimistically assume that we need to create
# variables for all used paths.
moved_value_names = {}
moved_values_to_name = []
if old is not None:
modified_paths = set([change.target for change in diff.changes])
_add_path_aliases(modified_paths, old)
for path in sorted(used_paths, key=daglish.path_str):
if any(path[:i] in modified_paths for i in range(len(path) + 1)):
moved_value_names[path] = namespace.get_new_name(
_path_to_name(path), f'moved_{param_name}_')
moved_values_to_name.append(ObjectToName(f'moved_{param_name}_', path))
else:
for path in sorted(used_paths, key=daglish.path_str):
moved_value_names[path] = namespace.get_new_name(
_path_to_name(path), f'original_{param_name}_')
moved_values_to_name.append(ObjectToName(f'original_{param_name}_', path))

if variable_naming == 'explicit':
initial_names = assign_explicit_names(moved_values_to_name)
else:
initial_names = assign_short_names(moved_values_to_name)

moved_value_names = {
to_name.path: namespace.get_new_name(name, prefix='')
for to_name, name in zip(moved_values_to_name, initial_names)
}

# Add variables for new shared values added by the diff.
new_shared_value_names = [
Expand Down