From ef2f7de922ab972b02808be03244b6add76415f4 Mon Sep 17 00:00:00 2001
From: Yernar Sadybekov
Date: Mon, 23 Jun 2025 22:54:31 -0700
Subject: [PATCH] Add optimizer configuration supporting optimizer type,
 learning rate, momentum, and weight decay (#3107)

Summary:
Extends the optimizer configuration of the TorchRec train pipeline benchmark:
both the dense optimizer and the sparse (fused) optimizer can now be configured
with an optimizer type, learning rate, momentum, and weight decay.

Differential Revision: D76837924
---
 .../benchmark/benchmark_pipeline_utils.py | 40 +++++++++++++------
 .../benchmark/benchmark_train_pipeline.py | 38 ++++++++++++++++--
 2 files changed, 62 insertions(+), 16 deletions(-)

diff --git a/torchrec/distributed/benchmark/benchmark_pipeline_utils.py b/torchrec/distributed/benchmark/benchmark_pipeline_utils.py
index bd6fc9bea..e52454c63 100644
--- a/torchrec/distributed/benchmark/benchmark_pipeline_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_pipeline_utils.py
@@ -432,7 +432,11 @@ def generate_sharded_model_and_optimizer(
     kernel_type: str,
     pg: dist.ProcessGroup,
     device: torch.device,
-    fused_params: Optional[Dict[str, Any]] = None,
+    fused_params: Dict[str, Any],
+    dense_optimizer: str,
+    dense_lr: float,
+    dense_momentum: Optional[float],
+    dense_weight_decay: Optional[float],
     planner: Optional[
         Union[
             EmbeddingShardingPlanner,
@@ -440,13 +444,11 @@ def generate_sharded_model_and_optimizer(
         ]
     ] = None,
 ) -> Tuple[nn.Module, Optimizer]:
-    # Ensure fused_params is always a dictionary
-    fused_params_dict = {} if fused_params is None else fused_params
 
     sharder = TestEBCSharder(
         sharding_type=sharding_type,
         kernel_type=kernel_type,
-        fused_params=fused_params_dict,
+        fused_params=fused_params,
     )
     sharders = [cast(ModuleSharder[nn.Module], sharder)]
 
@@ -466,14 +468,28 @@ def generate_sharded_model_and_optimizer(
         sharders=sharders,
         plan=plan,
     ).to(device)
-    optimizer = optim.SGD(
-        [
-            param
-            for name, param in sharded_model.named_parameters()
-            if "sparse" not in name
-        ],
-        lr=0.1,
-    )
+
+    # Get dense parameters
+    dense_params = [
+        param
+        for name, param in sharded_model.named_parameters()
+        if "sparse" not in name
+    ]
+
+    # Create optimizer based on the specified type
+    optimizer_class = getattr(optim, dense_optimizer)
+
+    # Create optimizer with momentum and/or weight_decay if provided
+    optimizer_kwargs = {"lr": dense_lr}
+
+    if dense_momentum is not None:
+        optimizer_kwargs["momentum"] = dense_momentum
+
+    if dense_weight_decay is not None:
+        optimizer_kwargs["weight_decay"] = dense_weight_decay
+
+    optimizer = optimizer_class(dense_params, **optimizer_kwargs)
+
     return sharded_model, optimizer
 
diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
index f9aacb30b..27d8203ce 100644
--- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py
+++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -83,6 +83,14 @@ class RunOptions:
         pooling_factors (Optional[List[float]]): Pooling factors for each feature of the
             table. This is the average number of values each sample has for the feature.
         num_poolings (Optional[List[float]]): Number of poolings for each feature of the table.
+        dense_optimizer (str): Optimizer to use for dense parameters.
+            Default is "SGD".
+        dense_lr (float): Learning rate for dense parameters.
+            Default is 0.1.
+        sparse_optimizer (str): Optimizer to use for sparse parameters.
+            Default is "EXACT_ADAGRAD".
+        sparse_lr (float): Learning rate for sparse parameters.
+            Default is 0.1.
""" world_size: int = 2 @@ -94,6 +102,14 @@ class RunOptions: planner_type: str = "embedding" pooling_factors: Optional[List[float]] = None num_poolings: Optional[List[float]] = None + dense_optimizer: str = "SGD" + dense_lr: float = 0.1 + dense_momentum: Optional[float] = None + dense_weight_decay: Optional[float] = None + sparse_optimizer: str = "EXACT_ADAGRAD" + sparse_lr: float = 0.1 + sparse_momentum: Optional[float] = None + sparse_weight_decay: Optional[float] = None @dataclass @@ -286,6 +302,19 @@ def runner( num_batches=run_option.num_batches, ) + # Prepare fused_params for sparse optimizer + fused_params = { + "optimizer": getattr(EmbOptimType, run_option.sparse_optimizer.upper()), + "learning_rate": run_option.sparse_lr, + } + + # Add momentum and weight_decay to fused_params if provided + if run_option.sparse_momentum is not None: + fused_params["momentum"] = run_option.sparse_momentum + + if run_option.sparse_weight_decay is not None: + fused_params["weight_decay"] = run_option.sparse_weight_decay + sharded_model, optimizer = generate_sharded_model_and_optimizer( model=unsharded_model, sharding_type=run_option.sharding_type.value, @@ -293,10 +322,11 @@ def runner( # pyre-ignore pg=ctx.pg, device=ctx.device, - fused_params={ - "optimizer": EmbOptimType.EXACT_ADAGRAD, - "learning_rate": 0.1, - }, + fused_params=fused_params, + dense_optimizer=run_option.dense_optimizer, + dense_lr=run_option.dense_lr, + dense_momentum=run_option.dense_momentum, + dense_weight_decay=run_option.dense_weight_decay, planner=planner, ) pipeline = generate_pipeline(