diff --git a/perf/common.jl b/perf/common.jl
new file mode 100644
index 0000000000..331da89a3e
--- /dev/null
+++ b/perf/common.jl
@@ -0,0 +1,257 @@
+using BenchmarkTools: @benchmark
+using Reactant, Enzyme, PrettyTables, Statistics
+
+function simple_mse_loss(model, x, z, ps, st)
+    y, _ = Lux.apply(model, x, ps, st)
+    return MSELoss()(y, z)
+end
+
+function simple_mse_loss_gradient(model, x, z, ps, st)
+    return Enzyme.gradient(
+        Reverse, simple_mse_loss, Const(model), Const(x), Const(z), ps, Const(st)
+    )
+end
+
+function benchmark_nn_primal(
+    model, x, z, ps, st; disable_scatter_gather_bench=true, disable_pad_bench=true
+)
+    results = Vector{Tuple{String,String,Float64,Float64,Float64}}()
+
+    # Only XLA
+    compiled_fwd_xla = @compile compile_options = Reactant.DefaultXLACompileOptions(;
+        sync=true
+    ) simple_mse_loss(model, x, z, ps, st)
+    bench = @benchmark $compiled_fwd_xla($model, $x, $z, $ps, $st) setup = (GC.gc(true))
+    push!(results, ("Primal", "Only XLA", median(bench).time, std(bench).time, 1.0))
+    baseline = median(bench).time
+
+    # Default
+    compiled_fwd = @compile sync = true simple_mse_loss(model, x, z, ps, st)
+    bench = @benchmark $compiled_fwd($model, $x, $z, $ps, $st) setup = (GC.gc(true))
+    push!(
+        results,
+        (
+            "Primal",
+            "All",
+            median(bench).time,
+            std(bench).time,
+            median(bench).time / baseline,
+        ),
+    )
+
+    # Disable Scatter
+    if disable_scatter_gather_bench
+        compiled_fwd_no_scatter = @compile compile_options = CompileOptions(;
+            disable_scatter_gather_optimization_passes=true, sync=true
+        ) simple_mse_loss(model, x, z, ps, st)
+        bench = @benchmark $compiled_fwd_no_scatter($model, $x, $z, $ps, $st) setup = (GC.gc(
+            true
+        ))
+
+        push!(
+            results,
+            (
+                "Primal",
+                "No Scatter/Gather Optimizations",
+                median(bench).time,
+                std(bench).time,
+                median(bench).time / baseline,
+            ),
+        )
+    end
+
+    # Disable Pad
+    if disable_pad_bench
+        compiled_fwd_no_pad = @compile compile_options = CompileOptions(;
+            disable_pad_optimization_passes=true, sync=true
+        ) simple_mse_loss(model, x, z, ps, st)
+        bench = @benchmark $compiled_fwd_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
+            true
+        ))
+
+        push!(
+            results,
+            (
+                "Primal",
+                "No Pad Optimizations",
+                median(bench).time,
+                std(bench).time,
+                median(bench).time / baseline,
+            ),
+        )
+    end
+
+    # Disable Scatter and Pad
+    if disable_scatter_gather_bench && disable_pad_bench
+        compiled_fwd_no_scatter_pad = @compile compile_options = CompileOptions(;
+            disable_scatter_gather_optimization_passes=true,
+            disable_pad_optimization_passes=true,
+            sync=true,
+        ) simple_mse_loss(model, x, z, ps, st)
+        bench = @benchmark $compiled_fwd_no_scatter_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
+            true
+        ))
+
+        push!(
+            results,
+            (
+                "Primal",
+                "No Scatter/Gather and Pad Optimizations",
+                median(bench).time,
+                std(bench).time,
+                median(bench).time / baseline,
+            ),
+        )
+    end
+
+    sort!(results; by=x -> x[3])
+    return results
+end
+
+function benchmark_nn_gradient(model, x, z, ps, st; kwargs...)
+    return vcat(
+        [
+            benchmark_nn_gradient_internal(model, x, z, ps, st, mode; kwargs...) for
+            mode in [:all, :before_enzyme, :after_enzyme]
+        ]...,
+    )
+end
+
+function benchmark_nn_gradient_internal(
+    model, x, z, ps, st, mode; disable_scatter_gather_bench=true, disable_pad_bench=true
+)
+    @info "Benchmarking gradient with mode: $(Meta.quot(mode))"
+
+    results = Vector{Tuple{String,String,Float64,Float64,Float64}}()
+
+    # Only XLA
+    compiled_grad_xla = @compile compile_options = Reactant.DefaultXLACompileOptions(;
+        sync=true
+    ) simple_mse_loss_gradient(model, x, z, ps, st)
+    bench = @benchmark $compiled_grad_xla($model, $x, $z, $ps, $st) setup = (GC.gc(true))
+    push!(
+        results, ("Gradient ($mode)", "Only XLA", median(bench).time, std(bench).time, 1.0)
+    )
+    baseline = median(bench).time
+
+    display(results[end])
+
+    # Default
+    compiled_grad = @compile sync = true optimize = mode simple_mse_loss_gradient(
+        model, x, z, ps, st
+    )
+    bench = @benchmark $compiled_grad($model, $x, $z, $ps, $st) setup = (GC.gc(true))
+    push!(
+        results,
+        (
+            "Gradient ($mode)",
+            "All",
+            median(bench).time,
+            std(bench).time,
+            median(bench).time / baseline,
+        ),
+    )
+
+    display(results[end])
+
+    # Disable Scatter
+    if disable_scatter_gather_bench
+        compiled_grad_no_scatter = @compile compile_options = CompileOptions(;
+            disable_scatter_gather_optimization_passes=true,
+            optimization_passes=mode,
+            sync=true,
+        ) simple_mse_loss_gradient(model, x, z, ps, st)
+        bench = @benchmark $compiled_grad_no_scatter($model, $x, $z, $ps, $st) setup = (GC.gc(
+            true
+        ))
+
+        push!(
+            results,
+            (
+                "Gradient ($mode)",
+                "No Scatter/Gather Optimizations",
+                median(bench).time,
+                std(bench).time,
+                median(bench).time / baseline,
+            ),
+        )
+
+        display(results[end])
+    end
+
+    # Disable Pad
+    if disable_pad_bench
+        compiled_grad_no_pad = @compile compile_options = CompileOptions(;
+            disable_pad_optimization_passes=true, optimization_passes=mode, sync=true
+        ) simple_mse_loss_gradient(model, x, z, ps, st)
+        bench = @benchmark $compiled_grad_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
+            true
+        ))
+
+        push!(
+            results,
+            (
+                "Gradient ($mode)",
+                "No Pad Optimizations",
+                median(bench).time,
+                std(bench).time,
+                median(bench).time / baseline,
+            ),
+        )
+
+        display(results[end])
+    end
+
+    # Disable Pad and Scatter
+    if disable_scatter_gather_bench && disable_pad_bench
+        compiled_grad_no_scatter_no_pad = @compile compile_options = CompileOptions(;
+            disable_scatter_gather_optimization_passes=true,
+            disable_pad_optimization_passes=true,
+            optimization_passes=mode,
+            sync=true,
+        ) simple_mse_loss_gradient(model, x, z, ps, st)
+        bench = @benchmark $compiled_grad_no_scatter_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
+            true
+        ))
+
+        push!(
+            results,
+            (
+                "Gradient ($mode)",
+                "No Scatter/Gather/Pad Optimizations",
+                median(bench).time,
+                std(bench).time,
+                median(bench).time / baseline,
+            ),
+        )
+
+        display(results[end])
+    end
+
+    sort!(results; by=x -> x[3])
+    return results
+end
+
+function pretty_print_table(results)
+    header = (
+        ["Mode", "Optimization Passes", "Median Time", "Std. Dev. Time", "Relative Timing"],
+        ["", "", "s", "s", "Time / XLA Time"],
+    )
+
+    results = copy(results)
+    results[:, 3] ./= 1e9
+    results[:, 4] ./= 1e9
+
+    hl_r = Highlighter((data, i, j) -> j == 5 && data[i, j] > 1.0, crayon"bold red")
+    hl_g = Highlighter((data, i, j) -> j == 5 && data[i, j] < 1.0, crayon"bold green")
+    display(
+        pretty_table(
+            results;
+            header,
+            header_crayon=crayon"yellow bold",
+            highlighters=(hl_r, hl_g),
+            tf=tf_unicode_rounded,
+        ),
+    )
+    return nothing
+end
diff --git a/perf/neuraloperators/Project.toml b/perf/neuraloperators/Project.toml
new file mode 100644
index 0000000000..838e520c97
--- /dev/null
+++ b/perf/neuraloperators/Project.toml
@@ -0,0 +1,22 @@
+[deps]
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
+PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+
+[sources]
+Reactant = {path = "../../"}
+
+[compat]
+BenchmarkTools = "1.6"
+CSV = "0.10.15"
+Lux = "1.13.4"
+NeuralOperators = "0.6"
+PrettyTables = "2.4.0"
+Random = "1.11"
+julia = "1.11"
diff --git a/perf/neuraloperators/main.jl b/perf/neuraloperators/main.jl
new file mode 100644
index 0000000000..cb063671c6
--- /dev/null
+++ b/perf/neuraloperators/main.jl
@@ -0,0 +1,102 @@
+using NeuralOperators, Lux, Random
+
+include("../common.jl")
+
+const xdev = reactant_device(; force=true)
+
+function run_deeponet_benchmarks()
+    @info "Running DeepONet benchmarks"
+
+    model = DeepONet(;
+        branch=(64, ntuple(Returns(256), 4)..., 16),
+        trunk=(1, ntuple(Returns(256), 4)..., 16),
+        branch_activation=gelu,
+        trunk_activation=gelu,
+    )
+    ps, st = xdev(Lux.setup(Random.default_rng(), model))
+    u = xdev(rand(Float32, 64, 1024))
+    y = xdev(rand(Float32, 1, 128))
+    z = xdev(rand(Float32, 128, 1024))
+
+    primal_timings = Reactant.with_config(;
+        dot_general_precision=PrecisionConfig.HIGH,
+        convolution_precision=PrecisionConfig.HIGH,
+    ) do
+        benchmark_nn_primal(
+            model,
+            (u, y),
+            z,
+            ps,
+            st;
+            disable_scatter_gather_bench=true,
+            disable_pad_bench=true,
+        )
+    end
+
+    gradient_timings = Reactant.with_config(;
+        dot_general_precision=PrecisionConfig.HIGH,
+        convolution_precision=PrecisionConfig.HIGH,
+    ) do
+        benchmark_nn_gradient(
+            model,
+            (u, y),
+            z,
+            ps,
+            st;
+            disable_scatter_gather_bench=true,
+            disable_pad_bench=true,
+        )
+    end
+
+    timings = vcat(primal_timings, gradient_timings)
+    pretty_print_table(permutedims(hcat([[t...] for t in timings]...), (2, 1)))
+
+    return nothing
+end
+
+function run_fno_benchmarks()
+    @info "Running FNO benchmarks"
+
+    model = FourierNeuralOperator((16, 16), 3, 8, 64)
+    ps, st = xdev(Lux.setup(Random.default_rng(), model))
+    x = xdev(rand(Float32, 64, 64, 1, 256))
+    z = xdev(rand(Float32, 64, 64, 8, 256))
+
+    primal_timings = Reactant.with_config(;
+        dot_general_precision=PrecisionConfig.HIGH,
+        convolution_precision=PrecisionConfig.HIGH,
+    ) do
+        benchmark_nn_primal(
+            model,
+            x,
+            z,
+            ps,
+            st;
+            disable_scatter_gather_bench=true,
+            disable_pad_bench=true,
+        )
+    end
+
+    gradient_timings = Reactant.with_config(;
+        dot_general_precision=PrecisionConfig.HIGH,
+        convolution_precision=PrecisionConfig.HIGH,
+    ) do
+        benchmark_nn_gradient(
+            model,
+            x,
+            z,
+            ps,
+            st;
+            disable_scatter_gather_bench=true,
+            disable_pad_bench=true,
+        )
+    end
+
+    timings = vcat(primal_timings, gradient_timings)
+    pretty_print_table(permutedims(hcat([[t...] for t in timings]...), (2, 1)))
+
+    return nothing
+end
+
+run_deeponet_benchmarks()
+run_fno_benchmarks()