|
15 | 15 | get_loaded_modules, |
16 | 16 | get_functions, |
17 | 17 | ) |
| 18 | + |
| 19 | +def remove_top_level_returns(code): |
| 20 | + """ |
| 21 | + Removes return statements from top-level functions while leaving nested functions intact. |
| 22 | + """ |
| 23 | + class TopLevelReturnRemover(ast.NodeTransformer): |
| 24 | + def __init__(self): |
| 25 | + super().__init__() |
| 26 | + self.in_top_level = False |
| 27 | + |
| 28 | + def visit_FunctionDef(self, node): |
| 29 | + # If entering a top-level function, mark it |
| 30 | + if not self.in_top_level: |
| 31 | + self.in_top_level = True |
| 32 | + node.body = [stmt for stmt in node.body if not isinstance(stmt, ast.Return)] |
| 33 | + self.in_top_level = False |
| 34 | + else: |
| 35 | + # Leave nested functions intact |
| 36 | + self.generic_visit(node) |
| 37 | + return node |
| 38 | + |
| 39 | + # Parse, transform, and unparse the code |
| 40 | + tree = ast.parse(code) |
| 41 | + transformer = TopLevelReturnRemover() |
| 42 | + transformed_tree = transformer.visit(tree) |
| 43 | + return ast.unparse(transformed_tree) |
| 44 | + |
18 | 45 | def parse_cell(func): |
19 | 46 | """ |
20 | 47 | Inspect the function to detect: |
@@ -45,9 +72,7 @@ def parse_cell(func): |
45 | 72 | def filter_return_statements(node): |
46 | 73 | """Recursively remove trivial returns from function bodies.""" |
47 | 74 | if isinstance(node, ast.FunctionDef): |
48 | | - node.body = [filter_return_statements(subnode) for subnode in node.body if not (isinstance(subnode, ast.Return) and not subnode.value)] |
49 | | - elif isinstance(node, ast.Module): |
50 | | - node.body = [filter_return_statements(subnode) for subnode in node.body] |
| 75 | + node.body = [filter_return_statements(subnode) for subnode in node.body if not (isinstance(subnode, ast.Return))] |
51 | 76 | return node |
52 | 77 |
|
53 | 78 | # Remove trivial returns from the function body |
|
0 commit comments