From 1642a0427b3c8e8f28c9c9ed84f4fae0858d7e8a Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 2 Feb 2024 13:20:23 +0100 Subject: [PATCH 1/2] Avoid recursing into already substituted expressions --- dask_expr/_core.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 72e8cc1ba..68a5a7bb9 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -442,7 +442,7 @@ def __dask_graph__(self): def dask(self): return self.__dask_graph__() - def substitute(self, old, new) -> Expr: + def substitute(self, old, new, _seen=None) -> Expr: """Substitute a specific term within the expression Note that replacing non-`Expr` terms may produce @@ -461,7 +461,10 @@ def substitute(self, old, new) -> Expr: >>> (df + 10).substitute(10, 20) df + 20 """ - + if _seen is None: + _seen = set() + if self._name in _seen: + return self # Check if we are replacing a literal if isinstance(old, Expr): substitute_literal = False @@ -508,6 +511,8 @@ def substitute(self, old, new) -> Expr: if update: # Only recreate if something changed return type(self)(*new_exprs) + else: + _seen.add(self._name) return self def substitute_parameters(self, substitutions: dict) -> Expr: From bbddff71408ec351283d27378464298f72e373b2 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 2 Feb 2024 13:31:25 +0100 Subject: [PATCH 2/2] move to private method --- dask_expr/_core.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 68a5a7bb9..51716c973 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -442,7 +442,7 @@ def __dask_graph__(self): def dask(self): return self.__dask_graph__() - def substitute(self, old, new, _seen=None) -> Expr: + def substitute(self, old, new) -> Expr: """Substitute a specific term within the expression Note that replacing non-`Expr` terms may produce @@ -461,8 +461,9 @@ def substitute(self, old, new, _seen=None) -> Expr: >>> (df + 10).substitute(10, 20) df + 20 """ - if _seen is None: - _seen = set() + return self._substitute(old, new, _seen=set()) + + def _substitute(self, old, new, _seen): if self._name in _seen: return self # Check if we are replacing a literal @@ -479,7 +480,7 @@ def substitute(self, old, new, _seen=None) -> Expr: update = False for operand in self.operands: if isinstance(operand, Expr): - val = operand.substitute(old, new) + val = operand._substitute(old, new, _seen) if operand._name != val._name: update = True new_exprs.append(val) @@ -494,7 +495,7 @@ def substitute(self, old, new, _seen=None) -> Expr: # do so for the `Fused.exprs` operand. val = [] for op in operand: - val.append(op.substitute(old, new)) + val.append(op._substitute(old, new, _seen)) if val[-1]._name != op._name: update = True new_exprs.append(val)