From 79e302e2efe23c443cc016b7ae643935bcf47886 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 17 Jun 2025 23:14:41 -0400
Subject: [PATCH 1/7] perf: run ablations for the paper

---
 perf/common.jl                    | 120 ++++++++++++++++++++++++++++++
 perf/neuraloperators/Project.toml |  21 ++++++
 perf/neuraloperators/main.jl      |  39 ++++++++++
 src/CompileOptions.jl             |   7 --
 src/Compiler.jl                   |   2 -
 5 files changed, 180 insertions(+), 9 deletions(-)
 create mode 100644 perf/common.jl
 create mode 100644 perf/neuraloperators/Project.toml
 create mode 100644 perf/neuraloperators/main.jl

diff --git a/perf/common.jl b/perf/common.jl
new file mode 100644
index 0000000000..247544b847
--- /dev/null
+++ b/perf/common.jl
@@ -0,0 +1,120 @@
+using BenchmarkTools: @belapsed
+using Reactant, Enzyme, PrettyTables, Statistics
+
+function simple_mse_loss(model, x, ps, st)
+    y, _ = Lux.apply(model, x, ps, st)
+    return sum(abs2, y)
+end
+
+function benchmark_nn_primal(
+    model, x, ps, st; disable_scatter_gather_bench=true, disable_pad_bench=true
+)
+    results = Vector{Tuple{String,String,Float64,Float64,Float64}}()
+
+    # Only XLA
+    compiled_fwd_xla = @compile sync = true compile_options = Reactant.DefaultXLACompileOptions() simple_mse_loss(
+        model, x, ps, st
+    )
+    bench = @benchmark $compiled_fwd_xla($model, $x, $ps, $st)
+    push!(results, ("Primal", "Only XLA", median(bench).time, std(bench).time, 1.0))
+    baseline = median(bench).time
+
+    # Default
+    compiled_fwd = @compile sync = true simple_mse_loss(model, x, ps, st)
+    bench = @benchmark $compiled_fwd($model, $x, $ps, $st)
+    push!(
+        results,
+        (
+            "Primal",
+            "All",
+            median(bench).time,
+            std(bench).time,
+            median(bench).time / baseline,
+        ),
+    )
+
+    # Disable Scatter
+    if disable_scatter_gather_bench
+        compiled_fwd_no_scatter = @compile sync = true compile_options = CompileOptions(;
+            disable_scatter_gather_optimization_passes=true
+        ) simple_mse_loss(model, x, ps, st)
+        bench = @benchmark $compiled_fwd_no_scatter($model, $x, $ps, $st)
+
+        push!(
+            results,
+            (
+                "Primal",
+                "No Scatter/Gather Optimizations",
+                median(bench).time,
+                std(bench).time,
+                median(bench).time / baseline,
+            ),
+        )
+    end
+
+    # Disable Pad
+    if disable_pad_bench
+        compiled_fwd_no_pad = @compile sync = true compile_options = CompileOptions(;
+            disable_pad_optimization_passes=true
+        ) simple_mse_loss(model, x, ps, st)
+        bench = @benchmark $compiled_fwd_no_pad($model, $x, $ps, $st)
+
+        push!(
+            results,
+            (
+                "Primal",
+                "No Pad Optimizations",
+                median(bench).time,
+                std(bench).time,
+                median(bench).time / baseline,
+            ),
+        )
+    end
+
+    # Disable Scatter and Pad
+    if disable_scatter_gather_bench && disable_pad_bench
+        compiled_fwd_no_scatter_pad = @compile sync = true compile_options = CompileOptions(;
+            disable_scatter_gather_optimization_passes=true,
+            disable_pad_optimization_passes=true,
+        ) simple_mse_loss(model, x, ps, st)
+        bench = @benchmark $compiled_fwd_no_scatter_pad($model, $x, $ps, $st)
+
+        push!(
+            results,
+            (
+                "Primal",
+                "No Scatter/Gather and Pad Optimizations",
+                median(bench).time,
+                std(bench).time,
+                median(bench).time / baseline,
+            ),
+        )
+    end
+
+    sort!(results, by=x -> x[3])
+    return results
+end
+
+function pretty_print_table(results)
+    header = (
+        ["Mode", "Optimization Passes", "Median Time", "Std. Dev. Time", "Relative Timing"],
+        ["", "", "s", "s", "Time / XLA Time"],
+    )
+
+    results = copy(results)
+    results[:, 3] ./= 1e9
+    results[:, 4] ./= 1e9
+
+    hl_r = Highlighter((data, i, j) -> j == 5 && data[i, j] > 1.0, crayon"bold red")
+    hl_g = Highlighter((data, i, j) -> j == 5 && data[i, j] < 1.0, crayon"bold green")
+    display(
+        pretty_table(
+            results;
+            header,
+            header_crayon=crayon"yellow bold",
+            highlighters=(hl_r, hl_g),
+            tf=tf_unicode_rounded,
+        ),
+    )
+    return nothing
+end
diff --git a/perf/neuraloperators/Project.toml b/perf/neuraloperators/Project.toml
new file mode 100644
index 0000000000..838e520c97
--- /dev/null
+++ b/perf/neuraloperators/Project.toml
@@ -0,0 +1,21 @@
+[deps]
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
+PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
+
+[sources]
+Reactant = {path = "../../"}
+
+[compat]
+BenchmarkTools = "1.6"
+CSV = "0.10.15"
+Lux = "1.13.4"
+NeuralOperators = "0.6"
+PrettyTables = "2.4.0"
+Random = "1.11"
+julia = "1.11"
diff --git a/perf/neuraloperators/main.jl b/perf/neuraloperators/main.jl
new file mode 100644
index 0000000000..01177b1b8a
--- /dev/null
+++ b/perf/neuraloperators/main.jl
@@ -0,0 +1,39 @@
+using NeuralOperators, Lux, Random
+
+include("../common.jl")
+
+const xdev = reactant_device()
+
+function run_deeponet_benchmarks()
+    @info "Running DeepONet benchmarks"
+
+    model = DeepONet(;
+        branch=(64, ntuple(Returns(256), 5)..., 16),
+        trunk=(1, ntuple(Returns(256), 5)..., 16),
+        branch_activation=gelu,
+        trunk_activation=gelu,
+    )
+    ps, st = xdev(Lux.setup(Random.default_rng(), model))
+    u = xdev(rand(Float32, 64, 1024))
+    y = xdev(rand(Float32, 1, 128))
+
+    primal_timings = Reactant.with_config(;
+        dot_general_precision=PrecisionConfig.HIGH,
+        convolution_precision=PrecisionConfig.HIGH,
+    ) do
+        benchmark_nn_primal(
+            model,
+            (u, y),
+            ps,
+            st;
+            disable_scatter_gather_bench=true,
+            disable_pad_bench=true,
+        )
+    end
+
+    pretty_print_table(permutedims(hcat([[t...] for t in primal_timings]...), (2, 1)))
+
+    return nothing
+end
+
+run_deeponet_benchmarks()
diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl
index 30dfda915f..74b310fcad 100644
--- a/src/CompileOptions.jl
+++ b/src/CompileOptions.jl
@@ -138,9 +138,6 @@ Fine-grained control over the compilation options for the Reactant compiler.
 
   - `assert_nonallocating`: If `true`, we make sure that no new buffers are returned by the
     function. Any buffer returned must be donated from the inputs. Defaults to `false`.
-  - `sync`: Reactant computations are asynchronous by default. If `true`, the computation
-    will be executed synchronously, blocking till the computation is complete. This is
-    recommended when benchmarking.
 
 # Extended Help
 
@@ -178,7 +175,6 @@ struct CompileOptions
     # julia codegen options
     assert_nonallocating::Bool
     donated_args::Symbol
-    sync::Bool
     ## private options for ablation studies
     disable_scatter_gather_optimization_passes::Bool
     disable_pad_optimization_passes::Bool
@@ -201,7 +197,6 @@ function CompileOptions(;
     optimize_communications::Union{Bool,OptimizeCommunicationOptions}=true,
     assert_nonallocating::Bool=false,
     donated_args::Symbol=:auto,
-    sync::Bool=false,
     disable_scatter_gather_optimization_passes::Bool=false,
     disable_pad_optimization_passes::Bool=false,
 )
@@ -248,7 +243,6 @@ function CompileOptions(;
         optimize_communications,
         assert_nonallocating,
         donated_args,
-        sync,
         disable_scatter_gather_optimization_passes,
         disable_pad_optimization_passes,
     )
@@ -288,7 +282,6 @@ function __compile_options_with_reversed_propagation(compile_options::CompileOpt
         compile_options.optimize_communications,
         compile_options.assert_nonallocating,
         compile_options.donated_args,
-        compile_options.sync,
         compile_options.disable_scatter_gather_optimization_passes,
         compile_options.disable_pad_optimization_passes,
     )
diff --git a/src/Compiler.jl b/src/Compiler.jl
index addca39369..9cb3d8a892 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1260,7 +1260,6 @@ function __get_compile_options_and_kwargs(;
     optimize_communications::Union{Bool,OptimizeCommunicationOptions}=true,
     assert_nonallocating::Bool=false,
     donated_args::Symbol=:auto,
-    sync::Bool=false,
     kwargs...,
 )
     return (
@@ -1282,7 +1281,6 @@ function __get_compile_options_and_kwargs(;
         optimize_communications,
         assert_nonallocating,
         donated_args,
-        sync,
     ),
     kwargs,
 )

From 4e7acac8ef6909ba9d023a1fe3d66aee65269a2f Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 17 Jun 2025 23:16:13 -0400
Subject: [PATCH 2/7] Update perf/common.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 perf/common.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/perf/common.jl b/perf/common.jl
index 247544b847..7ecaf2626c 100644
--- a/perf/common.jl
+++ b/perf/common.jl
@@ -91,7 +91,7 @@ function benchmark_nn_primal(
         )
     end
 
-    sort!(results, by=x -> x[3])
+    sort!(results; by=x -> x[3])
     return results
 end

From 89b526ef0ece2457d1f6d41f9f6021d5ba4494fc Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 17 Jun 2025 23:26:20 -0400
Subject: [PATCH 3/7] fix: import
---
 perf/common.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/perf/common.jl b/perf/common.jl
index 7ecaf2626c..8259a8eba2 100644
--- a/perf/common.jl
+++ b/perf/common.jl
@@ -1,4 +1,4 @@
-using BenchmarkTools: @belapsed
+using BenchmarkTools: @benchmark
 using Reactant, Enzyme, PrettyTables, Statistics
 
 function simple_mse_loss(model, x, ps, st)

From c51a5176edfdc27ec23e495a6c52f0224fb779a3 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 17 Jun 2025 23:34:43 -0400
Subject: [PATCH 4/7] perf: FNOs
---
 perf/neuraloperators/main.jl | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/perf/neuraloperators/main.jl b/perf/neuraloperators/main.jl
index 01177b1b8a..8a7908042d 100644
--- a/perf/neuraloperators/main.jl
+++ b/perf/neuraloperators/main.jl
@@ -36,4 +36,26 @@ function run_deeponet_benchmarks()
     return nothing
 end
 
+function run_fno_benchmarks()
+    @info "Running FNO benchmarks"
+
+    model = FourierNeuralOperator((16, 16), 3, 8, 64)
+    ps, st = xdev(Lux.setup(Random.default_rng(), model))
+    x = xdev(rand(Float32, 64, 64, 1, 256))
+
+    primal_timings = Reactant.with_config(;
+        dot_general_precision=PrecisionConfig.HIGH,
+        convolution_precision=PrecisionConfig.HIGH,
+    ) do
+        benchmark_nn_primal(
+            model, x, ps, st; disable_scatter_gather_bench=true, disable_pad_bench=true
+        )
+    end
+
+    pretty_print_table(permutedims(hcat([[t...] for t in primal_timings]...), (2, 1)))
+
+    return nothing
+end
+
+run_fno_benchmarks()
 run_deeponet_benchmarks()

From a93b09f6118fde40824909ae7056af6afc455184 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 18 Jun 2025 00:05:03 -0400
Subject: [PATCH 5/7] perf: grad
---
 perf/common.jl               | 159 ++++++++++++++++++++++++++++++++---
 perf/neuraloperators/main.jl |  53 ++++++++++--
 src/CompileOptions.jl        |   7 ++
 src/Compiler.jl              |   2 +
 4 files changed, 202 insertions(+), 19 deletions(-)

diff --git a/perf/common.jl b/perf/common.jl
index 8259a8eba2..1690587934 100644
--- a/perf/common.jl
+++ b/perf/common.jl
@@ -1,27 +1,33 @@
 using BenchmarkTools: @benchmark
 using Reactant, Enzyme, PrettyTables, Statistics
 
-function simple_mse_loss(model, x, ps, st)
+function simple_mse_loss(model, x, z, ps, st)
     y, _ = Lux.apply(model, x, ps, st)
-    return sum(abs2, y)
+    return MSELoss()(y, z)
+end
+
+function simple_mse_loss_gradient(model, x, z, ps, st)
+    return Enzyme.gradient(
+        Reverse, simple_mse_loss, Const(model), Const(x), Const(z), ps, Const(st)
+    )
 end
 
 function benchmark_nn_primal(
-    model, x, ps, st; disable_scatter_gather_bench=true, disable_pad_bench=true
+    model, x, z, ps, st; disable_scatter_gather_bench=true, disable_pad_bench=true
 )
     results = Vector{Tuple{String,String,Float64,Float64,Float64}}()
 
     # Only XLA
     compiled_fwd_xla = @compile sync = true compile_options = Reactant.DefaultXLACompileOptions() simple_mse_loss(
-        model, x, ps, st
+        model, x, z, ps, st
     )
-    bench = @benchmark $compiled_fwd_xla($model, $x, $ps, $st)
+    bench = @benchmark $compiled_fwd_xla($model, $x, $z, $ps, $st) setup = (GC.gc(true))
     push!(results, ("Primal", "Only XLA", median(bench).time, std(bench).time, 1.0))
     baseline = median(bench).time
 
     # Default
-    compiled_fwd = @compile sync = true simple_mse_loss(model, x, ps, st)
-    bench = @benchmark $compiled_fwd($model, $x, $ps, $st)
+    compiled_fwd = @compile sync = true simple_mse_loss(model, x, z, ps, st)
+    bench = @benchmark $compiled_fwd($model, $x, $z, $ps, $st) setup = (GC.gc(true))
     push!(
         results,
         (
@@ -37,8 +43,10 @@ function benchmark_nn_primal(
     if disable_scatter_gather_bench
         compiled_fwd_no_scatter = @compile sync = true compile_options = CompileOptions(;
             disable_scatter_gather_optimization_passes=true
-        ) simple_mse_loss(model, x, ps, st)
-        bench = @benchmark $compiled_fwd_no_scatter($model, $x, $ps, $st)
+        ) simple_mse_loss(model, x, z, ps, st)
+        bench = @benchmark $compiled_fwd_no_scatter($model, $x, $z, $ps, $st) setup = (GC.gc(
+            true
+        ))
 
         push!(
             results,
@@ -56,8 +64,10 @@ function benchmark_nn_primal(
     if disable_pad_bench
         compiled_fwd_no_pad = @compile sync = true compile_options = CompileOptions(;
             disable_pad_optimization_passes=true
-        ) simple_mse_loss(model, x, ps, st)
-        bench = @benchmark $compiled_fwd_no_pad($model, $x, $ps, $st)
+        ) simple_mse_loss(model, x, z, ps, st)
+        bench = @benchmark $compiled_fwd_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
+            true
+        ))
 
         push!(
             results,
@@ -76,8 +86,10 @@ function benchmark_nn_primal(
         compiled_fwd_no_scatter_pad = @compile sync = true compile_options = CompileOptions(;
             disable_scatter_gather_optimization_passes=true,
             disable_pad_optimization_passes=true,
-        ) simple_mse_loss(model, x, ps, st)
-        bench = @benchmark $compiled_fwd_no_scatter_pad($model, $x, $ps, $st)
+        ) simple_mse_loss(model, x, z, ps, st)
+        bench = @benchmark $compiled_fwd_no_scatter_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
+            true
+        ))
 
         push!(
             results,
@@ -95,6 +107,127 @@ function benchmark_nn_primal(
     return results
 end
 
+function benchmark_nn_gradient(model, x, z, ps, st; kwargs...)
+    return vcat(
+        [
+            benchmark_nn_gradient_internal(model, x, z, ps, st, mode; kwargs...) for
+            mode in [:all, :before_enzyme, :after_enzyme]
+        ]...,
+    )
+end
+
+function benchmark_nn_gradient_internal(
+    model, x, z, ps, st, mode; disable_scatter_gather_bench=true, disable_pad_bench=true
+)
+    @info "Benchmarking gradient with mode: $(Meta.quot(mode))"
+
+    results = Vector{Tuple{String,String,Float64,Float64,Float64}}()
+
+    # Only XLA
+    compiled_grad_xla = @compile sync = true compile_options = Reactant.DefaultXLACompileOptions() simple_mse_loss_gradient(
+        model, x, z, ps, st
+    )
+    bench = @benchmark $compiled_grad_xla($model, $x, $z, $ps, $st) setup = (GC.gc(true))
+    push!(
+        results, ("Gradient ($mode)", "Only XLA", median(bench).time, std(bench).time, 1.0)
+    )
+    baseline = median(bench).time
+
+    display(results[end])
+
+    # Default
+    compiled_grad = @compile sync = true optimize = mode simple_mse_loss_gradient(
+        model, x, z, ps, st
+    )
+    bench = @benchmark $compiled_grad($model, $x, $z, $ps, $st) setup = (GC.gc(true))
+    push!(
+        results,
+        (
+            "Gradient ($mode)",
+            "All",
+            median(bench).time,
+            std(bench).time,
+            median(bench).time / baseline,
+        ),
+    )
+
+    display(results[end])
+
+    # Disable Scatter
+    if disable_scatter_gather_bench
+        compiled_grad_no_scatter = @compile sync = true compile_options = CompileOptions(;
+            disable_scatter_gather_optimization_passes=true, optimization_passes=mode
+        ) simple_mse_loss_gradient(model, x, z, ps, st)
+        bench = @benchmark $compiled_grad_no_scatter($model, $x, $z, $ps, $st) setup = (GC.gc(
+            true
+        ))
+
+        push!(
+            results,
+            (
+                "Gradient ($mode)",
+                "No Scatter/Gather Optimizations",
+                median(bench).time,
+                std(bench).time,
+                median(bench).time / baseline,
+            ),
+        )
+
+        display(results[end])
+    end
+
+    # Disable Pad
+    if disable_pad_bench
+        compiled_grad_no_pad = @compile sync = true compile_options = CompileOptions(;
+            disable_pad_optimization_passes=true, optimization_passes=mode
+        ) simple_mse_loss_gradient(model, x, z, ps, st)
+        bench = @benchmark $compiled_grad_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
+            true
+        ))
+
+        push!(
+            results,
+            (
+                "Gradient ($mode)",
+                "No Pad Optimizations",
+                median(bench).time,
+                std(bench).time,
+                median(bench).time / baseline,
+            ),
+        )
+
+        display(results[end])
+    end
+
+    # Disable Pad and Scatter
+    if disable_scatter_gather_bench && disable_pad_bench
+        compiled_grad_no_scatter_no_pad = @compile sync = true compile_options = CompileOptions(;
+            disable_scatter_gather_optimization_passes=true,
+            disable_pad_optimization_passes=true,
+            optimization_passes=mode,
+        ) simple_mse_loss_gradient(model, x, z, ps, st)
+        bench = @benchmark $compiled_grad_no_scatter_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
+            true
+        ))
+
+        push!(
+            results,
+            (
+                "Gradient ($mode)",
+                "No Scatter/Gather/Pad Optimizations",
+                median(bench).time,
+                std(bench).time,
+                median(bench).time / baseline,
+            ),
+        )
+
+        display(results[end])
+    end
+
+    sort!(results; by=x -> x[3])
+    return results
+end
+
 function pretty_print_table(results)
     header = (
         ["Mode", "Optimization Passes", "Median Time", "Std. Dev. Time", "Relative Timing"],
diff --git a/perf/neuraloperators/main.jl b/perf/neuraloperators/main.jl
index 8a7908042d..6c75a9d1bb 100644
--- a/perf/neuraloperators/main.jl
+++ b/perf/neuraloperators/main.jl
@@ -8,14 +8,15 @@ function run_deeponet_benchmarks()
     @info "Running DeepONet benchmarks"
 
     model = DeepONet(;
-        branch=(64, ntuple(Returns(256), 5)..., 16),
-        trunk=(1, ntuple(Returns(256), 5)..., 16),
+        branch=(64, ntuple(Returns(256), 4)..., 16),
+        trunk=(1, ntuple(Returns(256), 4)..., 16),
         branch_activation=gelu,
         trunk_activation=gelu,
     )
     ps, st = xdev(Lux.setup(Random.default_rng(), model))
     u = xdev(rand(Float32, 64, 1024))
     y = xdev(rand(Float32, 1, 128))
+    z = xdev(rand(Float32, 128, 1024))
 
     primal_timings = Reactant.with_config(;
         dot_general_precision=PrecisionConfig.HIGH,
@@ -24,6 +25,7 @@ function run_deeponet_benchmarks()
         benchmark_nn_primal(
             model,
             (u, y),
+            z,
             ps,
             st;
             disable_scatter_gather_bench=true,
@@ -31,7 +33,23 @@ function run_deeponet_benchmarks()
         )
     end
 
-    pretty_print_table(permutedims(hcat([[t...] for t in primal_timings]...), (2, 1)))
+    gradient_timings = Reactant.with_config(;
+        dot_general_precision=PrecisionConfig.HIGH,
+        convolution_precision=PrecisionConfig.HIGH,
+    ) do
+        benchmark_nn_gradient(
+            model,
+            (u, y),
+            z,
+            ps,
+            st;
+            disable_scatter_gather_bench=true,
+            disable_pad_bench=true,
+        )
+    end
+
+    timings = vcat(primal_timings, gradient_timings)
+    pretty_print_table(permutedims(hcat([[t...] for t in timings]...), (2, 1)))
 
     return nothing
 end
@@ -42,20 +60,43 @@ function run_fno_benchmarks()
     model = FourierNeuralOperator((16, 16), 3, 8, 64)
     ps, st = xdev(Lux.setup(Random.default_rng(), model))
     x = xdev(rand(Float32, 64, 64, 1, 256))
+    z = xdev(rand(Float32, 64, 64, 8, 256))
 
     primal_timings = Reactant.with_config(;
         dot_general_precision=PrecisionConfig.HIGH,
         convolution_precision=PrecisionConfig.HIGH,
     ) do
         benchmark_nn_primal(
-            model, x, ps, st; disable_scatter_gather_bench=true, disable_pad_bench=true
+            model,
+            x,
+            z,
+            ps,
+            st;
+            disable_scatter_gather_bench=true,
+            disable_pad_bench=true,
         )
     end
 
-    pretty_print_table(permutedims(hcat([[t...] for t in primal_timings]...), (2, 1)))
+    gradient_timings = Reactant.with_config(;
+        dot_general_precision=PrecisionConfig.HIGH,
+        convolution_precision=PrecisionConfig.HIGH,
+    ) do
+        benchmark_nn_gradient(
+            model,
+            x,
+            z,
+            ps,
+            st;
+            disable_scatter_gather_bench=true,
+            disable_pad_bench=true,
+        )
+    end
+
+    timings = vcat(primal_timings, gradient_timings)
+    pretty_print_table(permutedims(hcat([[t...] for t in timings]...), (2, 1)))
 
     return nothing
 end
 
-run_fno_benchmarks()
 run_deeponet_benchmarks()
+run_fno_benchmarks()
diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl
index 74b310fcad..30dfda915f 100644
--- a/src/CompileOptions.jl
+++ b/src/CompileOptions.jl
@@ -138,6 +138,9 @@ Fine-grained control over the compilation options for the Reactant compiler.
 
   - `assert_nonallocating`: If `true`, we make sure that no new buffers are returned by the
     function. Any buffer returned must be donated from the inputs. Defaults to `false`.
+  - `sync`: Reactant computations are asynchronous by default. If `true`, the computation
+    will be executed synchronously, blocking till the computation is complete. This is
+    recommended when benchmarking.
 
 # Extended Help
 
@@ -175,6 +178,7 @@ struct CompileOptions
     # julia codegen options
     assert_nonallocating::Bool
     donated_args::Symbol
+    sync::Bool
     ## private options for ablation studies
     disable_scatter_gather_optimization_passes::Bool
     disable_pad_optimization_passes::Bool
@@ -197,6 +201,7 @@ function CompileOptions(;
     optimize_communications::Union{Bool,OptimizeCommunicationOptions}=true,
     assert_nonallocating::Bool=false,
     donated_args::Symbol=:auto,
+    sync::Bool=false,
     disable_scatter_gather_optimization_passes::Bool=false,
     disable_pad_optimization_passes::Bool=false,
 )
@@ -243,6 +248,7 @@ function CompileOptions(;
         optimize_communications,
         assert_nonallocating,
         donated_args,
+        sync,
         disable_scatter_gather_optimization_passes,
         disable_pad_optimization_passes,
     )
@@ -282,6 +288,7 @@ function __compile_options_with_reversed_propagation(compile_options::CompileOpt
         compile_options.optimize_communications,
         compile_options.assert_nonallocating,
         compile_options.donated_args,
+        compile_options.sync,
         compile_options.disable_scatter_gather_optimization_passes,
         compile_options.disable_pad_optimization_passes,
     )
diff --git a/src/Compiler.jl b/src/Compiler.jl
index 9cb3d8a892..addca39369 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1260,6 +1260,7 @@ function __get_compile_options_and_kwargs(;
     optimize_communications::Union{Bool,OptimizeCommunicationOptions}=true,
     assert_nonallocating::Bool=false,
     donated_args::Symbol=:auto,
+    sync::Bool=false,
     kwargs...,
 )
     return (
@@ -1281,6 +1282,7 @@ function __get_compile_options_and_kwargs(;
         optimize_communications,
         assert_nonallocating,
         donated_args,
+        sync,
     ),
     kwargs,
 )

From 5d23af002faab76a3dbfec4499eb5d11ebe2adf0 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 18 Jun 2025 22:45:14 -0400
Subject: [PATCH 6/7] fix: remove forced passes at end
---
 perf/neuraloperators/main.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/perf/neuraloperators/main.jl b/perf/neuraloperators/main.jl
index 6c75a9d1bb..cb063671c6 100644
--- a/perf/neuraloperators/main.jl
+++ b/perf/neuraloperators/main.jl
@@ -2,7 +2,7 @@ using NeuralOperators, Lux, Random
 
 include("../common.jl")
 
-const xdev = reactant_device()
+const xdev = reactant_device(; force=true)
 
 function run_deeponet_benchmarks()
     @info "Running DeepONet benchmarks"

From 2536db14522a735b7ee18f1b9ac188243620e22b Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 18 Jun 2025 23:01:43 -0400
Subject: [PATCH 7/7] fix: correct usage of sync
---
 perf/common.jl | 36 ++++++++++++++++++++----------------
 1 file changed, 20 insertions(+), 16 deletions(-)

diff --git a/perf/common.jl b/perf/common.jl
index 1690587934..331da89a3e 100644
--- a/perf/common.jl
+++ b/perf/common.jl
@@ -18,9 +18,9 @@ function benchmark_nn_primal(
     results = Vector{Tuple{String,String,Float64,Float64,Float64}}()
 
     # Only XLA
-    compiled_fwd_xla = @compile sync = true compile_options = Reactant.DefaultXLACompileOptions() simple_mse_loss(
-        model, x, z, ps, st
-    )
+    compiled_fwd_xla = @compile compile_options = Reactant.DefaultXLACompileOptions(;
+        sync=true
+    ) simple_mse_loss(model, x, z, ps, st)
     bench = @benchmark $compiled_fwd_xla($model, $x, $z, $ps, $st) setup = (GC.gc(true))
     push!(results, ("Primal", "Only XLA", median(bench).time, std(bench).time, 1.0))
     baseline = median(bench).time
@@ -41,8 +41,8 @@ function benchmark_nn_primal(
 
     # Disable Scatter
     if disable_scatter_gather_bench
-        compiled_fwd_no_scatter = @compile sync = true compile_options = CompileOptions(;
-            disable_scatter_gather_optimization_passes=true
+        compiled_fwd_no_scatter = @compile compile_options = CompileOptions(;
+            disable_scatter_gather_optimization_passes=true, sync=true
         ) simple_mse_loss(model, x, z, ps, st)
         bench = @benchmark $compiled_fwd_no_scatter($model, $x, $z, $ps, $st) setup = (GC.gc(
             true
@@ -62,8 +62,8 @@ function benchmark_nn_primal(
 
     # Disable Pad
     if disable_pad_bench
-        compiled_fwd_no_pad = @compile sync = true compile_options = CompileOptions(;
-            disable_pad_optimization_passes=true
+        compiled_fwd_no_pad = @compile compile_options = CompileOptions(;
+            disable_pad_optimization_passes=true, sync=true
         ) simple_mse_loss(model, x, z, ps, st)
         bench = @benchmark $compiled_fwd_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
             true
@@ -83,9 +83,10 @@ function benchmark_nn_primal(
 
     # Disable Scatter and Pad
     if disable_scatter_gather_bench && disable_pad_bench
-        compiled_fwd_no_scatter_pad = @compile sync = true compile_options = CompileOptions(;
+        compiled_fwd_no_scatter_pad = @compile compile_options = CompileOptions(;
             disable_scatter_gather_optimization_passes=true,
             disable_pad_optimization_passes=true,
+            sync=true,
         ) simple_mse_loss(model, x, z, ps, st)
         bench = @benchmark $compiled_fwd_no_scatter_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
             true
@@ -124,9 +125,9 @@ function benchmark_nn_gradient_internal(
     results = Vector{Tuple{String,String,Float64,Float64,Float64}}()
 
     # Only XLA
-    compiled_grad_xla = @compile sync = true compile_options = Reactant.DefaultXLACompileOptions() simple_mse_loss_gradient(
-        model, x, z, ps, st
-    )
+    compiled_grad_xla = @compile compile_options = Reactant.DefaultXLACompileOptions(;
+        sync=true
+    ) simple_mse_loss_gradient(model, x, z, ps, st)
     bench = @benchmark $compiled_grad_xla($model, $x, $z, $ps, $st) setup = (GC.gc(true))
     push!(
         results, ("Gradient ($mode)", "Only XLA", median(bench).time, std(bench).time, 1.0)
@@ -155,8 +156,10 @@ function benchmark_nn_gradient_internal(
 
     # Disable Scatter
     if disable_scatter_gather_bench
-        compiled_grad_no_scatter = @compile sync = true compile_options = CompileOptions(;
-            disable_scatter_gather_optimization_passes=true, optimization_passes=mode
+        compiled_grad_no_scatter = @compile compile_options = CompileOptions(;
+            disable_scatter_gather_optimization_passes=true,
+            optimization_passes=mode,
+            sync=true,
         ) simple_mse_loss_gradient(model, x, z, ps, st)
         bench = @benchmark $compiled_grad_no_scatter($model, $x, $z, $ps, $st) setup = (GC.gc(
            true
@@ -178,8 +181,8 @@ function benchmark_nn_gradient_internal(
 
     # Disable Pad
     if disable_pad_bench
-        compiled_grad_no_pad = @compile sync = true compile_options = CompileOptions(;
-            disable_pad_optimization_passes=true, optimization_passes=mode
+        compiled_grad_no_pad = @compile compile_options = CompileOptions(;
+            disable_pad_optimization_passes=true, optimization_passes=mode, sync=true
        ) simple_mse_loss_gradient(model, x, z, ps, st)
         bench = @benchmark $compiled_grad_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
             true
@@ -201,10 +204,11 @@ function benchmark_nn_gradient_internal(
 
     # Disable Pad and Scatter
     if disable_scatter_gather_bench && disable_pad_bench
-        compiled_grad_no_scatter_no_pad = @compile sync = true compile_options = CompileOptions(;
+        compiled_grad_no_scatter_no_pad = @compile compile_options = CompileOptions(;
             disable_scatter_gather_optimization_passes=true,
             disable_pad_optimization_passes=true,
             optimization_passes=mode,
+            sync=true,
         ) simple_mse_loss_gradient(model, x, z, ps, st)
         bench = @benchmark $compiled_grad_no_scatter_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
             true
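
For reference, a minimal sketch of the `sync` usage these patches converge on: after PATCH 7/7, synchronous (blocking) execution is requested through the `CompileOptions` constructor instead of the removed standalone `sync = true` keyword to `@compile`. The toy function and input below are illustrative assumptions and not part of the patches; `@compile`, `Reactant.to_rarray`, and `CompileOptions(; sync=true)` follow the usage shown in the patches themselves:

    using Reactant

    square_sum(x) = sum(abs2, x)  # same reduction as simple_mse_loss in PATCH 1/7

    x = Reactant.to_rarray(rand(Float32, 128))

    # Synchronous execution, recommended for benchmarking per the docstring
    # restored in PATCH 5/7:
    compiled = @compile compile_options = Reactant.CompileOptions(; sync=true) square_sum(x)
    compiled(x)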