diff --git a/sealir/tests/test_cost_extraction.py b/sealir/tests/test_cost_extraction.py
index 4fd3855..f6f35e4 100644
--- a/sealir/tests/test_cost_extraction.py
+++ b/sealir/tests/test_cost_extraction.py
@@ -1,5 +1,3 @@
-import json
-
 from egglog import (
     EGraph,
     Expr,
@@ -9,15 +7,14 @@
     i64Like,
     rewrite,
     ruleset,
+    default_cost_model,
+    get_callable_args,
+    get_callable_fn,
+    BaseExpr,
+    greedy_dag_cost_model,
 )
 
 from sealir.eqsat.rvsdg_eqsat import GraphRoot
-from sealir.eqsat.rvsdg_extract import (
-    CostModel,
-    EGraphJsonDict,
-    Extraction,
-    get_graph_root,
-)
 
 
 class Term(Expr):
@@ -44,20 +41,24 @@ def Loop(body_expr: Term) -> Term: ...
 def Repeat(body_expr: Term, ntime: i64Like) -> Term: ...
 
 
-class MyCostModel(CostModel):
-    def get_cost_function(self, nodename, op, ty, cost, children):
-        match op, tuple(children):
-            case "Pow", _:
-                cost = 10
-            case "Loop", (expr,):
-                return self.get_scaled(self_cost=13, multipliers=[23])
-            case "Repeat", (expr, ntimes):
-
-                def equ(expr, _, ntimes):
-                    return expr * ntimes
-
-                return self.get_equation(equ, constants=dict(ntimes=ntimes))
-        return self.get_simple(cost)
+@greedy_dag_cost_model
+def cost_model(egraph: EGraph, expr: BaseExpr, child_costs: list[int]) -> int:
+    """Return the cost of ``expr`` given the costs of its children.
+
+    ``Pow`` costs 10 plus its children. ``Loop`` costs 13 plus 24 times
+    the cost of its body. ``Repeat`` costs the sum of its children plus
+    the body cost multiplied by the literal repeat count. Any other
+    expression defers to ``default_cost_model``.
+    """
+    fn = get_callable_fn(expr)
+    if fn == Pow:
+        return sum(child_costs, start=10)
+    if fn == Loop:
+        return 13 + child_costs[0] * 24
+    match get_callable_args(expr, Repeat):
+        case (_, i64(n)):
+            return sum(child_costs, start=child_costs[0] * n)
+    return default_cost_model(egraph, expr, child_costs)
 
 
 @ruleset
@@ -66,45 +67,13 @@ def simplify_pow(x: Term, i: i64):
     yield rewrite(Pow(x, 1)).to(x)
 
 
-def _extraction(egraph, cost_model=None):
-    # TODO move this into rvsdg_extract
-    gdct: EGraphJsonDict = json.loads(
-        egraph._serialize(
-            n_inline_leaves=0, split_primitive_outputs=False
-        ).to_json()
-    )
-    [root] = get_graph_root(gdct)
-    root_eclass = gdct["nodes"][root]["eclass"]
-
-    cost_model = cost_model or CostModel()
-    _extraction = Extraction(gdct, root_eclass, cost_model)
-    cost, exgraph = _extraction.choose()
-    return cost, exgraph
-
-
-def _flatten_multidigraph(graph):
-    def format_node(node):
-        # Get all successors (children) of the node
-        children = [v for u, v in graph.out_edges(node)]
-        if not children:
-            return node
-        return tuple([node, *(format_node(child) for child in children)])
-
-    # Get all nodes with no incoming edges (roots)
-    roots = [n for n in graph.nodes if graph.in_degree(n) == 0]
-    return [format_node(root) for root in roots]
-
-
 def test_cost_duplicated_term():
     A = Term("A")
     B = Term("B")
-    expr = Add(A, Add(B, A))
+    expr = GraphRoot(Add(A, Add(B, A)))
     egraph = EGraph()
-    egraph.register(GraphRoot(expr))
-
-    cost, exgraph = _extraction(egraph, cost_model=MyCostModel())
-    [extracted] = _flatten_multidigraph(exgraph)
-    print("extracted:", extracted)
+    egraph.register(expr)
+    res, cost = egraph.extract(expr, cost_model=cost_model, include_cost=True)
     # t1 = function-0-Term___init__(primitive-String-2684354568)
     # 1 ------------^
     # 1 -------------------------------^
@@ -116,40 +85,23 @@ def test_cost_duplicated_term():
     # 1 --------^
     # 1 ------------^
     # = 7
-    assert cost == 7
-
-    match extracted:
-        case (graphroot, (add1, term1, (add2, term2, term3))):
-            pass
-        case _:
-            assert False, f"failed to match: {extracted}"
-    assert term1[0].endswith("Term___init__")
-    assert term2[0].endswith("Term___init__")
-    assert term3[0].endswith("Term___init__")
-    assert term1 == term3
-    assert term2 != term3
-    assert add1 != add2
-    assert add1 != add2
-    assert graphroot.endswith("GraphRoot")
-    assert add1.endswith("Add")
-    assert add2.endswith("Add")
+    assert cost.total == 7
+    assert res == expr
 
 
 def test_simplify_pow_2():
-
     A = Term("A")
-    expr = Pow(A, 2)
+    expr = GraphRoot(Pow(A, 2))
     egraph = EGraph()
-    egraph.register(GraphRoot(expr))
+    egraph.register(expr)
 
-    cost, exgraph = _extraction(egraph, cost_model=MyCostModel())
-    assert cost == 14
+    res, cost = egraph.extract(expr, cost_model=cost_model, include_cost=True)
+    assert res == expr
+    assert cost.total == 14
 
     egraph.run(simplify_pow.saturate())
-    cost, exgraph = _extraction(egraph, cost_model=MyCostModel())
-    [extracted] = _flatten_multidigraph(exgraph)
+    res, cost = egraph.extract(expr, cost_model=cost_model, include_cost=True)
 
-    print("extracted:", extracted)
     # t1 = function-1-Term___init__(primitive-String-2684354568)
     # 1 ------------^
     # 1 ----------------------------------^
@@ -157,33 +109,23 @@ def test_simplify_pow_2():
     # 1 --^
     # 1 --------^
     # = 4
-    assert cost == 4
-    match extracted:
-        case (graphroot, (mul1, term1, term2)):
-            pass
-        case _:
-            assert False, f"failed to match: {extracted}"
-    assert term1[0].endswith("Term___init__")
-    assert term2[0].endswith("Term___init__")
-    assert term1 == term2
-    assert graphroot.endswith("GraphRoot")
-    assert mul1.endswith("Mul")
+    assert res == GraphRoot(Mul(A, A))
+    assert cost.total == 4
 
 
 def test_simplify_pow_3():
     A = Term("A")
-    expr = Pow(A, 3)
+    expr = GraphRoot(Pow(A, 3))
     egraph = EGraph()
-    egraph.register(GraphRoot(expr))
+    egraph.register(expr)
 
-    cost, exgraph = _extraction(egraph, cost_model=MyCostModel())
-    assert cost == 14
+    res, cost = egraph.extract(expr, cost_model=cost_model, include_cost=True)
+    assert res == expr
+    assert cost.total == 14
 
     egraph.run(simplify_pow.saturate())
-    cost, exgraph = _extraction(egraph, cost_model=MyCostModel())
-    [extracted] = _flatten_multidigraph(exgraph)
+    res, cost = egraph.extract(expr, cost_model=cost_model, include_cost=True)
 
-    print("extracted:", extracted)
     # t1 = function-1-Term___init__(primitive-String-2684354568)
     # 1 ------------^
     # 1 ----------------------------------^
@@ -192,70 +134,33 @@ def test_simplify_pow_3():
     # 1 --------^
     # 1 ----------------^
     # = 5
-    assert cost == 5
-    match extracted:
-        case (graphroot, (mul1, term1, (mul2, term2, term3))):
-            pass
-        case _:
-            assert False, f"failed to match: {extracted}"
-    assert term1[0].endswith("Term___init__")
-    assert term1 == term2
-    assert term1 == term3
-    assert graphroot.endswith("GraphRoot")
-    assert mul1.endswith("Mul")
-    assert mul2.endswith("Mul")
-
-
-def test_simple_cost_func():
-    scf = CostModel().get_simple(self_cost=10)
-    cost = scf.compute(7, 8, 9)
-    # cost is just self_cost
-    assert cost == 10
-
-
-def test_scaled_cost_func():
-    ccf = CostModel().get_scaled(10, [2, 3, 4])
-    cost = ccf.compute(7, 8, 9)
-    assert cost == 10 + (2 * 7) + (3 * 8) + (4 * 9)
+    assert res == GraphRoot(Mul(A, Mul(A, A)))
+    assert cost.total == 5
 
 
 def test_loop_multiplier():
     A = Term("A")
-    expr = Loop(Pow(A, 3))
+    expr = GraphRoot(Loop(Pow(A, 3)))
     egraph = EGraph()
-    egraph.register(GraphRoot(expr))
+    egraph.register(expr)
     egraph.run(simplify_pow.saturate())
-    cost, exgraph = _extraction(egraph, cost_model=MyCostModel())
-    [extracted] = _flatten_multidigraph(exgraph)
-    print("extracted", extracted)
+    res, cost = egraph.extract(expr, cost_model=cost_model, include_cost=True)
     # Pow(A, 3) = 4 (see test_simplify_pow_3)
     # Loop(A, 3) = (4 * 23 + 13) + 4
     # ^^^^^^^^^^^ cost from multiplier and self_cost
     # ^^^^ cost from DAG of children
     # GraphRoot(...) = 1
-    assert cost == (4 * 23 + 13) + 4 + 1
-
-    match extracted:
-        case (graphroot, (loop, (mul1, term1, (mul2, term2, term3)))):
-            pass
-        case _:
-            assert False, f"failed to match: {extracted}"
-    assert term1[0].endswith("Term___init__")
-    assert term1 == term2
-    assert term1 == term3
-    assert graphroot.endswith("GraphRoot")
-    assert mul1.endswith("Mul")
-    assert mul2.endswith("Mul")
-    assert loop.endswith("Loop")
+    assert cost.total == (4 * 23 + 13) + 4 + 1
+    assert res == GraphRoot(Loop(Mul(A, Mul(A, A))))
 
 
 def test_const_factor():
     A = Term("A")
-    expr = Repeat(A, 13)
+    expr = GraphRoot(Repeat(A, 13))
     egraph = EGraph()
-    egraph.register(GraphRoot(expr))
-    cost, exgraph = _extraction(egraph, cost_model=MyCostModel())
-    [extracted] = _flatten_multidigraph(exgraph)
+    egraph.register(expr)
+    extracted, cost = egraph.extract(expr, cost_model=cost_model, include_cost=True)
+    assert extracted == expr
     dagcost = 2 + 1 + 1
     # ^ Term("A")
     # ^ literal 13
@@ -263,13 +168,4 @@ def test_const_factor():
     repeatcost = 2 * 13
     # ^ Term("A")
     # ^ ntimes
-    assert cost == repeatcost + dagcost
-    match extracted:
-        case (graphroot, (repeat, term, literal)):
-            pass
-        case _:
-            assert False, f"failed to match: {extracted}"
-    assert term[0].endswith("Term___init__")
-    assert literal.startswith("primitive-i64-")
-    assert graphroot.endswith("GraphRoot")
-    assert repeat.endswith("Repeat")
+    assert cost.total == repeatcost + dagcost