diff --git a/sealir/ase.py b/sealir/ase.py index fe7ade9..e6660ca 100644 --- a/sealir/ase.py +++ b/sealir/ase.py @@ -159,7 +159,12 @@ def fixup(x): @graphviz_function def render_dot( - self, *, gv, show_metadata: bool = False, only_reachable: bool = False + self, + *, + gv, + show_metadata: bool = False, + only_reachable: bool = False, + source_node: SExpr | None = None, ): def make_label(i, x): if isinstance(x, SExpr): @@ -172,7 +177,10 @@ def make_label(i, x): crawler = TapeCrawler(self, self._downcast) # Records that are children of the last node - crawler.seek(self.last()) + if source_node is None: + crawler.seek(self.last()) + else: + crawler.seek(source_node._handle) # Seek to the first non-metadata in the back while crawler.pos > 0: @@ -232,8 +240,7 @@ def make_label(i, x): if not head.startswith(metadata_prefix): lastname = nodename # emit start - if lastname: - edges.append((("start", lastname), {})) + edges.append((("start", f"node{source_node._handle}"), {})) # emit edges for args, kwargs in edges: g.edge(*args, **kwargs) @@ -814,7 +821,8 @@ def copy_tree_into(self: SExpr, tape: Tape) -> SExpr: Returns a fresh Expr in the new tape. """ oldtree = self._tape - crawler = TapeCrawler(oldtree, self._get_downcast()) + downcast = self._get_downcast() + crawler = TapeCrawler(oldtree, downcast) crawler.seek(self._handle) liveset = set(_select(crawler.walk_descendants(), 1)) surviving = sorted(liveset) @@ -836,7 +844,8 @@ def copy_tree_into(self: SExpr, tape: Tape) -> SExpr: out = tape.read_value(mapping[self._handle]) assert isinstance(out, SExpr) - return out + + return downcast(out) class BasicSExpr(SExpr): diff --git a/sealir/egg_utils.py b/sealir/egg_utils.py index bc19556..fa1c7b1 100644 --- a/sealir/egg_utils.py +++ b/sealir/egg_utils.py @@ -107,7 +107,7 @@ def extract_eclasses(egraph: EGraph) -> EClassData: def reconstruct( - nodes: dict[str, dict], class_data: dict[str, str] + nodes: dict[str, dict], class_data: dict[str, dict[str, str]] ) -> dict[str, Term]: done: dict[str, Term] = {} diff --git a/sealir/eqsat/py_eqsat.py b/sealir/eqsat/py_eqsat.py index 488d8ce..8d4fc67 100644 --- a/sealir/eqsat/py_eqsat.py +++ b/sealir/eqsat/py_eqsat.py @@ -6,6 +6,7 @@ function, i64, i64Like, + rewrite, rule, ruleset, union, @@ -31,6 +32,10 @@ def Py_NotIO(io: Term, term: Term) -> Term: ... +@function +def Py_NegIO(io: Term, term: Term) -> Term: ... + + @function def Py_Lt(a: Term, b: Term) -> Term: ... @@ -87,6 +92,10 @@ def Py_Div(a: Term, b: Term) -> Term: ... def Py_DivIO(io: Term, a: Term, b: Term) -> Term: ... +@function +def Py_FloorDivIO(io: Term, a: Term, b: Term) -> Term: ... + + @function def Py_MatMult(a: Term, b: Term) -> Term: ... @@ -119,6 +128,18 @@ def Py_AttrIO(io: Term, obj: Term, attrname: StringLike) -> Term: ... def Py_SubscriptIO(io: Term, obj: Term, index: Term) -> Term: ... +@function +def Py_SetitemIO(io: Term, obj: Term, index: Term, val: Term) -> Term: ... + + +@function +def Py_SliceIO(io: Term, lower: Term, upper: Term, step: Term) -> Term: ... + + +@function +def Py_Tuple(elems: TermList) -> Term: ... + + @function def Py_LoadGlobal(io: Term, name: StringLike) -> Term: ... 
@@ -225,5 +246,13 @@ def loop_rules( ) +@ruleset +def ruleset_literal_i64_folding(io: Term, ival: i64): + # Constant fold negation of integer literals + yield rewrite(Py_NegIO(io, Term.LiteralI64(ival)).getPort(1)).to( + Term.LiteralI64(0 - ival) + ) + + def make_rules(): - return loop_rules + return loop_rules | ruleset_literal_i64_folding diff --git a/sealir/eqsat/rvsdg_convert.py b/sealir/eqsat/rvsdg_convert.py index e39a085..bbe4a1a 100644 --- a/sealir/eqsat/rvsdg_convert.py +++ b/sealir/eqsat/rvsdg_convert.py @@ -133,6 +133,8 @@ def coro(expr: SExpr, state: ase.TraverseState): match op: case "not": res = py_eqsat.Py_NotIO(ioterm, operandterm) + case "-": + res = py_eqsat.Py_NegIO(ioterm, operandterm) case _: raise NotImplementedError(f"unsupported op: {op!r}") @@ -161,6 +163,9 @@ def coro(expr: SExpr, state: ase.TraverseState): case "/": res = py_eqsat.Py_DivIO(ioterm, lhsterm, rhsterm) + case "//": + res = py_eqsat.Py_FloorDivIO(ioterm, lhsterm, rhsterm) + case "@": res = py_eqsat.Py_MatMultIO(ioterm, lhsterm, rhsterm) @@ -238,6 +243,30 @@ def coro(expr: SExpr, state: ase.TraverseState): py_eqsat.Py_SubscriptIO(ioterm, valterm, idxterm) ) + case rg.PySetItem(io=io, obj=obj, value=value, index=index): + ioterm = yield io + objterm = yield obj + valterm = yield value + idxterm = yield index + return WrapTerm( + py_eqsat.Py_SetitemIO(ioterm, objterm, idxterm, valterm) + ) + + case rg.PySlice(io=io, lower=lower, upper=upper, step=step): + ioterm = yield io + lowerterm = yield lower + upperterm = yield upper + stepterm = yield step + return WrapTerm( + py_eqsat.Py_SliceIO(ioterm, lowerterm, upperterm, stepterm) + ) + + case rg.PyTuple(elems): + elemvals = [] + for el in elems: + elemvals.append((yield el)) + return py_eqsat.Py_Tuple(eg.termlist(*elemvals)) + case rg.PyInt(int(intval)): assert intval.bit_length() < 64 return eg.Term.LiteralI64(intval) diff --git a/sealir/eqsat/rvsdg_eqsat.py b/sealir/eqsat/rvsdg_eqsat.py index adbb606..332ac02 100644 --- a/sealir/eqsat/rvsdg_eqsat.py +++ b/sealir/eqsat/rvsdg_eqsat.py @@ -35,6 +35,8 @@ class DynInt(Expr): def __init__(self, num: i64Like): ... + def get(self) -> i64: ... + def __mul__(self, other: DynInt) -> DynInt: ... converter(i64, DynInt, DynInt) @@ -132,6 +134,15 @@ def dyn_index(self, target: Term) -> DynInt: ... class TermDict(Expr): def __init__(self, term_map: Map[String, Term]): ... + def lookup(self, key: StringLike) -> Unit: + """Trigger rule to lookup a value in the dict. + + This is needed for .get() to match + """ + ... + + def get(self, key: StringLike) -> Term: ... + @function(cost=MAXCOST) # max cost to make it unextractable def _dyn_index_partial(terms: Vec[Term], target: Term) -> DynInt: ... 
@@ -426,6 +437,20 @@ def ruleset_func_outputs( ).then(delete(x)) +@ruleset +def ruleset_termdict(mapping: Map[String, Term], key: String): + yield rule( + TermDict(mapping).lookup(key), + mapping.contains(key), + ).then(set_(TermDict(mapping).get(key)).to(mapping[key])) + + +@ruleset +def ruleset_dynint(n: i64, m: i64): + yield rule(DynInt(n)).then(set_(DynInt(n).get()).to(n)) + yield rewrite(DynInt(n) * DynInt(m)).to(DynInt(n * m)) + + ruleset_rvsdg_basic = ( ruleset_simplify_dbgvalue | ruleset_portlist_basic @@ -437,6 +462,8 @@ def ruleset_func_outputs( | ruleset_region_dyn_get | ruleset_region_propgate_output | ruleset_func_outputs + | ruleset_termdict + | ruleset_dynint ) diff --git a/sealir/eqsat/rvsdg_extract.py b/sealir/eqsat/rvsdg_extract.py index 7be2e71..9a1ccec 100644 --- a/sealir/eqsat/rvsdg_extract.py +++ b/sealir/eqsat/rvsdg_extract.py @@ -7,11 +7,21 @@ from dataclasses import dataclass, field from itertools import starmap from pprint import pformat -from typing import Any, Callable, NamedTuple, Sequence +from typing import ( + Any, + Callable, + Iterator, + MutableMapping, + NamedTuple, + Self, + Sequence, +) import networkx as nx from egglog import EGraph +from sealir.ase import SExpr + from .egraph_utils import EGraphJsonDict from .rvsdg_extract_details import EGraphToRVSDG @@ -80,51 +90,61 @@ def eval_constant_node( return None +@dataclass(frozen=True) +class ExtractionResult: + extraction: Extraction + root: str + cost: float + graph: nx.MultiDiGraph + + def convert( + self, + original_sexpr: SExpr, + converter_class, + memo: MutableMapping | None = None, + ) -> SExpr: + """Extract back to SExpr format using any converter class.""" + expr = _convert_graph_to_sexpr( + self, + original_sexpr, + converter_class=converter_class, + memo=memo, + ) + return expr + + def egraph_extraction( egraph: EGraph, - rvsdg_sexpr, + cost_model: CostModel | None = None, *, - cost_model=None, - converter_class=EGraphToRVSDG, -): - gdct: EGraphJsonDict = json.loads( - egraph._serialize( - n_inline_leaves=0, split_primitive_outputs=False - ).to_json() - ) - [root] = get_graph_root(gdct) - root_eclass = gdct["nodes"][root]["eclass"] - - cost_model = CostModel() if cost_model is None else cost_model - extraction = Extraction(gdct, root_eclass, cost_model) - cost, exgraph = extraction.choose() - - expr = convert_to_rvsdg( - exgraph, - gdct, - rvsdg_sexpr, - root, - egraph, - converter_class=converter_class, - ) - return cost, expr + stats: dict[str, object] | None = None, +) -> Extraction: + """Extract from an egraph and return an Extraction object. 
+ + The returned Extraction can be used to: + - .compute() to compute costs and return ExtractionResult + - .extract_graph_root() to extract using auto-detected GraphRoot + - .extract_enode(root) to extract from a specific enode + - .extract_eclass(eclass) to extract from an equivalence class + """ + return Extraction.from_egraph(egraph, cost_model, stats=stats) -def convert_to_rvsdg( - exgraph: nx.MultiDiGraph, - gdct: EGraphJsonDict, - rvsdg_sexpr, - root: str, - egraph: EGraph, +def _convert_graph_to_sexpr( + result: ExtractionResult, + original_sexpr, *, converter_class, -): + memo, +) -> SExpr: # Get declarations so we have named fields - state = egraph._state + state = result.extraction.egraph._state decls = state.__egg_decls__ # Do the conversion back into RVSDG - node_iterator = list(nx.dfs_postorder_nodes(exgraph, source=root)) + node_iterator = list( + nx.dfs_postorder_nodes(result.graph, source=result.root) + ) def egg_fn_to_arg_names(egg_fn: str) -> tuple[str, ...]: for ref in state.egg_fn_to_callable_refs[egg_fn]: @@ -138,25 +158,33 @@ def iterator(node_iter): # Get children nodes in order children = [ (data["label"], child) - for _, child, data in exgraph.out_edges(node, data=True) + for _, child, data in result.graph.out_edges(node, data=True) ] if children: children.sort() _, children = zip(*children) # extract argument names - kind, _, egg_fn = node.split("-") - match kind: - case "primitive": - pass - case "function": - # TODO: put this into the converter_class - arg_names = egg_fn_to_arg_names(egg_fn) - children = dict(zip(arg_names, children, strict=True)) - case _: - raise NotImplementedError(f"kind is {kind!r}") - yield node, children - - conversion = converter_class(gdct, rvsdg_sexpr, egg_fn_to_arg_names) + if node == "common_root": + yield node, children + else: + kind, _, egg_fn = node.split("-") + match kind: + case "primitive": + pass + case "function": + # TODO: put this into the converter_class + arg_names = egg_fn_to_arg_names(egg_fn) + children = dict(zip(arg_names, children, strict=True)) + case _: + raise NotImplementedError(f"kind is {kind!r}") + yield node, children + + conversion = converter_class( + result.extraction.graph_json, + original_sexpr, + egg_fn_to_arg_names, + memo=memo, + ) return conversion.run(iterator(node_iterator)) @@ -184,15 +212,26 @@ def render_extraction_graph(G: nx.MultiDiGraph, filename: str): class Extraction: + graph_json: EGraphJsonDict nodes: dict[str, Node] node_types: dict[str, str] root_eclass: str cost_model: CostModel + egraph: EGraph + _selections: dict[str, Bucket] | None + _computed: bool _DEBUG = False - def __init__(self, graph_json: EGraphJsonDict, root_eclass, cost_model): - self.root_eclass = root_eclass + def __init__( + self, + graph_json: EGraphJsonDict, + cost_model, + egraph: EGraph, + stats: dict[str, Any] | None = None, + ): + self.graph_json = graph_json + self.root_eclass = "common_root" self.class_data = defaultdict(set) self.nodes = {k: Node(**v) for k, v in graph_json["nodes"].items()} self.node_types = { @@ -202,25 +241,93 @@ def __init__(self, graph_json: EGraphJsonDict, root_eclass, cost_model): for k, node in self.nodes.items(): self.class_data[node.eclass].add(k) self.cost_model = cost_model + self.egraph = egraph + self._selections = None + self._computed = False + self.stats = stats or {} + + @classmethod + def from_egraph( + cls, + egraph: EGraph, + cost_model: CostModel | None = None, + *, + stats: dict[str, object] | None = None, + ) -> "Extraction": + """Create an Extraction object from an 
egraph.
+
+        This replaces the standalone egraph_extraction function.
+        """
+        gdct: EGraphJsonDict = json.loads(
+            egraph._serialize(
+                n_inline_leaves=0, split_primitive_outputs=False
+            ).to_json()
+        )
+        if stats is not None:
+            stats["num_enodes"] = len(gdct["nodes"])
+            stats["num_eclasses"] = len(gdct["class_data"])
+
+        cost_model = CostModel() if cost_model is None else cost_model
+        return cls(gdct, cost_model, egraph, stats)
+
+    def _create_common_root(
+        self, G: nx.DiGraph, eclassmap: dict[str, set[str]]
+    ) -> str:
+        """Create a common root node and add it to the graph.
+
+        Args:
+            G: The eclass dependency graph
+            eclassmap: Mapping from eclass names to sets of node names
+
+        Returns:
+            The name of the common root node
+        """
+        common_root = "common_root"
+
+        # Collect the root eclasses: every eclass with an in-degree of 0,
+        # excluding Unit eclasses.
+        root_eclasses = [
+            node
+            for node, in_degree in G.in_degree()
+            if in_degree == 0
+            and self.graph_json["class_data"][node]["type"] != "Unit"
+        ]
+        G.add_node(common_root, shape="rect")
+        for n in root_eclasses:
+            G.add_edge(common_root, n)
+
+        self.nodes[common_root] = Node(
+            children=[next(iter(eclassmap[ec])) for ec in root_eclasses],
+            cost=0.0,
+            eclass="common_root",
+            op="common_root",
+            subsumed=False,
+        )
+
+        return common_root
 
     def _compute_cost(
         self,
         max_iter=1000,
         max_no_progress=100,
         epsilon=1e-6,
-    ) -> dict[str, Bucket]:
+    ) -> tuple[dict[str, Bucket], int]:
         """
         Uses dynamic programming with iterative cost propagation
 
         Args:
-            max_iter (int, optional): Maximum number of iterations to compute
-            costs. Defaults to 10000. max_no_progress (int, optional): Maximum
-            iterations without cost improvement. Defaults to 500.
+            max_iter (int, optional):
+                Maximum number of iterations to compute costs.
+                Defaults to 1000.
+            max_no_progress (int, optional):
+                Maximum iterations without cost improvement.
+                Defaults to 100.
 
         Returns:
-            dict[str, Bucket]: A mapping of equivalence classes to their lowest
-            cost representations.
- + tuple[dict[str, Bucket], int]: A tuple containing: + - A mapping of equivalence classes to their lowest cost + representations + - The number of rounds (iterations) that were performed Performance notes @@ -258,6 +365,9 @@ def _compute_cost( for child in enode.children: G.add_edge(k, nodes[child].eclass) + # Create common root node + common_root = self._create_common_root(G, eclassmap) + # Get per-node cost function cm = self.cost_model nodecostmap: dict[str, CostFunc] = {} @@ -265,13 +375,16 @@ def _compute_cost( if k not in eclassmap: node = nodes[k] children_eclasses = [nodes[c].eclass for c in node.children] - nodecostmap[k] = cm.get_cost_function( - nodename=k, - op=node.op, - ty=self.node_types[k], - cost=node.cost, - children=children_eclasses, - ) + if k == "common_root": + nodecostmap[k] = cm.get_simple(1) + else: + nodecostmap[k] = cm.get_cost_function( + nodename=k, + op=node.op, + ty=self.node_types[k], + cost=node.cost, + children=children_eclasses, + ) if self._DEBUG: render_extraction_graph(G, "eclass") @@ -289,7 +402,7 @@ def _compute_cost( # Use BFS layers to estimate topological sort topo_ordered = [] - for layer in nx.bfs_layers(G, self.root_eclass): + for layer in nx.bfs_layers(G, [common_root]): topo_ordered += layer def propagate_cost(state_tracker): @@ -304,7 +417,6 @@ def propagate_cost(state_tracker): for k in reversed(topo_ordered): if k not in eclassmap: node = nodes[k] - cost_dag = dagcost.compute_cost(k) cost = sum(cost_dag.values()) selections[node.eclass].put(cost, k) @@ -325,7 +437,10 @@ def propagate_cost(state_tracker): costchanged.update(state_tracker) if costchanged.converged(): # root score is computed? - if math.isfinite(state_tracker.current[self.root_eclass]): + if all( + math.isfinite(state_tracker.current[root]) + for root in [common_root] + ): break # root score is missing? 
@@ -338,24 +453,62 @@ def propagate_cost(state_tracker): else: last_changed_i = round_i - return selections + return dict(**selections), round_i + + def compute(self) -> Self: + """Compute extraction costs (idempotent).""" + if not self._computed: + self._selections, round_i = self._compute_cost() + self.stats["extraction_iteration_count"] = round_i + self._computed = True + + return self + + def iter_graph_root(self) -> Iterator[str]: + return iter(self.nodes["common_root"].children) + + def extract_enode(self, root: str) -> ExtractionResult: + """Extract starting from an enode (specific node).""" - def choose(self) -> tuple[float, nx.MultiDiGraph]: - selections = self._compute_cost() + root_eclass = self.nodes[root].eclass + return self._do_extract(root_eclass) + def extract_eclass(self, root_eclass: str) -> ExtractionResult: + """Extract starting from an equivalence class.""" + + return self._do_extract(root_eclass) + + def extract_graph_root(self) -> ExtractionResult: + """Extract using the automatically detected GraphRoot.""" + + [root] = get_graph_root(self.graph_json) + root_eclass = self.nodes[root].eclass + return self._do_extract(root_eclass) + + def extract_common_root(self) -> ExtractionResult: + return self._do_extract("common_root") + + def _do_extract(self, root_eclass: str) -> ExtractionResult: + """Internal method to perform the actual extraction.""" + if not self._computed: + self.compute() + + selections = self._selections + assert selections is not None nodes = self.nodes - chosen_root, rootcost = selections[self.root_eclass].best() + chosen_root, rootcost = selections[root_eclass].best() # make selected graph G = nx.MultiDiGraph() todolist = [chosen_root] + visited = set() while todolist: cur = todolist.pop() if cur in visited: continue visited.add(cur) - + G.add_node(cur) for i, u in enumerate(nodes[cur].children): child_eclass = nodes[u].eclass child_key, cost = selections[child_eclass].best() @@ -364,7 +517,9 @@ def choose(self) -> tuple[float, nx.MultiDiGraph]: if self._DEBUG: render_extraction_graph(G, "chosen") - return rootcost, G + + # Create a new ExtractionResult that holds the extracted data + return ExtractionResult(self, chosen_root, rootcost, G) @dataclass @@ -413,6 +568,8 @@ class SubgraphCost: "Stores computed node costs for reuse." _stats: _SubgraphCostStats = field(default_factory=_SubgraphCostStats) "Tracks cache hits and misses." 
+ _visited_dag_eclass: list[str] = field(default_factory=list) + "Tracks Children DAG path to avoid recursion" def compute_cost(self, nodename: str) -> dict[str, float]: if (cc := self._cache.get(nodename)) is None: @@ -463,14 +620,21 @@ def _compute_cost(self, nodename: str) -> dict[str, float]: return costs def _compute_choice(self, eclass: str) -> dict[str, float]: - selections = self.selections - - choices = selections[eclass] - if not choices: + if eclass in self._visited_dag_eclass: + # Avoid recursion return {eclass: MAX_COST} - best = choices.best() - - return self.compute_cost(best.name) + self._visited_dag_eclass.append(eclass) + try: + selections = self.selections + + choices = selections[eclass] + if not choices: + return {eclass: MAX_COST} + best = choices.best() + + return self.compute_cost(best.name) + finally: + self._visited_dag_eclass.pop() class ExtractionError(Exception): diff --git a/sealir/eqsat/rvsdg_extract_details.py b/sealir/eqsat/rvsdg_extract_details.py index 3282b8a..bccbf2e 100644 --- a/sealir/eqsat/rvsdg_extract_details.py +++ b/sealir/eqsat/rvsdg_extract_details.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Iterable, Iterator +from typing import Iterable, Iterator, MutableMapping from sealir import ase from sealir.rvsdg import Grammar @@ -14,16 +14,49 @@ class Data: ... +_memo_elem_type = ase.value_type | tuple + + +class _TypeCheckedDict(MutableMapping): + def __init__(self): + self._data = {} + + def __getitem__(self, key): + return self._data[key] + + def __setitem__(self, key, value): + self._check_setitem(key, value) + self._data[key] = value + + def __delitem__(self, key): + del self._data[key] + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + def _check_setitem(self, key, value): + if not isinstance(value, _memo_elem_type): + raise TypeError(f"{type(value)} :: {value}") + + class EGraphToRVSDG: allow_dynamic_op = False + unknown_use_generic = False grammar = Grammar def __init__( - self, gdct: EGraphJsonDict, rvsdg_sexpr: ase.SExpr, egg_fn_to_arg_names + self, + gdct: EGraphJsonDict, + rvsdg_sexpr: ase.SExpr, + egg_fn_to_arg_names, + memo: MutableMapping | None, ): self.rvsdg_sexpr = rvsdg_sexpr self.gdct = gdct - self.memo = {} + self.memo = memo if memo is not None else _TypeCheckedDict() self.egg_fn_to_arg_names = egg_fn_to_arg_names def run(self, node_and_children): @@ -31,12 +64,14 @@ def run(self, node_and_children): with self.rvsdg_sexpr._tape as tape: grm = self.grammar(tape) for key, child_keys in node_and_children: - try: - last = memo[key] = self.handle(key, child_keys, grm) - except Exception as e: - e.add_note(f"Extracting: {key}, {child_keys}") - raise - + if key in memo: + last = memo[key] + else: + try: + last = memo[key] = self.handle(key, child_keys, grm) + except Exception as e: + e.add_note(f"Extracting: {key}, {child_keys}") + raise return last def lookup_sexpr(self, uid: int) -> ase.SExpr: @@ -70,6 +105,7 @@ def filter_by_type( def dispatch(self, key: str, grm: Grammar): if key in self.memo: return self.memo[key] + assert False, "nothing should use this anymroe" node = self.gdct["nodes"][key] child_keys = node["children"] for k in child_keys: @@ -88,6 +124,15 @@ def get_children(self, key): def handle( self, key: str, child_keys: list[str] | dict[str, str], grm: Grammar ): + if key == "common_root": + # legalize child + values = [] + for k in child_keys: + val = self.memo[k] + if isinstance(val, ase.SExpr): + 
values.append(val) + return grm.write(rg.Rootset(tuple(values))) + allow_dynamic_op = self.allow_dynamic_op nodes = self.gdct["nodes"] @@ -104,7 +149,7 @@ def get_children() -> dict | list: return [memo[v] for v in child_keys] if key.startswith("primitive-"): - return self.handle_primitive(node_type, node, get_children()) + return self.handle_primitive(node_type, node, get_children(), grm) elif key.startswith("function-"): op = node["op"] children = get_children() @@ -114,9 +159,6 @@ def get_children() -> dict | list: attrs = self.handle_region_attributes(key, grm) return grm.write(rg.RegionBegin(inports=ins, attrs=attrs)) case "Term", children: - extended_handle = self.handle_Term(op, children, grm) - if extended_handle is not NotImplemented: - return extended_handle match op, children: case "GraphRoot", {"t": term}: return term @@ -222,9 +264,13 @@ def get_children() -> dict | list: rg.Unpack(val=regionbegin, idx=idx) ) case _: - raise NotImplementedError( - f"invalid Term: {node_type}, {children}" + extended_handle = self.handle_Term( + op, children, grm ) + if extended_handle is not NotImplemented: + return extended_handle + return self.handle_unknown(key, op, children, grm) + case "TermList", {"terms": terms}: return tuple(terms) case "PortList", {"ports": ports}: @@ -256,13 +302,14 @@ def get_children() -> dict | list: res = handler(key, op, children, grm) if res is not NotImplemented: return res - raise NotImplementedError( - f"function of: {op!r} :: {node_type}, {children}" - ) + + return self.handle_unknown(key, op, children, grm) else: raise NotImplementedError(key) - def handle_primitive(self, node_type: str, node, children: tuple): + def handle_primitive( + self, node_type: str, node, children: tuple, grm: Grammar + ): match node_type: case "String": unquoted = node["op"][1:-1] @@ -280,13 +327,20 @@ def handle_primitive(self, node_type: str, node, children: tuple): case "f64": return float(node["op"]) case "Vec_Term": - return children + return tuple(children) case "Vec_Port": - return children + return tuple(children) case "Vec_String": - return children + return tuple(children) case _: - raise NotImplementedError(f"primitive of: {node_type}") + if node_type.startswith("Vec_"): + return grm.write( + rg.GenericList( + name=node_type, children=tuple(children) + ) + ) + else: + raise NotImplementedError(node_type) def handle_Value(self, op: str, children: dict | list, grm: Grammar): match op, children: @@ -453,8 +507,50 @@ def handle_Py_Term(self, op: str, children: dict | list, grm: Grammar): operands=operands, ) ) + case "Py_Tuple", {"elems": tuple(elems)}: + return grm.write( + rg.PyTuple( + elems=elems, + ) + ) + case "Py_SliceIO", { + "io": io, + "lower": lower, + "upper": upper, + "step": step, + }: + return grm.write( + rg.PySlice(io=io, lower=lower, upper=upper, step=step) + ) + case "Py_SubscriptIO", {"io": io, "obj": obj, "index": index}: + return grm.write(rg.PySubscript(io=io, value=obj, index=index)) case _: return NotImplemented def handle_region_attributes(self, key: str, grm: Grammar): return grm.write(rg.Attrs(())) + + def handle_unknown( + self, key: str, op: str, children: dict | list, grm: Grammar + ): + if self.unknown_use_generic: + return self.handle_generic(key, op, children, grm) + else: + nodes = self.gdct["nodes"] + node = nodes[key] + eclass = node["eclass"] + node_type = self.gdct["class_data"][eclass]["type"] + raise NotImplementedError(f"{node_type}: {key} - {op}, {children}") + + def handle_generic( + self, key: str, op: str, children: dict | list, 
grm: Grammar + ): + assert isinstance(children, dict) + # flatten children.values() into SExpr + values = [] + for k, v in children.items(): + if isinstance(v, tuple): + values.append(grm.write(rg.GenericList(name=k, children=v))) + else: + values.append(v) + return grm.write(rg.Generic(name=str(op), children=tuple(values))) diff --git a/sealir/grammar.py b/sealir/grammar.py index 9ae41b2..2fa0c5f 100644 --- a/sealir/grammar.py +++ b/sealir/grammar.py @@ -228,8 +228,10 @@ def __getattr__( return super().__getattribute__(name) try: idx = self._slots[name] - except IndexError: - raise AttributeError(name) + except KeyError: + raise AttributeError( + f"{self._head} doesn't have attribute {name!r}" + ) if idx + 1 == len(self._rulety._fields): last_fd = self._rulety._fields[-1] diff --git a/sealir/llvm_pyapi_backend.py b/sealir/llvm_pyapi_backend.py index afca1fe..2f90e0a 100644 --- a/sealir/llvm_pyapi_backend.py +++ b/sealir/llvm_pyapi_backend.py @@ -180,8 +180,8 @@ def get_region_args(): return SSAValue(builder.function.args[idx]) case rg.Unpack(val=source, idx=int(idx)): - ports: PackedValues = yield source - return ports[idx] + packedvalues: PackedValues = yield source + return packedvalues[idx] case rg.PyBinOpPure(op=op, lhs=lhs, rhs=rhs): lhsval = (yield lhs).value @@ -376,7 +376,7 @@ def get_region_args(): # end loop builder.position_at_end(bb_endloop) # Returns the value from the loop body because this is a tail loop - return loopout_values + return PackedValues.make(*loopout_values) case rg.PyForLoop( iter_arg_idx=int(iter_arg_idx), diff --git a/sealir/rvsdg/evaluating.py b/sealir/rvsdg/evaluating.py index 95339d0..ef88c35 100644 --- a/sealir/rvsdg/evaluating.py +++ b/sealir/rvsdg/evaluating.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from dataclasses import dataclass from pprint import pprint -from typing import Any, Sequence, Type, TypeAlias +from typing import Any, Mapping, Sequence, Type, TypeAlias from sealir import ase, rvsdg from sealir.rvsdg import grammar as rg @@ -57,10 +57,10 @@ def evaluate( callargs: tuple, callkwargs: dict, *, - init_scope: dict | None = None, + init_scope: Mapping | None = None, init_state: ase.TraverseState | None = None, init_memo: dict | None = None, - global_ns: dict | None = None, + global_ns: Mapping | None = None, dbginfo: rvsdg.SourceDebugInfo, ): stack: list[dict[str, Any]] = [{}] @@ -81,7 +81,7 @@ def push(callargs): finally: stack.pop() - def scope() -> dict[str, Any]: + def scope() -> Mapping[str, Any]: return stack[-1] def get_region_args() -> list[Any]: diff --git a/sealir/rvsdg/grammar.py b/sealir/rvsdg/grammar.py index c2d8e50..ce84886 100644 --- a/sealir/rvsdg/grammar.py +++ b/sealir/rvsdg/grammar.py @@ -11,6 +11,20 @@ class _Root(grammar.Rule): pass +class Rootset(_Root): + roots: tuple[SExpr, ...] + + +class Generic(_Root): + name: str + children: tuple[SExpr, ...] + + +class GenericList(_Root): + name: str + children: tuple[SExpr, ...] 
+ + class Loc(_Root): filename: str line_first: int @@ -159,6 +173,20 @@ class PySubscript(_Root): index: SExpr +class PySetItem(_Root): + io: SExpr + obj: SExpr + index: SExpr + value: SExpr + + +class PySlice(_Root): + io: SExpr + lower: SExpr + upper: SExpr + step: SExpr + + class PyUnaryOp(_Root): op: str io: SExpr diff --git a/sealir/rvsdg/restructuring.py b/sealir/rvsdg/restructuring.py index 8c7facd..0b3db9f 100644 --- a/sealir/rvsdg/restructuring.py +++ b/sealir/rvsdg/restructuring.py @@ -434,11 +434,11 @@ def unpack_pystr(sexpr: SExpr) -> str | None: def unpack_pyast_name(sexpr: SExpr) -> str: - assert sexpr._head == "PyAst_Name" + assert sexpr._head == "PyAst_Name", sexpr return cast(str, sexpr._args[0]) -def is_directive(text: str) -> str: +def is_directive(text: str) -> bool: return text.startswith("#file:") or text.startswith("#loc:") @@ -647,6 +647,15 @@ def get_loopvar(ports): res = yield rval tar: SExpr + if len(targets) == 1 and targets[0]._head == "PyAst_Subscript": + [lhs, indices, loc] = targets[0]._args + lhs = yield lhs + indices = yield indices + rval = yield rval + setitem = rg.PySetItem( + io=ctx.load_io(), obj=lhs, index=indices, value=rval + ) + return ctx.insert_io_node(setitem) if ( len(targets) == 1 and unpack_pyast_name(targets[0]) == internal_prefix("_") @@ -775,6 +784,19 @@ def get_loopvar(ports): rg.PySubscript(io=ctx.load_io(), value=value, index=index) ) + case ("PyAst_Slice", (lower, upper, step, interloc)): + lower_val = yield lower + upper_val = yield upper + step_val = yield step + return ctx.insert_io_node( + rg.PySlice( + io=ctx.load_io(), + lower=lower_val, + upper=upper_val, + step=step_val, + ) + ) + case ("PyAst_Pass", (interloc,)): return @@ -1041,6 +1063,7 @@ def formatter(expr: SExpr, state: ase.TraverseState): for arg in expr._args: if isinstance(arg, SExpr): text = yield arg + assert text is not None, arg else: text = repr(arg) argrefs.append(text) diff --git a/sealir/rvsdg/scfg_to_sexpr.py b/sealir/rvsdg/scfg_to_sexpr.py index 9f1262a..bb2a182 100644 --- a/sealir/rvsdg/scfg_to_sexpr.py +++ b/sealir/rvsdg/scfg_to_sexpr.py @@ -176,6 +176,16 @@ def visit_Subscript(self, node: ast.Subscript) -> SExpr: self.get_loc(node), ) + def visit_Slice(self, node: ast.Slice) -> SExpr: + none = self._tape.expr("PyAst_None", self.get_loc(node)) + return self._tape.expr( + "PyAst_Slice", + self.visit(node.lower) if node.lower is not None else none, + self.visit(node.upper) if node.upper is not None else none, + self.visit(node.step) if node.step is not None else none, + self.get_loc(node), + ) + def visit_Call(self, node: ast.Call) -> SExpr: posargs = self._tape.expr( "PyAst_callargs_pos", *map(self.visit, node.args) diff --git a/sealir/tests/test_cost_extraction.py b/sealir/tests/test_cost_extraction.py index 4fd3855..f436205 100644 --- a/sealir/tests/test_cost_extraction.py +++ b/sealir/tests/test_cost_extraction.py @@ -1,5 +1,3 @@ -import json - from egglog import ( EGraph, Expr, @@ -12,12 +10,7 @@ ) from sealir.eqsat.rvsdg_eqsat import GraphRoot -from sealir.eqsat.rvsdg_extract import ( - CostModel, - EGraphJsonDict, - Extraction, - get_graph_root, -) +from sealir.eqsat.rvsdg_extract import CostModel, egraph_extraction class Term(Expr): @@ -66,22 +59,6 @@ def simplify_pow(x: Term, i: i64): yield rewrite(Pow(x, 1)).to(x) -def _extraction(egraph, cost_model=None): - # TODO move this into rvsdg_extract - gdct: EGraphJsonDict = json.loads( - egraph._serialize( - n_inline_leaves=0, split_primitive_outputs=False - ).to_json() - ) - [root] = 
get_graph_root(gdct) - root_eclass = gdct["nodes"][root]["eclass"] - - cost_model = cost_model or CostModel() - _extraction = Extraction(gdct, root_eclass, cost_model) - cost, exgraph = _extraction.choose() - return cost, exgraph - - def _flatten_multidigraph(graph): def format_node(node): # Get all successors (children) of the node @@ -90,9 +67,17 @@ def format_node(node): return node return tuple([node, *(format_node(child) for child in children)]) - # Get all nodes with no incoming edges (roots) - roots = [n for n in graph.nodes if graph.in_degree(n) == 0] - return [format_node(root) for root in roots] + roots = [k for k in graph.nodes if k.endswith("GraphRoot")] + return [format_node(x) for x in roots] + + +def _extraction(egraph): + extraction = egraph_extraction(egraph, cost_model=MyCostModel()) + # Use the new explicit method for extracting with auto-detected root + result = extraction.extract_graph_root() + exgraph = result.graph + [extracted] = _flatten_multidigraph(exgraph) + return result.cost, extracted def test_cost_duplicated_term(): @@ -102,8 +87,7 @@ def test_cost_duplicated_term(): egraph = EGraph() egraph.register(GraphRoot(expr)) - cost, exgraph = _extraction(egraph, cost_model=MyCostModel()) - [extracted] = _flatten_multidigraph(exgraph) + cost, extracted = _extraction(egraph) print("extracted:", extracted) # t1 = function-0-Term___init__(primitive-String-2684354568) # 1 ------------^ @@ -142,12 +126,18 @@ def test_simplify_pow_2(): egraph = EGraph() egraph.register(GraphRoot(expr)) - cost, exgraph = _extraction(egraph, cost_model=MyCostModel()) - assert cost == 14 + extraction = egraph_extraction(egraph, cost_model=MyCostModel()) + # Use the new explicit method for extracting from default root + result = extraction.extract_graph_root() + assert result.cost == 14 egraph.run(simplify_pow.saturate()) - cost, exgraph = _extraction(egraph, cost_model=MyCostModel()) - [extracted] = _flatten_multidigraph(exgraph) + + extraction = egraph_extraction(egraph, cost_model=MyCostModel()) + result = extraction.extract_graph_root() + cost = result.cost + + [extracted] = _flatten_multidigraph(result.graph) print("extracted:", extracted) # t1 = function-1-Term___init__(primitive-String-2684354568) @@ -176,11 +166,16 @@ def test_simplify_pow_3(): egraph = EGraph() egraph.register(GraphRoot(expr)) - cost, exgraph = _extraction(egraph, cost_model=MyCostModel()) - assert cost == 14 + extraction = egraph_extraction(egraph, cost_model=MyCostModel()) + result = extraction.extract_graph_root() + exgraph = result.graph + assert result.cost == 14 egraph.run(simplify_pow.saturate()) - cost, exgraph = _extraction(egraph, cost_model=MyCostModel()) + extraction = egraph_extraction(egraph, cost_model=MyCostModel()) + result = extraction.extract_graph_root() + exgraph = result.graph + cost = result.cost [extracted] = _flatten_multidigraph(exgraph) print("extracted:", extracted) @@ -225,8 +220,7 @@ def test_loop_multiplier(): egraph = EGraph() egraph.register(GraphRoot(expr)) egraph.run(simplify_pow.saturate()) - cost, exgraph = _extraction(egraph, cost_model=MyCostModel()) - [extracted] = _flatten_multidigraph(exgraph) + cost, extracted = _extraction(egraph) print("extracted", extracted) # Pow(A, 3) = 4 (see test_simplify_pow_3) # Loop(A, 3) = (4 * 23 + 13) + 4 @@ -254,8 +248,7 @@ def test_const_factor(): expr = Repeat(A, 13) egraph = EGraph() egraph.register(GraphRoot(expr)) - cost, exgraph = _extraction(egraph, cost_model=MyCostModel()) - [extracted] = _flatten_multidigraph(exgraph) + cost, extracted 
= _extraction(egraph) dagcost = 2 + 1 + 1 # ^ Term("A") # ^ literal 13 diff --git a/sealir/tests/test_rvsdg_egglog_from_source.py b/sealir/tests/test_rvsdg_egglog_from_source.py index 9ce0f5e..9d8893d 100644 --- a/sealir/tests/test_rvsdg_egglog_from_source.py +++ b/sealir/tests/test_rvsdg_egglog_from_source.py @@ -25,8 +25,10 @@ def define_egraph(egraph, func): if verbose: print(egraph.extract(root)) + return root cost, extracted = middle_end(rvsdg_expr, define_egraph) + if verbose: print("Extracted from EGraph".center(80, "=")) print("cost =", cost) diff --git a/sealir/tests/test_rvsdg_egraph_roundtrip.py b/sealir/tests/test_rvsdg_egraph_roundtrip.py index 50adf16..753bff0 100644 --- a/sealir/tests/test_rvsdg_egraph_roundtrip.py +++ b/sealir/tests/test_rvsdg_egraph_roundtrip.py @@ -3,7 +3,11 @@ from sealir import rvsdg from sealir.eqsat.rvsdg_convert import egraph_conversion from sealir.eqsat.rvsdg_eqsat import GraphRoot -from sealir.eqsat.rvsdg_extract import egraph_extraction +from sealir.eqsat.rvsdg_extract import ( + EGraphToRVSDG, + egraph_extraction, + get_graph_root, +) from sealir.llvm_pyapi_backend import llvm_codegen @@ -16,7 +20,7 @@ def frontend(fn): return rvsdg_expr, dbginfo -def middle_end(rvsdg_expr, apply_to_egraph, cost_model=None): +def middle_end(rvsdg_expr, apply_to_egraph, cost_model=None, stats=None): """The middle end encode the RVSDG into a EGraph to apply rewrite rules. After that, it is extracted back into RVSDG. """ @@ -29,10 +33,10 @@ def middle_end(rvsdg_expr, apply_to_egraph, cost_model=None): apply_to_egraph(egraph, func) # Extraction - cost, extracted = egraph_extraction( - egraph, rvsdg_expr, cost_model=cost_model - ) - return cost, extracted + extraction = egraph_extraction(egraph, cost_model=cost_model, stats=stats) + result = extraction.extract_graph_root() + expr = result.convert(rvsdg_expr, EGraphToRVSDG) + return result.cost, expr def compiler_pipeline(fn, args, *, verbose=False): @@ -50,10 +54,13 @@ def display_egraph(egraph: EGraph, func): return root - cost, extracted = middle_end(rvsdg_expr, display_egraph) + stats = {} + cost, extracted = middle_end(rvsdg_expr, display_egraph, stats=stats) + print("Extracted from EGraph".center(80, "=")) print("cost =", cost) print(rvsdg.format_rvsdg(extracted)) + print("stats:", stats) jt = llvm_codegen(rvsdg_expr) res = jt(*args)
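For reference, the extraction changes above replace the old egraph_extraction(...) -> (cost, expr) helper with an Extraction object plus ExtractionResult.convert(...). The following is a minimal usage sketch that mirrors the updated middle_end helper in sealir/tests/test_rvsdg_egraph_roundtrip.py; the function name is illustrative, rvsdg_expr is assumed to come from the existing frontend step, and cost_model is any CostModel instance:

    from egglog import EGraph

    from sealir.eqsat.rvsdg_convert import egraph_conversion
    from sealir.eqsat.rvsdg_eqsat import GraphRoot
    from sealir.eqsat.rvsdg_extract import EGraphToRVSDG, egraph_extraction

    def extract_back_to_rvsdg(rvsdg_expr, cost_model=None):
        # Encode the RVSDG into an egraph and register the graph root.
        egraph = EGraph()
        func = egraph_conversion(rvsdg_expr)
        egraph.register(GraphRoot(func))
        # (rewrite rulesets would be applied here via egraph.run(...))

        # Build the Extraction; costs are computed lazily on first extraction,
        # and enode/eclass counts plus iteration counts land in `stats`.
        stats: dict[str, object] = {}
        extraction = egraph_extraction(egraph, cost_model=cost_model, stats=stats)

        # Extract from the auto-detected GraphRoot and convert back to an SExpr.
        result = extraction.extract_graph_root()
        expr = result.convert(rvsdg_expr, EGraphToRVSDG)
        return result.cost, expr, stats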