Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions sealir/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,12 @@ def fixup(x):

@graphviz_function
def render_dot(
self, *, gv, show_metadata: bool = False, only_reachable: bool = False
self,
*,
gv,
show_metadata: bool = False,
only_reachable: bool = False,
source_node: SExpr | None = None,
):
def make_label(i, x):
if isinstance(x, SExpr):
Expand All @@ -172,7 +177,10 @@ def make_label(i, x):
crawler = TapeCrawler(self, self._downcast)

# Records that are children of the last node
crawler.seek(self.last())
if source_node is None:
crawler.seek(self.last())
else:
crawler.seek(source_node._handle)

# Seek to the first non-metadata in the back
while crawler.pos > 0:
Expand Down Expand Up @@ -232,8 +240,7 @@ def make_label(i, x):
if not head.startswith(metadata_prefix):
lastname = nodename
# emit start
if lastname:
edges.append((("start", lastname), {}))
edges.append((("start", f"node{source_node._handle}"), {}))
# emit edges
for args, kwargs in edges:
g.edge(*args, **kwargs)
Expand Down Expand Up @@ -814,7 +821,8 @@ def copy_tree_into(self: SExpr, tape: Tape) -> SExpr:
Returns a fresh Expr in the new tape.
"""
oldtree = self._tape
crawler = TapeCrawler(oldtree, self._get_downcast())
downcast = self._get_downcast()
crawler = TapeCrawler(oldtree, downcast)
crawler.seek(self._handle)
liveset = set(_select(crawler.walk_descendants(), 1))
surviving = sorted(liveset)
Expand All @@ -836,7 +844,8 @@ def copy_tree_into(self: SExpr, tape: Tape) -> SExpr:

out = tape.read_value(mapping[self._handle])
assert isinstance(out, SExpr)
return out

return downcast(out)


class BasicSExpr(SExpr):
Expand Down
2 changes: 1 addition & 1 deletion sealir/egg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def extract_eclasses(egraph: EGraph) -> EClassData:


def reconstruct(
nodes: dict[str, dict], class_data: dict[str, str]
nodes: dict[str, dict], class_data: dict[str, dict[str, str]]
) -> dict[str, Term]:
done: dict[str, Term] = {}

Expand Down
31 changes: 30 additions & 1 deletion sealir/eqsat/py_eqsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
function,
i64,
i64Like,
rewrite,
rule,
ruleset,
union,
Expand All @@ -31,6 +32,10 @@
def Py_NotIO(io: Term, term: Term) -> Term: ...


@function
def Py_NegIO(io: Term, term: Term) -> Term: ...


@function
def Py_Lt(a: Term, b: Term) -> Term: ...

Expand Down Expand Up @@ -87,6 +92,10 @@ def Py_Div(a: Term, b: Term) -> Term: ...
def Py_DivIO(io: Term, a: Term, b: Term) -> Term: ...


@function
def Py_FloorDivIO(io: Term, a: Term, b: Term) -> Term: ...


@function
def Py_MatMult(a: Term, b: Term) -> Term: ...

Expand Down Expand Up @@ -119,6 +128,18 @@ def Py_AttrIO(io: Term, obj: Term, attrname: StringLike) -> Term: ...
def Py_SubscriptIO(io: Term, obj: Term, index: Term) -> Term: ...


@function
def Py_SetitemIO(io: Term, obj: Term, index: Term, val: Term) -> Term: ...


@function
def Py_SliceIO(io: Term, lower: Term, upper: Term, step: Term) -> Term: ...


@function
def Py_Tuple(elems: TermList) -> Term: ...


@function
def Py_LoadGlobal(io: Term, name: StringLike) -> Term: ...

Expand Down Expand Up @@ -225,5 +246,13 @@ def loop_rules(
)


@ruleset
def ruleset_literal_i64_folding(io: Term, ival: i64):
# Constant fold negation of integer literals
yield rewrite(Py_NegIO(io, Term.LiteralI64(ival)).getPort(1)).to(
Term.LiteralI64(0 - ival)
)


def make_rules():
return loop_rules
return loop_rules | ruleset_literal_i64_folding
29 changes: 29 additions & 0 deletions sealir/eqsat/rvsdg_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def coro(expr: SExpr, state: ase.TraverseState):
match op:
case "not":
res = py_eqsat.Py_NotIO(ioterm, operandterm)
case "-":
res = py_eqsat.Py_NegIO(ioterm, operandterm)
case _:
raise NotImplementedError(f"unsupported op: {op!r}")

Expand Down Expand Up @@ -161,6 +163,9 @@ def coro(expr: SExpr, state: ase.TraverseState):
case "/":
res = py_eqsat.Py_DivIO(ioterm, lhsterm, rhsterm)

case "//":
res = py_eqsat.Py_FloorDivIO(ioterm, lhsterm, rhsterm)

case "@":
res = py_eqsat.Py_MatMultIO(ioterm, lhsterm, rhsterm)

Expand Down Expand Up @@ -238,6 +243,30 @@ def coro(expr: SExpr, state: ase.TraverseState):
py_eqsat.Py_SubscriptIO(ioterm, valterm, idxterm)
)

case rg.PySetItem(io=io, obj=obj, value=value, index=index):
ioterm = yield io
objterm = yield obj
valterm = yield value
idxterm = yield index
return WrapTerm(
py_eqsat.Py_SetitemIO(ioterm, objterm, idxterm, valterm)
)

case rg.PySlice(io=io, lower=lower, upper=upper, step=step):
ioterm = yield io
lowerterm = yield lower
upperterm = yield upper
stepterm = yield step
return WrapTerm(
py_eqsat.Py_SliceIO(ioterm, lowerterm, upperterm, stepterm)
)

case rg.PyTuple(elems):
elemvals = []
for el in elems:
elemvals.append((yield el))
return py_eqsat.Py_Tuple(eg.termlist(*elemvals))

case rg.PyInt(int(intval)):
assert intval.bit_length() < 64
return eg.Term.LiteralI64(intval)
Expand Down
27 changes: 27 additions & 0 deletions sealir/eqsat/rvsdg_eqsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@

class DynInt(Expr):
def __init__(self, num: i64Like): ...
def get(self) -> i64: ...
def __mul__(self, other: DynInt) -> DynInt: ...


converter(i64, DynInt, DynInt)
Expand Down Expand Up @@ -132,6 +134,15 @@ def dyn_index(self, target: Term) -> DynInt: ...
class TermDict(Expr):
def __init__(self, term_map: Map[String, Term]): ...

def lookup(self, key: StringLike) -> Unit:
"""Trigger rule to lookup a value in the dict.

This is needed for .get() to match
"""
...

def get(self, key: StringLike) -> Term: ...


@function(cost=MAXCOST) # max cost to make it unextractable
def _dyn_index_partial(terms: Vec[Term], target: Term) -> DynInt: ...
Expand Down Expand Up @@ -426,6 +437,20 @@ def ruleset_func_outputs(
).then(delete(x))


@ruleset
def ruleset_termdict(mapping: Map[String, Term], key: String):
yield rule(
TermDict(mapping).lookup(key),
mapping.contains(key),
).then(set_(TermDict(mapping).get(key)).to(mapping[key]))


@ruleset
def ruleset_dynint(n: i64, m: i64):
yield rule(DynInt(n)).then(set_(DynInt(n).get()).to(n))
yield rewrite(DynInt(n) * DynInt(m)).to(DynInt(n * m))


ruleset_rvsdg_basic = (
ruleset_simplify_dbgvalue
| ruleset_portlist_basic
Expand All @@ -437,6 +462,8 @@ def ruleset_func_outputs(
| ruleset_region_dyn_get
| ruleset_region_propgate_output
| ruleset_func_outputs
| ruleset_termdict
| ruleset_dynint
)


Expand Down
Loading
Loading