diff --git a/fiddle/_src/codegen/auto_config/code_ir.py b/fiddle/_src/codegen/auto_config/code_ir.py index 0d152a4b..858982b4 100644 --- a/fiddle/_src/codegen/auto_config/code_ir.py +++ b/fiddle/_src/codegen/auto_config/code_ir.py @@ -121,6 +121,13 @@ class AttributeExpression(CodegenNode): base: Any # Wrapped expression, can involve VariableReference's attribute: str + def __hash__(self): + # Currently, some Pax (https://github.com/google/paxml) codegen involves + # having AttributeExpression's as dict keys, as those keys are rewritten to + # expressions. This function allows for that, but one shouldn't generally + # assume equality as object identity here. + return id(self) + @dataclasses.dataclass class ArgFactoryExpr(CodegenNode): diff --git a/fiddle/_src/codegen/auto_config/experimental_top_level_api_test.py b/fiddle/_src/codegen/auto_config/experimental_top_level_api_test.py index 8399bdf6..bcda458d 100644 --- a/fiddle/_src/codegen/auto_config/experimental_top_level_api_test.py +++ b/fiddle/_src/codegen/auto_config/experimental_top_level_api_test.py @@ -192,9 +192,9 @@ def test_sub_fixtures_with_shared_nodes(self, api: str): # the MoveComplexNodesToVariables pass is run. 
num_lines = len(code.splitlines()) if complexity is None: - self.assertLessEqual(num_lines, 25) + self.assertLessEqual(num_lines, 22) else: - self.assertGreater(num_lines, 25) + self.assertGreater(num_lines, 22) matches = re.findall(r"def\ (?P<name>[\w_]+)\(", code) self.assertEqual( diff --git a/fiddle/_src/codegen/auto_config/ir_printer.py b/fiddle/_src/codegen/auto_config/ir_printer.py index 9aae519f..d032f997 100644 --- a/fiddle/_src/codegen/auto_config/ir_printer.py +++ b/fiddle/_src/codegen/auto_config/ir_printer.py @@ -79,7 +79,7 @@ def traverse(value, state: daglish.State) -> str: + ", ".join(f'"{key}": {value}' for key, value in value.items()) + "}" ) - elif isinstance(value, code_ir.VariableReference): + elif isinstance(value, code_ir.BaseNameReference): return value.name.value elif isinstance(value, code_ir.AttributeExpression): base_obj = state.call(value.base, daglish.Attr("base")) @@ -90,6 +90,22 @@ def traverse(value, state: daglish.State) -> str: elif isinstance(value, code_ir.WithTagsCall): sub_value = state.map_children(value).expression return f"WithTagsCall[{sub_value}]" + elif isinstance(value, code_ir.SymbolOrFixtureCall): + symbol_expression = state.call( + value.symbol_expression, daglish.Attr("symbol_expression") + ) + positional_arg_expressions = state.call( + value.positional_arg_expressions, + daglish.Attr("positional_arg_expressions"), + ) + arg_expressions = state.call( + value.arg_expressions, daglish.Attr("arg_expressions") + ) + return ( + f"call:<{symbol_expression}" + f"(*[{positional_arg_expressions}]," + f" **{arg_expressions})>" + ) elif isinstance(value, code_ir.Name): return value.value elif isinstance(value, type): diff --git a/fiddle/_src/codegen/auto_config/ir_printer_test.py b/fiddle/_src/codegen/auto_config/ir_printer_test.py index 9b2b691f..427b646b 100644 --- a/fiddle/_src/codegen/auto_config/ir_printer_test.py +++ b/fiddle/_src/codegen/auto_config/ir_printer_test.py @@ -73,6 +73,20 @@ def test_format_attributes(self): attr 
= code_ir.AttributeExpression(self_var, "foo") self.assertEqual(ir_printer.format_expr(attr), "self.foo") + def test_format_calls(self): + call = code_ir.SymbolOrFixtureCall( + symbol_expression=code_ir.Name("foo"), + positional_arg_expressions=[code_ir.Name("bar")], + arg_expressions={"baz": code_ir.Name("qux")}, + ) + self.assertEqual( + ir_printer.format_expr(call), 'call:<foo(*[[bar]], **{"baz": qux})>' + ) + + def test_format_module_reference(self): + module_reference = code_ir.ModuleReference(code_ir.Name("foo")) + self.assertEqual(ir_printer.format_expr(module_reference), "foo") + def test_format_simple_ir(self): task = test_fixtures.simple_ir() code = "\n".join(ir_printer.format_fn(task.top_level_call.fn)) diff --git a/fiddle/_src/codegen/auto_config/ir_to_cst.py b/fiddle/_src/codegen/auto_config/ir_to_cst.py index 779d6a57..f4d076dc 100644 --- a/fiddle/_src/codegen/auto_config/ir_to_cst.py +++ b/fiddle/_src/codegen/auto_config/ir_to_cst.py @@ -154,6 +154,7 @@ def _prepare_args_helper( except: print(f"\n\nERROR CONVERTING: {value!r}") print(f"\n\nTYPE: {type(value)}") + print(f"\n\nPATH: {daglish.path_str(state.current_path)}") raise return daglish.MemoizedTraversal.run(traverse, expr) @@ -198,7 +199,7 @@ def code_for_fn( ), ] ) - if fn.parameters: + if fn.parameters and len(fn.parameters) > 1: whitespace_before_params = cst.ParenthesizedWhitespace( cst.TrailingWhitespace(), indent=True, diff --git a/fiddle/_src/codegen/codegen_diff.py b/fiddle/_src/codegen/codegen_diff.py index 317d7738..bf43ddc7 100644 --- a/fiddle/_src/codegen/codegen_diff.py +++ b/fiddle/_src/codegen/codegen_diff.py @@ -19,7 +19,7 @@ import functools import re import types -from typing import Any, Callable, Dict, List, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Set, Tuple from fiddle import daglish from fiddle import diffing @@ -31,10 +31,13 @@ import libcst as cst -def fiddler_from_diff(diff: diffing.Diff, - old: Any = None, - func_name: str = 'fiddler', - param_name: str = 'cfg'): +def 
fiddler_from_diff( + diff: diffing.Diff, + old: Any = None, + func_name: str = 'fiddler', + param_name: str = 'cfg', + import_manager: Optional[import_manager_lib.ImportManager] = None, +): """Returns the CST for a fiddler function that applies the changes in `diff`. The returned `cst.Module` consists of a set of `import` statements for any @@ -66,18 +69,26 @@ def fiddler_from_diff(diff: diffing.Diff, all referenced paths. func_name: The name for the fiddler function. param_name: The name for the parameter to the fiddler function. + import_manager: Existing import manager. Usually set to None, but if you are + integrating this with other code generation tasks, it can be nice to + share. Returns: An `cst.Module` object. You can convert this to a string using `result.code`. """ - # Create a namespace to keep track of variables that we add. Reserve the - # names of the param & func. - namespace = namespace_lib.Namespace() - namespace.add(param_name) - namespace.add(func_name) - - import_manager = import_manager_lib.ImportManager(namespace) + if import_manager is None: + # Create a namespace to keep track of variables that we add. Reserve the + # names of the param & func. + namespace = namespace_lib.Namespace() + namespace.add(param_name) + namespace.add(func_name) + + import_manager = import_manager_lib.ImportManager(namespace) + else: + namespace = import_manager.namespace + namespace.add(param_name) + namespace.add(func_name) # Get a list of paths that are referenced by the diff. used_paths = _find_used_paths(diff)