diff --git a/torchrec/distributed/benchmark/benchmark_pipeline_utils.py b/torchrec/distributed/benchmark/benchmark_pipeline_utils.py
index bd6fc9bea..162712586 100644
--- a/torchrec/distributed/benchmark/benchmark_pipeline_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_pipeline_utils.py
@@ -62,6 +62,7 @@ class BaseModelConfig(ABC):
 
     # Common parameters for all model types
     batch_size: int
+    batch_sizes: Optional[List[int]]
     num_float_features: int
     feature_pooling_avg: int
     use_offsets: bool
@@ -283,6 +284,7 @@ def generate_pipeline(
     model: nn.Module,
     opt: torch.optim.Optimizer,
     device: torch.device,
+    apply_jit: bool = False,
 ) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
     """
     Generate a training pipeline instance based on the configuration.
@@ -303,6 +305,8 @@ def generate_pipeline(
         model (nn.Module): The model to be trained.
         opt (torch.optim.Optimizer): The optimizer to use for training.
         device (torch.device): The device to run the training on.
+        apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
+            Default is False.
 
     Returns:
         Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the
@@ -324,7 +328,11 @@ def generate_pipeline(
 
     if pipeline_type == "semi":
         return TrainPipelineSemiSync(
-            model=model, optimizer=opt, device=device, start_batch=0
+            model=model,
+            optimizer=opt,
+            device=device,
+            start_batch=0,
+            apply_jit=apply_jit,
         )
     elif pipeline_type == "fused":
         return TrainPipelineFusedSparseDist(
@@ -332,12 +340,16 @@
             optimizer=opt,
             device=device,
             emb_lookup_stream=emb_lookup_stream,
+            apply_jit=apply_jit,
         )
-    elif pipeline_type in _pipeline_cls:
-        Pipeline = _pipeline_cls[pipeline_type]
-        return Pipeline(model=model, optimizer=opt, device=device)
+    elif pipeline_type == "base":
+        assert apply_jit is False, "JIT is not supported for base pipeline"
+
+        return TrainPipelineBase(model=model, optimizer=opt, device=device)
     else:
-        raise RuntimeError(f"unknown pipeline option {pipeline_type}")
+        Pipeline = _pipeline_cls[pipeline_type]
+        # pyre-ignore[28]
+        return Pipeline(model=model, optimizer=opt, device=device, apply_jit=apply_jit)
 
 
 def generate_planner(
@@ -347,8 +359,7 @@
     weighted_tables: Optional[List[EmbeddingBagConfig]],
     sharding_type: ShardingType,
     compute_kernel: EmbeddingComputeKernel,
-    num_batches: int,
-    batch_size: int,
+    batch_sizes: List[int],
     pooling_factors: Optional[List[float]],
     num_poolings: Optional[List[float]],
 ) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
@@ -362,8 +373,7 @@
         weighted_tables: List of weighted embedding tables
         sharding_type: Strategy for sharding embedding tables
         compute_kernel: Compute kernel to use for embedding tables
-        num_batches: Number of batches to process
-        batch_size: Size of each batch
+        batch_sizes: Sizes of each batch
         pooling_factors: Pooling factors for each feature of the table
         num_poolings: Number of poolings for each feature of the table
 
@@ -375,6 +385,7 @@
     """
     # Create parameter constraints for tables
     constraints = {}
+    num_batches = len(batch_sizes)
 
     if pooling_factors is None:
         pooling_factors = [POOLING_FACTOR] * num_batches
@@ -382,8 +393,6 @@
     if num_poolings is None:
         num_poolings = [NUM_POOLINGS] * num_batches
 
-    batch_sizes = [batch_size] * num_batches
-
     assert (
         len(pooling_factors) == num_batches and len(num_poolings) == num_batches
     ), "The length of pooling_factors and num_poolings must match the number of batches."
@@ -481,7 +490,7 @@ def generate_data(
     tables: List[EmbeddingBagConfig],
     weighted_tables: List[EmbeddingBagConfig],
     model_config: BaseModelConfig,
-    num_batches: int,
+    batch_sizes: List[int],
 ) -> List[ModelInput]:
     """
     Generate model input data for benchmarking.
@@ -499,7 +508,7 @@
 
     return [
         ModelInput.generate(
-            batch_size=model_config.batch_size,
+            batch_size=batch_size,
             tables=tables,
             weighted_tables=weighted_tables,
             num_float_features=model_config.num_float_features,
@@ -517,5 +526,5 @@
             ),
             pin_memory=model_config.pin_memory,
         )
-        for _ in range(num_batches)
+        for batch_size in batch_sizes
     ]
diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
index f9aacb30b..7bfd5bc76 100644
--- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py
+++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -136,10 +136,13 @@ class PipelineConfig:
         emb_lookup_stream (str): The stream to use for embedding lookups.
             Only used by certain pipeline types (e.g., "fused").
             Default is "data_dist".
+        apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
+            Default is False.
     """
 
     pipeline: str = "base"
     emb_lookup_stream: str = "data_dist"
+    apply_jit: bool = False
 
 
 @dataclass
@@ -148,6 +151,7 @@ class ModelSelectionConfig:
 
     # Common config for all model types
     batch_size: int = 8192
+    batch_sizes: Optional[List[int]] = None
     num_float_features: int = 10
     feature_pooling_avg: int = 10
     use_offsets: bool = False
@@ -200,6 +204,7 @@ def main(
     model_config = create_model_config(
         model_name=model_selection.model_name,
         batch_size=model_selection.batch_size,
+        batch_sizes=model_selection.batch_sizes,
         num_float_features=model_selection.num_float_features,
         feature_pooling_avg=model_selection.feature_pooling_avg,
         use_offsets=model_selection.use_offsets,
@@ -266,6 +271,15 @@ def runner(
         compute_device=ctx.device.type,
     )
 
+    batch_sizes = model_config.batch_sizes
+
+    if batch_sizes is None:
+        batch_sizes = [model_config.batch_size] * run_option.num_batches
+    else:
+        assert (
+            len(batch_sizes) == run_option.num_batches
+        ), "The length of batch_sizes must match the number of batches."
+
     # Create a planner for sharding based on the specified type
     planner = generate_planner(
         planner_type=run_option.planner_type,
@@ -274,8 +288,7 @@ def runner(
         weighted_tables=weighted_tables,
         sharding_type=run_option.sharding_type,
         compute_kernel=run_option.compute_kernel,
-        num_batches=run_option.num_batches,
-        batch_size=model_config.batch_size,
+        batch_sizes=batch_sizes,
         pooling_factors=run_option.pooling_factors,
         num_poolings=run_option.num_poolings,
     )
@@ -283,7 +296,7 @@
         tables=tables,
         weighted_tables=weighted_tables,
         model_config=model_config,
-        num_batches=run_option.num_batches,
+        batch_sizes=batch_sizes,
     )
 
     sharded_model, optimizer = generate_sharded_model_and_optimizer(
@@ -299,14 +312,6 @@ def runner(
         },
         planner=planner,
     )
-    pipeline = generate_pipeline(
-        pipeline_type=pipeline_config.pipeline,
-        emb_lookup_stream=pipeline_config.emb_lookup_stream,
-        model=sharded_model,
-        opt=optimizer,
-        device=ctx.device,
-    )
-    pipeline.progress(iter(bench_inputs))
 
     def _func_to_benchmark(
         bench_inputs: List[ModelInput],
@@ -320,20 +325,47 @@ def _func_to_benchmark(
             except StopIteration:
                 break
 
-    result = benchmark_func(
-        name=type(pipeline).__name__,
-        bench_inputs=bench_inputs,  # pyre-ignore
-        prof_inputs=bench_inputs,  # pyre-ignore
-        num_benchmarks=5,
-        num_profiles=2,
-        profile_dir=run_option.profile,
-        world_size=run_option.world_size,
-        func_to_benchmark=_func_to_benchmark,
-        benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
-        rank=rank,
+    # Run comparison if apply_jit is True, otherwise run single benchmark
+    jit_configs = (
+        [(True, "WithJIT"), (False, "WithoutJIT")]
+        if pipeline_config.apply_jit
+        else [(False, "")]
     )
+    results = []
+
+    for apply_jit, jit_suffix in jit_configs:
+        pipeline = generate_pipeline(
+            pipeline_type=pipeline_config.pipeline,
+            emb_lookup_stream=pipeline_config.emb_lookup_stream,
+            model=sharded_model,
+            opt=optimizer,
+            device=ctx.device,
+            apply_jit=apply_jit,
+        )
+        pipeline.progress(iter(bench_inputs))
+
+        name = (
+            f"{type(pipeline).__name__}{jit_suffix}"
+            if jit_suffix
+            else type(pipeline).__name__
+        )
+        result = benchmark_func(
+            name=name,
+            bench_inputs=bench_inputs,  # pyre-ignore
+            prof_inputs=bench_inputs,  # pyre-ignore
+            num_benchmarks=5,
+            num_profiles=2,
+            profile_dir=run_option.profile,
+            world_size=run_option.world_size,
+            func_to_benchmark=_func_to_benchmark,
+            benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
+            rank=rank,
+        )
+        results.append(result)
+
     if rank == 0:
-        print(result)
+        for result in results:
+            print(result)
 
 
 if __name__ == "__main__":
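
For quick reference, here is a minimal standalone sketch of the two behaviors this patch introduces in runner(): the batch_sizes fallback and the JIT A/B fan-out. The helper names resolve_batch_sizes and jit_variants are illustrative stand-ins only and do not exist in TorchRec; the logic mirrors the added hunks above.

    from typing import List, Optional, Tuple


    def resolve_batch_sizes(
        batch_sizes: Optional[List[int]], batch_size: int, num_batches: int
    ) -> List[int]:
        # Mirrors the fallback added to runner(): broadcast the scalar
        # batch_size across num_batches when no per-batch list is supplied.
        if batch_sizes is None:
            return [batch_size] * num_batches
        assert (
            len(batch_sizes) == num_batches
        ), "The length of batch_sizes must match the number of batches."
        return batch_sizes


    def jit_variants(apply_jit: bool) -> List[Tuple[bool, str]]:
        # Mirrors jit_configs: apply_jit=True benchmarks both variants so the
        # printed results compare "WithJIT" against "WithoutJIT"; otherwise a
        # single unsuffixed benchmark is produced.
        return [(True, "WithJIT"), (False, "WithoutJIT")] if apply_jit else [(False, "")]


    print(resolve_batch_sizes(None, 8192, 3))                 # [8192, 8192, 8192]
    print(resolve_batch_sizes([4096, 8192, 16384], 8192, 3))  # explicit sizes pass through
    print(jit_variants(True))  # [(True, 'WithJIT'), (False, 'WithoutJIT')]

As in the patch, both JIT variants would reuse the same sharded model and optimizer; only the pipeline wrapper built by generate_pipeline differs between the two benchmarked runs.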