Skip to content

Commit 9de5486

Browse files
author
Red Giuliano
committed
added return statements and requirements for functions
1 parent 9c2de41 commit 9de5486

File tree

1 file changed

+60
-18
lines changed

1 file changed

+60
-18
lines changed

zt_backend/utils/pyfile_parser.py

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
import os
55
import importlib.util
66
from collections import OrderedDict, defaultdict
7+
import astroid
78
from pathlib import Path
89
from uuid import uuid4
910
from zt_backend.models.notebook import Notebook, CodeCell
10-
11+
from zt_backend.runner.code_cell_parser import get_imports, get_defined_names,get_loaded_names,get_loaded_modules,get_functions
1112
def parse_cell(func):
1213
"""
1314
Inspect the function to detect:
@@ -38,7 +39,7 @@ def parse_cell(func):
3839
def filter_return_statements(node):
3940
"""Recursively remove trivial returns from function bodies."""
4041
if isinstance(node, ast.FunctionDef):
41-
node.body = [filter_return_statements(subnode) for subnode in node.body if not (isinstance(subnode, ast.Return) and not subnode.value)]
42+
node.body = [filter_return_statements(subnode) for subnode in node.body if not (isinstance(subnode, ast.Return))]
4243
elif isinstance(node, ast.Module):
4344
node.body = [filter_return_statements(subnode) for subnode in node.body]
4445
return node
@@ -132,6 +133,63 @@ def load_notebook_from_file(file_path, notebook_variable_name="notebook"):
132133
raise AttributeError(f"{notebook_variable_name} not found in {file_path}")
133134

134135

136+
def build_cell_code_block(fn_name, cell_obj, def_line):
137+
"""
138+
If the cell is e.g. sql, we'll put " zt.sql(\"\"\"...\")" in the file.
139+
Meanwhile, in the notebook data structure, .code contains only the raw query.
140+
"""
141+
return_line= " return\n"
142+
143+
if cell_obj.cellType == "code":
144+
# 1) Attempt to parse & extract loaded/defined names
145+
try:
146+
module = astroid.parse(cell_obj.code)
147+
all_imports = get_imports(module)
148+
function_names, function_args = get_functions(module)
149+
defined_names = get_defined_names(module) + function_names
150+
loaded_names = (
151+
get_loaded_names(module, defined_names)
152+
+ get_loaded_modules(module, all_imports)
153+
)
154+
# Optionally ensure uniqueness while preserving order
155+
seen = set()
156+
loaded_names_ordered = []
157+
for name in loaded_names:
158+
if name not in seen:
159+
seen.add(name)
160+
loaded_names_ordered.append(name)
161+
loaded_names = loaded_names_ordered
162+
except Exception as e:
163+
print(f"[Warning] Could not parse code in cell '{fn_name}': {e}")
164+
loaded_names = []
165+
defined_names = []
166+
167+
# 2) Build function signature (optional arguments)
168+
if loaded_names:
169+
signature_line = f"def {fn_name}({', '.join(loaded_names)}):"
170+
else:
171+
signature_line = f"def {fn_name}():"
172+
173+
def_line=signature_line
174+
175+
if defined_names:
176+
return_line = f" return {', '.join(defined_names)}\n"
177+
178+
179+
180+
lines = [def_line] # e.g. "def cell_0():"
181+
if cell_obj.cellType in ["markdown", "sql", "text"]:
182+
# Write zt.sql("""some code""") in the file
183+
lines.append(f" zt.{cell_obj.cellType}(\"\"\"{cell_obj.code}\"\"\")")
184+
else:
185+
# For "code" cells, we just insert them line by line
186+
for raw in cell_obj.code.splitlines():
187+
lines.append(f" {raw}")
188+
lines.append(return_line)
189+
return lines
190+
191+
192+
135193
def update_notebook_file(filepath, notebook_obj):
136194
"""
137195
Update or create a Python file to match the given Notebook’s cell definitions
@@ -151,22 +209,6 @@ def maybe_add_blank_line(line_list):
151209
if line_list and line_list[-1].strip(): # last line not empty
152210
line_list.append("\n")
153211

154-
# Helper to build code block for each cell (in the file).
155-
def build_cell_code_block(fn_name, cell_obj, def_line):
156-
"""
157-
If the cell is e.g. sql, we'll put " zt.sql(\"\"\"...\")" in the file.
158-
Meanwhile, in the notebook data structure, .code contains only the raw query.
159-
"""
160-
lines = [def_line] # e.g. "def cell_0():"
161-
if cell_obj.cellType in ["markdown", "sql", "text"]:
162-
# Write zt.sql("""some code""") in the file
163-
lines.append(f" zt.{cell_obj.cellType}(\"\"\"{cell_obj.code}\"\"\")")
164-
else:
165-
# For "code" cells, we just insert them line by line
166-
for raw in cell_obj.code.splitlines():
167-
lines.append(f" {raw}")
168-
lines.append(" return")
169-
return lines
170212

171213
# 1) Read existing lines or init
172214
try:

0 commit comments

Comments
 (0)