
Commit 93a186e

SSYernar authored and facebook-github-bot committed
Add JIT and Variable Batch Support to Benchmark
Summary: This update introduces an option to apply Just-In-Time (JIT) compilation in the training pipeline configuration for performance comparison. It also adds support for variable batch sizes, including the generation of variable-batch KeyedJaggedTensors (VB-KJTs).

Differential Revision: D76928208
1 parent: 7eee82f
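As context for the apply_jit option below, here is a minimal sketch of what JIT compilation of a model means here, using the public torch.jit.script API. ToyDense is a hypothetical stand-in and not part of this commit; in the benchmark, the train pipelines receive apply_jit and script the sharded model internally.

import torch
import torch.nn as nn

# Toy stand-in for the benchmarked model; the real benchmark applies JIT
# inside the train pipeline when apply_jit=True.
class ToyDense(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(10, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

scripted = torch.jit.script(ToyDense())  # roughly what apply_jit=True implies
out = scripted(torch.randn(8, 10))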

File tree

torchrec/distributed/benchmark/benchmark_pipeline_utils.py
torchrec/distributed/benchmark/benchmark_train_pipeline.py

2 files changed (+77, −37 lines)

torchrec/distributed/benchmark/benchmark_pipeline_utils.py

Lines changed: 22 additions & 14 deletions
@@ -53,6 +53,7 @@ class BaseModelConfig(ABC):
 
     # Common parameters for all model types
     batch_size: int
+    batch_sizes: Optional[List[int]]
     num_float_features: int
     feature_pooling_avg: int
     use_offsets: bool
@@ -274,6 +275,7 @@ def generate_pipeline(
     model: nn.Module,
     opt: torch.optim.Optimizer,
     device: torch.device,
+    apply_jit: bool = False,
 ) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
     """
     Generate a training pipeline instance based on the configuration.
@@ -294,6 +296,8 @@ def generate_pipeline(
         model (nn.Module): The model to be trained.
         opt (torch.optim.Optimizer): The optimizer to use for training.
         device (torch.device): The device to run the training on.
+        apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
+            Default is False.
 
     Returns:
         Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the
@@ -315,20 +319,27 @@ def generate_pipeline(
 
     if pipeline_type == "semi":
         return TrainPipelineSemiSync(
-            model=model, optimizer=opt, device=device, start_batch=0
+            model=model,
+            optimizer=opt,
+            device=device,
+            start_batch=0,
+            apply_jit=apply_jit,
         )
     elif pipeline_type == "fused":
         return TrainPipelineFusedSparseDist(
             model=model,
             optimizer=opt,
             device=device,
             emb_lookup_stream=emb_lookup_stream,
+            apply_jit=apply_jit,
         )
-    elif pipeline_type in _pipeline_cls:
-        Pipeline = _pipeline_cls[pipeline_type]
-        return Pipeline(model=model, optimizer=opt, device=device)
+    elif pipeline_type == "base":
+        assert apply_jit is False, "JIT is not supported for base pipeline"
+
+        return TrainPipelineBase(model=model, optimizer=opt, device=device)
     else:
-        raise RuntimeError(f"unknown pipeline option {pipeline_type}")
+        Pipeline = _pipeline_cls[pipeline_type]
+        return Pipeline(model=model, optimizer=opt, device=device, apply_jit=apply_jit)  # pyre-ignore[28]
 
 
 def generate_planner(
@@ -338,8 +349,7 @@ def generate_planner(
     weighted_tables: Optional[List[EmbeddingBagConfig]],
     sharding_type: ShardingType,
     compute_kernel: EmbeddingComputeKernel,
-    num_batches: int,
-    batch_size: int,
+    batch_sizes: List[int],
     pooling_factors: Optional[List[float]],
     num_poolings: Optional[List[float]],
 ) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
@@ -353,8 +363,7 @@ def generate_planner(
         weighted_tables: List of weighted embedding tables
         sharding_type: Strategy for sharding embedding tables
         compute_kernel: Compute kernel to use for embedding tables
-        num_batches: Number of batches to process
-        batch_size: Size of each batch
+        batch_sizes: Sizes of each batch
         pooling_factors: Pooling factors for each feature of the table
         num_poolings: Number of poolings for each feature of the table
 
@@ -366,15 +375,14 @@ def generate_planner(
     """
     # Create parameter constraints for tables
    constraints = {}
+    num_batches = len(batch_sizes)
 
     if pooling_factors is None:
         pooling_factors = [POOLING_FACTOR] * num_batches
 
     if num_poolings is None:
         num_poolings = [NUM_POOLINGS] * num_batches
 
-    batch_sizes = [batch_size] * num_batches
-
     assert (
         len(pooling_factors) == num_batches and len(num_poolings) == num_batches
     ), "The length of pooling_factors and num_poolings must match the number of batches."
@@ -472,7 +480,7 @@ def generate_data(
     tables: List[EmbeddingBagConfig],
     weighted_tables: List[EmbeddingBagConfig],
     model_config: BaseModelConfig,
-    num_batches: int,
+    batch_sizes: List[int],
 ) -> List[ModelInput]:
     """
     Generate model input data for benchmarking.
@@ -490,7 +498,7 @@ def generate_data(
 
     return [
         ModelInput.generate(
-            batch_size=model_config.batch_size,
+            batch_size=batch_size,
            tables=tables,
             weighted_tables=weighted_tables,
             num_float_features=model_config.num_float_features,
@@ -508,5 +516,5 @@ def generate_data(
             ),
             pin_memory=model_config.pin_memory,
         )
-        for _ in range(num_batches)
+        for batch_size in batch_sizes
     ]
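A note on the variable-batch data this file now generates: batch_sizes carries one entry per batch, each entry becomes the batch_size of one ModelInput.generate call, and when batch_sizes is unset the fixed batch_size is repeated num_batches times (see the runner below). As a rough, self-contained sketch of the VB-KJT the summary mentions, assuming torchrec's public KeyedJaggedTensor constructor and its stride_per_key_per_rank argument (neither is touched by this diff), a variable-batch KJT with per-key batch sizes looks like:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features with different per-key batch sizes on a single rank:
# f1 has 3 samples (lengths 1, 2, 1), f2 has 2 samples (lengths 2, 1),
# so the stride (batch size) differs per key.
vb_kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([10, 20, 30, 40, 50, 60, 70]),
    lengths=torch.tensor([1, 2, 1, 2, 1]),
    stride_per_key_per_rank=[[3], [2]],  # batch size per key, per rank
)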

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 55 additions & 23 deletions
@@ -125,10 +125,13 @@ class PipelineConfig:
         emb_lookup_stream (str): The stream to use for embedding lookups.
             Only used by certain pipeline types (e.g., "fused").
             Default is "data_dist".
+        apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
+            Default is False.
     """
 
     pipeline: str = "base"
     emb_lookup_stream: str = "data_dist"
+    apply_jit: bool = False
 
 
 @dataclass
@@ -137,6 +140,7 @@ class ModelSelectionConfig:
 
     # Common config for all model types
     batch_size: int = 8192
+    batch_sizes: Optional[List[int]] = None
     num_float_features: int = 10
     feature_pooling_avg: int = 10
     use_offsets: bool = False
@@ -189,6 +193,7 @@ def main(
     model_config = create_model_config(
         model_name=model_selection.model_name,
         batch_size=model_selection.batch_size,
+        batch_sizes=model_selection.batch_sizes,
        num_float_features=model_selection.num_float_features,
         feature_pooling_avg=model_selection.feature_pooling_avg,
         use_offsets=model_selection.use_offsets,
@@ -255,6 +260,15 @@ def runner(
         compute_device=ctx.device.type,
     )
 
+    batch_sizes = model_config.batch_sizes
+
+    if batch_sizes is None:
+        batch_sizes = [model_config.batch_size] * run_option.num_batches
+    else:
+        assert (
+            len(batch_sizes) == run_option.num_batches
+        ), "The length of batch_sizes must match the number of batches."
+
     # Create a planner for sharding based on the specified type
     planner = generate_planner(
         planner_type=run_option.planner_type,
@@ -263,16 +277,15 @@ def runner(
         weighted_tables=weighted_tables,
         sharding_type=run_option.sharding_type,
         compute_kernel=run_option.compute_kernel,
-        num_batches=run_option.num_batches,
-        batch_size=model_config.batch_size,
+        batch_sizes=batch_sizes,
         pooling_factors=run_option.pooling_factors,
         num_poolings=run_option.num_poolings,
     )
     bench_inputs = generate_data(
         tables=tables,
         weighted_tables=weighted_tables,
         model_config=model_config,
-        num_batches=run_option.num_batches,
+        batch_sizes=batch_sizes,
     )
 
     sharded_model, optimizer = generate_sharded_model_and_optimizer(
@@ -288,14 +301,6 @@ def runner(
         },
         planner=planner,
     )
-    pipeline = generate_pipeline(
-        pipeline_type=pipeline_config.pipeline,
-        emb_lookup_stream=pipeline_config.emb_lookup_stream,
-        model=sharded_model,
-        opt=optimizer,
-        device=ctx.device,
-    )
-    pipeline.progress(iter(bench_inputs))
 
     def _func_to_benchmark(
         bench_inputs: List[ModelInput],
@@ -309,20 +314,47 @@ def _func_to_benchmark(
         except StopIteration:
             break
 
-    result = benchmark_func(
-        name=type(pipeline).__name__,
-        bench_inputs=bench_inputs,  # pyre-ignore
-        prof_inputs=bench_inputs,  # pyre-ignore
-        num_benchmarks=5,
-        num_profiles=2,
-        profile_dir=run_option.profile,
-        world_size=run_option.world_size,
-        func_to_benchmark=_func_to_benchmark,
-        benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
-        rank=rank,
+    # Run comparison if apply_jit is True, otherwise run single benchmark
+    jit_configs = (
+        [(True, "WithJIT"), (False, "WithoutJIT")]
+        if pipeline_config.apply_jit
+        else [(False, "")]
     )
+    results = []
+
+    for apply_jit, jit_suffix in jit_configs:
+        pipeline = generate_pipeline(
+            pipeline_type=pipeline_config.pipeline,
+            emb_lookup_stream=pipeline_config.emb_lookup_stream,
+            model=sharded_model,
+            opt=optimizer,
+            device=ctx.device,
+            apply_jit=apply_jit,
+        )
+        pipeline.progress(iter(bench_inputs))
+
+        name = (
+            f"{type(pipeline).__name__}{jit_suffix}"
+            if jit_suffix
+            else type(pipeline).__name__
+        )
+        result = benchmark_func(
+            name=name,
+            bench_inputs=bench_inputs,  # pyre-ignore
+            prof_inputs=bench_inputs,  # pyre-ignore
+            num_benchmarks=5,
+            num_profiles=2,
+            profile_dir=run_option.profile,
+            world_size=run_option.world_size,
+            func_to_benchmark=_func_to_benchmark,
+            benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
+            rank=rank,
+        )
+        results.append(result)
+
     if rank == 0:
-        print(result)
+        for result in results:
+            print(result)
 
 
 if __name__ == "__main__":
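To summarize the new control flow outside the TorchRec specifics: when apply_jit is set, the runner builds and benchmarks the pipeline twice, once with JIT and once without, and prints both results on rank 0. Below is a self-contained sketch of that pattern; run_once is a hypothetical stand-in for generate_pipeline plus pipeline.progress, and the sleeps only simulate work.

import time
from typing import List, Tuple

def run_once(apply_jit: bool) -> None:
    # Stand-in for building the pipeline with apply_jit and driving it
    # over bench_inputs; the sleep merely simulates a training run.
    time.sleep(0.02 if apply_jit else 0.03)

def compare_jit() -> List[Tuple[str, float]]:
    # Mirrors the runner's new flow: benchmark both configurations and
    # suffix the reported names so the results print side by side.
    results: List[Tuple[str, float]] = []
    for apply_jit, suffix in [(True, "WithJIT"), (False, "WithoutJIT")]:
        start = time.perf_counter()
        run_once(apply_jit)
        results.append((f"TrainPipelineSparseDist{suffix}", time.perf_counter() - start))
    return results

for name, seconds in compare_jit():
    print(f"{name}: {seconds * 1e3:.1f} ms")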
