Skip to content

Commit 70f90b7

Browse files
committed
fix: correct usage of sync
1 parent 81ba6fc commit 70f90b7

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

perf/common.jl

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ function benchmark_nn_primal(
1818
results = Vector{Tuple{String,String,Float64,Float64,Float64}}()
1919

2020
# Only XLA
21-
compiled_fwd_xla = @compile sync = true compile_options = Reactant.DefaultXLACompileOptions() simple_mse_loss(
22-
model, x, z, ps, st
23-
)
21+
compiled_fwd_xla = @compile compile_options = Reactant.DefaultXLACompileOptions(;
22+
sync=true
23+
) simple_mse_loss(model, x, z, ps, st)
2424
bench = @benchmark $compiled_fwd_xla($model, $x, $z, $ps, $st) setup = (GC.gc(true))
2525
push!(results, ("Primal", "Only XLA", median(bench).time, std(bench).time, 1.0))
2626
baseline = median(bench).time
@@ -41,8 +41,8 @@ function benchmark_nn_primal(
4141

4242
# Disable Scatter
4343
if disable_scatter_gather_bench
44-
compiled_fwd_no_scatter = @compile sync = true compile_options = CompileOptions(;
45-
disable_scatter_gather_optimization_passes=true
44+
compiled_fwd_no_scatter = @compile compile_options = CompileOptions(;
45+
disable_scatter_gather_optimization_passes=true, sync=true
4646
) simple_mse_loss(model, x, z, ps, st)
4747
bench = @benchmark $compiled_fwd_no_scatter($model, $x, $z, $ps, $st) setup = (GC.gc(
4848
true
@@ -62,8 +62,8 @@ function benchmark_nn_primal(
6262

6363
# Disable Pad
6464
if disable_pad_bench
65-
compiled_fwd_no_pad = @compile sync = true compile_options = CompileOptions(;
66-
disable_pad_optimization_passes=true
65+
compiled_fwd_no_pad = @compile compile_options = CompileOptions(;
66+
disable_pad_optimization_passes=true, sync=true
6767
) simple_mse_loss(model, x, z, ps, st)
6868
bench = @benchmark $compiled_fwd_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
6969
true
@@ -83,9 +83,10 @@ function benchmark_nn_primal(
8383

8484
# Disable Scatter and Pad
8585
if disable_scatter_gather_bench && disable_pad_bench
86-
compiled_fwd_no_scatter_pad = @compile sync = true compile_options = CompileOptions(;
86+
compiled_fwd_no_scatter_pad = @compile compile_options = CompileOptions(;
8787
disable_scatter_gather_optimization_passes=true,
8888
disable_pad_optimization_passes=true,
89+
sync=true,
8990
) simple_mse_loss(model, x, z, ps, st)
9091
bench = @benchmark $compiled_fwd_no_scatter_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
9192
true
@@ -124,9 +125,9 @@ function benchmark_nn_gradient_internal(
124125
results = Vector{Tuple{String,String,Float64,Float64,Float64}}()
125126

126127
# Only XLA
127-
compiled_grad_xla = @compile sync = true compile_options = Reactant.DefaultXLACompileOptions() simple_mse_loss_gradient(
128-
model, x, z, ps, st
129-
)
128+
compiled_grad_xla = @compile compile_options = Reactant.DefaultXLACompileOptions(;
129+
sync=true
130+
) simple_mse_loss_gradient(model, x, z, ps, st)
130131
bench = @benchmark $compiled_grad_xla($model, $x, $z, $ps, $st) setup = (GC.gc(true))
131132
push!(
132133
results, ("Gradient ($mode)", "Only XLA", median(bench).time, std(bench).time, 1.0)
@@ -155,8 +156,10 @@ function benchmark_nn_gradient_internal(
155156

156157
# Disable Scatter
157158
if disable_scatter_gather_bench
158-
compiled_grad_no_scatter = @compile sync = true compile_options = CompileOptions(;
159-
disable_scatter_gather_optimization_passes=true, optimization_passes=mode
159+
compiled_grad_no_scatter = @compile compile_options = CompileOptions(;
160+
disable_scatter_gather_optimization_passes=true,
161+
optimization_passes=mode,
162+
sync=true,
160163
) simple_mse_loss_gradient(model, x, z, ps, st)
161164
bench = @benchmark $compiled_grad_no_scatter($model, $x, $z, $ps, $st) setup = (GC.gc(
162165
true
@@ -178,8 +181,8 @@ function benchmark_nn_gradient_internal(
178181

179182
# Disable Pad
180183
if disable_pad_bench
181-
compiled_grad_no_pad = @compile sync = true compile_options = CompileOptions(;
182-
disable_pad_optimization_passes=true, optimization_passes=mode
184+
compiled_grad_no_pad = @compile compile_options = CompileOptions(;
185+
disable_pad_optimization_passes=true, optimization_passes=mode, sync=true
183186
) simple_mse_loss_gradient(model, x, z, ps, st)
184187
bench = @benchmark $compiled_grad_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
185188
true
@@ -201,10 +204,11 @@ function benchmark_nn_gradient_internal(
201204

202205
# Disable Pad and Scatter
203206
if disable_scatter_gather_bench && disable_pad_bench
204-
compiled_grad_no_scatter_no_pad = @compile sync = true compile_options = CompileOptions(;
207+
compiled_grad_no_scatter_no_pad = @compile compile_options = CompileOptions(;
205208
disable_scatter_gather_optimization_passes=true,
206209
disable_pad_optimization_passes=true,
207210
optimization_passes=mode,
211+
sync=true,
208212
) simple_mse_loss_gradient(model, x, z, ps, st)
209213
bench = @benchmark $compiled_grad_no_scatter_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
210214
true

0 commit comments

Comments
 (0)