diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 72e8cc1ba..51716c973 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -461,7 +461,11 @@ def substitute(self, old, new) -> Expr: >>> (df + 10).substitute(10, 20) df + 20 """ + return self._substitute(old, new, _seen=set()) + def _substitute(self, old, new, _seen): + if self._name in _seen: + return self # Check if we are replacing a literal if isinstance(old, Expr): substitute_literal = False @@ -476,7 +480,7 @@ def substitute(self, old, new) -> Expr: update = False for operand in self.operands: if isinstance(operand, Expr): - val = operand.substitute(old, new) + val = operand._substitute(old, new, _seen) if operand._name != val._name: update = True new_exprs.append(val) @@ -491,7 +495,7 @@ def substitute(self, old, new) -> Expr: # do so for the `Fused.exprs` operand. val = [] for op in operand: - val.append(op.substitute(old, new)) + val.append(op._substitute(old, new, _seen)) if val[-1]._name != op._name: update = True new_exprs.append(val) @@ -508,6 +512,8 @@ def substitute(self, old, new) -> Expr: if update: # Only recreate if something changed return type(self)(*new_exprs) + else: + _seen.add(self._name) return self def substitute_parameters(self, substitutions: dict) -> Expr: