diff --git a/bigframes/core/bigframe_node.py b/bigframes/core/bigframe_node.py index e14b48a7ec..c71f2136fc 100644 --- a/bigframes/core/bigframe_node.py +++ b/bigframes/core/bigframe_node.py @@ -330,12 +330,32 @@ def top_down( """ Perform a top-down transformation of the BigFrameNode tree. """ + results: Dict[BigFrameNode, BigFrameNode] = {} + # Each stack entry is (node, t_node). t_node is None until transform(node) is called. + stack: list[tuple[BigFrameNode, typing.Optional[BigFrameNode]]] = [(self, None)] - @functools.cache - def recursive_transform(node: BigFrameNode) -> BigFrameNode: - return transform(node).transform_children(recursive_transform) + while stack: + node, t_node = stack[-1] + + if t_node is None: + if node in results: + stack.pop() + continue + t_node = transform(node) + stack[-1] = (node, t_node) + + all_done = True + for child in reversed(t_node.child_nodes): + if child not in results: + stack.append((child, None)) + all_done = False + break + + if all_done: + results[node] = t_node.transform_children(lambda x: results[x]) + stack.pop() - return recursive_transform(self) + return results[self] def bottom_up( self: BigFrameNode,