@@ -125,10 +125,13 @@ class PipelineConfig:
         emb_lookup_stream (str): The stream to use for embedding lookups.
             Only used by certain pipeline types (e.g., "fused").
             Default is "data_dist".
+        apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
+            Default is False.
     """

     pipeline: str = "base"
     emb_lookup_stream: str = "data_dist"
+    apply_jit: bool = False


 @dataclass
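A minimal usage sketch (not part of the diff), assuming the other PipelineConfig fields keep the defaults shown above:

    # Illustrative only: turns on the JIT vs. no-JIT comparison performed in runner().
    pipeline_config = PipelineConfig(apply_jit=True)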
@@ -137,6 +140,7 @@ class ModelSelectionConfig:

     # Common config for all model types
     batch_size: int = 8192
+    batch_sizes: Optional[List[int]] = None
     num_float_features: int = 10
     feature_pooling_avg: int = 10
     use_offsets: bool = False
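A sketch of the two batching modes this adds, assuming ModelSelectionConfig can be constructed with just these fields (all other fields defaulted):

    # Uniform batches (existing behaviour): every batch uses batch_size.
    uniform = ModelSelectionConfig(batch_size=8192)
    # Per-batch sizes (new): the list length must equal run_option.num_batches,
    # which runner() asserts further down in this diff.
    variable = ModelSelectionConfig(batch_sizes=[4096, 8192, 16384])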
@@ -189,6 +193,7 @@ def main(
     model_config = create_model_config(
         model_name=model_selection.model_name,
         batch_size=model_selection.batch_size,
+        batch_sizes=model_selection.batch_sizes,
         num_float_features=model_selection.num_float_features,
         feature_pooling_avg=model_selection.feature_pooling_avg,
         use_offsets=model_selection.use_offsets,
@@ -255,6 +260,15 @@ def runner(
         compute_device=ctx.device.type,
     )

+    batch_sizes = model_config.batch_sizes
+
+    if batch_sizes is None:
+        batch_sizes = [model_config.batch_size] * run_option.num_batches
+    else:
+        assert (
+            len(batch_sizes) == run_option.num_batches
+        ), "The length of batch_sizes must match the number of batches."
+
     # Create a planner for sharding based on the specified type
     planner = generate_planner(
         planner_type=run_option.planner_type,
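The fallback above keeps the old uniform-size behaviour when batch_sizes is not provided; a small worked illustration (values hypothetical):

    # batch_sizes=None, batch_size=8192, num_batches=3 -> [8192, 8192, 8192]
    # batch_sizes=[1024, 2048, 4096], num_batches=3    -> used as-is
    # batch_sizes=[1024, 2048], num_batches=3          -> AssertionError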
@@ -263,16 +277,15 @@ def runner(
         weighted_tables=weighted_tables,
         sharding_type=run_option.sharding_type,
         compute_kernel=run_option.compute_kernel,
-        num_batches=run_option.num_batches,
-        batch_size=model_config.batch_size,
+        batch_sizes=batch_sizes,
         pooling_factors=run_option.pooling_factors,
         num_poolings=run_option.num_poolings,
     )
     bench_inputs = generate_data(
         tables=tables,
         weighted_tables=weighted_tables,
         model_config=model_config,
-        num_batches=run_option.num_batches,
+        batch_sizes=batch_sizes,
     )

     sharded_model, optimizer = generate_sharded_model_and_optimizer(
@@ -288,14 +301,6 @@ def runner(
         },
         planner=planner,
     )
-    pipeline = generate_pipeline(
-        pipeline_type=pipeline_config.pipeline,
-        emb_lookup_stream=pipeline_config.emb_lookup_stream,
-        model=sharded_model,
-        opt=optimizer,
-        device=ctx.device,
-    )
-    pipeline.progress(iter(bench_inputs))


     def _func_to_benchmark(
         bench_inputs: List[ModelInput],
@@ -309,20 +314,47 @@ def _func_to_benchmark(
             except StopIteration:
                 break

-    result = benchmark_func(
-        name=type(pipeline).__name__,
-        bench_inputs=bench_inputs,  # pyre-ignore
-        prof_inputs=bench_inputs,  # pyre-ignore
-        num_benchmarks=5,
-        num_profiles=2,
-        profile_dir=run_option.profile,
-        world_size=run_option.world_size,
-        func_to_benchmark=_func_to_benchmark,
-        benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
-        rank=rank,
+    # Run comparison if apply_jit is True, otherwise run single benchmark
+    jit_configs = (
+        [(True, "WithJIT"), (False, "WithoutJIT")]
+        if pipeline_config.apply_jit
+        else [(False, "")]
     )
+    results = []
+
+    for apply_jit, jit_suffix in jit_configs:
+        pipeline = generate_pipeline(
+            pipeline_type=pipeline_config.pipeline,
+            emb_lookup_stream=pipeline_config.emb_lookup_stream,
+            model=sharded_model,
+            opt=optimizer,
+            device=ctx.device,
+            apply_jit=apply_jit,
+        )
+        pipeline.progress(iter(bench_inputs))
+
+        name = (
+            f"{type(pipeline).__name__}{jit_suffix}"
+            if jit_suffix
+            else type(pipeline).__name__
+        )
+        result = benchmark_func(
+            name=name,
+            bench_inputs=bench_inputs,  # pyre-ignore
+            prof_inputs=bench_inputs,  # pyre-ignore
+            num_benchmarks=5,
+            num_profiles=2,
+            profile_dir=run_option.profile,
+            world_size=run_option.world_size,
+            func_to_benchmark=_func_to_benchmark,
+            benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
+            rank=rank,
+        )
+        results.append(result)
+
     if rank == 0:
-        print(result)
+        for result in results:
+            print(result)


 if __name__ == "__main__":
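With apply_jit=True the same sharded model is benchmarked twice and both results are printed on rank 0. Assuming the default pipeline="base" resolves to torchrec's TrainPipelineBase, the reported names would look roughly like:

    TrainPipelineBaseWithJIT
    TrainPipelineBaseWithoutJIT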