Skip to content

Commit ee99c47

Browse files
authored
fix: ensure our HLO passes are not run for only XLA case (#1410)
1 parent 4eaf23f commit ee99c47

File tree

2 files changed

+34
-22
lines changed

2 files changed

+34
-22
lines changed

src/CompileOptions.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -295,16 +295,28 @@ function __compile_options_with_reversed_propagation(compile_options::CompileOpt
295295
end
296296

297297
"""
298-
DefaultXLACompileOptions()
298+
DefaultXLACompileOptions(;
299+
donated_args=:auto, sync=false, optimize_then_pad=true, assert_nonallocating=false
300+
)
299301
300302
Runs specific Enzyme-JAX passes to ensure that the generated code is compatible with
301-
XLA compilation.
303+
XLA compilation. For the documentation of the allowed kwargs see [`CompileOptions`](@ref).
302304
303305
!!! warning
304306
305307
This is mostly a benchmarking option, and the default [`CompileOptions`](@ref) is almost
306308
certainly a better option.
307309
"""
308-
function DefaultXLACompileOptions()
309-
return CompileOptions(; optimization_passes=:only_enzyme, inline=false)
310+
function DefaultXLACompileOptions(;
311+
donated_args=:auto, sync=false, optimize_then_pad=true, assert_nonallocating=false
312+
)
313+
return CompileOptions(;
314+
optimization_passes=:only_enzyme,
315+
inline=false,
316+
donated_args,
317+
sync,
318+
optimize_then_pad,
319+
assert_nonallocating,
320+
optimize_communications=false,
321+
)
310322
end

src/Compiler.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1785,7 +1785,7 @@ function compile_mlir!(
17851785
],
17861786
',',
17871787
),
1788-
"only_enzyme",
1788+
"no_enzyme",
17891789
)
17901790
elseif compile_options.optimization_passes === :only_enzyme
17911791
run_pass_pipeline!(
@@ -1801,7 +1801,7 @@ function compile_mlir!(
18011801
],
18021802
',',
18031803
),
1804-
"after_enzyme",
1804+
"only_enzyme",
18051805
)
18061806
elseif compile_options.optimization_passes === :after_enzyme
18071807
run_pass_pipeline!(
@@ -1852,7 +1852,7 @@ function compile_mlir!(
18521852
end,
18531853
',',
18541854
),
1855-
"before_enzyme",
1855+
"after_enzyme",
18561856
)
18571857
elseif compile_options.optimization_passes === :before_enzyme
18581858
run_pass_pipeline!(
@@ -1887,7 +1887,7 @@ function compile_mlir!(
18871887
end,
18881888
',',
18891889
),
1890-
"after_enzyme",
1890+
"before_enzyme",
18911891
)
18921892
elseif compile_options.optimization_passes === :canonicalize
18931893
run_pass_pipeline!(mod, "mark-func-memory-effects,canonicalize", "canonicalize")
@@ -1897,23 +1897,23 @@ function compile_mlir!(
18971897
run_pass_pipeline!(mod, compile_options.optimization_passes, "custom_pass")
18981898
end
18991899

1900-
if !(compile_options.optimization_passes isa String)
1901-
if compile_options.optimization_passes (:none, :just_batch, :canonicalize) && (
1900+
if compile_options.optimization_passes isa Symbol &&
1901+
compile_options.optimization_passes === :all &&
1902+
(
19021903
compile_options.transpose_propagate === :up ||
19031904
compile_options.reshape_propagate === :up
19041905
)
1905-
# We tried propagating reshapes and transposes up. If at this point we are left
1906-
# with them, we propagate them down to minimize the number of Ops in the IR.
1907-
run_pass_pipeline!(
1908-
mod,
1909-
optimization_passes(
1910-
Reactant.__compile_options_with_reversed_propagation(compile_options);
1911-
recognize_comms,
1912-
lower_comms,
1913-
),
1914-
"post_op_transpose_reshape",
1915-
)
1916-
end
1906+
# We tried propagating reshapes and transposes up. If at this point we are left
1907+
# with them, we propagate them down to minimize the number of Ops in the IR.
1908+
run_pass_pipeline!(
1909+
mod,
1910+
optimization_passes(
1911+
Reactant.__compile_options_with_reversed_propagation(compile_options);
1912+
recognize_comms,
1913+
lower_comms,
1914+
),
1915+
"post_op_transpose_reshape",
1916+
)
19171917
end
19181918

19191919
if backend == "cuda" && compile_options.cudnn_hlo_optimize

0 commit comments

Comments
 (0)