Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions sealir-tutorials/ch02_egraph_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@
from sealir import rvsdg
from sealir.eqsat.rvsdg_convert import egraph_conversion
from sealir.eqsat.rvsdg_eqsat import GraphRoot
from sealir.eqsat.rvsdg_extract import egraph_extraction
from sealir.eqsat.rvsdg_extract import (
CostModel,
EGraphToRVSDG,
egraph_extraction,
)

# We'll be extending from chapter 1.
from ch01_basic_compiler import (
Expand Down Expand Up @@ -113,14 +117,22 @@ def max_if_else(x, y):
# efficiency. While all variants are functionally identical,
# we are primarily interested in identifying the "best" one,
# where "best" depends on context—such as execution speed, code size, or
# energy efficiency. To address this, the `egraph_extraction()` function allows
# users to define custom cost models, tailoring the selection process to
# prioritize the variant that aligns with their specific optimization goals.
# energy efficiency.
#
# The extraction process involves three steps:
# 1. Create an extraction instance with `egraph_extraction()` using a cost model
# 2. Extract the graph root to get extraction results
# 3. Convert to s-expression using a converter class
#
# This 3-step approach allows users to define custom cost models, tailoring
# the selection process to prioritize the variant that aligns with their
# specific optimization goals.

if __name__ == "__main__":
help(egraph_extraction)

# Here, we will use the default cost model, which is based on the node count.
# The extraction follows the 3-step process described above.


class EGraphExtractionOutput(TypedDict):
Expand All @@ -135,7 +147,10 @@ def pipeline_egraph_extraction(
with pipeline_report.nest(
"EGraph Extraction", default_expanded=True
) as report:
cost, extracted = egraph_extraction(egraph, rvsdg_expr)
extraction = egraph_extraction(egraph)
extresult = extraction.extract_graph_root()
extracted = extresult.convert(rvsdg_expr, EGraphToRVSDG)
cost = extresult.cost
report.append("Cost", cost)
report.append("Extracted", rvsdg.format_rvsdg(extracted))
return {"cost": cost, "extracted": extracted}
Expand Down
33 changes: 24 additions & 9 deletions sealir-tutorials/ch03_egraph_program_rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,16 @@

from __future__ import annotations

from egglog import EGraph, Ruleset, Unit, function, i64, rewrite, rule, ruleset
from egglog import (
EGraph,
Schedule,
Unit,
function,
i64,
rewrite,
rule,
ruleset,
)
from sealir import rvsdg
from sealir.eqsat import rvsdg_eqsat
from sealir.eqsat.rvsdg_eqsat import GraphRoot, Term, TermList
Expand All @@ -46,18 +55,19 @@
from utils import IN_NOTEBOOK, Report, display

# Next, we'll explore a new compiler pipeline designed with customizable
# rulesets. To enable this flexibility, we've introduced a `ruleset` argument,
# allowing you to tailor the pipeline's behavior to your specific needs.
# rule schedules. To enable this flexibility, we've introduced a `rule_schedule`
# argument, allowing you to tailor the pipeline's behavior to your specific needs.
# A rule schedule defines how and when rules are applied during saturation.


def egraph_saturation(
egraph: EGraph,
egraph_root: GraphRoot,
ruleset: Ruleset,
rule_schedule: Schedule,
pipeline_report=Report.Sink(),
) -> EGraphOutput:
# Apply the ruleset to the egraph
egraph.run(ruleset.saturate())
# Apply the rule schedule to the egraph for saturation
egraph.run(rule_schedule)
pipeline_report.append("EGraph Saturated", egraph)
return {"egraph": egraph, "egraph_root": egraph_root}

Expand Down Expand Up @@ -140,12 +150,15 @@ def ifelse_fold(a, b):
else:
return b

# Add our const-propagation rule to the basic rvsdg ruleset
# Add our const-propagation rule to the basic rvsdg ruleset.
# We use .saturate() to create a schedule that runs the rules to saturation.
my_ruleset = rvsdg_eqsat.ruleset_rvsdg_basic | ruleset_const_propagate

report = Report("Test", default_expanded=True)
jt = compiler_pipeline(
fn=ifelse_fold, pipeline_report=report, ruleset=my_ruleset
fn=ifelse_fold,
pipeline_report=report,
rule_schedule=my_ruleset.saturate(),
).jit_func
report.display()
run_test(ifelse_fold, jt, (12, 34))
Expand Down Expand Up @@ -197,7 +210,9 @@ def ruleset_const_fold_if_else(a: Term, b: Term, c: Term, operands: TermList):

report = Report("Test", default_expanded=True)
jt = compiler_pipeline(
fn=ifelse_fold, pipeline_report=report, ruleset=my_ruleset
fn=ifelse_fold,
pipeline_report=report,
rule_schedule=my_ruleset.saturate(),
).jit_func
report.display()
run_test(ifelse_fold, jt, (12, 34))
Expand Down
40 changes: 34 additions & 6 deletions sealir-tutorials/ch04_0_typeinfer_prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@
from sealir.eqsat.rvsdg_extract import (
CostModel,
EGraphToRVSDG,
Extraction,
egraph_extraction,
)
from sealir.llvm_pyapi_backend import SSAValue
from sealir.rvsdg import grammar as rg

from ch02_egraph_basic import (
BackendOutput,
Expand All @@ -70,6 +72,11 @@
# - `converter_class` is for customizing EGraph-to-RVSDG conversion as we will be
# introducing new RVSDG operations for typed operations.
# - `cost_model` is for customizing the cost of the new operations.
#
# The extraction API now uses a 3-step process:
# 1. Create an extraction instance with `egraph_extraction()`
# 2. Extract the graph root to get the extraction result
# 3. Convert the result to an s-expression with the converter class


def pipeline_egraph_extraction(
Expand All @@ -82,11 +89,18 @@ def pipeline_egraph_extraction(
with pipeline_report.nest(
"EGraph Extraction", default_expanded=True
) as report:
cost, extracted = egraph_extraction(
# Step 1: Create extraction instance with custom cost model
extraction = egraph_extraction(
egraph,
cost_model=cost_model, # <-------------- new
)
# Step 2: Extract the graph root
extresult = extraction.extract_graph_root()
cost = extresult.cost
# Step 3: Extract s-expression with custom converter
extracted = extresult.convert(
rvsdg_expr,
converter_class=converter_class, # <---- new
cost_model=cost_model, # <-------------- new
)
report.append("Cost", cost)
report.append("Extracted", rvsdg.format_rvsdg(extracted))
Expand Down Expand Up @@ -131,6 +145,7 @@ def add_x_y(x, y):


# We will start with the same ruleset as in chapter 3.
# Note that we now use .saturate() to create a rule schedule.

basic_ruleset = rvsdg_eqsat.ruleset_rvsdg_basic | ruleset_const_propagate

Expand All @@ -143,7 +158,7 @@ def add_x_y(x, y):
report = Report("Compiler Pipeline", default_expanded=True)
jt = compiler_pipeline(
fn=add_x_y,
ruleset=basic_ruleset,
rule_schedule=basic_ruleset.saturate(),
converter_class=EGraphToRVSDG,
codegen_extension=None,
cost_model=None,
Expand Down Expand Up @@ -301,6 +316,19 @@ def handle_Term(self, op: str, children: dict | list, grm: Grammar):
# Use parent's implementation for other terms.
return super().handle_Term(op, children, grm)

# Handle a "Type" node by writing it out through the grammar writer `grm`.
#
# Only the case where `op` is "Type" and `children` is a mapping whose
# "name" entry is the string "Int64" is supported: it is emitted as a
# generic node named "Type" whose single child is the type name, and the
# result of `grm.write(...)` is returned.
# Every other op/children combination (including any other type name)
# falls through to NotImplementedError.
# `key` is accepted but not used in this body.
def handle_Type(
self, key: str, op: str, children: dict | list, grm: Grammar
):

match op, children:
# `str(typename)` is a class pattern: it matches only when the "name"
# entry is a str, and binds that string to `typename`.
case "Type", {"name": str(typename)}:
if typename == "Int64":
return grm.write(
rg.Generic(name="Type", children=tuple([typename]))
)

# No supported pattern matched; report the offending op and children.
raise NotImplementedError("handle_Type", op, children)


# The LLVM code-generation also needs an extension:

Expand Down Expand Up @@ -352,7 +380,7 @@ def get_cost_function(self, nodename, op, ty, cost, children):
report = Report("Compiler Pipeline", default_expanded=True)
jt = compiler_pipeline(
fn=add_x_y,
ruleset=typeinfer_ruleset,
rule_schedule=typeinfer_ruleset.saturate(),
converter_class=ExtendEGraphToRVSDG,
codegen_extension=codegen_extension,
cost_model=MyCostModel(),
Expand Down Expand Up @@ -387,7 +415,7 @@ def chained_additions(x, y):
report = Report("Compiler Pipeline", default_expanded=True)
jt = compiler_pipeline(
fn=chained_additions,
ruleset=typeinfer_ruleset,
rule_schedule=typeinfer_ruleset.saturate(),
converter_class=ExtendEGraphToRVSDG,
codegen_extension=codegen_extension,
cost_model=MyCostModel(),
Expand Down Expand Up @@ -427,7 +455,7 @@ def ruleset_optimize_boxing(x: Term):
report = Report("Compiler Pipeline", default_expanded=True)
jt = compiler_pipeline(
fn=chained_additions,
ruleset=optimized_ruleset,
rule_schedule=optimized_ruleset.saturate(),
converter_class=ExtendEGraphToRVSDG,
codegen_extension=codegen_extension,
cost_model=MyCostModel(),
Expand Down
Loading
Loading