diff --git a/src/tirx/transform/common_subexpr_elim.cc b/src/tirx/transform/common_subexpr_elim.cc index 38925dc25a8d..e5230812bc4f 100644 --- a/src/tirx/transform/common_subexpr_elim.cc +++ b/src/tirx/transform/common_subexpr_elim.cc @@ -49,6 +49,9 @@ * - It is not a leaf (Var, IntImm, FloatImm, StringImm). * - It does not contain Call or BufferLoad (side-effects / memory dependence). * - It is not Ramp or Broadcast (hardware-specific vector ops). + * - It is not a wholly-constant expression (no Var anywhere in the tree). + * Constant subtrees provide no CSE benefit — the constant folder + * collapses them — and lifting them only adds noise like `cse_v = 1`. * * Scope tree * ---------- @@ -263,6 +266,10 @@ class CSEPlanner : public StmtExprVisitor { * - Not a Call or BufferLoad (side effects / memory dependence). * - Not Ramp or Broadcast (hardware-specific vector construction). * - Does not transitively contain any forbidden node. + * - Is not a wholly-constant expression (contains no Var in its tree). + * Compound expressions like `Cast(int32, 1)` or `Min(1, 2)` pass the + * leaf-only checks above but are still compile-time constants — the + * constant folder will collapse them, so hoisting only adds noise. * * \param expr The expression to check. * \return true if the expression can participate in CSE. @@ -275,6 +282,13 @@ class CSEPlanner : public StmtExprVisitor { if (IsForbiddenNode(expr)) return false; if (expr.as() || expr.as()) return false; if (CheckContains::ExprContains(expr, IsForbiddenNode)) return false; + // Reject wholly-constant expressions (no Var anywhere in the tree). + // BufferLoad is already filtered above by IsForbiddenNode, so + // "contains no Var" is sufficient to declare the expression a + // compile-time constant. Hoisting it adds noise; the constant + // folder will collapse it. + auto contains_var = [](const PrimExpr& e) { return e.as() != nullptr; }; + if (!CheckContains::ExprContains(expr, contains_var)) return false; return true; } diff --git a/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py index 8786720a2522..bb5d1357b142 100644 --- a/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py @@ -713,6 +713,41 @@ def test_let_floordiv_pattern(): assert "cse_v" not in script, f"CSE incorrectly extracted from Let body:\n{script}" +# ===================================================================== +# T22: Wholly-constant Cast is not lifted +# `Cast(int32, 1)` is a compound expression whose leaves are constants +# only. The constant folder collapses it; CSE adds noise only. +# ===================================================================== +def test_no_lift_constant_expression(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main(B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32): + B[i1] = T.Cast("int32", 1) + i1 + B[i2] = T.Cast("int32", 1) + i2 + + after = tvm.tirx.transform.CommonSubexprElim()(Before) + tvm.ir.assert_structural_equal(after, Before) + assert "cse_v" not in after["main"].script() + + +# ===================================================================== +# T23: Wholly-constant Min is not lifted +# `T.min(1, 2)` is a compound constant; CSE must not lift it. +# ===================================================================== +def test_no_lift_constant_min(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main(B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32): + B[i1] = T.min(1, 2) + i1 + B[i2] = T.min(1, 2) + i2 + + after = tvm.tirx.transform.CommonSubexprElim()(Before) + tvm.ir.assert_structural_equal(after, Before) + assert "cse_v" not in after["main"].script() + + if __name__ == "__main__": test_basic() test_if_single_branch() @@ -735,3 +770,5 @@ def test_let_floordiv_pattern(): test_let_value_cse() test_nested_let_no_extraction() test_let_floordiv_pattern() + test_no_lift_constant_expression() + test_no_lift_constant_min()