From a54842e66d1f39f4bec96977533e1d62b218eb23 Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Fri, 18 Jul 2025 09:27:39 -0500 Subject: [PATCH 01/21] Add rule for termdict lookup --- sealir/eqsat/rvsdg_eqsat.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sealir/eqsat/rvsdg_eqsat.py b/sealir/eqsat/rvsdg_eqsat.py index adbb606..dbf96d4 100644 --- a/sealir/eqsat/rvsdg_eqsat.py +++ b/sealir/eqsat/rvsdg_eqsat.py @@ -132,6 +132,15 @@ def dyn_index(self, target: Term) -> DynInt: ... class TermDict(Expr): def __init__(self, term_map: Map[String, Term]): ... + def lookup(self, key: StringLike) -> Unit: + """Trigger rule to lookup a value in the dict. + + This is needed for .get() to match + """ + ... + + def get(self, key: StringLike) -> Term: ... + @function(cost=MAXCOST) # max cost to make it unextractable def _dyn_index_partial(terms: Vec[Term], target: Term) -> DynInt: ... @@ -426,6 +435,14 @@ def ruleset_func_outputs( ).then(delete(x)) +@ruleset +def ruleset_termdict(mapping: Map[String, Term], key: String): + yield rule( + TermDict(mapping).lookup(key), + mapping.contains(key), + ).then(set_(TermDict(mapping).get(key)).to(mapping[key])) + + ruleset_rvsdg_basic = ( ruleset_simplify_dbgvalue | ruleset_portlist_basic @@ -437,6 +454,7 @@ def ruleset_func_outputs( | ruleset_region_dyn_get | ruleset_region_propgate_output | ruleset_func_outputs + | ruleset_termdict ) From b117fc4d5963e94583c30aeeb54f9dce24cbc38f Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Fri, 25 Jul 2025 08:43:45 -0500 Subject: [PATCH 02/21] Add unary `-` --- sealir/eqsat/py_eqsat.py | 15 ++++++++++++++- sealir/eqsat/rvsdg_convert.py | 2 ++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sealir/eqsat/py_eqsat.py b/sealir/eqsat/py_eqsat.py index 488d8ce..2e93622 100644 --- a/sealir/eqsat/py_eqsat.py +++ b/sealir/eqsat/py_eqsat.py @@ -6,6 +6,7 @@ function, i64, i64Like, + rewrite, rule, ruleset, union, @@ -31,6 +32,10 @@ def Py_NotIO(io: Term, term: Term) -> Term: ... +@function +def Py_NegIO(io: Term, term: Term) -> Term: ... + + @function def Py_Lt(a: Term, b: Term) -> Term: ... @@ -225,5 +230,13 @@ def loop_rules( ) +@ruleset +def ruleset_literal_i64_folding(io: Term, ival: i64): + # Constant fold negation of integer literals + yield rewrite(Py_NegIO(io, Term.LiteralI64(ival)).getPort(1)).to( + Term.LiteralI64(0 - ival) + ) + + def make_rules(): - return loop_rules + return loop_rules | ruleset_literal_i64_folding diff --git a/sealir/eqsat/rvsdg_convert.py b/sealir/eqsat/rvsdg_convert.py index e39a085..3dd792f 100644 --- a/sealir/eqsat/rvsdg_convert.py +++ b/sealir/eqsat/rvsdg_convert.py @@ -133,6 +133,8 @@ def coro(expr: SExpr, state: ase.TraverseState): match op: case "not": res = py_eqsat.Py_NotIO(ioterm, operandterm) + case "-": + res = py_eqsat.Py_NegIO(ioterm, operandterm) case _: raise NotImplementedError(f"unsupported op: {op!r}") From ca5d04f9afe90e68be01dc56cf169302741b0f62 Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Fri, 25 Jul 2025 08:50:24 -0500 Subject: [PATCH 03/21] Add a way to get e-graph extraction statistic --- sealir/eqsat/rvsdg_extract.py | 35 ++++++++++++++------- sealir/tests/test_rvsdg_egraph_roundtrip.py | 13 ++++++-- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/sealir/eqsat/rvsdg_extract.py b/sealir/eqsat/rvsdg_extract.py index 7be2e71..7a00691 100644 --- a/sealir/eqsat/rvsdg_extract.py +++ b/sealir/eqsat/rvsdg_extract.py @@ -86,6 +86,7 @@ def egraph_extraction( *, cost_model=None, converter_class=EGraphToRVSDG, + stats: dict[str, Any] | None = None, ): gdct: EGraphJsonDict = json.loads( egraph._serialize( @@ -95,9 +96,13 @@ def egraph_extraction( [root] = get_graph_root(gdct) root_eclass = gdct["nodes"][root]["eclass"] + if stats is not None: + stats["num_enodes"] = len(gdct["nodes"]) + stats["num_eclasses"] = len(gdct["class_data"]) + cost_model = CostModel() if cost_model is None else cost_model extraction = Extraction(gdct, root_eclass, cost_model) - cost, exgraph = extraction.choose() + cost, exgraph = extraction.choose(stats=stats) expr = convert_to_rvsdg( exgraph, @@ -208,19 +213,23 @@ def _compute_cost( max_iter=1000, max_no_progress=100, epsilon=1e-6, - ) -> dict[str, Bucket]: + ) -> tuple[dict[str, Bucket], int]: """ Uses dynamic programming with iterative cost propagation Args: - max_iter (int, optional): Maximum number of iterations to compute - costs. Defaults to 10000. max_no_progress (int, optional): Maximum - iterations without cost improvement. Defaults to 500. + max_iter (int, optional): + Maximum number of iterations to compute costs. + Defaults to 10000. + max_no_progress (int, optional): + Maximum iterations without cost improvement. + Defaults to 500. Returns: - dict[str, Bucket]: A mapping of equivalence classes to their lowest - cost representations. - + tuple[dict[str, Bucket], int]: A tuple containing: + - A mapping of equivalence classes to their lowest cost + representations + - The number of rounds (iterations) that were performed Performance notes @@ -338,10 +347,14 @@ def propagate_cost(state_tracker): else: last_changed_i = round_i - return selections + return selections, round_i - def choose(self) -> tuple[float, nx.MultiDiGraph]: - selections = self._compute_cost() + def choose( + self, stats: dict[str, Any] | None = None + ) -> tuple[float, nx.MultiDiGraph]: + selections, round_i = self._compute_cost() + if stats is not None: + stats["extraction_iteration_count"] = round_i nodes = self.nodes chosen_root, rootcost = selections[self.root_eclass].best() diff --git a/sealir/tests/test_rvsdg_egraph_roundtrip.py b/sealir/tests/test_rvsdg_egraph_roundtrip.py index 50adf16..65f4ef1 100644 --- a/sealir/tests/test_rvsdg_egraph_roundtrip.py +++ b/sealir/tests/test_rvsdg_egraph_roundtrip.py @@ -16,7 +16,9 @@ def frontend(fn): return rvsdg_expr, dbginfo -def middle_end(rvsdg_expr, apply_to_egraph, cost_model=None): +def middle_end( + rvsdg_expr, apply_to_egraph, cost_model=None, extract_kwargs=None +): """The middle end encode the RVSDG into a EGraph to apply rewrite rules. After that, it is extracted back into RVSDG. """ @@ -29,8 +31,9 @@ def middle_end(rvsdg_expr, apply_to_egraph, cost_model=None): apply_to_egraph(egraph, func) # Extraction + extract_kwargs = extract_kwargs or {} cost, extracted = egraph_extraction( - egraph, rvsdg_expr, cost_model=cost_model + egraph, rvsdg_expr, cost_model=cost_model, **extract_kwargs ) return cost, extracted @@ -50,10 +53,14 @@ def display_egraph(egraph: EGraph, func): return root - cost, extracted = middle_end(rvsdg_expr, display_egraph) + extract_kwargs = dict(stats={}) + cost, extracted = middle_end( + rvsdg_expr, display_egraph, extract_kwargs=extract_kwargs + ) print("Extracted from EGraph".center(80, "=")) print("cost =", cost) print(rvsdg.format_rvsdg(extracted)) + print("stats:", extract_kwargs) jt = llvm_codegen(rvsdg_expr) res = jt(*args) From f35b09fd7ed732c9f24ad11b56c7ca733ade4c28 Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Fri, 29 Aug 2025 13:19:53 -0500 Subject: [PATCH 04/21] Move extract extension to the end to allow generic extension at the youngest subclass --- sealir/eqsat/rvsdg_extract_details.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sealir/eqsat/rvsdg_extract_details.py b/sealir/eqsat/rvsdg_extract_details.py index 3282b8a..1133479 100644 --- a/sealir/eqsat/rvsdg_extract_details.py +++ b/sealir/eqsat/rvsdg_extract_details.py @@ -114,9 +114,6 @@ def get_children() -> dict | list: attrs = self.handle_region_attributes(key, grm) return grm.write(rg.RegionBegin(inports=ins, attrs=attrs)) case "Term", children: - extended_handle = self.handle_Term(op, children, grm) - if extended_handle is not NotImplemented: - return extended_handle match op, children: case "GraphRoot", {"t": term}: return term @@ -222,6 +219,9 @@ def get_children() -> dict | list: rg.Unpack(val=regionbegin, idx=idx) ) case _: + extended_handle = self.handle_Term(op, children, grm) + if extended_handle is not NotImplemented: + return extended_handle raise NotImplementedError( f"invalid Term: {node_type}, {children}" ) From 3ce570d0a890e6b968c85d96562c00ae0c7c98c7 Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Tue, 2 Sep 2025 15:27:29 -0500 Subject: [PATCH 05/21] wip --- sealir/eqsat/rvsdg_extract.py | 68 ++++++++++++++++++--------- sealir/eqsat/rvsdg_extract_details.py | 13 +++++ sealir/rvsdg/grammar.py | 9 ++++ sealir/rvsdg/restructuring.py | 1 + 4 files changed, 69 insertions(+), 22 deletions(-) diff --git a/sealir/eqsat/rvsdg_extract.py b/sealir/eqsat/rvsdg_extract.py index 7a00691..62d0c97 100644 --- a/sealir/eqsat/rvsdg_extract.py +++ b/sealir/eqsat/rvsdg_extract.py @@ -112,6 +112,7 @@ def egraph_extraction( egraph, converter_class=converter_class, ) + return cost, expr @@ -129,7 +130,9 @@ def convert_to_rvsdg( decls = state.__egg_decls__ # Do the conversion back into RVSDG - node_iterator = list(nx.dfs_postorder_nodes(exgraph, source=root)) + common_root = "common_root" + + node_iterator = list(nx.dfs_postorder_nodes(exgraph, source=common_root)) def egg_fn_to_arg_names(egg_fn: str) -> tuple[str, ...]: for ref in state.egg_fn_to_callable_refs[egg_fn]: @@ -149,17 +152,20 @@ def iterator(node_iter): children.sort() _, children = zip(*children) # extract argument names - kind, _, egg_fn = node.split("-") - match kind: - case "primitive": - pass - case "function": - # TODO: put this into the converter_class - arg_names = egg_fn_to_arg_names(egg_fn) - children = dict(zip(arg_names, children, strict=True)) - case _: - raise NotImplementedError(f"kind is {kind!r}") - yield node, children + if node == common_root: + yield node, children + else: + kind, _, egg_fn = node.split("-") + match kind: + case "primitive": + pass + case "function": + # TODO: put this into the converter_class + arg_names = egg_fn_to_arg_names(egg_fn) + children = dict(zip(arg_names, children, strict=True)) + case _: + raise NotImplementedError(f"kind is {kind!r}") + yield node, children conversion = converter_class(gdct, rvsdg_sexpr, egg_fn_to_arg_names) return conversion.run(iterator(node_iterator)) @@ -197,7 +203,8 @@ class Extraction: _DEBUG = False def __init__(self, graph_json: EGraphJsonDict, root_eclass, cost_model): - self.root_eclass = root_eclass + self.graph_json = graph_json + self.root_eclass = "common_root" self.class_data = defaultdict(set) self.nodes = {k: Node(**v) for k, v in graph_json["nodes"].items()} self.node_types = { @@ -267,6 +274,18 @@ def _compute_cost( for child in enode.children: G.add_edge(k, nodes[child].eclass) + # Get all nodes with in-degree of 0 + common_root = "common_root" + + # Do the conversion back into RVSDG + root_eclasses =[node for node, in_degree in G.in_degree() if in_degree == 0 and self.graph_json['class_data'][node]['type'] != "Unit"] + G.add_node(common_root, shape="rect") + for n in root_eclasses: + G.add_edge(common_root, n) + + self.nodes[common_root] = Node(children=[next(iter(eclassmap[ec])) for ec in root_eclasses], cost=0.0, eclass="common_root", op="common_root", subsumed=False) + + # Get per-node cost function cm = self.cost_model nodecostmap: dict[str, CostFunc] = {} @@ -274,13 +293,16 @@ def _compute_cost( if k not in eclassmap: node = nodes[k] children_eclasses = [nodes[c].eclass for c in node.children] - nodecostmap[k] = cm.get_cost_function( - nodename=k, - op=node.op, - ty=self.node_types[k], - cost=node.cost, - children=children_eclasses, - ) + if k == "common_root": + nodecostmap[k] = cm.get_simple(0) + else: + nodecostmap[k] = cm.get_cost_function( + nodename=k, + op=node.op, + ty=self.node_types[k], + cost=node.cost, + children=children_eclasses, + ) if self._DEBUG: render_extraction_graph(G, "eclass") @@ -298,7 +320,7 @@ def _compute_cost( # Use BFS layers to estimate topological sort topo_ordered = [] - for layer in nx.bfs_layers(G, self.root_eclass): + for layer in nx.bfs_layers(G, [common_root]): topo_ordered += layer def propagate_cost(state_tracker): @@ -334,7 +356,7 @@ def propagate_cost(state_tracker): costchanged.update(state_tracker) if costchanged.converged(): # root score is computed? - if math.isfinite(state_tracker.current[self.root_eclass]): + if all(math.isfinite(state_tracker.current[root]) for root in [common_root]): break # root score is missing? @@ -357,11 +379,13 @@ def choose( stats["extraction_iteration_count"] = round_i nodes = self.nodes + assert self.root_eclass == "common_root" chosen_root, rootcost = selections[self.root_eclass].best() # make selected graph G = nx.MultiDiGraph() todolist = [chosen_root] + visited = set() while todolist: cur = todolist.pop() diff --git a/sealir/eqsat/rvsdg_extract_details.py b/sealir/eqsat/rvsdg_extract_details.py index 1133479..550a0b3 100644 --- a/sealir/eqsat/rvsdg_extract_details.py +++ b/sealir/eqsat/rvsdg_extract_details.py @@ -70,6 +70,7 @@ def filter_by_type( def dispatch(self, key: str, grm: Grammar): if key in self.memo: return self.memo[key] + assert False, "nothing should use this anymroe" node = self.gdct["nodes"][key] child_keys = node["children"] for k in child_keys: @@ -88,6 +89,9 @@ def get_children(self, key): def handle( self, key: str, child_keys: list[str] | dict[str, str], grm: Grammar ): + if key == "common_root": + return grm.write(rg.Rootset(tuple(self.memo[k] for k in child_keys))) + allow_dynamic_op = self.allow_dynamic_op nodes = self.gdct["nodes"] @@ -458,3 +462,12 @@ def handle_Py_Term(self, op: str, children: dict | list, grm: Grammar): def handle_region_attributes(self, key: str, grm: Grammar): return grm.write(rg.Attrs(())) + + def handle_generic( + self, key: str, op: str, children: dict | list, grm: Grammar + ): + assert isinstance(children, dict) + print("---? generic") + return grm.write( + rg.Generic(name=str(op), children=tuple(children.values())) + ) \ No newline at end of file diff --git a/sealir/rvsdg/grammar.py b/sealir/rvsdg/grammar.py index c2d8e50..2baeeb3 100644 --- a/sealir/rvsdg/grammar.py +++ b/sealir/rvsdg/grammar.py @@ -11,6 +11,15 @@ class _Root(grammar.Rule): pass +class Rootset(_Root): + roots: tuple[SExpr, ...] + + +class Generic(_Root): + name: str + children: tuple[SExpr, ...] + + class Loc(_Root): filename: str line_first: int diff --git a/sealir/rvsdg/restructuring.py b/sealir/rvsdg/restructuring.py index 8c7facd..542024f 100644 --- a/sealir/rvsdg/restructuring.py +++ b/sealir/rvsdg/restructuring.py @@ -1041,6 +1041,7 @@ def formatter(expr: SExpr, state: ase.TraverseState): for arg in expr._args: if isinstance(arg, SExpr): text = yield arg + assert text is not None, arg else: text = repr(arg) argrefs.append(text) From 30302968ff0b2118b369fc8d0b55bc2e1ffb62e2 Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Wed, 3 Sep 2025 17:54:21 -0500 Subject: [PATCH 06/21] GenericList and handle_primitives --- sealir/eqsat/rvsdg_extract_details.py | 10 ++++++---- sealir/rvsdg/grammar.py | 5 +++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/sealir/eqsat/rvsdg_extract_details.py b/sealir/eqsat/rvsdg_extract_details.py index 550a0b3..fc9ec8e 100644 --- a/sealir/eqsat/rvsdg_extract_details.py +++ b/sealir/eqsat/rvsdg_extract_details.py @@ -108,7 +108,7 @@ def get_children() -> dict | list: return [memo[v] for v in child_keys] if key.startswith("primitive-"): - return self.handle_primitive(node_type, node, get_children()) + return self.handle_primitive(node_type, node, get_children(), grm) elif key.startswith("function-"): op = node["op"] children = get_children() @@ -266,7 +266,7 @@ def get_children() -> dict | list: else: raise NotImplementedError(key) - def handle_primitive(self, node_type: str, node, children: tuple): + def handle_primitive(self, node_type: str, node, children: tuple, grm: Grammar): match node_type: case "String": unquoted = node["op"][1:-1] @@ -290,7 +290,10 @@ def handle_primitive(self, node_type: str, node, children: tuple): case "Vec_String": return children case _: - raise NotImplementedError(f"primitive of: {node_type}") + if node_type.startswith("Vec_"): + return grm.write(rg.GenericList(name=node_type, children=tuple(children))) + else: + raise NotImplementedError(node_type) def handle_Value(self, op: str, children: dict | list, grm: Grammar): match op, children: @@ -467,7 +470,6 @@ def handle_generic( self, key: str, op: str, children: dict | list, grm: Grammar ): assert isinstance(children, dict) - print("---? generic") return grm.write( rg.Generic(name=str(op), children=tuple(children.values())) ) \ No newline at end of file diff --git a/sealir/rvsdg/grammar.py b/sealir/rvsdg/grammar.py index 2baeeb3..094994f 100644 --- a/sealir/rvsdg/grammar.py +++ b/sealir/rvsdg/grammar.py @@ -20,6 +20,11 @@ class Generic(_Root): children: tuple[SExpr, ...] +class GenericList(_Root): + name: str + children: tuple[SExpr, ...] + + class Loc(_Root): filename: str line_first: int From eaa056576153f2c176d08cb244ad70ac17a3738d Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Fri, 5 Sep 2025 14:09:03 -0500 Subject: [PATCH 07/21] More --- sealir/eqsat/py_eqsat.py | 8 +++++ sealir/eqsat/rvsdg_convert.py | 15 +++++++++ sealir/eqsat/rvsdg_eqsat.py | 9 ++++++ sealir/eqsat/rvsdg_extract.py | 21 ++++++++++--- sealir/eqsat/rvsdg_extract_details.py | 45 +++++++++++++++++++++++---- sealir/rvsdg/grammar.py | 7 +++++ sealir/rvsdg/restructuring.py | 13 ++++++++ sealir/rvsdg/scfg_to_sexpr.py | 10 ++++++ 8 files changed, 118 insertions(+), 10 deletions(-) diff --git a/sealir/eqsat/py_eqsat.py b/sealir/eqsat/py_eqsat.py index 2e93622..2b4d3b4 100644 --- a/sealir/eqsat/py_eqsat.py +++ b/sealir/eqsat/py_eqsat.py @@ -124,6 +124,14 @@ def Py_AttrIO(io: Term, obj: Term, attrname: StringLike) -> Term: ... def Py_SubscriptIO(io: Term, obj: Term, index: Term) -> Term: ... +@function +def Py_SliceIO(io: Term, lower: Term, upper: Term, step: Term) -> Term: ... + + +@function +def Py_Tuple(elems: TermList) -> Term: ... + + @function def Py_LoadGlobal(io: Term, name: StringLike) -> Term: ... diff --git a/sealir/eqsat/rvsdg_convert.py b/sealir/eqsat/rvsdg_convert.py index 3dd792f..8be2391 100644 --- a/sealir/eqsat/rvsdg_convert.py +++ b/sealir/eqsat/rvsdg_convert.py @@ -240,6 +240,21 @@ def coro(expr: SExpr, state: ase.TraverseState): py_eqsat.Py_SubscriptIO(ioterm, valterm, idxterm) ) + case rg.PySlice(io=io, lower=lower, upper=upper, step=step): + ioterm = yield io + lowerterm = yield lower + upperterm = yield upper + stepterm = yield step + return WrapTerm( + py_eqsat.Py_SliceIO(ioterm, lowerterm, upperterm, stepterm) + ) + + case rg.PyTuple(elems): + elemvals = [] + for el in elems: + elemvals.append((yield el)) + return py_eqsat.Py_Tuple(eg.termlist(*elemvals)) + case rg.PyInt(int(intval)): assert intval.bit_length() < 64 return eg.Term.LiteralI64(intval) diff --git a/sealir/eqsat/rvsdg_eqsat.py b/sealir/eqsat/rvsdg_eqsat.py index dbf96d4..332ac02 100644 --- a/sealir/eqsat/rvsdg_eqsat.py +++ b/sealir/eqsat/rvsdg_eqsat.py @@ -35,6 +35,8 @@ class DynInt(Expr): def __init__(self, num: i64Like): ... + def get(self) -> i64: ... + def __mul__(self, other: DynInt) -> DynInt: ... converter(i64, DynInt, DynInt) @@ -443,6 +445,12 @@ def ruleset_termdict(mapping: Map[String, Term], key: String): ).then(set_(TermDict(mapping).get(key)).to(mapping[key])) +@ruleset +def ruleset_dynint(n: i64, m: i64): + yield rule(DynInt(n)).then(set_(DynInt(n).get()).to(n)) + yield rewrite(DynInt(n) * DynInt(m)).to(DynInt(n * m)) + + ruleset_rvsdg_basic = ( ruleset_simplify_dbgvalue | ruleset_portlist_basic @@ -455,6 +463,7 @@ def ruleset_termdict(mapping: Map[String, Term], key: String): | ruleset_region_propgate_output | ruleset_func_outputs | ruleset_termdict + | ruleset_dynint ) diff --git a/sealir/eqsat/rvsdg_extract.py b/sealir/eqsat/rvsdg_extract.py index 62d0c97..9f99325 100644 --- a/sealir/eqsat/rvsdg_extract.py +++ b/sealir/eqsat/rvsdg_extract.py @@ -278,13 +278,23 @@ def _compute_cost( common_root = "common_root" # Do the conversion back into RVSDG - root_eclasses =[node for node, in_degree in G.in_degree() if in_degree == 0 and self.graph_json['class_data'][node]['type'] != "Unit"] + root_eclasses = [ + node + for node, in_degree in G.in_degree() + if in_degree == 0 + and self.graph_json["class_data"][node]["type"] != "Unit" + ] G.add_node(common_root, shape="rect") for n in root_eclasses: G.add_edge(common_root, n) - self.nodes[common_root] = Node(children=[next(iter(eclassmap[ec])) for ec in root_eclasses], cost=0.0, eclass="common_root", op="common_root", subsumed=False) - + self.nodes[common_root] = Node( + children=[next(iter(eclassmap[ec])) for ec in root_eclasses], + cost=0.0, + eclass="common_root", + op="common_root", + subsumed=False, + ) # Get per-node cost function cm = self.cost_model @@ -356,7 +366,10 @@ def propagate_cost(state_tracker): costchanged.update(state_tracker) if costchanged.converged(): # root score is computed? - if all(math.isfinite(state_tracker.current[root]) for root in [common_root]): + if all( + math.isfinite(state_tracker.current[root]) + for root in [common_root] + ): break # root score is missing? diff --git a/sealir/eqsat/rvsdg_extract_details.py b/sealir/eqsat/rvsdg_extract_details.py index fc9ec8e..92b3cb4 100644 --- a/sealir/eqsat/rvsdg_extract_details.py +++ b/sealir/eqsat/rvsdg_extract_details.py @@ -90,7 +90,9 @@ def handle( self, key: str, child_keys: list[str] | dict[str, str], grm: Grammar ): if key == "common_root": - return grm.write(rg.Rootset(tuple(self.memo[k] for k in child_keys))) + return grm.write( + rg.Rootset(tuple(self.memo[k] for k in child_keys)) + ) allow_dynamic_op = self.allow_dynamic_op @@ -223,7 +225,9 @@ def get_children() -> dict | list: rg.Unpack(val=regionbegin, idx=idx) ) case _: - extended_handle = self.handle_Term(op, children, grm) + extended_handle = self.handle_Term( + op, children, grm + ) if extended_handle is not NotImplemented: return extended_handle raise NotImplementedError( @@ -260,13 +264,21 @@ def get_children() -> dict | list: res = handler(key, op, children, grm) if res is not NotImplemented: return res + + def fmt(kv): + k, v = kv + return f"{k}={ase.pretty_str(v)}" + + fmt_children = "\n".join(map(fmt, children.items())) raise NotImplementedError( - f"function of: {op!r} :: {node_type}, {children}" + f"function of: {op!r} :: {node_type}, {children}\n{fmt_children}" ) else: raise NotImplementedError(key) - def handle_primitive(self, node_type: str, node, children: tuple, grm: Grammar): + def handle_primitive( + self, node_type: str, node, children: tuple, grm: Grammar + ): match node_type: case "String": unquoted = node["op"][1:-1] @@ -291,7 +303,11 @@ def handle_primitive(self, node_type: str, node, children: tuple, grm: Grammar): return children case _: if node_type.startswith("Vec_"): - return grm.write(rg.GenericList(name=node_type, children=tuple(children))) + return grm.write( + rg.GenericList( + name=node_type, children=tuple(children) + ) + ) else: raise NotImplementedError(node_type) @@ -460,6 +476,23 @@ def handle_Py_Term(self, op: str, children: dict | list, grm: Grammar): operands=operands, ) ) + case "Py_Tuple", {"elems": tuple(elems)}: + return grm.write( + rg.PyTuple( + elems=elems, + ) + ) + case "Py_SliceIO", { + "io": io, + "lower": lower, + "upper": upper, + "step": step, + }: + return grm.write( + rg.PySlice(io=io, lower=lower, upper=upper, step=step) + ) + case "Py_SubscriptIO", {"io": io, "obj": obj, "index": index}: + return grm.write(rg.PySubscript(io=io, value=obj, index=index)) case _: return NotImplemented @@ -472,4 +505,4 @@ def handle_generic( assert isinstance(children, dict) return grm.write( rg.Generic(name=str(op), children=tuple(children.values())) - ) \ No newline at end of file + ) diff --git a/sealir/rvsdg/grammar.py b/sealir/rvsdg/grammar.py index 094994f..34cde67 100644 --- a/sealir/rvsdg/grammar.py +++ b/sealir/rvsdg/grammar.py @@ -173,6 +173,13 @@ class PySubscript(_Root): index: SExpr +class PySlice(_Root): + io: SExpr + lower: SExpr + upper: SExpr + step: SExpr + + class PyUnaryOp(_Root): op: str io: SExpr diff --git a/sealir/rvsdg/restructuring.py b/sealir/rvsdg/restructuring.py index 542024f..08f7239 100644 --- a/sealir/rvsdg/restructuring.py +++ b/sealir/rvsdg/restructuring.py @@ -775,6 +775,19 @@ def get_loopvar(ports): rg.PySubscript(io=ctx.load_io(), value=value, index=index) ) + case ("PyAst_Slice", (lower, upper, step, interloc)): + lower_val = yield lower + upper_val = yield upper + step_val = yield step + return ctx.insert_io_node( + rg.PySlice( + io=ctx.load_io(), + lower=lower_val, + upper=upper_val, + step=step_val, + ) + ) + case ("PyAst_Pass", (interloc,)): return diff --git a/sealir/rvsdg/scfg_to_sexpr.py b/sealir/rvsdg/scfg_to_sexpr.py index 9f1262a..bb2a182 100644 --- a/sealir/rvsdg/scfg_to_sexpr.py +++ b/sealir/rvsdg/scfg_to_sexpr.py @@ -176,6 +176,16 @@ def visit_Subscript(self, node: ast.Subscript) -> SExpr: self.get_loc(node), ) + def visit_Slice(self, node: ast.Slice) -> SExpr: + none = self._tape.expr("PyAst_None", self.get_loc(node)) + return self._tape.expr( + "PyAst_Slice", + self.visit(node.lower) if node.lower is not None else none, + self.visit(node.upper) if node.upper is not None else none, + self.visit(node.step) if node.step is not None else none, + self.get_loc(node), + ) + def visit_Call(self, node: ast.Call) -> SExpr: posargs = self._tape.expr( "PyAst_callargs_pos", *map(self.visit, node.args) From 3685ca6d3d1cbf57d956ae47b3d27505ae22d82e Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:47:10 -0500 Subject: [PATCH 08/21] more resilience formatting --- sealir/eqsat/rvsdg_extract_details.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sealir/eqsat/rvsdg_extract_details.py b/sealir/eqsat/rvsdg_extract_details.py index 92b3cb4..9798b12 100644 --- a/sealir/eqsat/rvsdg_extract_details.py +++ b/sealir/eqsat/rvsdg_extract_details.py @@ -267,7 +267,11 @@ def get_children() -> dict | list: def fmt(kv): k, v = kv - return f"{k}={ase.pretty_str(v)}" + if isinstance(v, ase.SExpr): + return f"{k}={ase.pretty_str(v)}" + else: + return f"{k}={v}" + fmt_children = "\n".join(map(fmt, children.items())) raise NotImplementedError( From 24eb0ca64027b03cad0556d46d74464aa40e08c8 Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Fri, 12 Sep 2025 09:25:01 -0500 Subject: [PATCH 09/21] Stronger typing in extraction and flattening of Generic --- sealir/eqsat/rvsdg_extract_details.py | 47 +++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/sealir/eqsat/rvsdg_extract_details.py b/sealir/eqsat/rvsdg_extract_details.py index 9798b12..5a906af 100644 --- a/sealir/eqsat/rvsdg_extract_details.py +++ b/sealir/eqsat/rvsdg_extract_details.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Iterable, Iterator +from typing import Iterable, Iterator, MutableMapping from sealir import ase from sealir.rvsdg import Grammar @@ -14,6 +14,34 @@ class Data: ... +_memo_elem_type = ase.value_type | tuple + + +class _TypeCheckedDict(MutableMapping): + def __init__(self): + self._data = {} + + def __getitem__(self, key): + return self._data[key] + + def __setitem__(self, key, value): + self._check_setitem(key, value) + self._data[key] = value + + def __delitem__(self, key): + del self._data[key] + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + def _check_setitem(self, key, value): + if not isinstance(value, _memo_elem_type): + raise TypeError(f"{type(value)} :: {value}") + + class EGraphToRVSDG: allow_dynamic_op = False grammar = Grammar @@ -23,7 +51,7 @@ def __init__( ): self.rvsdg_sexpr = rvsdg_sexpr self.gdct = gdct - self.memo = {} + self.memo = _TypeCheckedDict() self.egg_fn_to_arg_names = egg_fn_to_arg_names def run(self, node_and_children): @@ -300,11 +328,11 @@ def handle_primitive( case "f64": return float(node["op"]) case "Vec_Term": - return children + return tuple(children) case "Vec_Port": - return children + return tuple(children) case "Vec_String": - return children + return tuple(children) case _: if node_type.startswith("Vec_"): return grm.write( @@ -507,6 +535,13 @@ def handle_generic( self, key: str, op: str, children: dict | list, grm: Grammar ): assert isinstance(children, dict) + # flatten children.values() into SExpr + values = [] + for k, v in children.items(): + if isinstance(v, tuple): + values.append(grm.write(rg.GenericList(name=k, children=v))) + else: + values.append(v) return grm.write( - rg.Generic(name=str(op), children=tuple(children.values())) + rg.Generic(name=str(op), children=tuple(values)) ) From 8a997cbcd1e274c6f129fac35f1a761e7e32b504 Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Fri, 12 Sep 2025 15:06:49 -0500 Subject: [PATCH 10/21] Support setitem --- sealir/eqsat/py_eqsat.py | 4 ++++ sealir/eqsat/rvsdg_convert.py | 9 +++++++++ sealir/rvsdg/grammar.py | 7 +++++++ sealir/rvsdg/restructuring.py | 15 ++++++++++++++- 4 files changed, 34 insertions(+), 1 deletion(-) diff --git a/sealir/eqsat/py_eqsat.py b/sealir/eqsat/py_eqsat.py index 2b4d3b4..bf571fd 100644 --- a/sealir/eqsat/py_eqsat.py +++ b/sealir/eqsat/py_eqsat.py @@ -124,6 +124,10 @@ def Py_AttrIO(io: Term, obj: Term, attrname: StringLike) -> Term: ... def Py_SubscriptIO(io: Term, obj: Term, index: Term) -> Term: ... +@function +def Py_SetitemIO(io: Term, obj: Term, index: Term, val: Term) -> Term: ... + + @function def Py_SliceIO(io: Term, lower: Term, upper: Term, step: Term) -> Term: ... diff --git a/sealir/eqsat/rvsdg_convert.py b/sealir/eqsat/rvsdg_convert.py index 8be2391..8003192 100644 --- a/sealir/eqsat/rvsdg_convert.py +++ b/sealir/eqsat/rvsdg_convert.py @@ -240,6 +240,15 @@ def coro(expr: SExpr, state: ase.TraverseState): py_eqsat.Py_SubscriptIO(ioterm, valterm, idxterm) ) + case rg.PySetItem(io=io, obj=obj, value=value, index=index): + ioterm = yield io + objterm = yield obj + valterm = yield value + idxterm = yield index + return WrapTerm( + py_eqsat.Py_SetitemIO(ioterm, objterm, idxterm, valterm) + ) + case rg.PySlice(io=io, lower=lower, upper=upper, step=step): ioterm = yield io lowerterm = yield lower diff --git a/sealir/rvsdg/grammar.py b/sealir/rvsdg/grammar.py index 34cde67..ce84886 100644 --- a/sealir/rvsdg/grammar.py +++ b/sealir/rvsdg/grammar.py @@ -173,6 +173,13 @@ class PySubscript(_Root): index: SExpr +class PySetItem(_Root): + io: SExpr + obj: SExpr + index: SExpr + value: SExpr + + class PySlice(_Root): io: SExpr lower: SExpr diff --git a/sealir/rvsdg/restructuring.py b/sealir/rvsdg/restructuring.py index 08f7239..f6687cb 100644 --- a/sealir/rvsdg/restructuring.py +++ b/sealir/rvsdg/restructuring.py @@ -434,7 +434,7 @@ def unpack_pystr(sexpr: SExpr) -> str | None: def unpack_pyast_name(sexpr: SExpr) -> str: - assert sexpr._head == "PyAst_Name" + assert sexpr._head == "PyAst_Name", sexpr return cast(str, sexpr._args[0]) @@ -647,6 +647,19 @@ def get_loopvar(ports): res = yield rval tar: SExpr + if ( + len(targets) == 1 + and targets[0]._head == "PyAst_Subscript" + ): + [lhs, indices, loc] = targets[0]._args + lhs = (yield lhs) + indices = (yield indices) + rval = (yield rval) + setitem = rg.PySetItem(io=ctx.load_io(), + obj=lhs, + index=indices, + value=rval) + return ctx.insert_io_node(setitem) if ( len(targets) == 1 and unpack_pyast_name(targets[0]) == internal_prefix("_") From 258607a64b188e636a0238801ff45bd64ad0e510 Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Fri, 12 Sep 2025 16:37:32 -0500 Subject: [PATCH 11/21] Temp fix for handling Vec in Rootset This needs refactoring later. --- sealir/eqsat/rvsdg_extract_details.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sealir/eqsat/rvsdg_extract_details.py b/sealir/eqsat/rvsdg_extract_details.py index 5a906af..011adf6 100644 --- a/sealir/eqsat/rvsdg_extract_details.py +++ b/sealir/eqsat/rvsdg_extract_details.py @@ -118,8 +118,14 @@ def handle( self, key: str, child_keys: list[str] | dict[str, str], grm: Grammar ): if key == "common_root": + # legalize child + values = [] + for k in child_keys: + val = self.memo[k] + if isinstance(val, ase.SExpr): + values.append(val) return grm.write( - rg.Rootset(tuple(self.memo[k] for k in child_keys)) + rg.Rootset(tuple(values)) ) allow_dynamic_op = self.allow_dynamic_op From 9a8fca6757194d76a1090c5c27737f21bbc5de6b Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Tue, 16 Sep 2025 14:37:25 -0500 Subject: [PATCH 12/21] FloorDiv --- sealir/eqsat/py_eqsat.py | 2 ++ sealir/eqsat/rvsdg_convert.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/sealir/eqsat/py_eqsat.py b/sealir/eqsat/py_eqsat.py index bf571fd..3f8066d 100644 --- a/sealir/eqsat/py_eqsat.py +++ b/sealir/eqsat/py_eqsat.py @@ -91,6 +91,8 @@ def Py_Div(a: Term, b: Term) -> Term: ... @function def Py_DivIO(io: Term, a: Term, b: Term) -> Term: ... +@function +def Py_FloorDivIO(io: Term, a: Term, b: Term) -> Term: ... @function def Py_MatMult(a: Term, b: Term) -> Term: ... diff --git a/sealir/eqsat/rvsdg_convert.py b/sealir/eqsat/rvsdg_convert.py index 8003192..bbe4a1a 100644 --- a/sealir/eqsat/rvsdg_convert.py +++ b/sealir/eqsat/rvsdg_convert.py @@ -163,6 +163,9 @@ def coro(expr: SExpr, state: ase.TraverseState): case "/": res = py_eqsat.Py_DivIO(ioterm, lhsterm, rhsterm) + case "//": + res = py_eqsat.Py_FloorDivIO(ioterm, lhsterm, rhsterm) + case "@": res = py_eqsat.Py_MatMultIO(ioterm, lhsterm, rhsterm) From ae04966b01521cdd9b0c110484f1774aaf04d105 Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Fri, 19 Sep 2025 17:22:46 -0500 Subject: [PATCH 13/21] copy_into_tree is missing downcast --- sealir/ase.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sealir/ase.py b/sealir/ase.py index fe7ade9..4fc0c5a 100644 --- a/sealir/ase.py +++ b/sealir/ase.py @@ -814,7 +814,8 @@ def copy_tree_into(self: SExpr, tape: Tape) -> SExpr: Returns a fresh Expr in the new tape. """ oldtree = self._tape - crawler = TapeCrawler(oldtree, self._get_downcast()) + downcast = self._get_downcast() + crawler = TapeCrawler(oldtree, downcast) crawler.seek(self._handle) liveset = set(_select(crawler.walk_descendants(), 1)) surviving = sorted(liveset) @@ -836,7 +837,8 @@ def copy_tree_into(self: SExpr, tape: Tape) -> SExpr: out = tape.read_value(mapping[self._handle]) assert isinstance(out, SExpr) - return out + + return downcast(out) class BasicSExpr(SExpr): From 2189b399555d300350bea059f039a26fb3cb8e0a Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Mon, 27 Oct 2025 15:07:15 -0500 Subject: [PATCH 14/21] Fix tests --- sealir/eqsat/py_eqsat.py | 2 ++ sealir/eqsat/rvsdg_extract.py | 1 - sealir/eqsat/rvsdg_extract_details.py | 9 ++------- sealir/rvsdg/restructuring.py | 18 +++++++----------- sealir/tests/test_cost_extraction.py | 5 ++--- sealir/tests/test_rvsdg_egglog_from_source.py | 2 ++ sealir/tests/test_rvsdg_egraph_roundtrip.py | 5 ++++- 7 files changed, 19 insertions(+), 23 deletions(-) diff --git a/sealir/eqsat/py_eqsat.py b/sealir/eqsat/py_eqsat.py index 3f8066d..8d4fc67 100644 --- a/sealir/eqsat/py_eqsat.py +++ b/sealir/eqsat/py_eqsat.py @@ -91,9 +91,11 @@ def Py_Div(a: Term, b: Term) -> Term: ... @function def Py_DivIO(io: Term, a: Term, b: Term) -> Term: ... + @function def Py_FloorDivIO(io: Term, a: Term, b: Term) -> Term: ... + @function def Py_MatMult(a: Term, b: Term) -> Term: ... diff --git a/sealir/eqsat/rvsdg_extract.py b/sealir/eqsat/rvsdg_extract.py index 9f99325..fcc7709 100644 --- a/sealir/eqsat/rvsdg_extract.py +++ b/sealir/eqsat/rvsdg_extract.py @@ -112,7 +112,6 @@ def egraph_extraction( egraph, converter_class=converter_class, ) - return cost, expr diff --git a/sealir/eqsat/rvsdg_extract_details.py b/sealir/eqsat/rvsdg_extract_details.py index 011adf6..b08f07f 100644 --- a/sealir/eqsat/rvsdg_extract_details.py +++ b/sealir/eqsat/rvsdg_extract_details.py @@ -124,9 +124,7 @@ def handle( val = self.memo[k] if isinstance(val, ase.SExpr): values.append(val) - return grm.write( - rg.Rootset(tuple(values)) - ) + return grm.write(rg.Rootset(tuple(values))) allow_dynamic_op = self.allow_dynamic_op @@ -306,7 +304,6 @@ def fmt(kv): else: return f"{k}={v}" - fmt_children = "\n".join(map(fmt, children.items())) raise NotImplementedError( f"function of: {op!r} :: {node_type}, {children}\n{fmt_children}" @@ -548,6 +545,4 @@ def handle_generic( values.append(grm.write(rg.GenericList(name=k, children=v))) else: values.append(v) - return grm.write( - rg.Generic(name=str(op), children=tuple(values)) - ) + return grm.write(rg.Generic(name=str(op), children=tuple(values))) diff --git a/sealir/rvsdg/restructuring.py b/sealir/rvsdg/restructuring.py index f6687cb..6c6b2b3 100644 --- a/sealir/rvsdg/restructuring.py +++ b/sealir/rvsdg/restructuring.py @@ -647,18 +647,14 @@ def get_loopvar(ports): res = yield rval tar: SExpr - if ( - len(targets) == 1 - and targets[0]._head == "PyAst_Subscript" - ): + if len(targets) == 1 and targets[0]._head == "PyAst_Subscript": [lhs, indices, loc] = targets[0]._args - lhs = (yield lhs) - indices = (yield indices) - rval = (yield rval) - setitem = rg.PySetItem(io=ctx.load_io(), - obj=lhs, - index=indices, - value=rval) + lhs = yield lhs + indices = yield indices + rval = yield rval + setitem = rg.PySetItem( + io=ctx.load_io(), obj=lhs, index=indices, value=rval + ) return ctx.insert_io_node(setitem) if ( len(targets) == 1 diff --git a/sealir/tests/test_cost_extraction.py b/sealir/tests/test_cost_extraction.py index 4fd3855..0dbf7a6 100644 --- a/sealir/tests/test_cost_extraction.py +++ b/sealir/tests/test_cost_extraction.py @@ -90,9 +90,8 @@ def format_node(node): return node return tuple([node, *(format_node(child) for child in children)]) - # Get all nodes with no incoming edges (roots) - roots = [n for n in graph.nodes if graph.in_degree(n) == 0] - return [format_node(root) for root in roots] + roots = [k for k in graph.nodes if k.endswith("GraphRoot")] + return [format_node(x) for x in roots] def test_cost_duplicated_term(): diff --git a/sealir/tests/test_rvsdg_egglog_from_source.py b/sealir/tests/test_rvsdg_egglog_from_source.py index 9ce0f5e..9d8893d 100644 --- a/sealir/tests/test_rvsdg_egglog_from_source.py +++ b/sealir/tests/test_rvsdg_egglog_from_source.py @@ -25,8 +25,10 @@ def define_egraph(egraph, func): if verbose: print(egraph.extract(root)) + return root cost, extracted = middle_end(rvsdg_expr, define_egraph) + if verbose: print("Extracted from EGraph".center(80, "=")) print("cost =", cost) diff --git a/sealir/tests/test_rvsdg_egraph_roundtrip.py b/sealir/tests/test_rvsdg_egraph_roundtrip.py index 65f4ef1..6eac2e1 100644 --- a/sealir/tests/test_rvsdg_egraph_roundtrip.py +++ b/sealir/tests/test_rvsdg_egraph_roundtrip.py @@ -32,9 +32,11 @@ def middle_end( # Extraction extract_kwargs = extract_kwargs or {} - cost, extracted = egraph_extraction( + cost, rootset = egraph_extraction( egraph, rvsdg_expr, cost_model=cost_model, **extract_kwargs ) + [extracted] = [node for node in rootset._args if node._head == "Func"] + return cost, extracted @@ -57,6 +59,7 @@ def display_egraph(egraph: EGraph, func): cost, extracted = middle_end( rvsdg_expr, display_egraph, extract_kwargs=extract_kwargs ) + print("Extracted from EGraph".center(80, "=")) print("cost =", cost) print(rvsdg.format_rvsdg(extracted)) From b78e9d37628fce105505a1899557bf107097252c Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Mon, 27 Oct 2025 16:54:52 -0500 Subject: [PATCH 15/21] Refactor --- sealir/eqsat/rvsdg_extract.py | 57 ++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/sealir/eqsat/rvsdg_extract.py b/sealir/eqsat/rvsdg_extract.py index fcc7709..d905f98 100644 --- a/sealir/eqsat/rvsdg_extract.py +++ b/sealir/eqsat/rvsdg_extract.py @@ -214,6 +214,40 @@ def __init__(self, graph_json: EGraphJsonDict, root_eclass, cost_model): self.class_data[node.eclass].add(k) self.cost_model = cost_model + def _create_common_root(self, G: nx.DiGraph, eclassmap: dict[str, set[str]]) -> str: + """Create a common root node and add it to the graph. + + Args: + G: The eclass dependency graph + eclassmap: Mapping from eclass names to sets of node names + + Returns: + The name of the common root node + """ + # Get all nodes with in-degree of 0 + common_root = "common_root" + + # Do the conversion back into RVSDG + root_eclasses = [ + node + for node, in_degree in G.in_degree() + if in_degree == 0 + and self.graph_json["class_data"][node]["type"] != "Unit" + ] + G.add_node(common_root, shape="rect") + for n in root_eclasses: + G.add_edge(common_root, n) + + self.nodes[common_root] = Node( + children=[next(iter(eclassmap[ec])) for ec in root_eclasses], + cost=0.0, + eclass="common_root", + op="common_root", + subsumed=False, + ) + + return common_root + def _compute_cost( self, max_iter=1000, @@ -273,27 +307,8 @@ def _compute_cost( for child in enode.children: G.add_edge(k, nodes[child].eclass) - # Get all nodes with in-degree of 0 - common_root = "common_root" - - # Do the conversion back into RVSDG - root_eclasses = [ - node - for node, in_degree in G.in_degree() - if in_degree == 0 - and self.graph_json["class_data"][node]["type"] != "Unit" - ] - G.add_node(common_root, shape="rect") - for n in root_eclasses: - G.add_edge(common_root, n) - - self.nodes[common_root] = Node( - children=[next(iter(eclassmap[ec])) for ec in root_eclasses], - cost=0.0, - eclass="common_root", - op="common_root", - subsumed=False, - ) + # Create common root node + common_root = self._create_common_root(G, eclassmap) # Get per-node cost function cm = self.cost_model From 3ec2b7fc3a2953ce466da3124bd40db4aa821e45 Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Tue, 28 Oct 2025 15:23:23 -0500 Subject: [PATCH 16/21] Handle children dag eclass cycle --- sealir/eqsat/rvsdg_extract.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/sealir/eqsat/rvsdg_extract.py b/sealir/eqsat/rvsdg_extract.py index d905f98..97249f8 100644 --- a/sealir/eqsat/rvsdg_extract.py +++ b/sealir/eqsat/rvsdg_extract.py @@ -359,7 +359,6 @@ def propagate_cost(state_tracker): for k in reversed(topo_ordered): if k not in eclassmap: node = nodes[k] - cost_dag = dagcost.compute_cost(k) cost = sum(cost_dag.values()) selections[node.eclass].put(cost, k) @@ -477,6 +476,8 @@ class SubgraphCost: "Stores computed node costs for reuse." _stats: _SubgraphCostStats = field(default_factory=_SubgraphCostStats) "Tracks cache hits and misses." + _visited_dag_eclass: list[str] = field(default_factory=list) + "Tracks Children DAG path to avoid recursion" def compute_cost(self, nodename: str) -> dict[str, float]: if (cc := self._cache.get(nodename)) is None: @@ -527,15 +528,21 @@ def _compute_cost(self, nodename: str) -> dict[str, float]: return costs def _compute_choice(self, eclass: str) -> dict[str, float]: - selections = self.selections - - choices = selections[eclass] - if not choices: + if eclass in self._visited_dag_eclass: + # Avoid recursion return {eclass: MAX_COST} - best = choices.best() - - return self.compute_cost(best.name) - + self._visited_dag_eclass.append(eclass) + try: + selections = self.selections + + choices = selections[eclass] + if not choices: + return {eclass: MAX_COST} + best = choices.best() + + return self.compute_cost(best.name) + finally: + self._visited_dag_eclass.pop() class ExtractionError(Exception): pass From 21e7cb14827bb2633333938cf4c20b5e01481c09 Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Wed, 5 Nov 2025 10:51:23 -0600 Subject: [PATCH 17/21] Fix typing --- sealir/egg_utils.py | 2 +- sealir/eqsat/rvsdg_extract.py | 20 ++++++++++++++++---- sealir/llvm_pyapi_backend.py | 6 +++--- sealir/rvsdg/evaluating.py | 8 ++++---- sealir/rvsdg/restructuring.py | 2 +- sealir/tests/test_rvsdg_egraph_roundtrip.py | 3 +-- 6 files changed, 26 insertions(+), 15 deletions(-) diff --git a/sealir/egg_utils.py b/sealir/egg_utils.py index bc19556..fa1c7b1 100644 --- a/sealir/egg_utils.py +++ b/sealir/egg_utils.py @@ -107,7 +107,7 @@ def extract_eclasses(egraph: EGraph) -> EClassData: def reconstruct( - nodes: dict[str, dict], class_data: dict[str, str] + nodes: dict[str, dict], class_data: dict[str, dict[str, str]] ) -> dict[str, Term]: done: dict[str, Term] = {} diff --git a/sealir/eqsat/rvsdg_extract.py b/sealir/eqsat/rvsdg_extract.py index 97249f8..5d833b5 100644 --- a/sealir/eqsat/rvsdg_extract.py +++ b/sealir/eqsat/rvsdg_extract.py @@ -12,6 +12,7 @@ import networkx as nx from egglog import EGraph +from sealir.ase import SExpr from .egraph_utils import EGraphJsonDict from .rvsdg_extract_details import EGraphToRVSDG @@ -87,7 +88,7 @@ def egraph_extraction( cost_model=None, converter_class=EGraphToRVSDG, stats: dict[str, Any] | None = None, -): +) -> tuple[float, CommonRoot]: gdct: EGraphJsonDict = json.loads( egraph._serialize( n_inline_leaves=0, split_primitive_outputs=False @@ -123,7 +124,7 @@ def convert_to_rvsdg( egraph: EGraph, *, converter_class, -): +) -> CommonRoot: # Get declarations so we have named fields state = egraph._state decls = state.__egg_decls__ @@ -167,7 +168,15 @@ def iterator(node_iter): yield node, children conversion = converter_class(gdct, rvsdg_sexpr, egg_fn_to_arg_names) - return conversion.run(iterator(node_iterator)) + return CommonRoot(conversion.run(iterator(node_iterator))) + + +@dataclass(frozen=True) +class CommonRoot: + root: SExpr + + def filter_children(self, pred) -> list[SExpr]: + return list(filter(pred, self.root._args)) def get_graph_root(graph_json: EGraphJsonDict) -> set[str]: @@ -214,7 +223,9 @@ def __init__(self, graph_json: EGraphJsonDict, root_eclass, cost_model): self.class_data[node.eclass].add(k) self.cost_model = cost_model - def _create_common_root(self, G: nx.DiGraph, eclassmap: dict[str, set[str]]) -> str: + def _create_common_root( + self, G: nx.DiGraph, eclassmap: dict[str, set[str]] + ) -> str: """Create a common root node and add it to the graph. Args: @@ -544,6 +555,7 @@ def _compute_choice(self, eclass: str) -> dict[str, float]: finally: self._visited_dag_eclass.pop() + class ExtractionError(Exception): pass diff --git a/sealir/llvm_pyapi_backend.py b/sealir/llvm_pyapi_backend.py index afca1fe..2f90e0a 100644 --- a/sealir/llvm_pyapi_backend.py +++ b/sealir/llvm_pyapi_backend.py @@ -180,8 +180,8 @@ def get_region_args(): return SSAValue(builder.function.args[idx]) case rg.Unpack(val=source, idx=int(idx)): - ports: PackedValues = yield source - return ports[idx] + packedvalues: PackedValues = yield source + return packedvalues[idx] case rg.PyBinOpPure(op=op, lhs=lhs, rhs=rhs): lhsval = (yield lhs).value @@ -376,7 +376,7 @@ def get_region_args(): # end loop builder.position_at_end(bb_endloop) # Returns the value from the loop body because this is a tail loop - return loopout_values + return PackedValues.make(*loopout_values) case rg.PyForLoop( iter_arg_idx=int(iter_arg_idx), diff --git a/sealir/rvsdg/evaluating.py b/sealir/rvsdg/evaluating.py index 95339d0..0160dce 100644 --- a/sealir/rvsdg/evaluating.py +++ b/sealir/rvsdg/evaluating.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from dataclasses import dataclass from pprint import pprint -from typing import Any, Sequence, Type, TypeAlias +from typing import Any, Sequence, Type, TypeAlias, Mapping from sealir import ase, rvsdg from sealir.rvsdg import grammar as rg @@ -57,10 +57,10 @@ def evaluate( callargs: tuple, callkwargs: dict, *, - init_scope: dict | None = None, + init_scope: Mapping | None = None, init_state: ase.TraverseState | None = None, init_memo: dict | None = None, - global_ns: dict | None = None, + global_ns: Mapping | None = None, dbginfo: rvsdg.SourceDebugInfo, ): stack: list[dict[str, Any]] = [{}] @@ -81,7 +81,7 @@ def push(callargs): finally: stack.pop() - def scope() -> dict[str, Any]: + def scope() -> Mapping[str, Any]: return stack[-1] def get_region_args() -> list[Any]: diff --git a/sealir/rvsdg/restructuring.py b/sealir/rvsdg/restructuring.py index 6c6b2b3..0b3db9f 100644 --- a/sealir/rvsdg/restructuring.py +++ b/sealir/rvsdg/restructuring.py @@ -438,7 +438,7 @@ def unpack_pyast_name(sexpr: SExpr) -> str: return cast(str, sexpr._args[0]) -def is_directive(text: str) -> str: +def is_directive(text: str) -> bool: return text.startswith("#file:") or text.startswith("#loc:") diff --git a/sealir/tests/test_rvsdg_egraph_roundtrip.py b/sealir/tests/test_rvsdg_egraph_roundtrip.py index 6eac2e1..1433606 100644 --- a/sealir/tests/test_rvsdg_egraph_roundtrip.py +++ b/sealir/tests/test_rvsdg_egraph_roundtrip.py @@ -35,8 +35,7 @@ def middle_end( cost, rootset = egraph_extraction( egraph, rvsdg_expr, cost_model=cost_model, **extract_kwargs ) - [extracted] = [node for node in rootset._args if node._head == "Func"] - + [extracted] = rootset.filter_children(lambda node: node._head == "Func") return cost, extracted From 65f4ae9eae9111d3d104688f89cef2ec6429774e Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Wed, 5 Nov 2025 16:26:16 -0600 Subject: [PATCH 18/21] Refactor extraction --- sealir/eqsat/rvsdg_extract.py | 189 +++++++++++++------- sealir/rvsdg/evaluating.py | 2 +- sealir/tests/test_cost_extraction.py | 68 ++++--- sealir/tests/test_rvsdg_egraph_roundtrip.py | 28 ++- 4 files changed, 168 insertions(+), 119 deletions(-) diff --git a/sealir/eqsat/rvsdg_extract.py b/sealir/eqsat/rvsdg_extract.py index 5d833b5..1831b81 100644 --- a/sealir/eqsat/rvsdg_extract.py +++ b/sealir/eqsat/rvsdg_extract.py @@ -7,12 +7,13 @@ from dataclasses import dataclass, field from itertools import starmap from pprint import pformat -from typing import Any, Callable, NamedTuple, Sequence +from typing import Any, Callable, NamedTuple, Self, Sequence import networkx as nx from egglog import EGraph from sealir.ase import SExpr + from .egraph_utils import EGraphJsonDict from .rvsdg_extract_details import EGraphToRVSDG @@ -81,58 +82,54 @@ def eval_constant_node( return None +@dataclass(frozen=True) +class ExtractionResult: + extraction: Extraction + root: str + cost: float + graph: nx.MultiDiGraph + + def extract_sexpr(self, original_sexpr: SExpr, converter_class) -> SExpr: + """Extract back to SExpr format using any converter class.""" + expr = _convert_graph_to_sexpr( + self, + original_sexpr, + converter_class=converter_class, + ) + return expr + + def egraph_extraction( egraph: EGraph, - rvsdg_sexpr, + cost_model: CostModel | None = None, *, - cost_model=None, - converter_class=EGraphToRVSDG, - stats: dict[str, Any] | None = None, -) -> tuple[float, CommonRoot]: - gdct: EGraphJsonDict = json.loads( - egraph._serialize( - n_inline_leaves=0, split_primitive_outputs=False - ).to_json() - ) - [root] = get_graph_root(gdct) - root_eclass = gdct["nodes"][root]["eclass"] - - if stats is not None: - stats["num_enodes"] = len(gdct["nodes"]) - stats["num_eclasses"] = len(gdct["class_data"]) - - cost_model = CostModel() if cost_model is None else cost_model - extraction = Extraction(gdct, root_eclass, cost_model) - cost, exgraph = extraction.choose(stats=stats) - - expr = convert_to_rvsdg( - exgraph, - gdct, - rvsdg_sexpr, - root, - egraph, - converter_class=converter_class, - ) - return cost, expr + stats: dict[str, object] | None = None, +) -> Extraction: + """Extract from an egraph and return an Extraction object. + + The returned Extraction can be used to: + - .compute() to compute costs and return ExtractionResult + - .extract_graph_root() to extract using auto-detected GraphRoot + - .extract_enode(root) to extract from a specific enode + - .extract_eclass(eclass) to extract from an equivalence class + """ + return Extraction.from_egraph(egraph, cost_model, stats=stats) -def convert_to_rvsdg( - exgraph: nx.MultiDiGraph, - gdct: EGraphJsonDict, - rvsdg_sexpr, - root: str, - egraph: EGraph, +def _convert_graph_to_sexpr( + result: ExtractionResult, + original_sexpr, *, converter_class, -) -> CommonRoot: +) -> SExpr: # Get declarations so we have named fields - state = egraph._state + state = result.extraction.egraph._state decls = state.__egg_decls__ # Do the conversion back into RVSDG - common_root = "common_root" - - node_iterator = list(nx.dfs_postorder_nodes(exgraph, source=common_root)) + node_iterator = list( + nx.dfs_postorder_nodes(result.graph, source=result.root) + ) def egg_fn_to_arg_names(egg_fn: str) -> tuple[str, ...]: for ref in state.egg_fn_to_callable_refs[egg_fn]: @@ -146,13 +143,13 @@ def iterator(node_iter): # Get children nodes in order children = [ (data["label"], child) - for _, child, data in exgraph.out_edges(node, data=True) + for _, child, data in result.graph.out_edges(node, data=True) ] if children: children.sort() _, children = zip(*children) # extract argument names - if node == common_root: + if node == "common_root": yield node, children else: kind, _, egg_fn = node.split("-") @@ -167,16 +164,10 @@ def iterator(node_iter): raise NotImplementedError(f"kind is {kind!r}") yield node, children - conversion = converter_class(gdct, rvsdg_sexpr, egg_fn_to_arg_names) - return CommonRoot(conversion.run(iterator(node_iterator))) - - -@dataclass(frozen=True) -class CommonRoot: - root: SExpr - - def filter_children(self, pred) -> list[SExpr]: - return list(filter(pred, self.root._args)) + conversion = converter_class( + result.extraction.graph_json, original_sexpr, egg_fn_to_arg_names + ) + return conversion.run(iterator(node_iterator)) def get_graph_root(graph_json: EGraphJsonDict) -> set[str]: @@ -203,14 +194,24 @@ def render_extraction_graph(G: nx.MultiDiGraph, filename: str): class Extraction: + graph_json: EGraphJsonDict nodes: dict[str, Node] node_types: dict[str, str] root_eclass: str cost_model: CostModel + egraph: EGraph + _selections: dict[str, Bucket] | None + _computed: bool _DEBUG = False - def __init__(self, graph_json: EGraphJsonDict, root_eclass, cost_model): + def __init__( + self, + graph_json: EGraphJsonDict, + cost_model, + egraph: EGraph, + stats: dict[str, Any] | None = None, + ): self.graph_json = graph_json self.root_eclass = "common_root" self.class_data = defaultdict(set) @@ -222,6 +223,34 @@ def __init__(self, graph_json: EGraphJsonDict, root_eclass, cost_model): for k, node in self.nodes.items(): self.class_data[node.eclass].add(k) self.cost_model = cost_model + self.egraph = egraph + self._selections = None + self._computed = False + self.stats = stats or {} + + @classmethod + def from_egraph( + cls, + egraph: EGraph, + cost_model: CostModel | None = None, + *, + stats: dict[str, object] | None = None, + ) -> "Extraction": + """Create an Extraction object from an egraph. + + This replaces the standalone egraph_extraction function. + """ + gdct: EGraphJsonDict = json.loads( + egraph._serialize( + n_inline_leaves=0, split_primitive_outputs=False + ).to_json() + ) + if stats is not None: + stats["num_enodes"] = len(gdct["nodes"]) + stats["num_eclasses"] = len(gdct["class_data"]) + + cost_model = CostModel() if cost_model is None else cost_model + return cls(gdct, cost_model, egraph, stats) def _create_common_root( self, G: nx.DiGraph, eclassmap: dict[str, set[str]] @@ -329,7 +358,7 @@ def _compute_cost( node = nodes[k] children_eclasses = [nodes[c].eclass for c in node.children] if k == "common_root": - nodecostmap[k] = cm.get_simple(0) + nodecostmap[k] = cm.get_simple(1) else: nodecostmap[k] = cm.get_cost_function( nodename=k, @@ -406,18 +435,44 @@ def propagate_cost(state_tracker): else: last_changed_i = round_i - return selections, round_i + return dict(**selections), round_i - def choose( - self, stats: dict[str, Any] | None = None - ) -> tuple[float, nx.MultiDiGraph]: - selections, round_i = self._compute_cost() - if stats is not None: - stats["extraction_iteration_count"] = round_i + def compute(self) -> Self: + """Compute extraction costs (idempotent).""" + if not self._computed: + self._selections, round_i = self._compute_cost() + self.stats["extraction_iteration_count"] = round_i + self._computed = True + + return self + def extract_enode(self, root: str) -> ExtractionResult: + """Extract starting from an enode (specific node).""" + + root_eclass = self.nodes[root].eclass + return self._do_extract(root, root_eclass) + + def extract_eclass(self, root_eclass: str) -> ExtractionResult: + """Extract starting from an equivalence class.""" + + return self._do_extract(root_eclass, root_eclass) + + def extract_graph_root(self) -> ExtractionResult: + """Extract using the automatically detected GraphRoot.""" + + [root] = get_graph_root(self.graph_json) + root_eclass = self.nodes[root].eclass + return self._do_extract(root, root_eclass) + + def _do_extract(self, root: str, root_eclass: str) -> ExtractionResult: + """Internal method to perform the actual extraction.""" + if not self._computed: + self.compute() + + selections = self._selections + assert selections is not None nodes = self.nodes - assert self.root_eclass == "common_root" - chosen_root, rootcost = selections[self.root_eclass].best() + chosen_root, rootcost = selections[root_eclass].best() # make selected graph G = nx.MultiDiGraph() @@ -438,7 +493,9 @@ def choose( if self._DEBUG: render_extraction_graph(G, "chosen") - return rootcost, G + + # Create a new ExtractionResult that holds the extracted data + return ExtractionResult(self, root, rootcost, G) @dataclass diff --git a/sealir/rvsdg/evaluating.py b/sealir/rvsdg/evaluating.py index 0160dce..ef88c35 100644 --- a/sealir/rvsdg/evaluating.py +++ b/sealir/rvsdg/evaluating.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from dataclasses import dataclass from pprint import pprint -from typing import Any, Sequence, Type, TypeAlias, Mapping +from typing import Any, Mapping, Sequence, Type, TypeAlias from sealir import ase, rvsdg from sealir.rvsdg import grammar as rg diff --git a/sealir/tests/test_cost_extraction.py b/sealir/tests/test_cost_extraction.py index 0dbf7a6..f436205 100644 --- a/sealir/tests/test_cost_extraction.py +++ b/sealir/tests/test_cost_extraction.py @@ -1,5 +1,3 @@ -import json - from egglog import ( EGraph, Expr, @@ -12,12 +10,7 @@ ) from sealir.eqsat.rvsdg_eqsat import GraphRoot -from sealir.eqsat.rvsdg_extract import ( - CostModel, - EGraphJsonDict, - Extraction, - get_graph_root, -) +from sealir.eqsat.rvsdg_extract import CostModel, egraph_extraction class Term(Expr): @@ -66,22 +59,6 @@ def simplify_pow(x: Term, i: i64): yield rewrite(Pow(x, 1)).to(x) -def _extraction(egraph, cost_model=None): - # TODO move this into rvsdg_extract - gdct: EGraphJsonDict = json.loads( - egraph._serialize( - n_inline_leaves=0, split_primitive_outputs=False - ).to_json() - ) - [root] = get_graph_root(gdct) - root_eclass = gdct["nodes"][root]["eclass"] - - cost_model = cost_model or CostModel() - _extraction = Extraction(gdct, root_eclass, cost_model) - cost, exgraph = _extraction.choose() - return cost, exgraph - - def _flatten_multidigraph(graph): def format_node(node): # Get all successors (children) of the node @@ -94,6 +71,15 @@ def format_node(node): return [format_node(x) for x in roots] +def _extraction(egraph): + extraction = egraph_extraction(egraph, cost_model=MyCostModel()) + # Use the new explicit method for extracting with auto-detected root + result = extraction.extract_graph_root() + exgraph = result.graph + [extracted] = _flatten_multidigraph(exgraph) + return result.cost, extracted + + def test_cost_duplicated_term(): A = Term("A") B = Term("B") @@ -101,8 +87,7 @@ def test_cost_duplicated_term(): egraph = EGraph() egraph.register(GraphRoot(expr)) - cost, exgraph = _extraction(egraph, cost_model=MyCostModel()) - [extracted] = _flatten_multidigraph(exgraph) + cost, extracted = _extraction(egraph) print("extracted:", extracted) # t1 = function-0-Term___init__(primitive-String-2684354568) # 1 ------------^ @@ -141,12 +126,18 @@ def test_simplify_pow_2(): egraph = EGraph() egraph.register(GraphRoot(expr)) - cost, exgraph = _extraction(egraph, cost_model=MyCostModel()) - assert cost == 14 + extraction = egraph_extraction(egraph, cost_model=MyCostModel()) + # Use the new explicit method for extracting from default root + result = extraction.extract_graph_root() + assert result.cost == 14 egraph.run(simplify_pow.saturate()) - cost, exgraph = _extraction(egraph, cost_model=MyCostModel()) - [extracted] = _flatten_multidigraph(exgraph) + + extraction = egraph_extraction(egraph, cost_model=MyCostModel()) + result = extraction.extract_graph_root() + cost = result.cost + + [extracted] = _flatten_multidigraph(result.graph) print("extracted:", extracted) # t1 = function-1-Term___init__(primitive-String-2684354568) @@ -175,11 +166,16 @@ def test_simplify_pow_3(): egraph = EGraph() egraph.register(GraphRoot(expr)) - cost, exgraph = _extraction(egraph, cost_model=MyCostModel()) - assert cost == 14 + extraction = egraph_extraction(egraph, cost_model=MyCostModel()) + result = extraction.extract_graph_root() + exgraph = result.graph + assert result.cost == 14 egraph.run(simplify_pow.saturate()) - cost, exgraph = _extraction(egraph, cost_model=MyCostModel()) + extraction = egraph_extraction(egraph, cost_model=MyCostModel()) + result = extraction.extract_graph_root() + exgraph = result.graph + cost = result.cost [extracted] = _flatten_multidigraph(exgraph) print("extracted:", extracted) @@ -224,8 +220,7 @@ def test_loop_multiplier(): egraph = EGraph() egraph.register(GraphRoot(expr)) egraph.run(simplify_pow.saturate()) - cost, exgraph = _extraction(egraph, cost_model=MyCostModel()) - [extracted] = _flatten_multidigraph(exgraph) + cost, extracted = _extraction(egraph) print("extracted", extracted) # Pow(A, 3) = 4 (see test_simplify_pow_3) # Loop(A, 3) = (4 * 23 + 13) + 4 @@ -253,8 +248,7 @@ def test_const_factor(): expr = Repeat(A, 13) egraph = EGraph() egraph.register(GraphRoot(expr)) - cost, exgraph = _extraction(egraph, cost_model=MyCostModel()) - [extracted] = _flatten_multidigraph(exgraph) + cost, extracted = _extraction(egraph) dagcost = 2 + 1 + 1 # ^ Term("A") # ^ literal 13 diff --git a/sealir/tests/test_rvsdg_egraph_roundtrip.py b/sealir/tests/test_rvsdg_egraph_roundtrip.py index 1433606..d9bae30 100644 --- a/sealir/tests/test_rvsdg_egraph_roundtrip.py +++ b/sealir/tests/test_rvsdg_egraph_roundtrip.py @@ -3,7 +3,11 @@ from sealir import rvsdg from sealir.eqsat.rvsdg_convert import egraph_conversion from sealir.eqsat.rvsdg_eqsat import GraphRoot -from sealir.eqsat.rvsdg_extract import egraph_extraction +from sealir.eqsat.rvsdg_extract import ( + EGraphToRVSDG, + egraph_extraction, + get_graph_root, +) from sealir.llvm_pyapi_backend import llvm_codegen @@ -16,9 +20,7 @@ def frontend(fn): return rvsdg_expr, dbginfo -def middle_end( - rvsdg_expr, apply_to_egraph, cost_model=None, extract_kwargs=None -): +def middle_end(rvsdg_expr, apply_to_egraph, cost_model=None, stats=None): """The middle end encode the RVSDG into a EGraph to apply rewrite rules. After that, it is extracted back into RVSDG. """ @@ -31,12 +33,10 @@ def middle_end( apply_to_egraph(egraph, func) # Extraction - extract_kwargs = extract_kwargs or {} - cost, rootset = egraph_extraction( - egraph, rvsdg_expr, cost_model=cost_model, **extract_kwargs - ) - [extracted] = rootset.filter_children(lambda node: node._head == "Func") - return cost, extracted + extraction = egraph_extraction(egraph, cost_model=cost_model, stats=stats) + result = extraction.extract_graph_root() + expr = result.extract_sexpr(rvsdg_expr, EGraphToRVSDG) + return result.cost, expr def compiler_pipeline(fn, args, *, verbose=False): @@ -54,15 +54,13 @@ def display_egraph(egraph: EGraph, func): return root - extract_kwargs = dict(stats={}) - cost, extracted = middle_end( - rvsdg_expr, display_egraph, extract_kwargs=extract_kwargs - ) + stats = {} + cost, extracted = middle_end(rvsdg_expr, display_egraph, stats=stats) print("Extracted from EGraph".center(80, "=")) print("cost =", cost) print(rvsdg.format_rvsdg(extracted)) - print("stats:", extract_kwargs) + print("stats:", stats) jt = llvm_codegen(rvsdg_expr) res = jt(*args) From 8e11963527c645cc2aead167f56beb3e8b84a4b3 Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Thu, 6 Nov 2025 15:51:02 -0600 Subject: [PATCH 19/21] Continue refactor --- sealir/ase.py | 15 +++++--- sealir/eqsat/rvsdg_extract.py | 42 +++++++++++++++++----- sealir/eqsat/rvsdg_extract_details.py | 52 +++++++++++++++------------ 3 files changed, 74 insertions(+), 35 deletions(-) diff --git a/sealir/ase.py b/sealir/ase.py index 4fc0c5a..e6660ca 100644 --- a/sealir/ase.py +++ b/sealir/ase.py @@ -159,7 +159,12 @@ def fixup(x): @graphviz_function def render_dot( - self, *, gv, show_metadata: bool = False, only_reachable: bool = False + self, + *, + gv, + show_metadata: bool = False, + only_reachable: bool = False, + source_node: SExpr | None = None, ): def make_label(i, x): if isinstance(x, SExpr): @@ -172,7 +177,10 @@ def make_label(i, x): crawler = TapeCrawler(self, self._downcast) # Records that are children of the last node - crawler.seek(self.last()) + if source_node is None: + crawler.seek(self.last()) + else: + crawler.seek(source_node._handle) # Seek to the first non-metadata in the back while crawler.pos > 0: @@ -232,8 +240,7 @@ def make_label(i, x): if not head.startswith(metadata_prefix): lastname = nodename # emit start - if lastname: - edges.append((("start", lastname), {})) + edges.append((("start", f"node{source_node._handle}"), {})) # emit edges for args, kwargs in edges: g.edge(*args, **kwargs) diff --git a/sealir/eqsat/rvsdg_extract.py b/sealir/eqsat/rvsdg_extract.py index 1831b81..bc46edb 100644 --- a/sealir/eqsat/rvsdg_extract.py +++ b/sealir/eqsat/rvsdg_extract.py @@ -7,7 +7,15 @@ from dataclasses import dataclass, field from itertools import starmap from pprint import pformat -from typing import Any, Callable, NamedTuple, Self, Sequence +from typing import ( + Any, + Callable, + Iterator, + MutableMapping, + NamedTuple, + Self, + Sequence, +) import networkx as nx from egglog import EGraph @@ -89,12 +97,18 @@ class ExtractionResult: cost: float graph: nx.MultiDiGraph - def extract_sexpr(self, original_sexpr: SExpr, converter_class) -> SExpr: + def extract_sexpr( + self, + original_sexpr: SExpr, + converter_class, + memo: MutableMapping | None = None, + ) -> SExpr: """Extract back to SExpr format using any converter class.""" expr = _convert_graph_to_sexpr( self, original_sexpr, converter_class=converter_class, + memo=memo, ) return expr @@ -121,6 +135,7 @@ def _convert_graph_to_sexpr( original_sexpr, *, converter_class, + memo, ) -> SExpr: # Get declarations so we have named fields state = result.extraction.egraph._state @@ -165,7 +180,10 @@ def iterator(node_iter): yield node, children conversion = converter_class( - result.extraction.graph_json, original_sexpr, egg_fn_to_arg_names + result.extraction.graph_json, + original_sexpr, + egg_fn_to_arg_names, + memo=memo, ) return conversion.run(iterator(node_iterator)) @@ -446,25 +464,31 @@ def compute(self) -> Self: return self + def iter_graph_root(self) -> Iterator[str]: + return iter(self.nodes["common_root"].children) + def extract_enode(self, root: str) -> ExtractionResult: """Extract starting from an enode (specific node).""" root_eclass = self.nodes[root].eclass - return self._do_extract(root, root_eclass) + return self._do_extract(root_eclass) def extract_eclass(self, root_eclass: str) -> ExtractionResult: """Extract starting from an equivalence class.""" - return self._do_extract(root_eclass, root_eclass) + return self._do_extract(root_eclass) def extract_graph_root(self) -> ExtractionResult: """Extract using the automatically detected GraphRoot.""" [root] = get_graph_root(self.graph_json) root_eclass = self.nodes[root].eclass - return self._do_extract(root, root_eclass) + return self._do_extract(root_eclass) - def _do_extract(self, root: str, root_eclass: str) -> ExtractionResult: + def extract_common_root(self) -> ExtractionResult: + return self._do_extract("common_root") + + def _do_extract(self, root_eclass: str) -> ExtractionResult: """Internal method to perform the actual extraction.""" if not self._computed: self.compute() @@ -484,7 +508,7 @@ def _do_extract(self, root: str, root_eclass: str) -> ExtractionResult: if cur in visited: continue visited.add(cur) - + G.add_node(cur) for i, u in enumerate(nodes[cur].children): child_eclass = nodes[u].eclass child_key, cost = selections[child_eclass].best() @@ -495,7 +519,7 @@ def _do_extract(self, root: str, root_eclass: str) -> ExtractionResult: render_extraction_graph(G, "chosen") # Create a new ExtractionResult that holds the extracted data - return ExtractionResult(self, root, rootcost, G) + return ExtractionResult(self, chosen_root, rootcost, G) @dataclass diff --git a/sealir/eqsat/rvsdg_extract_details.py b/sealir/eqsat/rvsdg_extract_details.py index b08f07f..bccbf2e 100644 --- a/sealir/eqsat/rvsdg_extract_details.py +++ b/sealir/eqsat/rvsdg_extract_details.py @@ -44,14 +44,19 @@ def _check_setitem(self, key, value): class EGraphToRVSDG: allow_dynamic_op = False + unknown_use_generic = False grammar = Grammar def __init__( - self, gdct: EGraphJsonDict, rvsdg_sexpr: ase.SExpr, egg_fn_to_arg_names + self, + gdct: EGraphJsonDict, + rvsdg_sexpr: ase.SExpr, + egg_fn_to_arg_names, + memo: MutableMapping | None, ): self.rvsdg_sexpr = rvsdg_sexpr self.gdct = gdct - self.memo = _TypeCheckedDict() + self.memo = memo if memo is not None else _TypeCheckedDict() self.egg_fn_to_arg_names = egg_fn_to_arg_names def run(self, node_and_children): @@ -59,12 +64,14 @@ def run(self, node_and_children): with self.rvsdg_sexpr._tape as tape: grm = self.grammar(tape) for key, child_keys in node_and_children: - try: - last = memo[key] = self.handle(key, child_keys, grm) - except Exception as e: - e.add_note(f"Extracting: {key}, {child_keys}") - raise - + if key in memo: + last = memo[key] + else: + try: + last = memo[key] = self.handle(key, child_keys, grm) + except Exception as e: + e.add_note(f"Extracting: {key}, {child_keys}") + raise return last def lookup_sexpr(self, uid: int) -> ase.SExpr: @@ -262,9 +269,8 @@ def get_children() -> dict | list: ) if extended_handle is not NotImplemented: return extended_handle - raise NotImplementedError( - f"invalid Term: {node_type}, {children}" - ) + return self.handle_unknown(key, op, children, grm) + case "TermList", {"terms": terms}: return tuple(terms) case "PortList", {"ports": ports}: @@ -297,17 +303,7 @@ def get_children() -> dict | list: if res is not NotImplemented: return res - def fmt(kv): - k, v = kv - if isinstance(v, ase.SExpr): - return f"{k}={ase.pretty_str(v)}" - else: - return f"{k}={v}" - - fmt_children = "\n".join(map(fmt, children.items())) - raise NotImplementedError( - f"function of: {op!r} :: {node_type}, {children}\n{fmt_children}" - ) + return self.handle_unknown(key, op, children, grm) else: raise NotImplementedError(key) @@ -534,6 +530,18 @@ def handle_Py_Term(self, op: str, children: dict | list, grm: Grammar): def handle_region_attributes(self, key: str, grm: Grammar): return grm.write(rg.Attrs(())) + def handle_unknown( + self, key: str, op: str, children: dict | list, grm: Grammar + ): + if self.unknown_use_generic: + return self.handle_generic(key, op, children, grm) + else: + nodes = self.gdct["nodes"] + node = nodes[key] + eclass = node["eclass"] + node_type = self.gdct["class_data"][eclass]["type"] + raise NotImplementedError(f"{node_type}: {key} - {op}, {children}") + def handle_generic( self, key: str, op: str, children: dict | list, grm: Grammar ): From 354cc67cb5c3fa36d1eacc1b377a6811878c9c7c Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Fri, 7 Nov 2025 13:37:08 -0600 Subject: [PATCH 20/21] Fix exception --- sealir/grammar.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sealir/grammar.py b/sealir/grammar.py index 9ae41b2..2fa0c5f 100644 --- a/sealir/grammar.py +++ b/sealir/grammar.py @@ -228,8 +228,10 @@ def __getattr__( return super().__getattribute__(name) try: idx = self._slots[name] - except IndexError: - raise AttributeError(name) + except KeyError: + raise AttributeError( + f"{self._head} doesn't have attribute {name!r}" + ) if idx + 1 == len(self._rulety._fields): last_fd = self._rulety._fields[-1] From 248f470ba7b3e8e2e1d93eaaa54574f0c7b5a7f1 Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Fri, 7 Nov 2025 14:54:37 -0600 Subject: [PATCH 21/21] Rename extract_sexpr to convert --- sealir/eqsat/rvsdg_extract.py | 2 +- sealir/tests/test_rvsdg_egraph_roundtrip.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sealir/eqsat/rvsdg_extract.py b/sealir/eqsat/rvsdg_extract.py index bc46edb..9a1ccec 100644 --- a/sealir/eqsat/rvsdg_extract.py +++ b/sealir/eqsat/rvsdg_extract.py @@ -97,7 +97,7 @@ class ExtractionResult: cost: float graph: nx.MultiDiGraph - def extract_sexpr( + def convert( self, original_sexpr: SExpr, converter_class, diff --git a/sealir/tests/test_rvsdg_egraph_roundtrip.py b/sealir/tests/test_rvsdg_egraph_roundtrip.py index d9bae30..753bff0 100644 --- a/sealir/tests/test_rvsdg_egraph_roundtrip.py +++ b/sealir/tests/test_rvsdg_egraph_roundtrip.py @@ -35,7 +35,7 @@ def middle_end(rvsdg_expr, apply_to_egraph, cost_model=None, stats=None): # Extraction extraction = egraph_extraction(egraph, cost_model=cost_model, stats=stats) result = extraction.extract_graph_root() - expr = result.extract_sexpr(rvsdg_expr, EGraphToRVSDG) + expr = result.convert(rvsdg_expr, EGraphToRVSDG) return result.cost, expr