Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 46 additions & 157 deletions sealir/tests/test_cost_extraction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json

from egglog import (
EGraph,
Expr,
Expand All @@ -9,15 +7,14 @@
i64Like,
rewrite,
ruleset,
default_cost_model,
get_callable_args,
get_callable_fn,
BaseExpr,
greedy_dag_cost_model,
)

from sealir.eqsat.rvsdg_eqsat import GraphRoot
from sealir.eqsat.rvsdg_extract import (
CostModel,
EGraphJsonDict,
Extraction,
get_graph_root,
)


class Term(Expr):
Expand All @@ -44,20 +41,17 @@ def Loop(body_expr: Term) -> Term: ...
def Repeat(body_expr: Term, ntime: i64Like) -> Term: ...


class MyCostModel(CostModel):
    """Cost model with custom costs for the ``Pow``, ``Loop`` and ``Repeat`` ops."""

    def get_cost_function(self, nodename, op, ty, cost, children):
        """Pick the cost function for one node, dispatching on ``op``.

        * ``"Pow"`` overrides the incoming ``cost`` with 10 and falls through
          to the simple cost function.
        * ``"Loop"`` with exactly one child uses a scaled cost function
          (``self_cost=13``, ``multipliers=[23]``).
        * ``"Repeat"`` with exactly two children uses an equation cost function
          that multiplies the first operand by ``ntimes``.
        * Anything else gets the simple cost function built from ``cost``.
        """
        operands = tuple(children)

        if op == "Pow":
            cost = 10
        elif op == "Loop" and len(operands) == 1:
            return self.get_scaled(self_cost=13, multipliers=[23])
        elif op == "Repeat" and len(operands) == 2:
            ntimes = operands[1]

            def equ(expr, _, ntimes):
                return expr * ntimes

            return self.get_equation(equ, constants={"ntimes": ntimes})

        return self.get_simple(cost)
@greedy_dag_cost_model
def cost_model(egraph: EGraph, expr: BaseExpr, child_costs: list[int]) -> int:
callable = get_callable_fn(expr)
if callable == Pow:
return sum(child_costs, start=10)
if callable == Loop:
return 13 + child_costs[0] * 24
match get_callable_args(expr, Repeat):
case (_, i64(n)):
return sum(child_costs, start=child_costs[0] * n)
return default_cost_model(egraph, expr, child_costs)


@ruleset
Expand All @@ -66,45 +60,13 @@ def simplify_pow(x: Term, i: i64):
yield rewrite(Pow(x, 1)).to(x)


def _extraction(egraph, cost_model=None):
    """Serialize ``egraph`` to JSON and extract from its single graph root.

    Uses ``cost_model`` when given (and truthy), otherwise a fresh
    ``CostModel()``. Returns the ``(cost, exgraph)`` pair produced by
    ``Extraction.choose()``.
    """
    # TODO move this into rvsdg_extract
    serialized = egraph._serialize(
        n_inline_leaves=0,
        split_primitive_outputs=False,
    )
    gdct: EGraphJsonDict = json.loads(serialized.to_json())

    # Exactly one graph root is expected; the unpacking fails otherwise.
    [root] = get_graph_root(gdct)
    root_eclass = gdct["nodes"][root]["eclass"]

    if not cost_model:
        cost_model = CostModel()

    extractor = Extraction(gdct, root_eclass, cost_model)
    cost, exgraph = extractor.choose()
    return cost, exgraph


def _flatten_multidigraph(graph):
    """Render every root of ``graph`` as a nested tuple.

    A node with successors becomes ``(node, child0, child1, ...)`` with each
    child rendered the same way; a node without successors is returned as is.
    The result is a list with one entry per node that has no incoming edges.
    """

    def format_node(node):
        # Successors are the targets of the node's outgoing edges.
        successors = [dst for _, dst in graph.out_edges(node)]
        if successors:
            return (node, *map(format_node, successors))
        # A leaf is represented by the bare node.
        return node

    # Roots are the nodes nothing points at.
    roots = [
        candidate
        for candidate in graph.nodes
        if graph.in_degree(candidate) == 0
    ]
    return list(map(format_node, roots))


def test_cost_duplicated_term():
A = Term("A")
B = Term("B")
expr = Add(A, Add(B, A))
expr = GraphRoot(Add(A, Add(B, A)))
egraph = EGraph()
egraph.register(GraphRoot(expr))

cost, exgraph = _extraction(egraph, cost_model=MyCostModel())
[extracted] = _flatten_multidigraph(exgraph)
print("extracted:", extracted)
egraph.register(expr)
res, cost = egraph.extract(expr, cost_model=cost_model, include_cost=True)
# t1 = function-0-Term___init__(primitive-String-2684354568)
# 1 ------------^
# 1 -------------------------------^
Expand All @@ -116,74 +78,47 @@ def test_cost_duplicated_term():
# 1 --------^
# 1 ------------^
# = 7
assert cost == 7

match extracted:
case (graphroot, (add1, term1, (add2, term2, term3))):
pass
case _:
assert False, f"failed to match: {extracted}"
assert term1[0].endswith("Term___init__")
assert term2[0].endswith("Term___init__")
assert term3[0].endswith("Term___init__")
assert term1 == term3
assert term2 != term3
assert add1 != add2
assert add1 != add2
assert graphroot.endswith("GraphRoot")
assert add1.endswith("Add")
assert add2.endswith("Add")
assert cost.total == 7
assert res == expr


# Checks extraction of Pow(A, 2) wrapped in GraphRoot, before and after
# saturating the simplify_pow ruleset: the expected cost is 14 while only the
# Pow form exists and 4 once GraphRoot(Mul(A, A)) can be extracted.
# NOTE(review): this span appears to interleave two versions of the test, one
# using _extraction/_flatten_multidigraph/MyCostModel and one using
# egraph.extract with cost_model; confirm which lines are live before editing.
def test_simplify_pow_2():

A = Term("A")
expr = Pow(A, 2)
expr = GraphRoot(Pow(A, 2))
egraph = EGraph()
egraph.register(GraphRoot(expr))
egraph.register(expr)

# Before any rewrite runs, the Pow form is the only candidate.
# Presumably 14 = 10 (Pow) + 2 (Term("A")) + 1 (literal) + 1 (GraphRoot),
# going by the per-node costs annotated in the other tests; TODO confirm.
cost, exgraph = _extraction(egraph, cost_model=MyCostModel())
assert cost == 14
res, cost = egraph.extract(expr, cost_model=cost_model, include_cost=True)
assert res == expr
assert cost.total == 14

# After saturation the cheaper Mul(A, A) form is expected to be chosen.
egraph.run(simplify_pow.saturate())
cost, exgraph = _extraction(egraph, cost_model=MyCostModel())
[extracted] = _flatten_multidigraph(exgraph)
res, cost = egraph.extract(expr, cost_model=cost_model, include_cost=True)

print("extracted:", extracted)
# Cost breakdown of the extracted term (one line per counted node):
# t1 = function-1-Term___init__(primitive-String-2684354568)
# 1 ------------^
# 1 ----------------------------------^
# GraphRoot(Mul(t1, t1)))
# 1 --^
# 1 --------^
# = 4
assert cost == 4
# Shape check on the flattened extraction: GraphRoot -> Mul -> two Term leaves.
match extracted:
case (graphroot, (mul1, term1, term2)):
pass
case _:
assert False, f"failed to match: {extracted}"
assert term1[0].endswith("Term___init__")
assert term2[0].endswith("Term___init__")
# Both operands of Mul must be the same Term node.
assert term1 == term2
assert graphroot.endswith("GraphRoot")
assert mul1.endswith("Mul")
assert res == GraphRoot(Mul(A, A))
assert cost.total == 4


def test_simplify_pow_3():
A = Term("A")
expr = Pow(A, 3)
expr = GraphRoot(Pow(A, 3))
egraph = EGraph()
egraph.register(GraphRoot(expr))
egraph.register(expr)

cost, exgraph = _extraction(egraph, cost_model=MyCostModel())
assert cost == 14
res, cost = egraph.extract(expr, cost_model=cost_model, include_cost=True)
assert res == expr
assert cost.total == 14

egraph.run(simplify_pow.saturate())
cost, exgraph = _extraction(egraph, cost_model=MyCostModel())
[extracted] = _flatten_multidigraph(exgraph)
res, cost = egraph.extract(expr, cost_model=cost_model, include_cost=True)

print("extracted:", extracted)
# t1 = function-1-Term___init__(primitive-String-2684354568)
# 1 ------------^
# 1 ----------------------------------^
Expand All @@ -192,84 +127,38 @@ def test_simplify_pow_3():
# 1 --------^
# 1 ----------------^
# = 5
assert cost == 5
match extracted:
case (graphroot, (mul1, term1, (mul2, term2, term3))):
pass
case _:
assert False, f"failed to match: {extracted}"
assert term1[0].endswith("Term___init__")
assert term1 == term2
assert term1 == term3
assert graphroot.endswith("GraphRoot")
assert mul1.endswith("Mul")
assert mul2.endswith("Mul")


def test_simple_cost_func():
    """The simple cost function reports only ``self_cost``, whatever the operands."""
    simple_fn = CostModel().get_simple(self_cost=10)
    # cost is just self_cost
    assert simple_fn.compute(7, 8, 9) == 10


def test_scaled_cost_func():
    """The scaled cost is ``self_cost`` plus each operand cost times its multiplier."""
    self_cost = 10
    multipliers = [2, 3, 4]
    operand_costs = (7, 8, 9)

    scaled_fn = CostModel().get_scaled(self_cost, multipliers)
    actual = scaled_fn.compute(*operand_costs)

    weighted = sum(m * c for m, c in zip(multipliers, operand_costs))
    assert actual == self_cost + weighted
assert res == GraphRoot(Mul(A, Mul(A, A)))
assert cost.total == 5


# Checks that, after saturating simplify_pow, GraphRoot(Loop(Pow(A, 3))) is
# extracted as GraphRoot(Loop(Mul(A, Mul(A, A)))) and that the Loop cost scales
# with the cost of its body.
# NOTE(review): this span appears to interleave two versions of the test, one
# using _extraction/_flatten_multidigraph/MyCostModel and one using
# egraph.extract with cost_model; confirm which lines are live before editing.
def test_loop_multiplier():
A = Term("A")
expr = Loop(Pow(A, 3))
expr = GraphRoot(Loop(Pow(A, 3)))
egraph = EGraph()
egraph.register(GraphRoot(expr))
egraph.register(expr)
egraph.run(simplify_pow.saturate())
cost, exgraph = _extraction(egraph, cost_model=MyCostModel())
[extracted] = _flatten_multidigraph(exgraph)
print("extracted", extracted)
res, cost = egraph.extract(expr, cost_model=cost_model, include_cost=True)
# Pow(A, 3) simplifies to Mul(A, Mul(A, A)), which costs 4
# (see test_simplify_pow_3: its total of 5 includes 1 for GraphRoot).
# Loop(Pow(A, 3)) = (4 * 23 + 13) + 4
#   (4 * 23 + 13): cost from multiplier and self_cost
#   + 4:           cost from DAG of children
# GraphRoot(...) = 1
assert cost == (4 * 23 + 13) + 4 + 1

# Shape check on the flattened extraction:
# GraphRoot -> Loop -> Mul(Term, Mul(Term, Term)).
match extracted:
case (graphroot, (loop, (mul1, term1, (mul2, term2, term3)))):
pass
case _:
assert False, f"failed to match: {extracted}"
assert term1[0].endswith("Term___init__")
# All three leaves must be the same Term node.
assert term1 == term2
assert term1 == term3
assert graphroot.endswith("GraphRoot")
assert mul1.endswith("Mul")
assert mul2.endswith("Mul")
assert loop.endswith("Loop")
assert cost.total == (4 * 23 + 13) + 4 + 1
assert res == GraphRoot(Loop(Mul(A, Mul(A, A))))


# Checks that the cost of Repeat(A, 13) includes the body cost multiplied by
# the constant repeat count, on top of the plain per-node costs.
# NOTE(review): this span appears to interleave two versions of the test, one
# using _extraction/_flatten_multidigraph/MyCostModel and one using
# egraph.extract with cost_model; confirm which lines are live before editing.
def test_const_factor():
A = Term("A")
expr = Repeat(A, 13)
expr = GraphRoot(Repeat(A, 13))
egraph = EGraph()
# NOTE(review): if the GraphRoot(...) assignment above and this register call
# are both live, GraphRoot is applied twice here; confirm that is intended.
egraph.register(GraphRoot(expr))
cost, exgraph = _extraction(egraph, cost_model=MyCostModel())
[extracted] = _flatten_multidigraph(exgraph)
extracted, cost = egraph.extract(expr, cost_model=cost_model, include_cost=True)
assert extracted == expr
dagcost = 2 + 1 + 1
# 2 = Term("A")
# 1 = literal 13
# 1 = GraphRoot
repeatcost = 2 * 13
# 2 = Term("A")
# 13 = ntimes
assert cost == repeatcost + dagcost
# Shape check on the flattened extraction: GraphRoot -> Repeat(Term, literal).
match extracted:
case (graphroot, (repeat, term, literal)):
pass
case _:
assert False, f"failed to match: {extracted}"
assert term[0].endswith("Term___init__")
assert literal.startswith("primitive-i64-")
assert graphroot.endswith("GraphRoot")
assert repeat.endswith("Repeat")
assert cost.total == repeatcost + dagcost