@@ -18,9 +18,9 @@ function benchmark_nn_primal(
18
18
results = Vector {Tuple{String,String,Float64,Float64,Float64}} ()
19
19
20
20
# Only XLA
21
- compiled_fwd_xla = @compile sync = true compile_options = Reactant. DefaultXLACompileOptions () simple_mse_loss (
22
- model, x, z, ps, st
23
- )
21
+ compiled_fwd_xla = @compile compile_options = Reactant. DefaultXLACompileOptions (;
22
+ sync = true
23
+ ) simple_mse_loss (model, x, z, ps, st)
24
24
bench = @benchmark $ compiled_fwd_xla ($ model, $ x, $ z, $ ps, $ st) setup = (GC. gc (true ))
25
25
push! (results, (" Primal" , " Only XLA" , median (bench). time, std (bench). time, 1.0 ))
26
26
baseline = median (bench). time
@@ -41,8 +41,8 @@ function benchmark_nn_primal(
41
41
42
42
# Disable Scatter
43
43
if disable_scatter_gather_bench
44
- compiled_fwd_no_scatter = @compile sync = true compile_options = CompileOptions (;
45
- disable_scatter_gather_optimization_passes= true
44
+ compiled_fwd_no_scatter = @compile compile_options = CompileOptions (;
45
+ disable_scatter_gather_optimization_passes= true , sync = true
46
46
) simple_mse_loss (model, x, z, ps, st)
47
47
bench = @benchmark $ compiled_fwd_no_scatter ($ model, $ x, $ z, $ ps, $ st) setup = (GC. gc (
48
48
true
@@ -62,8 +62,8 @@ function benchmark_nn_primal(
62
62
63
63
# Disable Pad
64
64
if disable_pad_bench
65
- compiled_fwd_no_pad = @compile sync = true compile_options = CompileOptions (;
66
- disable_pad_optimization_passes= true
65
+ compiled_fwd_no_pad = @compile compile_options = CompileOptions (;
66
+ disable_pad_optimization_passes= true , sync = true
67
67
) simple_mse_loss (model, x, z, ps, st)
68
68
bench = @benchmark $ compiled_fwd_no_pad ($ model, $ x, $ z, $ ps, $ st) setup = (GC. gc (
69
69
true
@@ -83,9 +83,10 @@ function benchmark_nn_primal(
83
83
84
84
# Disable Scatter and Pad
85
85
if disable_scatter_gather_bench && disable_pad_bench
86
- compiled_fwd_no_scatter_pad = @compile sync = true compile_options = CompileOptions (;
86
+ compiled_fwd_no_scatter_pad = @compile compile_options = CompileOptions (;
87
87
disable_scatter_gather_optimization_passes= true ,
88
88
disable_pad_optimization_passes= true ,
89
+ sync= true ,
89
90
) simple_mse_loss (model, x, z, ps, st)
90
91
bench = @benchmark $ compiled_fwd_no_scatter_pad ($ model, $ x, $ z, $ ps, $ st) setup = (GC. gc (
91
92
true
@@ -124,9 +125,9 @@ function benchmark_nn_gradient_internal(
124
125
results = Vector {Tuple{String,String,Float64,Float64,Float64}} ()
125
126
126
127
# Only XLA
127
- compiled_grad_xla = @compile sync = true compile_options = Reactant. DefaultXLACompileOptions () simple_mse_loss_gradient (
128
- model, x, z, ps, st
129
- )
128
+ compiled_grad_xla = @compile compile_options = Reactant. DefaultXLACompileOptions (;
129
+ sync = true
130
+ ) simple_mse_loss_gradient (model, x, z, ps, st)
130
131
bench = @benchmark $ compiled_grad_xla ($ model, $ x, $ z, $ ps, $ st) setup = (GC. gc (true ))
131
132
push! (
132
133
results, (" Gradient ($mode )" , " Only XLA" , median (bench). time, std (bench). time, 1.0 )
@@ -155,8 +156,10 @@ function benchmark_nn_gradient_internal(
155
156
156
157
# Disable Scatter
157
158
if disable_scatter_gather_bench
158
- compiled_grad_no_scatter = @compile sync = true compile_options = CompileOptions (;
159
- disable_scatter_gather_optimization_passes= true , optimization_passes= mode
159
+ compiled_grad_no_scatter = @compile compile_options = CompileOptions (;
160
+ disable_scatter_gather_optimization_passes= true ,
161
+ optimization_passes= mode,
162
+ sync= true ,
160
163
) simple_mse_loss_gradient (model, x, z, ps, st)
161
164
bench = @benchmark $ compiled_grad_no_scatter ($ model, $ x, $ z, $ ps, $ st) setup = (GC. gc (
162
165
true
@@ -178,8 +181,8 @@ function benchmark_nn_gradient_internal(
178
181
179
182
# Disable Pad
180
183
if disable_pad_bench
181
- compiled_grad_no_pad = @compile sync = true compile_options = CompileOptions (;
182
- disable_pad_optimization_passes= true , optimization_passes= mode
184
+ compiled_grad_no_pad = @compile compile_options = CompileOptions (;
185
+ disable_pad_optimization_passes= true , optimization_passes= mode, sync = true
183
186
) simple_mse_loss_gradient (model, x, z, ps, st)
184
187
bench = @benchmark $ compiled_grad_no_pad ($ model, $ x, $ z, $ ps, $ st) setup = (GC. gc (
185
188
true
@@ -201,10 +204,11 @@ function benchmark_nn_gradient_internal(
201
204
202
205
# Disable Pad and Scatter
203
206
if disable_scatter_gather_bench && disable_pad_bench
204
- compiled_grad_no_scatter_no_pad = @compile sync = true compile_options = CompileOptions (;
207
+ compiled_grad_no_scatter_no_pad = @compile compile_options = CompileOptions (;
205
208
disable_scatter_gather_optimization_passes= true ,
206
209
disable_pad_optimization_passes= true ,
207
210
optimization_passes= mode,
211
+ sync= true ,
208
212
) simple_mse_loss_gradient (model, x, z, ps, st)
209
213
bench = @benchmark $ compiled_grad_no_scatter_no_pad ($ model, $ x, $ z, $ ps, $ st) setup = (GC. gc (
210
214
true
0 commit comments