diff --git a/sealir-tutorials/ch02_egraph_basic.py b/sealir-tutorials/ch02_egraph_basic.py index d1de0a2..2ded9b5 100644 --- a/sealir-tutorials/ch02_egraph_basic.py +++ b/sealir-tutorials/ch02_egraph_basic.py @@ -42,7 +42,11 @@ from sealir import rvsdg from sealir.eqsat.rvsdg_convert import egraph_conversion from sealir.eqsat.rvsdg_eqsat import GraphRoot -from sealir.eqsat.rvsdg_extract import egraph_extraction +from sealir.eqsat.rvsdg_extract import ( + CostModel, + EGraphToRVSDG, + egraph_extraction, +) # We'll be extending from chapter 1. from ch01_basic_compiler import ( @@ -113,14 +117,22 @@ def max_if_else(x, y): # efficiency. While all variants are functionally identical, # we are primarily interested in identifying the "best" one, # where "best" depends on context--—such as execution speed, code size, or -# energy efficiency. To address this, the `egraph_extraction()` function allows -# users to define custom cost models, tailoring the selection process to -# prioritize the variant that aligns with their specific optimization goals. +# energy efficiency. +# +# The extraction process involves three steps: +# 1. Create an extraction instance with `egraph_extraction()` using a cost model +# 2. Extract the graph root to get extraction results +# 3. Convert to s-expression using a converter class +# +# This 3-step approach allows users to define custom cost models, tailoring +# the selection process to prioritize the variant that aligns with their +# specific optimization goals. if __name__ == "__main__": help(egraph_extraction) # Here, we will use the default cost model, which is based on the node count. +# The extraction follows the 3-step process described above. class EGraphExtractionOutput(TypedDict): @@ -135,7 +147,10 @@ def pipeline_egraph_extraction( with pipeline_report.nest( "EGraph Extraction", default_expanded=True ) as report: - cost, extracted = egraph_extraction(egraph, rvsdg_expr) + extraction = egraph_extraction(egraph) + extresult = extraction.extract_graph_root() + extracted = extresult.convert(rvsdg_expr, EGraphToRVSDG) + cost = extresult.cost report.append("Cost", cost) report.append("Extracted", rvsdg.format_rvsdg(extracted)) return {"cost": cost, "extracted": extracted} diff --git a/sealir-tutorials/ch03_egraph_program_rewrites.py b/sealir-tutorials/ch03_egraph_program_rewrites.py index 2cee8d5..1d7972b 100644 --- a/sealir-tutorials/ch03_egraph_program_rewrites.py +++ b/sealir-tutorials/ch03_egraph_program_rewrites.py @@ -30,7 +30,16 @@ from __future__ import annotations -from egglog import EGraph, Ruleset, Unit, function, i64, rewrite, rule, ruleset +from egglog import ( + EGraph, + Schedule, + Unit, + function, + i64, + rewrite, + rule, + ruleset, +) from sealir import rvsdg from sealir.eqsat import rvsdg_eqsat from sealir.eqsat.rvsdg_eqsat import GraphRoot, Term, TermList @@ -46,18 +55,19 @@ from utils import IN_NOTEBOOK, Report, display # Next, we'll explore a new compiler pipeline designed with customizable -# rulesets. To enable this flexibility, we've introduced a `ruleset` argument, -# allowing you to tailor the pipeline's behavior to your specific needs. +# rule schedules. To enable this flexibility, we've introduced a `rule_schedule` +# argument, allowing you to tailor the pipeline's behavior to your specific needs. +# A rule schedule defines how and when rules are applied during saturation. def egraph_saturation( egraph: EGraph, egraph_root: GraphRoot, - ruleset: Ruleset, + rule_schedule: Schedule, pipeline_report=Report.Sink(), ) -> EGraphOutput: - # Apply the ruleset to the egraph - egraph.run(ruleset.saturate()) + # Apply the rule schedule to the egraph for saturation + egraph.run(rule_schedule) pipeline_report.append("EGraph Saturated", egraph) return {"egraph": egraph, "egraph_root": egraph_root} @@ -140,12 +150,15 @@ def ifelse_fold(a, b): else: return b - # Add our const-propagation rule to the basic rvsdg ruleset + # Add our const-propagation rule to the basic rvsdg ruleset. + # We use .saturate() to create a schedule that runs the rules to saturation. my_ruleset = rvsdg_eqsat.ruleset_rvsdg_basic | ruleset_const_propagate report = Report("Test", default_expanded=True) jt = compiler_pipeline( - fn=ifelse_fold, pipeline_report=report, ruleset=my_ruleset + fn=ifelse_fold, + pipeline_report=report, + rule_schedule=my_ruleset.saturate(), ).jit_func report.display() run_test(ifelse_fold, jt, (12, 34)) @@ -197,7 +210,9 @@ def ruleset_const_fold_if_else(a: Term, b: Term, c: Term, operands: TermList): report = Report("Test", default_expanded=True) jt = compiler_pipeline( - fn=ifelse_fold, pipeline_report=report, ruleset=my_ruleset + fn=ifelse_fold, + pipeline_report=report, + rule_schedule=my_ruleset.saturate(), ).jit_func report.display() run_test(ifelse_fold, jt, (12, 34)) diff --git a/sealir-tutorials/ch04_0_typeinfer_prelude.py b/sealir-tutorials/ch04_0_typeinfer_prelude.py index 05ad6ed..f730202 100644 --- a/sealir-tutorials/ch04_0_typeinfer_prelude.py +++ b/sealir-tutorials/ch04_0_typeinfer_prelude.py @@ -44,9 +44,11 @@ from sealir.eqsat.rvsdg_extract import ( CostModel, EGraphToRVSDG, + Extraction, egraph_extraction, ) from sealir.llvm_pyapi_backend import SSAValue +from sealir.rvsdg import grammar as rg from ch02_egraph_basic import ( BackendOutput, @@ -70,6 +72,11 @@ # - `converter_class` is for customizing EGraph-to-RVSDG conversion as we will be # introducing new RVSDG operations for typed operations. # - `cost_model` is for customizing the cost of the new operations. +# +# The extraction API has been updated to use the new 3-step process: +# 1. Create extraction with egraph_extraction() +# 2. Extract graph root to get extraction results +# 3. Convert s-expression with converter class def pipeline_egraph_extraction( @@ -82,11 +89,18 @@ def pipeline_egraph_extraction( with pipeline_report.nest( "EGraph Extraction", default_expanded=True ) as report: - cost, extracted = egraph_extraction( + # Step 1: Create extraction instance with custom cost model + extraction = egraph_extraction( egraph, + cost_model=cost_model, # <-------------- new + ) + # Step 2: Extract the graph root + extresult = extraction.extract_graph_root() + cost = extresult.cost + # Step 3: Extract s-expression with custom converter + extracted = extresult.convert( rvsdg_expr, converter_class=converter_class, # <---- new - cost_model=cost_model, # <-------------- new ) report.append("Cost", cost) report.append("Extracted", rvsdg.format_rvsdg(extracted)) @@ -131,6 +145,7 @@ def add_x_y(x, y): # We will start with the same ruleset as in chapter 3. +# Note that we now use .saturate() to create a rule schedule. basic_ruleset = rvsdg_eqsat.ruleset_rvsdg_basic | ruleset_const_propagate @@ -143,7 +158,7 @@ def add_x_y(x, y): report = Report("Compiler Pipeline", default_expanded=True) jt = compiler_pipeline( fn=add_x_y, - ruleset=basic_ruleset, + rule_schedule=basic_ruleset.saturate(), converter_class=EGraphToRVSDG, codegen_extension=None, cost_model=None, @@ -301,6 +316,19 @@ def handle_Term(self, op: str, children: dict | list, grm: Grammar): # Use parent's implementation for other terms. return super().handle_Term(op, children, grm) + def handle_Type( + self, key: str, op: str, children: dict | list, grm: Grammar + ): + + match op, children: + case "Type", {"name": str(typename)}: + if typename == "Int64": + return grm.write( + rg.Generic(name="Type", children=tuple([typename])) + ) + + raise NotImplementedError("handle_Type", op, children) + # The LLVM code-generation also needs an extension: @@ -352,7 +380,7 @@ def get_cost_function(self, nodename, op, ty, cost, children): report = Report("Compiler Pipeline", default_expanded=True) jt = compiler_pipeline( fn=add_x_y, - ruleset=typeinfer_ruleset, + rule_schedule=typeinfer_ruleset.saturate(), converter_class=ExtendEGraphToRVSDG, codegen_extension=codegen_extension, cost_model=MyCostModel(), @@ -387,7 +415,7 @@ def chained_additions(x, y): report = Report("Compiler Pipeline", default_expanded=True) jt = compiler_pipeline( fn=chained_additions, - ruleset=typeinfer_ruleset, + rule_schedule=typeinfer_ruleset.saturate(), converter_class=ExtendEGraphToRVSDG, codegen_extension=codegen_extension, cost_model=MyCostModel(), @@ -427,7 +455,7 @@ def ruleset_optimize_boxing(x: Term): report = Report("Compiler Pipeline", default_expanded=True) jt = compiler_pipeline( fn=chained_additions, - ruleset=optimized_ruleset, + rule_schedule=optimized_ruleset.saturate(), converter_class=ExtendEGraphToRVSDG, codegen_extension=codegen_extension, cost_model=MyCostModel(), diff --git a/sealir-tutorials/ch04_1_typeinfer_ifelse.py b/sealir-tutorials/ch04_1_typeinfer_ifelse.py index 6620fd1..e058556 100644 --- a/sealir-tutorials/ch04_1_typeinfer_ifelse.py +++ b/sealir-tutorials/ch04_1_typeinfer_ifelse.py @@ -29,6 +29,7 @@ from __future__ import annotations import ctypes +from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass from functools import partial @@ -56,6 +57,7 @@ rule, ruleset, set_, + subsume, union, vars_, ) @@ -151,10 +153,6 @@ def __or__(self, other: Type) -> Type: # Let's define some rules that will establish what is disallowed: -@function -def failed_to_unify(ty: Type) -> Unit: ... - - @ruleset def ruleset_type_basic( ta: Type, @@ -169,13 +167,6 @@ def ruleset_type_basic( yield rewrite(ta | tb).to(tb | ta) yield birewrite((ta | tb) | tc).to(ta | (tb | tc)) - # Identify errors - yield rule( - # If both sides are valid types and not equal, then fail - ty == ta | tb, - ne(ta).to(tb), # ta != tb - ).then(failed_to_unify(ty)) - if __name__ == "__main__": eg = EGraph() @@ -247,70 +238,18 @@ def getType(self) -> Type: # Type inference can fail, so we must provide a mechanism for reporting errors. # ### `ErrorMsg` -# We'll define a `ErrorMsg` class in the egraph to capture all the error -# message. The compilation will always start with `ErrorMsg.root()` in the -# EGraph. When type inference encounters an error, That root node will be -# merged with `ErroMsg.fail()` nodes. - - -class ErrorMsg(Expr): - @classmethod - def root(cls) -> ErrorMsg: - "The empty root" - ... - - @classmethod - def fail(cls, msg: String) -> ErrorMsg: - "A node for failure message" - ... +# We'll define a `ErrorMsg` class in the egraph to capture error messages. +# Error messages are created using the `ErrorMessage()` function, which takes +# a string message and creates an ErrorMsg node in the egraph. - @method(preserve=True) - def eval(self) -> tuple[str, tuple]: - """ - This is for converting the information in the EGraph back to - Python. This will parse the EGraph node to extract the message string. - """ - from egglog.builtins import ClassMethodRef, _extract_call - - call = _extract_call(self) - if isinstance(call.callable, ClassMethodRef): - assert call.callable.class_name == "ErrorMsg" - args = [self.__with_expr__(x).eval() for x in call.args] - return call.callable.method_name, tuple(args) - raise TypeError - - -# Helpers to process the error message - - -def get_error_message(err_info: tuple[str, tuple]) -> str: - "Helper to process the result of ErrorMsg.eval()" - match err_info: - case "fail", (msg,): - return msg - case _: - raise NotImplementedError +class ErrorMsg(Expr): ... -# For example -if __name__ == "__main__": - root = ErrorMsg.root() - eg = EGraph() - eg.register( - union(root).with_(ErrorMsg.fail("I failed")), - union(root).with_(ErrorMsg.fail("Failed again")), - ) - if IN_NOTEBOOK: - eg.display(graphviz=True) - msgs = eg.extract_multiple(root, n=3) - print(msgs) - for msg in msgs: - print(msg.eval()) - try: - print(get_error_message(msg.eval())) - except NotImplementedError: - print("no msg") +@function +def ErrorMessage(msg: StringLike) -> ErrorMsg: + "A node for failure message" + ... # ## Typing addition @@ -380,7 +319,7 @@ def ruleset_type_infer_add(): def setup_argtypes(*argtypes): def rule_gen(region): return [ - set_(TypedIns(region).arg(i).getType()).to(ty) + set_(_AttrRegionInputType(region, i).getType()).to(ty) for i, ty in enumerate(argtypes, start=1) ] @@ -412,16 +351,23 @@ def arg_rules( # Associate type variables to region inputs/outputs. -class TypedIns(Expr): - def __init__(self, region: Region): ... +class Attribute(Expr): ... - def arg(self, idx: i64Like) -> TypeVar: ... +@function +def _AttrRegionInputType(region: Region, argidx: i64) -> TypeVar: ... -class TypedOuts(Expr): - def __init__(self, region: Region): ... - def at(self, idx: i64Like) -> TypeVar: ... +@function +def _AttrRegionOutputType(region: Region, argidx: i64) -> TypeVar: ... + + +@function +def RegionInputType(region: Region, argidx: i64, typ: Type) -> Attribute: ... + + +@function +def RegionOutputType(region: Region, argidx: i64, typ: Type) -> Attribute: ... @ruleset @@ -435,7 +381,7 @@ def ruleset_region_types( # Propagate region types yield rule( # Inputs - typ == TypedIns(region).arg(idx), + typ == _AttrRegionInputType(region, idx), term == region.get(idx), ).then( union(TypeVar(term)).with_(typ), @@ -444,9 +390,27 @@ def ruleset_region_types( yield rule( # Outputs term == Term.RegionEnd(region=region, ports=portlist), - pv := portlist.getValue(idx), + portlist.getValue(idx), + ).then( + union(_AttrRegionOutputType(region, idx)).with_( + TypeVar(portlist.getValue(idx)) + ) + ) + + +@ruleset +def ruleset_annotate_types(ty: Type, region: Region, idx: i64): + yield rule( + ty == _AttrRegionInputType(region, idx).getType(), + ).then( + subsume(_AttrRegionInputType(region, idx)), + RegionInputType(region, idx, ty), + ) + yield rule( + ty == _AttrRegionOutputType(region, idx).getType(), ).then( - union(TypedOuts(region).at(idx)).with_(TypeVar(pv)), + subsume(_AttrRegionOutputType(region, idx)), + RegionOutputType(region, idx, ty), ) @@ -458,13 +422,13 @@ def ruleset_region_types( | ruleset_type_basic | ruleset_type_infer_add | setup_argtypes(TypeInt64, TypeInt64) - ) + ).saturate() + ruleset_annotate_types report = Report("Compiler Pipeline", default_expanded=True) try: # this should raise a NotImplementedError because we haven't implemented # the conversions of `Nb_Add_Int64` back into RVSDG. cres = _ch03_compiler_pipeline( - fn=example_0, ruleset=rules, pipeline_report=report + fn=example_0, rule_schedule=rules, pipeline_report=report ) except NotImplementedError as e: # Expect the error to be raised because we haven't implemented the @@ -489,7 +453,7 @@ def ruleset_region_types( # conversions of `Nb_Add_Int64` back into RVSDG. # # Observe in the egraph: -# - `Typedouts`, `TypedIns`, `Type.simple("Int64")` +# - `_AttrRegionOutputType`, `_AttrRegionInputType`, `Type.simple("Int64")` # ## Extend the rest of the compiler @@ -501,27 +465,38 @@ def ruleset_region_types( class CompilationError(Exception): - pass + def __init__(self, *messages): + self.messages = messages + + def __str__(self): + buf = [""] + for msg in self.messages: + buf.append(" - " + str(msg)) + return "\n".join(buf) -def egraph_saturation_with_error_checking( +def egraph_saturation_with_debug( egraph: EGraph, egraph_root: GraphRoot, - ruleset: Ruleset, + rule_schedule: Ruleset, pipeline_debug: bool = False, pipeline_report=Report.Sink(), + display_egraph=False, ) -> EGraphOutput: with pipeline_report.nest("Egraph Saturation") as report: # Define graph root that points to the function - - # Define the empty root node for the error messages - errors = ErrorMsg.root() - egraph.let("errors", errors) if pipeline_debug: report.append("[debug] initial egraph", egraph) # Run all the rules until saturation - egraph.run(ruleset.saturate()) + runreport = egraph.run(rule_schedule) + report.append("saturation report", runreport.updated) + + if display_egraph: + egraph.display() + # from sealir.model_explorer.core import prepare_egraph, visualize_egraph + # visualize_egraph(egraph, filepath="debug") + # prepare_egraph(egraph, filepath="debug") if pipeline_debug: report.append("[debug] saturated egraph", egraph) @@ -529,52 +504,83 @@ def egraph_saturation_with_error_checking( "[debug] egglog.extract", egraph.extract(egraph_root) ) - # Use egglog's default extractor to get the error messages - errmsgs = map( - lambda x: x.eval(), egraph.extract_multiple(errors, n=10) - ) - errmsgs_filtered = [ - get_error_message((meth, args)) - for meth, args in errmsgs - if meth != "root" - ] - if errmsgs_filtered: - # Raise CompilationError if there are compiler errors - raise CompilationError("\n".join(errmsgs_filtered)) - return dict(egraph=egraph, egraph_root=egraph_root) +class EGraphExtractionOutputMore(EGraphExtractionOutput): + extracted_roots: list[SExpr] + + def pipeline_egraph_extraction( egraph, rvsdg_expr, converter_class, cost_model, pipeline_report=Report.Sink(), -) -> EGraphExtractionOutput: +) -> EGraphExtractionOutputMore: with pipeline_report.nest( "EGraph Extraction", default_expanded=True ) as report: try: - # This is the same as ch4.1 - cost, extracted = egraph_extraction( + # Step 1 of 3-step extraction: Create extraction instance with cost model + # This follows the same pattern as ch4.0 but with enhanced error handling + extraction = egraph_extraction( egraph, - rvsdg_expr, - converter_class=converter_class, cost_model=cost_model, ) except ExtractionError as e: raise CompilationError("extraction failed") from e - report.append("Extracted RVSDG", format_rvsdg(extracted)) + # Look up by types + grouped_by_type = defaultdict(set) + for k, v in extraction.node_types.items(): + grouped_by_type[v].add(k) + + memo = {} + # Process error message + errors = [] + for k in grouped_by_type["ErrorMsg"]: + msg = extraction.extract_enode(k).convert( + rvsdg_expr, + converter_class=converter_class, + memo=memo, + ) + match msg.name, msg.children: + case "ErrorMessage", [str(msg)]: + errors.append(msg) + case _: + raise NotImplementedError(msg) + + if errors: + raise CompilationError(*errors) + + # Steps 2 & 3 of extraction: Extract graph root and then s-expression + extresult = extraction.extract_graph_root() # Step 2 + cost = extresult.cost + extracted = extresult.convert( # Step 3 + rvsdg_expr, + converter_class=converter_class, + memo=memo, + ) + report.append("Extraction stats", extraction.stats) report.append("Extracted cost", cost) + report.append(f"Extracted Func", format_rvsdg(extracted)) + + extracted_roots = defaultdict(list) + for k in extraction.iter_graph_root(): + node = extraction.extract_enode(k).convert( + rvsdg_expr, converter_class, memo=memo + ) + extracted_roots[extraction.node_types[k]].append(node) - return dict(cost=cost, extracted=extracted) + return dict( + cost=cost, extracted=extracted, extracted_roots=extracted_roots + ) pipeline_middle_end = ( _ch03_compiler_pipeline.trunc("egraph_saturation") - .extend(egraph_saturation_with_error_checking) + .extend(egraph_saturation_with_debug) .extend(pipeline_egraph_extraction) ) @@ -588,10 +594,16 @@ class BackendOutput(TypedDict): @pipeline_middle_end.extend def pipeline_backend( - extracted, argtypes, backend, pipeline_report=Report.Sink() + extracted, + argtypes, + extracted_roots, + backend, + pipeline_report=Report.Sink(), ) -> BackendOutput: with pipeline_report.nest("Backend") as report: - module = backend.lower(extracted, argtypes) + module = backend.lower( + extracted, argtypes, extracted_roots["Attribute"] + ) report.append("Lowered module", module) return dict(module=module) @@ -737,8 +749,12 @@ def ruleset_propagate_typeof_ifelse( ), then_region.get(idx), ).then( - union(TypeVar(operands[idx])).with_(TypedIns(then_region).arg(idx)), - union(TypeVar(operands[idx])).with_(TypedIns(else_region).arg(idx)), + union(TypeVar(operands[idx])).with_( + _AttrRegionInputType(then_region, idx) + ), + union(TypeVar(operands[idx])).with_( + _AttrRegionInputType(else_region, idx) + ), ) @function @@ -804,11 +820,13 @@ class NbOp_Type(NbOp_Base): class NbOp_InTypeAttr(NbOp_Base): + region: SExpr idx: int type: NbOp_Type class NbOp_OutTypeAttr(NbOp_Base): + region: SExpr idx: int type: NbOp_Type @@ -861,38 +879,7 @@ class Grammar(grammar.Grammar): # Define attribute formating -def my_attr_format(attrs: rg.Attrs) -> str: - ins = {} - outs = {} - others = [] - for attr in attrs.attrs: - match attr: - case NbOp_InTypeAttr(idx=int(idx), type=NbOp_Type(name=str(name))): - ins[idx] = name - case NbOp_OutTypeAttr( - idx=int(idx), type=NbOp_Type(name=str(name)) - ): - outs[idx] = name - case _: - others.append(attr) - - def format(dct): - if len(dct): - hi = max(dct.keys()) - out = ", ".join(dct.get(i, "_") for i in range(hi + 1)) - return f"({out})" - else: - return "()" - - outbuf = [] - if ins or outs: - outbuf.append(format(ins) + "->" + format(outs)) - for other in others: - outbuf.append(ase.pretty_str(other)) - return ", ".join(outbuf) - - -format_rvsdg = partial(rvsdg.format_rvsdg, format_attrs=my_attr_format) +format_rvsdg = rvsdg.format_rvsdg # ### Extend EGraph to RVSDG @@ -900,70 +887,24 @@ def format(dct): class ExtendEGraphToRVSDG(EGraphToRVSDG): grammar = Grammar + unknown_use_generic = True - def handle_region_attributes(self, key: str, grm: Grammar): - - def search_equiv_calls(self_key: str): - nodes = self.gdct["nodes"] - ecl = nodes[self_key]["eclass"] - for k, v in nodes.items(): - children = v["children"] - if children and nodes[children[0]]["eclass"] == ecl: - yield k, v - - def get_types(key_arg): - typs = [] - for k, v in search_equiv_calls(key_arg): - for j in self.search_eclass_siblings(k): - op = self.gdct["nodes"][j]["op"] - if op.startswith("Type."): - typ = self.dispatch(j, grm) - typs.append(typ) - return typs - - attrs = [] - typedargs = list(self.search_calls(key, "TypedIns")) - if typedargs: - [typedarg] = typedargs - for key_arg in self.search_method_calls(typedarg, "arg"): - _k_self, k_idx = self.get_children(key_arg) - # get the idx in `.arg(idx)` - idx = self.dispatch(k_idx, grm) - typs = get_types(key_arg) - - if len(typs) == 1: - typ = typs[0] - attrs.append(grm.write(NbOp_InTypeAttr(idx=idx, type=typ))) - else: - resolved = list(map(ase.pretty_str, typs)) - assert len(typs) == 0, f"multiple types: {resolved}" - - typedouts = list(self.search_calls(key, "TypedOuts")) - if typedouts: - [typedout] = typedouts - for key_at in self.search_method_calls(typedout, "at"): - _k_self, k_idx = self.get_children(key_at) - idx = self.dispatch(k_idx, grm) - - typs = get_types(key_at) - if len(typs) == 1: - typ = typs[0] - attrs.append( - grm.write(NbOp_OutTypeAttr(idx=idx, type=typ)) - ) - else: - assert len(typs) == 0, "multiple types" - - return grm.write(rg.Attrs(tuple(attrs))) + def is_type_from_egraph(self, node) -> bool: + op = node["op"] + if op.startswith("Type."): + return True + return False def handle_Type( self, key: str, op: str, children: dict | list, grm: Grammar ): - assert op == "Type.simple" - match children: - case {"name": name}: + match op, children: + case "Type.simple", {"name": name}: return grm.write(NbOp_Type(name)) - raise NotImplementedError + case "· | ·", {"self": lhs, "other": rhs}: + raise NotImplementedError(f"cannot unify: {lhs} and {rhs}") + case _: + return NotImplemented def handle_Term(self, op: str, children: dict | list, grm: Grammar): match op, children: @@ -987,6 +928,29 @@ def handle_Term(self, op: str, children: dict | list, grm: Grammar): # Use parent's implementation for other terms. return super().handle_Term(op, children, grm) + def handle_Attribute( + self, key: str, op: str, children: dict | list, grm: Grammar + ): + match op, children: + case "RegionInputType", { + "region": region, + "argidx": int(idx), + "typ": typ, + }: + return grm.write( + NbOp_InTypeAttr(region=region, idx=idx, type=typ) + ) + case "RegionOutputType", { + "region": region, + "argidx": int(idx), + "typ": typ, + }: + return grm.write( + NbOp_OutTypeAttr(region=region, idx=idx, type=typ) + ) + case _: + return NotImplemented + # ### Define cost model # penalize Python operations (`Py_` prefix) @@ -999,9 +963,11 @@ def get_cost_function(self, nodename, op, ty, cost, children): return self.get_simple(1) elif op.startswith("Py_"): # Penalize Python operations - return self.get_simple(float("inf")) + return self.get_simple(float(1e20)) elif op.startswith("Nb_"): return self.get_simple(cost) + elif op.endswith("getType") and ty == "Type": + return self.get_simple(1e20) # Fallthrough to parent's cost function return super().get_cost_function(nodename, op, ty, cost, children) @@ -1017,40 +983,41 @@ def get_port_by_name(ports: Sequence[rg.Port], name: str): raise ValueError(f"{name!r} not found") -class Attributes: - _typedins: dict[int, NbOp_InTypeAttr] - _typedouts: dict[int, NbOp_OutTypeAttr] - - def __init__(self, attrs: rg.Attrs): +class AttributeParser: - ins = {} - outs = {} - for attr in attrs.attrs: - match attr: - case NbOp_InTypeAttr(idx=idx): - ins[idx] = attr - case NbOp_OutTypeAttr(idx=idx): - outs[idx] = attr + def __init__(self, roots: Sequence[SExpr]): + region_inputs = {} + region_outputs = {} + for node in roots: + match node: + case NbOp_InTypeAttr(region=rb, idx=int(idx), type=typ): + record = region_inputs.setdefault(rb, {}) + record[idx] = typ + case NbOp_OutTypeAttr(region=rb, idx=int(idx), type=typ): + record = region_outputs.setdefault(rb, {}) + record[idx] = typ case _: - raise ValueError(attr) + raise ValueError(node) - self._typedins = ins - self._typedouts = outs + self._typedins = region_inputs + self._typedouts = region_outputs - def get_output_attribute(self, idx: int) -> NbOp_OutTypeAttr | None: - return self._typedouts.get(idx) + def get_region_attr(self, rb: rg.RegionBegin) -> RegionAttribute: + return RegionAttribute(self._typedins[rb], self._typedouts[rb]) - def get_output_type(self, idx: int) -> NbOp_Type | None: - at = self._typedouts.get(idx) - if at is not None: - return at.type - return None - def get_return_type(self, regionend: rg.RegionEnd): - i, p = get_port_by_name(regionend.ports, rvsdg.internal_prefix("ret")) - if attr := self.get_output_attribute(i): - return attr.type - raise CompilationError("Missing return type") +class RegionAttribute: + def __init__(self, typedins, typedouts): + self._typedins = typedins + self._typedouts = typedouts + + def get_return_type(self, func: rg.Func): + i, _ = get_port_by_name(func.body.ports, rvsdg.internal_prefix("ret")) + try: + out = self._typedouts[i] + except KeyError: + raise CompilationError("Missing return type") from None + return out def num_input_types(self): return len(self._typedins) @@ -1060,7 +1027,10 @@ def num_output_types(self): def input_types(self): for idx in range(1, self.num_input_types() + 1): - yield self._typedins[idx].type + yield self._typedins[idx] + + def get_output_type(self, i): + return self._typedouts.get(i) # - @@ -1082,7 +1052,7 @@ def __init__(self): self.initialize_llvm() def initialize_llvm(self): - llvm.initialize() + # llvm.initialize() llvm.initialize_native_target() llvm.initialize_native_asmprinter() @@ -1119,12 +1089,15 @@ def lower_cast(self, builder, value, fromty, toty): f"unsupported lower_cast: {fromty} -> {toty}" ) - def lower(self, root: rg.Func, argtypes): + def lower(self, root: rg.Func, argtypes, other_roots): mod = ir.Module() llargtypes = [*map(self.lower_type, argtypes)] fname = root.fname - retty = Attributes(root.body.begin.attrs).get_return_type(root.body) + self.attributes = AttributeParser(other_roots) + retty = self.attributes.get_region_attr( + root.body.begin + ).get_return_type(root) llrettype = self.lower_type(retty) fnty = ir.FunctionType(llrettype, llargtypes) @@ -1373,7 +1346,10 @@ def example_1(a, b): jit_func = jit_compiler( fn=example_1, argtypes=(Int64, Int64), - ruleset=(base_ruleset | setup_argtypes(TypeInt64, TypeInt64)), + rule_schedule=( + base_ruleset | setup_argtypes(TypeInt64, TypeInt64) + ).saturate() + + ruleset_annotate_types.saturate(), converter_class=ExtendEGraphToRVSDG, backend=Backend(), cost_model=MyCostModel(), @@ -1431,11 +1407,12 @@ def ruleset_type_infer_float( cres = jit_compiler( fn=example_2, argtypes=(Int64, Int64), - ruleset=( + rule_schedule=( base_ruleset | setup_argtypes(TypeInt64, TypeInt64) | ruleset_type_infer_float # < --- added for float() - ), + ).saturate() + + ruleset_annotate_types.saturate(), converter_class=ExtendEGraphToRVSDG, backend=Backend(), cost_model=MyCostModel(), @@ -1464,12 +1441,22 @@ def example_3(a, b): # Add rules to signal error +@function +def failed_to_unify(ty: Type) -> Unit: ... + + +@ruleset +def ruleset_prune_type_unify(ty: Type): + yield rewrite(ty | ty, subsume=True).to(ty) + + @ruleset -def ruleset_failed_to_unify(ty: Type): +def ruleset_failed_to_unify(ty: Type, tz: Type): yield rule( - failed_to_unify(ty), + ty | tz, ).then( - union(ErrorMsg.root()).with_(ErrorMsg.fail("fail to unify")), + failed_to_unify(ty | tz), + ErrorMessage("fail to unify"), ) @@ -1479,12 +1466,13 @@ def ruleset_failed_to_unify(ty: Type): jit_compiler( fn=example_3, argtypes=(Int64, Int64), - ruleset=( + rule_schedule=( base_ruleset | setup_argtypes(TypeInt64, TypeInt64) | ruleset_type_infer_float - | ruleset_failed_to_unify - ), + ).saturate() + + (ruleset_prune_type_unify.saturate()) + + (ruleset_failed_to_unify | ruleset_annotate_types).saturate(), converter_class=ExtendEGraphToRVSDG, backend=Backend(), cost_model=MyCostModel(), @@ -1528,10 +1516,8 @@ def ruleset_type_infer_failure_report( failed_to_unify(ty), name == then_ports[idx].name, ).then( - union(ErrorMsg.root()).with_( - ErrorMsg.fail( - join("Failed to unify if-else outgoing variables: ", name) - ) + ErrorMessage( + join("Failed to unify if-else outgoing variables: ", name) ), ) @@ -1542,13 +1528,17 @@ def ruleset_type_infer_failure_report( jit_compiler( fn=example_3, argtypes=(Int64, Int64), - ruleset=( + rule_schedule=( base_ruleset | setup_argtypes(TypeInt64, TypeInt64) | ruleset_type_infer_float - | ruleset_failed_to_unify + ).saturate() + + (ruleset_prune_type_unify.saturate()) + + ( + ruleset_failed_to_unify | ruleset_type_infer_failure_report - ), + | ruleset_annotate_types + ).saturate(), converter_class=ExtendEGraphToRVSDG, backend=Backend(), cost_model=MyCostModel(), diff --git a/sealir-tutorials/ch04_2_typeinfer_loops.py b/sealir-tutorials/ch04_2_typeinfer_loops.py index 3ea1447..49ee1e3 100644 --- a/sealir-tutorials/ch04_2_typeinfer_loops.py +++ b/sealir-tutorials/ch04_2_typeinfer_loops.py @@ -68,14 +68,15 @@ SExpr, Type, TypeBool, - TypedIns, TypeInt64, TypeVar, + _AttrRegionInputType, _wc, ) from ch04_1_typeinfer_ifelse import base_ruleset as _ch4_1_base_ruleset from ch04_1_typeinfer_ifelse import ( jit_compiler, + ruleset_annotate_types, ruleset_failed_to_unify, ruleset_type_infer_failure_report, ruleset_type_infer_float, @@ -116,7 +117,7 @@ def assign_output_loop_typevar( region.get(idx), ).then( # propagate loop inputs - union(TypeVar(operands[idx])).with_(TypedIns(region).arg(idx)), + union(TypeVar(operands[idx])).with_(_AttrRegionInputType(region, idx)), ) yield rule( @@ -289,13 +290,17 @@ def lower_expr(self, expr, state): base_ruleset = ( _ch4_1_base_ruleset | ruleset_type_infer_float - | ruleset_failed_to_unify - | ruleset_type_infer_failure_report | ruleset_type_infer_undef | ruleset_type_infer_not | ruleset_propagate_typeof_loops ) +finalize_ruleset = ( + ruleset_annotate_types + | ruleset_failed_to_unify + | ruleset_type_infer_failure_report +) + # ## Example 1: Simple While Loop # # Demonstrate loop compilation with a simple while loop example @@ -321,7 +326,12 @@ def example_1(init, n): cres = jit_compiler( fn=example_1, argtypes=(Int64, Int64), - ruleset=base_ruleset | setup_argtypes(TypeInt64, TypeInt64), + # Rule schedule: combine rulesets and create saturation schedules + # First saturate the base rules and argument types, then finalize + rule_schedule=( + base_ruleset | setup_argtypes(TypeInt64, TypeInt64) + ).saturate() + + (finalize_ruleset).saturate(), **compiler_config, ) jit_func = cres.jit_func @@ -348,7 +358,10 @@ def example_2(init, n): cres = jit_compiler( fn=example_2, argtypes=(Int64, Int64), - ruleset=base_ruleset | setup_argtypes(TypeInt64, TypeInt64), + rule_schedule=( + base_ruleset | setup_argtypes(TypeInt64, TypeInt64) + ).saturate() + + (finalize_ruleset).saturate(), **compiler_config, ) jit_func = cres.jit_func diff --git a/sealir-tutorials/ch05_typeinfer_array.py b/sealir-tutorials/ch05_typeinfer_array.py index f799835..82e39a8 100644 --- a/sealir-tutorials/ch05_typeinfer_array.py +++ b/sealir-tutorials/ch05_typeinfer_array.py @@ -44,6 +44,7 @@ function, i64, i64Like, + method, rewrite, rule, ruleset, @@ -62,8 +63,6 @@ from ch04_1_typeinfer_ifelse import ( Grammar, NbOp_Type, - TypedIns, - _wc, ) from ch04_2_typeinfer_loops import Backend as _ch04_2_Backend from ch04_2_typeinfer_loops import ( @@ -78,6 +77,9 @@ TypeInt64, TypeVar, base_ruleset, +) +from ch04_2_typeinfer_loops import finalize_ruleset as _ch04_2_finalize_ruleset +from ch04_2_typeinfer_loops import ( jit_compiler, setup_argtypes, ) @@ -124,6 +126,7 @@ def strided(cls) -> DataLayout: ... class ArrayDesc(Expr): + @method(cost=10000) def __init__(self, uid: StringLike): ... @property @@ -322,6 +325,25 @@ def ruleset_typeinfer_array_getitem( ) +@function +def FlattenArrayType(ndim: i64, dtype: Type, dataLayout: DataLayout) -> Type: ... + + +@ruleset +def ruleset_finalize_arraydesc( + ad: ArrayDesc, + ndim: i64, + dtype: Type, + datalayout: DataLayout, +): + yield rewrite(ad.toType(), subsume=True).to( + FlattenArrayType(ndim, dtype, datalayout), + ndim == ad.ndim, + dtype == ad.dtype, + datalayout == ad.dataLayout, + ) + + @function def Nb_Array_1D_Getitem_Scalar( io: Term, ary: Term, index: Term, dtype: Type @@ -359,6 +381,26 @@ def handle_Term(self, op: str, children: dict | list, grm: Grammar): ) return super().handle_Term(op, children, grm) + def handle_Type( + self, key: str, op: str, children: dict | list, grm: Grammar + ): + match op, children: + case "FlattenArrayType", { + "ndim": ndim, + "dtype": dtype, + "dataLayout": dataLayout, + }: + return grm.write( + NbOp_ArrayType( + ndim=ndim, + dtype=dtype, + datalayout=dataLayout.name, + shape=(-1,) * ndim, + ) + ) + case _: + return super().handle_Type(key, op, children, grm) + # ### Extend the LLVM Backend # @@ -425,6 +467,9 @@ class CtypeInt64Array1D(ctypes.Structure): "array_int64_1d", shape=("n",), dtype=TypeInt64, layout="c" ) + +finalize_ruleset = _ch04_2_finalize_ruleset | ruleset_finalize_arraydesc + compiler_config = dict( converter_class=ExtendEGraphToRVSDG, backend=Backend(), @@ -437,12 +482,18 @@ class CtypeInt64Array1D(ctypes.Structure): cres = jit_compiler( fn=example_1, argtypes=(array_1d_symbolic, Int64), - ruleset=( + # Rule schedule combines array type inference rules: + # - Base rules with argument types + # - Array metadata and shape information + # - Array getitem operations + # First saturate array inference, then finalize + rule_schedule=( base_ruleset | setup_argtypes(array_int64_1d.toType(), TypeInt64) | ruleset(*array_infos) | ruleset_typeinfer_array_getitem - ), + ).saturate() + + finalize_ruleset.saturate(), **compiler_config, ) jit_func = cres.jit_func @@ -477,12 +528,13 @@ def example_2(ary, size): cres = jit_compiler( fn=example_2, argtypes=(array_1d_symbolic, Int64), - ruleset=( + rule_schedule=( base_ruleset | setup_argtypes(array_int64_1d.toType(), TypeInt64) | ruleset(*array_infos) | ruleset_typeinfer_array_getitem - ), + ).saturate() + + finalize_ruleset.saturate(), **compiler_config, ) jit_func = cres.jit_func @@ -608,7 +660,7 @@ def ruleset_broadcasting( nd > y.ndim, nd_diff == nd - y.ndim, ).then( - subsume(bc), + # subsume(bc), union(z).with_(Broadcast(x, ArrayAddDim(y, nd_diff))), ) diff --git a/sealir-tutorials/ch06_mlir_backend.py b/sealir-tutorials/ch06_mlir_backend.py index a738955..e4f43da 100644 --- a/sealir-tutorials/ch06_mlir_backend.py +++ b/sealir-tutorials/ch06_mlir_backend.py @@ -51,7 +51,7 @@ run_test, ) from ch04_1_typeinfer_ifelse import ( - Attributes, + AttributeParser, ) from ch04_1_typeinfer_ifelse import ( ExtendEGraphToRVSDG as ConditionalExtendGraphtoRVSDG, @@ -84,6 +84,9 @@ NbOp_Not_Int64, ) from ch04_2_typeinfer_loops import base_ruleset as loop_ruleset +from ch04_2_typeinfer_loops import ( + finalize_ruleset, # Added to pipeline for final type resolution +) from utils import IN_NOTEBOOK, Report, display # ## MLIR Backend Implementation @@ -91,7 +94,7 @@ # Define the core MLIR backend class that handles type lowering and # expression compilation. -_DEBUG = False +_DEBUG = True @dataclass(frozen=True) @@ -128,26 +131,32 @@ def lower_type(self, ty: NbOp_Type): return self.f32 raise NotImplementedError(f"unknown type: {ty}") - def lower(self, root: rg.Func, argtypes): + def get_return_types(self, root: rg.Func): + return ( + self.lower_type( + self.attributes.get_region_attr( + root.body.begin + ).get_return_type(root) + ), + ) + + def lower(self, root: rg.Func, argtypes, extracted_roots): """Expression Lowering Lower RVSDG expressions to MLIR operations, handling control flow and data flow constructs. """ context = self.context - self.loc = loc = ir.Location.unknown(context=context) + self.loc = loc = ir.Location.name(f"{self}.lower()", context=context) self.module = module = ir.Module.create(loc=loc) + self.attributes = AttributeParser(extracted_roots) # Get the module body pointer so we can insert content into the # module. self.module_body = module_body = ir.InsertionPoint(module.body) # Convert SealIR types to MLIR types. input_types = tuple([self.lower_type(x) for x in argtypes]) - output_types = ( - self.lower_type( - Attributes(root.body.begin.attrs).get_return_type(root.body) - ), - ) + output_types = self.get_return_types(root) with context, loc, module_body: # Constuct a function that emits a callable C-interface. @@ -210,17 +219,20 @@ def run_passes(self, module): if _DEBUG: module.context.enable_multithreading(False) + + pass_man = passmanager.PassManager(context=module.context) if _DEBUG and not IN_NOTEBOOK: # notebook may hang if ir_printing is enabled and and MLIR failed. pass_man.enable_ir_printing() - pass_man = passmanager.PassManager(context=module.context) pass_man.add("convert-linalg-to-loops") pass_man.add("convert-scf-to-cf") pass_man.add("finalize-memref-to-llvm") pass_man.add("convert-math-to-libm") pass_man.add("convert-func-to-llvm") pass_man.add("convert-index-to-llvm") + pass_man.add("convert-arith-to-llvm") + pass_man.add("convert-cf-to-llvm") pass_man.add("reconcile-unrealized-casts") pass_man.enable_verifier(True) pass_man.run(module.operation) @@ -229,6 +241,9 @@ def run_passes(self, module): module.dump() return module + def _cast_return_value(self, val): + return val + def lower_expr(self, expr: SExpr, state: LowerStates): """Expression Lowering Implementation @@ -254,7 +269,7 @@ def lower_expr(self, expr: SExpr, state: LowerStates): portnames = [p.name for p in body.ports] retval = outs[portnames.index(internal_prefix("ret"))] - func.ReturnOp([retval]) + func.ReturnOp([self._cast_return_value(retval)]) case rg.RegionBegin(inports=ins): portvalues = [] for i, k in enumerate(ins): @@ -358,10 +373,10 @@ def lower_expr(self, expr: SExpr, state: LowerStates): condval = yield cond # process operands - rettys = Attributes(body.begin.attrs) + regionattrs = self.attributes.get_region_attr(body.begin) result_tys = [] - for i in range(0, rettys.num_output_types() + 1): - out_ty = rettys.get_output_type(i) + for i in range(0, regionattrs.num_output_types() + 1): + out_ty = regionattrs.get_output_type(i) if out_ty is not None: match out_ty.name: case "Int64": @@ -387,15 +402,15 @@ def lower_expr(self, expr: SExpr, state: LowerStates): return if_op.results case rg.Loop(body=rg.RegionEnd() as body, operands=operands): - rettys = Attributes(body.begin.attrs) + regionattrs = self.attributes.get_region_attr(body.begin) # process operands ops = [] for op in operands: ops.append((yield op)) result_tys = [] - for i in range(1, rettys.num_output_types() + 1): - out_ty = rettys.get_output_type(i) + for i in range(1, regionattrs.num_output_types() + 1): + out_ty = regionattrs.get_output_type(i) if out_ty is not None: match out_ty.name: case "Int64": @@ -429,7 +444,7 @@ def lower_expr(self, expr: SExpr, state: LowerStates): return while_op_res case _: - raise NotImplementedError(expr, type(expr)) + raise NotImplementedError(expr, type(expr), ase.as_tuple(expr)) # ## JIT Compilation # @@ -442,20 +457,14 @@ def jit_compile(self, llmod, func_node: rg.Func, func_name="func"): Convert the MLIR module into a JIT-callable function using the MLIR execution engine. """ - attributes = Attributes(func_node.body.begin.attrs) + funcattr = self.attributes.get_region_attr(func_node.body.begin) # Convert SealIR types into MLIR types with self.loc: input_types = tuple( - [self.lower_type(x) for x in attributes.input_types()] + [self.lower_type(x) for x in funcattr.input_types()] ) - output_types = ( - self.lower_type( - Attributes(func_node.body.begin.attrs).get_return_type( - func_node.body - ) - ), - ) + output_types = (self.lower_type(funcattr.get_return_type(func_node)),) return self.jit_compile_extra(llmod, input_types, output_types) def jit_compile_extra( @@ -602,7 +611,12 @@ def pipeline_run_be_passes( jit_func = jit_compiler( fn=example_1, argtypes=(Int64, Int64), - ruleset=(if_else_ruleset | setup_argtypes(TypeInt64, TypeInt64)), + # Two-phase rule schedule: first saturate type inference, then finalize + # The finalize_ruleset ensures all type information is properly resolved + rule_schedule=( + if_else_ruleset | setup_argtypes(TypeInt64, TypeInt64) + ).saturate() + + finalize_ruleset.saturate(), pipeline_report=report, **compiler_config, ).jit_func @@ -633,11 +647,12 @@ def example_2(a, b): jit_func = jit_compiler( fn=example_2, argtypes=(Int64, Int64), - ruleset=( + rule_schedule=( if_else_ruleset | setup_argtypes(TypeInt64, TypeInt64) | ruleset_type_infer_float # < --- added for float() - ), + ).saturate() + + finalize_ruleset.saturate(), pipeline_report=report, **compiler_config, ).jit_func @@ -668,7 +683,10 @@ def example_3(init, n): jit_func = jit_compiler( fn=example_3, argtypes=(Int64, Int64), - ruleset=(loop_ruleset | setup_argtypes(TypeInt64, TypeInt64)), + rule_schedule=( + (loop_ruleset | setup_argtypes(TypeInt64, TypeInt64)).saturate() + + finalize_ruleset.saturate() + ), pipeline_report=report, **compiler_config, ).jit_func @@ -697,7 +715,10 @@ def example_4(init, n): jit_func = jit_compiler( fn=example_4, argtypes=(Int64, Int64), - ruleset=(loop_ruleset | setup_argtypes(TypeInt64, TypeInt64)), + rule_schedule=( + loop_ruleset | setup_argtypes(TypeInt64, TypeInt64) + ).saturate() + + finalize_ruleset.saturate(), pipeline_report=report, **compiler_config, ).jit_func diff --git a/sealir-tutorials/ch07_mlir_ufunc.py b/sealir-tutorials/ch07_mlir_ufunc.py index 50454b1..da5e52d 100644 --- a/sealir-tutorials/ch07_mlir_ufunc.py +++ b/sealir-tutorials/ch07_mlir_ufunc.py @@ -49,7 +49,7 @@ base_ruleset, ) from ch06_mlir_backend import Backend as _Backend -from ch06_mlir_backend import ConditionalExtendGraphtoRVSDG, NbOp_Type +from ch06_mlir_backend import NbOp_Type, finalize_ruleset from utils.report import Report # ## Type Declarations @@ -224,10 +224,11 @@ def wrapper(inner_func): if extra_ruleset is not None: ruleset |= extra_ruleset # Compile the inner function and get the IR as a module. + # Use explicit .saturate() calls for the rule schedule cres = ufunc_compiler( fn=inner_func, argtypes=(input_type,) * num_inputs, - ruleset=ruleset, + rule_schedule=ruleset.saturate() + finalize_ruleset.saturate(), ndim=ndim, **compiler_config, ) diff --git a/sealir-tutorials/ch09_whole_program_compiler_driver.py b/sealir-tutorials/ch09_whole_program_compiler_driver.py index 469b8b4..23dc587 100644 --- a/sealir-tutorials/ch09_whole_program_compiler_driver.py +++ b/sealir-tutorials/ch09_whole_program_compiler_driver.py @@ -120,6 +120,10 @@ def __init__(self, source_code, file_name): self.functions = {} # List of all global ast.Call nodes self.global_calls = [] + # Mapping of imported Python modules in the global namespace + # Key is global name of the imported Module. + # Value is the fully-qualified import path. + self.imported = {} def get_call_graph(self) -> dict[str : tuple[str]]: """Obtain a call graph suitable for processing with networkx. @@ -192,6 +196,29 @@ def visit_all(self): """Visit all nodes in the AST.""" self.visit(self.tree) + def visit_Import(self, node): + # Add globals that are imported into self.imported + for alias in node.names: + # The name used in the global namespace + global_name = alias.asname if alias.asname else alias.name + # The fully-qualified import path + import_path = alias.name + self.imported[global_name] = import_path + + def visit_ImportFrom(self, node): + # Add globals that are imported into self.imported + module_name = node.module if node.module else "" + for alias in node.names: + # The name used in the global namespace + global_name = alias.asname if alias.asname else alias.name + # The fully-qualified import path + if module_name: + import_path = f"{module_name}.{alias.name}" + else: + # Handle relative imports (from . import name) + import_path = alias.name + self.imported[global_name] = import_path + def visit_FunctionDef(self, node): """Visit a function definition.""" # Create a new namespace for the function diff --git a/sealir-tutorials/demo01_gelu_tanh_approx.py b/sealir-tutorials/demo01_gelu_tanh_approx.py index 45cc504..38267bc 100644 --- a/sealir-tutorials/demo01_gelu_tanh_approx.py +++ b/sealir-tutorials/demo01_gelu_tanh_approx.py @@ -86,6 +86,7 @@ from ch05_typeinfer_array import MyCostModel as ch06_CostModel from ch05_typeinfer_array import ( base_ruleset, + finalize_ruleset, ) from ch06_mlir_backend import LowerStates, jit_compiler, run_test from ch07_mlir_ufunc import Backend as UfuncBackend @@ -476,9 +477,10 @@ def get_cost_function(self, nodename, op, ty, cost, children): jit_func = jit_compiler( fn=gelu_tanh_forward, argtypes=(Float32,), - ruleset=( + rule_schedule=( base_ruleset | setup_argtypes(TypeFloat32) | additional_rules - ), + ).saturate() + + finalize_ruleset.saturate(), pipeline_report=report, **compiler_config, ).jit_func @@ -562,12 +564,13 @@ def pow_expansion(x: Term, ival: i64): jit_func = jit_compiler( fn=gelu_tanh_forward, argtypes=(Float32,), - ruleset=( + rule_schedule=( base_ruleset | setup_argtypes(TypeFloat32) | additional_rules | optimize_rules - ), + ).saturate() + + finalize_ruleset.saturate(), pipeline_report=report, **compiler_config, ).jit_func diff --git a/sealir-tutorials/tests/test_ch03.py b/sealir-tutorials/tests/test_ch03.py index b362ea6..7e68566 100644 --- a/sealir-tutorials/tests/test_ch03.py +++ b/sealir-tutorials/tests/test_ch03.py @@ -22,8 +22,8 @@ def ifelse_fold(a, b): else: return b - def check(fn, ruleset): - cres = compiler_pipeline(fn=fn, ruleset=ruleset) + def check(fn, rule_schedule): + cres = compiler_pipeline(fn=fn, rule_schedule=rule_schedule) return [ cur for ps, cur in ase.walk_descendants_depth_first_no_repeat( @@ -41,18 +41,20 @@ def is_if_else(expr): ifelse_nodes = check( ifelse_fold, - ruleset=(rvsdg_eqsat.ruleset_rvsdg_basic | ruleset_const_propagate), + rule_schedule=( + rvsdg_eqsat.ruleset_rvsdg_basic | ruleset_const_propagate + ).saturate(), ) # folding shouldn't occur assert len(ifelse_nodes) == 1 ifelse_nodes = check( ifelse_fold, - ruleset=( + rule_schedule=( rvsdg_eqsat.ruleset_rvsdg_basic | ruleset_const_propagate | ruleset_const_fold_if_else - ), + ).saturate(), ) # folding should occur assert len(ifelse_nodes) == 0 diff --git a/sealir-tutorials/tests/test_ch04_0.py b/sealir-tutorials/tests/test_ch04_0.py index 740dd44..9fb410a 100644 --- a/sealir-tutorials/tests/test_ch04_0.py +++ b/sealir-tutorials/tests/test_ch04_0.py @@ -13,7 +13,7 @@ def test_ch04_0_autotest(): def check(fn, ruleset): cres = pipeline_backend( fn=fn, - ruleset=ruleset, + rule_schedule=ruleset.saturate(), converter_class=ExtendEGraphToRVSDG, cost_model=MyCostModel(), codegen_extension=codegen_extension, @@ -27,7 +27,7 @@ def test_ch04_0_code_functioning(): """ jt = compiler_pipeline( fn=chained_additions, - ruleset=optimized_ruleset, + rule_schedule=optimized_ruleset.saturate(), converter_class=ExtendEGraphToRVSDG, codegen_extension=codegen_extension, cost_model=MyCostModel(), diff --git a/sealir-tutorials/tests/test_demo01.py b/sealir-tutorials/tests/test_demo01.py index 3834f1a..b6f594c 100644 --- a/sealir-tutorials/tests/test_demo01.py +++ b/sealir-tutorials/tests/test_demo01.py @@ -12,9 +12,10 @@ def test_demo01_baseline(): cres = jit_compiler( fn=gelu_tanh_forward, argtypes=(Float32,), - ruleset=( + rule_schedule=( base_ruleset | setup_argtypes(TypeFloat32) | additional_rules - ), + ).saturate() + + finalize_ruleset.saturate(), **compiler_config ) llvm_module = cres.module @@ -27,12 +28,13 @@ def test_demo01_optimized(): cres = jit_compiler( fn=gelu_tanh_forward, argtypes=(Float32,), - ruleset=( + rule_schedule=( base_ruleset | setup_argtypes(TypeFloat32) | additional_rules | optimize_rules - ), + ).saturate() + + finalize_ruleset.saturate(), **compiler_config ) llvm_module = cres.module