
Added optimizer configuration that supports optimizer type, learning rate, momentum, and weight decay. #3107

Status: Closed (wants to merge 1 commit)
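
For context, a minimal sketch of how the new RunOptions fields introduced by this PR could be set. Constructing RunOptions directly like this is only an illustration; the benchmark's actual entry point (CLI/config parsing) is not part of this diff:

from torchrec.distributed.benchmark.benchmark_train_pipeline import RunOptions

# Hypothetical configuration exercising the new optimizer fields.
run_option = RunOptions(
    world_size=2,
    dense_optimizer="Adam",            # resolved via getattr(torch.optim, ...)
    dense_lr=1e-3,
    dense_weight_decay=1e-4,           # forwarded only when not None
    sparse_optimizer="EXACT_ADAGRAD",  # resolved via getattr(EmbOptimType, name.upper())
    sparse_lr=0.05,
)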
40 changes: 28 additions & 12 deletions torchrec/distributed/benchmark/benchmark_pipeline_utils.py
@@ -432,21 +432,23 @@ def generate_sharded_model_and_optimizer(
     kernel_type: str,
     pg: dist.ProcessGroup,
     device: torch.device,
-    fused_params: Optional[Dict[str, Any]] = None,
+    fused_params: Dict[str, Any],
+    dense_optimizer: str,
+    dense_lr: float,
+    dense_momentum: Optional[float],
+    dense_weight_decay: Optional[float],
     planner: Optional[
         Union[
             EmbeddingShardingPlanner,
             HeteroEmbeddingShardingPlanner,
         ]
     ] = None,
 ) -> Tuple[nn.Module, Optimizer]:
-    # Ensure fused_params is always a dictionary
-    fused_params_dict = {} if fused_params is None else fused_params
-
     sharder = TestEBCSharder(
         sharding_type=sharding_type,
         kernel_type=kernel_type,
-        fused_params=fused_params_dict,
+        fused_params=fused_params,
     )
     sharders = [cast(ModuleSharder[nn.Module], sharder)]

@@ -466,14 +468,28 @@ def generate_sharded_model_and_optimizer(
         sharders=sharders,
         plan=plan,
     ).to(device)
-    optimizer = optim.SGD(
-        [
-            param
-            for name, param in sharded_model.named_parameters()
-            if "sparse" not in name
-        ],
-        lr=0.1,
-    )
+
+    # Get dense parameters
+    dense_params = [
+        param
+        for name, param in sharded_model.named_parameters()
+        if "sparse" not in name
+    ]
+
+    # Create optimizer based on the specified type
+    optimizer_class = getattr(optim, dense_optimizer)
+
+    # Create optimizer with momentum and/or weight_decay if provided
+    optimizer_kwargs = {"lr": dense_lr}
+
+    if dense_momentum is not None:
+        optimizer_kwargs["momentum"] = dense_momentum
+
+    if dense_weight_decay is not None:
+        optimizer_kwargs["weight_decay"] = dense_weight_decay
+
+    optimizer = optimizer_class(dense_params, **optimizer_kwargs)
+
     return sharded_model, optimizer

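The change above resolves the dense optimizer by name from torch.optim and forwards momentum/weight_decay only when they are set. A standalone sketch of that pattern, with a toy model in place of the sharded benchmark model (the helper name build_dense_optimizer is illustrative, not part of the PR):

from typing import Any, Dict, Iterable, Optional

import torch.nn as nn
import torch.optim as optim


def build_dense_optimizer(
    params: Iterable[nn.Parameter],
    dense_optimizer: str = "SGD",
    dense_lr: float = 0.1,
    dense_momentum: Optional[float] = None,
    dense_weight_decay: Optional[float] = None,
) -> optim.Optimizer:
    # Resolve the class by name: "SGD" -> optim.SGD, "Adam" -> optim.Adam, etc.
    optimizer_class = getattr(optim, dense_optimizer)
    kwargs: Dict[str, Any] = {"lr": dense_lr}
    # Pass momentum/weight_decay only when configured; optimizers such as Adam
    # do not accept a momentum kwarg and would raise TypeError otherwise.
    if dense_momentum is not None:
        kwargs["momentum"] = dense_momentum
    if dense_weight_decay is not None:
        kwargs["weight_decay"] = dense_weight_decay
    return optimizer_class(params, **kwargs)


model = nn.Linear(8, 4)
opt = build_dense_optimizer(model.parameters(), dense_optimizer="SGD", dense_momentum=0.9)
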
38 changes: 34 additions & 4 deletions torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -83,6 +83,14 @@ class RunOptions:
         pooling_factors (Optional[List[float]]): Pooling factors for each feature of the table.
             This is the average number of values each sample has for the feature.
         num_poolings (Optional[List[float]]): Number of poolings for each feature of the table.
+        dense_optimizer (str): Optimizer to use for dense parameters.
+            Default is "SGD".
+        dense_lr (float): Learning rate for dense parameters.
+            Default is 0.1.
+        sparse_optimizer (str): Optimizer to use for sparse parameters.
+            Default is "EXACT_ADAGRAD".
+        sparse_lr (float): Learning rate for sparse parameters.
+            Default is 0.1.
     """

     world_size: int = 2
@@ -94,6 +102,14 @@ class RunOptions:
     planner_type: str = "embedding"
     pooling_factors: Optional[List[float]] = None
     num_poolings: Optional[List[float]] = None
+    dense_optimizer: str = "SGD"
+    dense_lr: float = 0.1
+    dense_momentum: Optional[float] = None
+    dense_weight_decay: Optional[float] = None
+    sparse_optimizer: str = "EXACT_ADAGRAD"
+    sparse_lr: float = 0.1
+    sparse_momentum: Optional[float] = None
+    sparse_weight_decay: Optional[float] = None


 @dataclass
@@ -286,17 +302,31 @@ def runner(
         num_batches=run_option.num_batches,
     )

+    # Prepare fused_params for sparse optimizer
+    fused_params = {
+        "optimizer": getattr(EmbOptimType, run_option.sparse_optimizer.upper()),
+        "learning_rate": run_option.sparse_lr,
+    }
+
+    # Add momentum and weight_decay to fused_params if provided
+    if run_option.sparse_momentum is not None:
+        fused_params["momentum"] = run_option.sparse_momentum
+
+    if run_option.sparse_weight_decay is not None:
+        fused_params["weight_decay"] = run_option.sparse_weight_decay
+
     sharded_model, optimizer = generate_sharded_model_and_optimizer(
         model=unsharded_model,
         sharding_type=run_option.sharding_type.value,
         kernel_type=run_option.compute_kernel.value,
         # pyre-ignore
         pg=ctx.pg,
         device=ctx.device,
-        fused_params={
-            "optimizer": EmbOptimType.EXACT_ADAGRAD,
-            "learning_rate": 0.1,
-        },
+        fused_params=fused_params,
+        dense_optimizer=run_option.dense_optimizer,
+        dense_lr=run_option.dense_lr,
+        dense_momentum=run_option.dense_momentum,
+        dense_weight_decay=run_option.dense_weight_decay,
         planner=planner,
     )
     pipeline = generate_pipeline(
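
On the sparse side, the optimizer name is upper-cased and looked up on EmbOptimType before being packed into fused_params for the embedding kernels. A minimal sketch of that mapping, assuming EmbOptimType comes from fbgemm_gpu.split_embedding_configs (the import path is an assumption, since the diff does not show it, and build_fused_params is an illustrative helper, not part of the PR):

from typing import Any, Dict, Optional

from fbgemm_gpu.split_embedding_configs import EmbOptimType


def build_fused_params(
    sparse_optimizer: str = "EXACT_ADAGRAD",
    sparse_lr: float = 0.1,
    sparse_momentum: Optional[float] = None,
    sparse_weight_decay: Optional[float] = None,
) -> Dict[str, Any]:
    # "exact_adagrad" or "EXACT_ADAGRAD" -> EmbOptimType.EXACT_ADAGRAD; an unknown
    # name raises AttributeError instead of silently falling back to a default.
    fused_params: Dict[str, Any] = {
        "optimizer": getattr(EmbOptimType, sparse_optimizer.upper()),
        "learning_rate": sparse_lr,
    }
    if sparse_momentum is not None:
        fused_params["momentum"] = sparse_momentum
    if sparse_weight_decay is not None:
        fused_params["weight_decay"] = sparse_weight_decay
    return fused_params


print(build_fused_params(sparse_optimizer="exact_adagrad", sparse_weight_decay=1e-4))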