Skip to content

Commit 199bc28

Browse files
Fiddle-Config Teamcopybara-github
authored andcommitted
Pax-specific flattened config codegen (MVP, will be building on it further soon).
This will help flatten nested experiment hierarchies, making configs for baseline models easier to read. PiperOrigin-RevId: 543513463
1 parent a2e585a commit 199bc28

File tree

6 files changed

+65
-16
lines changed

6 files changed

+65
-16
lines changed

fiddle/_src/codegen/auto_config/code_ir.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,13 @@ class AttributeExpression(CodegenNode):
121121
base: Any # Wrapped expression, can involve VariableReference's
122122
attribute: str
123123

124+
def __hash__(self):
125+
# Currently, some Pax (https://github.com/google/paxml) codegen involves
126+
# having AttributeExpression's as dict keys, as those keys are rewritten to
127+
# expressions. This function allows for that, but one shouldn't generally
128+
# assume equality as object identity here.
129+
return id(self)
130+
124131

125132
@dataclasses.dataclass
126133
class ArgFactoryExpr(CodegenNode):

fiddle/_src/codegen/auto_config/experimental_top_level_api_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,9 @@ def test_sub_fixtures_with_shared_nodes(self, api: str):
192192
# the MoveComplexNodesToVariables pass is run.
193193
num_lines = len(code.splitlines())
194194
if complexity is None:
195-
self.assertLessEqual(num_lines, 25)
195+
self.assertLessEqual(num_lines, 22)
196196
else:
197-
self.assertGreater(num_lines, 25)
197+
self.assertGreater(num_lines, 22)
198198

199199
matches = re.findall(r"def\ (?P<name>[\w_]+)\(", code)
200200
self.assertEqual(

fiddle/_src/codegen/auto_config/ir_printer.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def traverse(value, state: daglish.State) -> str:
7979
+ ", ".join(f'"{key}": {value}' for key, value in value.items())
8080
+ "}"
8181
)
82-
elif isinstance(value, code_ir.VariableReference):
82+
elif isinstance(value, code_ir.BaseNameReference):
8383
return value.name.value
8484
elif isinstance(value, code_ir.AttributeExpression):
8585
base_obj = state.call(value.base, daglish.Attr("base"))
@@ -90,6 +90,22 @@ def traverse(value, state: daglish.State) -> str:
9090
elif isinstance(value, code_ir.WithTagsCall):
9191
sub_value = state.map_children(value).expression
9292
return f"WithTagsCall[{sub_value}]"
93+
elif isinstance(value, code_ir.SymbolOrFixtureCall):
94+
symbol_expression = state.call(
95+
value.symbol_expression, daglish.Attr("symbol_expression")
96+
)
97+
positional_arg_expressions = state.call(
98+
value.positional_arg_expressions,
99+
daglish.Attr("positional_arg_expressions"),
100+
)
101+
arg_expressions = state.call(
102+
value.arg_expressions, daglish.Attr("arg_expressions")
103+
)
104+
return (
105+
f"call:<{symbol_expression}"
106+
f"(*[{positional_arg_expressions}],"
107+
f" **{arg_expressions})>"
108+
)
93109
elif isinstance(value, code_ir.Name):
94110
return value.value
95111
elif isinstance(value, type):

fiddle/_src/codegen/auto_config/ir_printer_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,20 @@ def test_format_attributes(self):
7373
attr = code_ir.AttributeExpression(self_var, "foo")
7474
self.assertEqual(ir_printer.format_expr(attr), "self.foo")
7575

76+
def test_format_calls(self):
77+
call = code_ir.SymbolOrFixtureCall(
78+
symbol_expression=code_ir.Name("foo"),
79+
positional_arg_expressions=[code_ir.Name("bar")],
80+
arg_expressions={"baz": code_ir.Name("qux")},
81+
)
82+
self.assertEqual(
83+
ir_printer.format_expr(call), 'call:<foo(*[[bar]], **{"baz": qux})>'
84+
)
85+
86+
def test_format_module_reference(self):
87+
module_reference = code_ir.ModuleReference(code_ir.Name("foo"))
88+
self.assertEqual(ir_printer.format_expr(module_reference), "foo")
89+
7690
def test_format_simple_ir(self):
7791
task = test_fixtures.simple_ir()
7892
code = "\n".join(ir_printer.format_fn(task.top_level_call.fn))

fiddle/_src/codegen/auto_config/ir_to_cst.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def _prepare_args_helper(
154154
except:
155155
print(f"\n\nERROR CONVERTING: {value!r}")
156156
print(f"\n\nTYPE: {type(value)}")
157+
print(f"\n\nPATH: {daglish.path_str(state.current_path)}")
157158
raise
158159

159160
return daglish.MemoizedTraversal.run(traverse, expr)
@@ -198,7 +199,7 @@ def code_for_fn(
198199
),
199200
]
200201
)
201-
if fn.parameters:
202+
if fn.parameters and len(fn.parameters) > 1:
202203
whitespace_before_params = cst.ParenthesizedWhitespace(
203204
cst.TrailingWhitespace(),
204205
indent=True,

fiddle/_src/codegen/codegen_diff.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import functools
2020
import re
2121
import types
22-
from typing import Any, Callable, Dict, List, Set, Tuple
22+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
2323

2424
from fiddle import daglish
2525
from fiddle import diffing
@@ -31,10 +31,13 @@
3131
import libcst as cst
3232

3333

34-
def fiddler_from_diff(diff: diffing.Diff,
35-
old: Any = None,
36-
func_name: str = 'fiddler',
37-
param_name: str = 'cfg'):
34+
def fiddler_from_diff(
35+
diff: diffing.Diff,
36+
old: Any = None,
37+
func_name: str = 'fiddler',
38+
param_name: str = 'cfg',
39+
import_manager: Optional[import_manager_lib.ImportManager] = None,
40+
):
3841
"""Returns the CST for a fiddler function that applies the changes in `diff`.
3942
4043
The returned `cst.Module` consists of a set of `import` statements for any
@@ -66,18 +69,26 @@ def fiddler_from_diff(diff: diffing.Diff,
6669
all referenced paths.
6770
func_name: The name for the fiddler function.
6871
param_name: The name for the parameter to the fiddler function.
72+
import_manager: Existing import manager. Usually set to None, but if you are
73+
integrating this with other code generation tasks, it can be nice to
74+
share.
6975
7076
Returns:
7177
An `cst.Module` object. You can convert this to a string using
7278
`result.code`.
7379
"""
74-
# Create a namespace to keep track of variables that we add. Reserve the
75-
# names of the param & func.
76-
namespace = namespace_lib.Namespace()
77-
namespace.add(param_name)
78-
namespace.add(func_name)
79-
80-
import_manager = import_manager_lib.ImportManager(namespace)
80+
if import_manager is None:
81+
# Create a namespace to keep track of variables that we add. Reserve the
82+
# names of the param & func.
83+
namespace = namespace_lib.Namespace()
84+
namespace.add(param_name)
85+
namespace.add(func_name)
86+
87+
import_manager = import_manager_lib.ImportManager(namespace)
88+
else:
89+
namespace = import_manager.namespace
90+
namespace.add(param_name)
91+
namespace.add(func_name)
8192

8293
# Get a list of paths that are referenced by the diff.
8394
used_paths = _find_used_paths(diff)

0 commit comments

Comments
 (0)