Skip to content

perf: run ablations for the paper #1408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 257 additions & 0 deletions perf/common.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
using BenchmarkTools: @benchmark
using Reactant, Enzyme, PrettyTables, Statistics

"""
    simple_mse_loss(model, x, z, ps, st)

Run the forward pass of `model` on input `x` with parameters `ps` and state
`st`, and return the mean-squared error of the prediction against target `z`.
"""
function simple_mse_loss(model, x, z, ps, st)
    # Lux.apply returns (output, new_state); only the output feeds the loss.
    prediction = first(Lux.apply(model, x, ps, st))
    loss_fn = MSELoss()
    return loss_fn(prediction, z)
end

"""
    simple_mse_loss_gradient(model, x, z, ps, st)

Differentiate `simple_mse_loss` with Enzyme reverse mode with respect to the
parameters `ps`; the model, inputs, target, and state are held constant.
"""
function simple_mse_loss_gradient(model, x, z, ps, st)
    frozen_model = Const(model)
    frozen_input = Const(x)
    frozen_target = Const(z)
    return Enzyme.gradient(
        Reverse, simple_mse_loss, frozen_model, frozen_input, frozen_target, ps, Const(st)
    )
end

# Run one compiled primal variant under BenchmarkTools.
# Returns the (median_time, std_time) pair in nanoseconds.
function primal_bench_trial(compiled, model, x, z, ps, st)
    bench = @benchmark $compiled($model, $x, $z, $ps, $st) setup = (GC.gc(true))
    return median(bench).time, std(bench).time
end

"""
    benchmark_nn_primal(model, x, z, ps, st;
                        disable_scatter_gather_bench=true, disable_pad_bench=true)

Benchmark the forward (primal) loss evaluation of `model` under several
Reactant compilation configurations: XLA-only (the baseline), the full default
pass pipeline, and — when the corresponding keyword is `true` — pipelines with
the scatter/gather and/or pad optimization passes disabled.

Returns a vector of `(mode, passes, median_ns, std_ns, relative)` tuples
sorted by median time, where `relative` is the time divided by the XLA
baseline's median.
"""
function benchmark_nn_primal(
    model, x, z, ps, st; disable_scatter_gather_bench=true, disable_pad_bench=true
)
    results = Vector{Tuple{String,String,Float64,Float64,Float64}}()

    # Only XLA: skip Reactant's own passes; this run defines the baseline.
    compiled_fwd_xla = @compile compile_options = Reactant.DefaultXLACompileOptions(;
        sync=true
    ) simple_mse_loss(model, x, z, ps, st)
    med, sd = primal_bench_trial(compiled_fwd_xla, model, x, z, ps, st)
    push!(results, ("Primal", "Only XLA", med, sd, 1.0))
    baseline = med

    # Default: the full optimization pipeline.
    compiled_fwd = @compile sync = true simple_mse_loss(model, x, z, ps, st)
    med, sd = primal_bench_trial(compiled_fwd, model, x, z, ps, st)
    push!(results, ("Primal", "All", med, sd, med / baseline))

    # Full pipeline minus the scatter/gather optimization passes.
    if disable_scatter_gather_bench
        compiled_fwd_no_scatter = @compile compile_options = CompileOptions(;
            disable_scatter_gather_optimization_passes=true, sync=true
        ) simple_mse_loss(model, x, z, ps, st)
        med, sd = primal_bench_trial(compiled_fwd_no_scatter, model, x, z, ps, st)
        push!(
            results,
            ("Primal", "No Scatter/Gather Optimizations", med, sd, med / baseline),
        )
    end

    # Full pipeline minus the pad optimization passes.
    if disable_pad_bench
        compiled_fwd_no_pad = @compile compile_options = CompileOptions(;
            disable_pad_optimization_passes=true, sync=true
        ) simple_mse_loss(model, x, z, ps, st)
        med, sd = primal_bench_trial(compiled_fwd_no_pad, model, x, z, ps, st)
        push!(results, ("Primal", "No Pad Optimizations", med, sd, med / baseline))
    end

    # Full pipeline minus both scatter/gather and pad optimization passes.
    if disable_scatter_gather_bench && disable_pad_bench
        compiled_fwd_no_scatter_pad = @compile compile_options = CompileOptions(;
            disable_scatter_gather_optimization_passes=true,
            disable_pad_optimization_passes=true,
            sync=true,
        ) simple_mse_loss(model, x, z, ps, st)
        med, sd = primal_bench_trial(compiled_fwd_no_scatter_pad, model, x, z, ps, st)
        push!(
            results,
            (
                "Primal",
                "No Scatter/Gather and Pad Optimizations",
                med,
                sd,
                med / baseline,
            ),
        )
    end

    # Fastest configuration first.
    sort!(results; by=r -> r[3])
    return results
end

"""
    benchmark_nn_gradient(model, x, z, ps, st; kwargs...)

Benchmark the gradient computation under each Enzyme optimization mode
(`:all`, `:before_enzyme`, `:after_enzyme`) and concatenate the per-mode
result rows into a single vector. Keyword arguments are forwarded to
`benchmark_nn_gradient_internal`.
"""
function benchmark_nn_gradient(model, x, z, ps, st; kwargs...)
    modes = (:all, :before_enzyme, :after_enzyme)
    return mapreduce(vcat, modes) do mode
        benchmark_nn_gradient_internal(model, x, z, ps, st, mode; kwargs...)
    end
end

# Run one compiled gradient variant under BenchmarkTools.
# Returns the (median_time, std_time) pair in nanoseconds.
function gradient_bench_trial(compiled, model, x, z, ps, st)
    bench = @benchmark $compiled($model, $x, $z, $ps, $st) setup = (GC.gc(true))
    return median(bench).time, std(bench).time
end

"""
    benchmark_nn_gradient_internal(model, x, z, ps, st, mode;
                                   disable_scatter_gather_bench=true,
                                   disable_pad_bench=true)

Benchmark `simple_mse_loss_gradient` for a single Enzyme optimization `mode`
under the same set of compilation configurations as `benchmark_nn_primal`:
XLA-only baseline, full pipeline, and variants with scatter/gather and/or pad
optimization passes disabled. Each row is `display`ed as it is produced so
long runs show progress.

Returns a vector of `(mode, passes, median_ns, std_ns, relative)` tuples
sorted by median time.
"""
function benchmark_nn_gradient_internal(
    model, x, z, ps, st, mode; disable_scatter_gather_bench=true, disable_pad_bench=true
)
    @info "Benchmarking gradient with mode: $(Meta.quot(mode))"

    results = Vector{Tuple{String,String,Float64,Float64,Float64}}()
    tag = "Gradient ($mode)"

    # Only XLA: skip Reactant's own passes; this run defines the baseline.
    compiled_grad_xla = @compile compile_options = Reactant.DefaultXLACompileOptions(;
        sync=true
    ) simple_mse_loss_gradient(model, x, z, ps, st)
    med, sd = gradient_bench_trial(compiled_grad_xla, model, x, z, ps, st)
    push!(results, (tag, "Only XLA", med, sd, 1.0))
    baseline = med
    display(results[end])

    # Default: the full pipeline restricted to this optimization mode.
    compiled_grad = @compile sync = true optimize = mode simple_mse_loss_gradient(
        model, x, z, ps, st
    )
    med, sd = gradient_bench_trial(compiled_grad, model, x, z, ps, st)
    push!(results, (tag, "All", med, sd, med / baseline))
    display(results[end])

    # Mode pipeline minus the scatter/gather optimization passes.
    if disable_scatter_gather_bench
        compiled_grad_no_scatter = @compile compile_options = CompileOptions(;
            disable_scatter_gather_optimization_passes=true,
            optimization_passes=mode,
            sync=true,
        ) simple_mse_loss_gradient(model, x, z, ps, st)
        med, sd = gradient_bench_trial(compiled_grad_no_scatter, model, x, z, ps, st)
        push!(results, (tag, "No Scatter/Gather Optimizations", med, sd, med / baseline))
        display(results[end])
    end

    # Mode pipeline minus the pad optimization passes.
    if disable_pad_bench
        compiled_grad_no_pad = @compile compile_options = CompileOptions(;
            disable_pad_optimization_passes=true, optimization_passes=mode, sync=true
        ) simple_mse_loss_gradient(model, x, z, ps, st)
        med, sd = gradient_bench_trial(compiled_grad_no_pad, model, x, z, ps, st)
        push!(results, (tag, "No Pad Optimizations", med, sd, med / baseline))
        display(results[end])
    end

    # Mode pipeline minus both scatter/gather and pad optimization passes.
    if disable_scatter_gather_bench && disable_pad_bench
        compiled_grad_no_scatter_no_pad = @compile compile_options = CompileOptions(;
            disable_scatter_gather_optimization_passes=true,
            disable_pad_optimization_passes=true,
            optimization_passes=mode,
            sync=true,
        ) simple_mse_loss_gradient(model, x, z, ps, st)
        med, sd = gradient_bench_trial(
            compiled_grad_no_scatter_no_pad, model, x, z, ps, st
        )
        push!(results, (tag, "No Scatter/Gather/Pad Optimizations", med, sd, med / baseline))
        display(results[end])
    end

    # Fastest configuration first.
    sort!(results; by=r -> r[3])
    return results
end

"""
    pretty_print_table(results)

Render a benchmark result matrix with PrettyTables. The time columns (3 and 4,
in nanoseconds) are shown in seconds; rows slower than the XLA baseline
(relative > 1) are highlighted bold red, faster ones bold green.
"""
function pretty_print_table(results)
    header = (
        ["Mode", "Optimization Passes", "Median Time", "Std. Dev. Time", "Relative Timing"],
        ["", "", "s", "s", "Time / XLA Time"],
    )

    # Work on a copy so the caller's matrix keeps its nanosecond units.
    table = copy(results)
    for col in (3, 4)
        table[:, col] ./= 1e9
    end

    slower = Highlighter((data, i, j) -> j == 5 && data[i, j] > 1.0, crayon"bold red")
    faster = Highlighter((data, i, j) -> j == 5 && data[i, j] < 1.0, crayon"bold green")

    display(
        pretty_table(
            table;
            header,
            header_crayon=crayon"yellow bold",
            highlighters=(slower, faster),
            tf=tf_unicode_rounded,
        ),
    )
    return nothing
end
21 changes: 21 additions & 0 deletions perf/neuraloperators/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"

[sources]
Reactant = {path = "../../"}

[compat]
BenchmarkTools = "1.6"
CSV = "0.10.15"
Lux = "1.13.4"
NeuralOperators = "0.6"
PrettyTables = "2.4.0"
Random = "1.11"
julia = "1.11"
102 changes: 102 additions & 0 deletions perf/neuraloperators/main.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using NeuralOperators, Lux, Random

include("../common.jl")

const xdev = reactant_device(; force=true)

"""
    run_deeponet_benchmarks()

Build a DeepONet (4 hidden layers of width 256 in both branch and trunk,
`gelu` activations), benchmark its primal loss and gradient under the ablation
configurations from `common.jl`, and print the combined timing table.
"""
function run_deeponet_benchmarks()
    @info "Running DeepONet benchmarks"

    model = DeepONet(;
        branch=(64, ntuple(Returns(256), 4)..., 16),
        trunk=(1, ntuple(Returns(256), 4)..., 16),
        branch_activation=gelu,
        trunk_activation=gelu,
    )
    ps, st = xdev(Lux.setup(Random.default_rng(), model))
    u = xdev(rand(Float32, 64, 1024))
    y = xdev(rand(Float32, 1, 128))
    z = xdev(rand(Float32, 128, 1024))

    # Primal and gradient runs share the same precision configuration, so run
    # both inside a single `with_config` block instead of duplicating it.
    timings = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        primal_timings = benchmark_nn_primal(
            model,
            (u, y),
            z,
            ps,
            st;
            disable_scatter_gather_bench=true,
            disable_pad_bench=true,
        )
        gradient_timings = benchmark_nn_gradient(
            model,
            (u, y),
            z,
            ps,
            st;
            disable_scatter_gather_bench=true,
            disable_pad_bench=true,
        )
        vcat(primal_timings, gradient_timings)
    end

    # Convert the vector of result tuples into a rows-by-columns matrix.
    pretty_print_table(permutedims(hcat([[t...] for t in timings]...), (2, 1)))

    return nothing
end

"""
    run_fno_benchmarks()

Build a Fourier Neural Operator (16×16 modes, 3 input channels, 8 output
channels, width 64), benchmark its primal loss and gradient under the ablation
configurations from `common.jl`, and print the combined timing table.
"""
function run_fno_benchmarks()
    @info "Running FNO benchmarks"

    model = FourierNeuralOperator((16, 16), 3, 8, 64)
    ps, st = xdev(Lux.setup(Random.default_rng(), model))
    x = xdev(rand(Float32, 64, 64, 1, 256))
    z = xdev(rand(Float32, 64, 64, 8, 256))

    # Primal and gradient runs share the same precision configuration, so run
    # both inside a single `with_config` block instead of duplicating it.
    timings = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        primal_timings = benchmark_nn_primal(
            model,
            x,
            z,
            ps,
            st;
            disable_scatter_gather_bench=true,
            disable_pad_bench=true,
        )
        gradient_timings = benchmark_nn_gradient(
            model,
            x,
            z,
            ps,
            st;
            disable_scatter_gather_bench=true,
            disable_pad_bench=true,
        )
        vcat(primal_timings, gradient_timings)
    end

    # Convert the vector of result tuples into a rows-by-columns matrix.
    pretty_print_table(permutedims(hcat([[t...] for t in timings]...), (2, 1)))

    return nothing
end

# Script entry point: run both benchmark suites sequentially.
run_deeponet_benchmarks()
run_fno_benchmarks()