Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/tirx/transform/common_subexpr_elim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
* - It is not a leaf (Var, IntImm, FloatImm, StringImm).
* - It does not contain Call or BufferLoad (side-effects / memory dependence).
* - It is not Ramp or Broadcast (hardware-specific vector ops).
* - It is not a wholly-constant expression (no Var anywhere in the tree).
* Constant subtrees provide no CSE benefit — the constant folder
* collapses them — and lifting them only adds noise like `cse_v = 1`.
*
* Scope tree
* ----------
Expand Down Expand Up @@ -263,6 +266,10 @@ class CSEPlanner : public StmtExprVisitor {
* - Not a Call or BufferLoad (side effects / memory dependence).
* - Not Ramp or Broadcast (hardware-specific vector construction).
* - Does not transitively contain any forbidden node.
* - Is not a wholly-constant expression (contains no Var in its tree).
* Compound expressions like `Cast(int32, 1)` or `Min(1, 2)` pass the
* leaf-only checks above but are still compile-time constants — the
* constant folder will collapse them, so hoisting only adds noise.
*
* \param expr The expression to check.
* \return true if the expression can participate in CSE.
Expand All @@ -275,6 +282,13 @@ class CSEPlanner : public StmtExprVisitor {
if (IsForbiddenNode(expr)) return false;
if (expr.as<RampNode>() || expr.as<BroadcastNode>()) return false;
if (CheckContains::ExprContains(expr, IsForbiddenNode)) return false;
// Reject wholly-constant expressions (no Var anywhere in the tree).
// BufferLoad is already filtered above by IsForbiddenNode, so
// "contains no Var" is sufficient to declare the expression a
// compile-time constant. Hoisting it adds noise; the constant
// folder will collapse it.
auto contains_var = [](const PrimExpr& e) { return e.as<VarNode>() != nullptr; };
if (!CheckContains::ExprContains(expr, contains_var)) return false;
Comment on lines 282 to +291
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The IsEligible predicate now performs two separate full traversals of the expression tree (lines 284 and 291) for every node visited during the CSE planning phase. This leads to $O(N^2)$ complexity relative to the expression depth, which can be significant for large generated TIR programs.

Furthermore, the check at line 282 is redundant because CheckContains::ExprContains at line 284 already checks the root node. Additionally, since IsEligible is only invoked from RecordExpr within specific VisitExpr_ overrides (which exclude Call and BufferLoad), the root-level check at line 282 will always be false.

Consider combining the 'forbidden node' and 'contains var' checks into a single pass or, ideally, caching these properties during the bottom-up traversal of CSEPlanner to maintain linear complexity.

return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,41 @@ def test_let_floordiv_pattern():
assert "cse_v" not in script, f"CSE incorrectly extracted from Let body:\n{script}"


# =====================================================================
# T22: Wholly-constant Cast is not lifted
# `Cast(int32, 1)` is a compound expression whose leaves are constants
# only. The constant folder collapses it; CSE adds noise only.
# =====================================================================
def test_no_lift_constant_expression():
@tvm.script.ir_module
class Before:
@T.prim_func
def main(B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32):
B[i1] = T.Cast("int32", 1) + i1
B[i2] = T.Cast("int32", 1) + i2

after = tvm.tirx.transform.CommonSubexprElim()(Before)
tvm.ir.assert_structural_equal(after, Before)
assert "cse_v" not in after["main"].script()


# =====================================================================
# T23: Wholly-constant Min is not lifted
# `T.min(1, 2)` is a compound constant; CSE must not lift it.
# =====================================================================
def test_no_lift_constant_min():
@tvm.script.ir_module
class Before:
@T.prim_func
def main(B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32):
B[i1] = T.min(1, 2) + i1
B[i2] = T.min(1, 2) + i2

after = tvm.tirx.transform.CommonSubexprElim()(Before)
tvm.ir.assert_structural_equal(after, Before)
assert "cse_v" not in after["main"].script()


if __name__ == "__main__":
test_basic()
test_if_single_branch()
Expand All @@ -735,3 +770,5 @@ def test_let_floordiv_pattern():
test_let_value_cse()
test_nested_let_no_extraction()
test_let_floordiv_pattern()
test_no_lift_constant_expression()
test_no_lift_constant_min()
Loading