Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 236 additions & 14 deletions torchtitan/components/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
# LICENSE file in the root directory of this source tree.

import functools
import os
import re
from collections import OrderedDict
from typing import Any, Generic, Iterator, TypeVar

import torch
Expand All @@ -21,6 +24,9 @@
from torchtitan.components.ft import FTManager, has_torchft
from torchtitan.config import Optimizer as OptimizerConfig
from torchtitan.distributed import ParallelDims
from torchtitan.experiments.distributed_scion import DistributedScion, naive_param_norm
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import Color

__all__ = [
"OptimizersContainer",
Expand All @@ -36,6 +42,55 @@
T = TypeVar("T", bound=Optimizer)


def _extract_param_groups(
model: torch.nn.Module,
optimizer_config: dict[str, Any] | None = None,
):
param_groups_config: list[dict[str, Any]] | None = (
optimizer_config.pop("param_groups", None)
if optimizer_config is not None
else None
)
if param_groups_config is None:
param_groups_config = []

param_dict = OrderedDict(
(n, p) for n, p in model.named_parameters() if p.requires_grad
)
params = []

color = Color()
for param_group_config in param_groups_config:
str_match = param_group_config.pop("param_str_match")
filter_fn = functools.partial(re.search, str_match)
param_names = [n for n in param_dict.keys() if filter_fn(n)]
group_params = {
"params": [param_dict.pop(n) for n in param_names],
"param_names": param_names,
}
assert len(group_params["params"]) == len(group_params["param_names"])

if len(param_names) == 0:
logger.warning(
f'{color.red}Notice: No parameters found for `str_match` "{str_match}" on '
f"global rank {torch.distributed.get_rank()}{color.reset}"
)
continue
group_params.update(param_group_config)
params.append(group_params)

param_names = list(param_dict.keys())
params.insert(
0,
{
"params": [param_dict.pop(n) for n in param_names],
"param_names": param_names,
},
)
assert not param_dict
return params


class OptimizersContainer(Optimizer, Stateful, Generic[T]):
"""A container for multiple optimizers.

Expand Down Expand Up @@ -74,11 +129,34 @@ def __init__(
all_params = []
self.optimizers = []
self.model_parts = model_parts
param_groups_config = optimizer_kwargs.get("param_groups", None)
# Whether to keep old LR values when loading.
self.preserve_lrs_when_loading = False
self.norms_to_log: list[str] | None = None

for model in self.model_parts:
params = [p for p in model.parameters() if p.requires_grad]
self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
# copy parts we will pop from to preserve settings across model parts
kwargs = optimizer_kwargs.copy()
if "param_groups" in optimizer_kwargs:
kwargs["param_groups"] = (
param_groups_config.copy()
if param_groups_config is not None
else None
)

extra_kwargs = kwargs.pop("extra_kwargs")
params = _extract_param_groups(model, kwargs)

is_scion = issubclass(optimizer_cls, (DistributedScion))
if is_scion:
kwargs.update(extra_kwargs)
self.optimizers.append(optimizer_cls(params, **kwargs))
all_params.extend(params)
self._validate_length(len(self.model_parts))
# Do not separately save the external settings in
# optimizer defaults.
optimizer_kwargs.pop("param_groups", None)
optimizer_kwargs.update(optimizer_kwargs.pop("extra_kwargs", {}))
self._post_init(all_params, optimizer_kwargs)

def __iter__(self) -> Iterator[T]:
Expand All @@ -93,7 +171,12 @@ def step(self, *args, **kwargs) -> None:

def zero_grad(self, *args, **kwargs) -> None:
    """Zero the gradients of all wrapped optimizers.

    Scion's "light" variant with momentum enabled is skipped.
    NOTE(review): presumably that variant keeps its momentum state in the
    gradient buffers, so zeroing them would erase optimizer state — confirm
    against DistributedScion's implementation.
    """
    # Fix: the span contained both the pre-change unconditional
    # `optimizer.zero_grad(...)` call and the new guarded call (diff
    # residue), which would have zeroed gradients unconditionally and
    # twice; keep only the guarded call.
    for optimizer in self.optimizers:
        if not (
            isinstance(optimizer, DistributedScion)
            and optimizer.is_light
            and optimizer.use_momentum
        ):
            optimizer.zero_grad(*args, **kwargs)

def state_dict(self) -> dict[str, Any]:
func = functools.partial(
Expand All @@ -107,13 +190,68 @@ def state_dict(self) -> dict[str, Any]:
}

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load optimizer state from ``state_dict`` into every optimizer.

    When ``self.preserve_lrs_when_loading`` is set, the learning rates
    active before loading are captured and written back afterwards,
    overriding whatever the checkpoint contained (a warning is logged for
    each overridden group).
    """
    saved_lrs: list[list[Any]] = []
    if self.preserve_lrs_when_loading:
        # Snapshot the current learning rates before the checkpoint
        # overwrites them.
        saved_lrs = [
            [group["lr"] for group in optimizer.param_groups]
            for optimizer in self.optimizers
        ]

    load_fn = functools.partial(
        set_optimizer_state_dict,
        optim_state_dict=state_dict,
        options=StateDictOptions(flatten_optimizer_state_dict=True),
    )
    list(map(load_fn, self.model_parts, self.optimizers))

    if self.preserve_lrs_when_loading:
        # Write the snapshotted learning rates back over the loaded state.
        for optimizer, optimizer_lrs in zip(self.optimizers, saved_lrs):
            for param_group, prev_lr in zip(optimizer.param_groups, optimizer_lrs):
                if param_group["lr"] != prev_lr:
                    logger.warning(
                        f"Restoring lr from {param_group['lr']} to {prev_lr} | "
                        f"for {param_group['param_names']}"
                    )
                    param_group["lr"] = prev_lr

def calculate_norm_at_next_step(self):
    """Ask each Distributed-Scion optimizer to compute parameter norms
    (for ``self.norms_to_log``) during its next ``step()`` call.

    Non-Scion optimizers are left untouched; their norms are computed
    separately in ``get_parameter_norms``.
    """
    # for Dist-scion, we tell the optimizer to calculate the norm at next step
    # in the step() function
    # NB: indexes optimizers by model-part position; assumes the i-th
    # optimizer belongs to the i-th model part.
    for i, _ in enumerate(self.model_parts):
        optimizer = self.optimizers[i]
        if isinstance(optimizer, DistributedScion):
            optimizer.calculate_norm_at_next_step(self.norms_to_log)

def get_parameter_norms(self):
    """Collect parameter norms from every (model part, optimizer) pair.

    Scion optimizers report the norms they computed during their own
    ``step()`` (see ``calculate_norm_at_next_step``); any other optimizer
    falls back to ``naive_param_norm``.

    Returns:
        dict mapping norm names to values, merged across all optimizers.
    """
    all_norms = {}
    for i, model_part in enumerate(self.model_parts):
        # NB: assumes correspondences between model parts and optimizers
        optimizer = self.optimizers[i]
        # Fix: this collection does not depend on any individual param
        # group, but it previously ran inside a `for group in
        # optimizer.param_groups` loop, redundantly recomputing/merging the
        # same dict once per group. Run it once per optimizer instead.
        if isinstance(optimizer, DistributedScion):
            all_norms.update(optimizer.get_norms_at_current_step())
        else:
            all_norms.update(
                naive_param_norm.get_parameter_norms(
                    [model_part],
                    [optimizer],
                    self.norms_to_log,
                )
            )
        # To debug, the naive path can be forced:
        # all_norms.update(
        #     naive_param_norm.get_parameter_norms([model_part], [optimizer])
        # )

    return all_norms

def get_lrs(self):
    """Return a flat mapping ``'lr/opt_{i}/group_{k}' -> current lr``
    covering every param group of every wrapped optimizer."""
    return {
        f"lr/opt_{opt_idx}/group_{group_idx}": group["lr"]
        for opt_idx, optimizer in enumerate(self.optimizers)
        for group_idx, group in enumerate(optimizer.param_groups)
    }

def _validate_length(self, expected_length: int) -> None:
assert expected_length == len(self.optimizers), (
"Must pass one optimizer per model part or per param if "
Expand Down Expand Up @@ -246,6 +384,7 @@ def build_optimizers(
optimizer_config: OptimizerConfig,
parallel_dims: ParallelDims,
ft_manager: FTManager | None = None,
extra_kwargs: dict[str, Any] | None = None,
) -> OptimizersContainer:
"""Create a OptimizersContainer for the given model parts and job config.

Expand Down Expand Up @@ -280,31 +419,114 @@ def build_optimizers(
"TorchFT is not supported with optimizers in backward."
)

extra_kwargs = extra_kwargs if extra_kwargs is not None else {}

name = optimizer_config.name
lr = optimizer_config.lr
beta1 = optimizer_config.beta1
beta2 = optimizer_config.beta2
eps = optimizer_config.eps
weight_decay = optimizer_config.weight_decay

optim_implementation = optimizer_config.implementation
assert optim_implementation in ["fused", "foreach", "for-loop"]
is_scion = name == "DistributedScion"

fused = optim_implementation == "fused"
foreach = optim_implementation == "foreach"
if name in ["Adam", "AdamW"]:
optim_implementation = optimizer_config.implementation
assert optim_implementation in ["fused", "foreach", "for-loop"]

optimizer_kwargs = {
"lr": lr,
"betas": (beta1, beta2),
"eps": eps,
"weight_decay": weight_decay,
"fused": fused,
"foreach": foreach,
fused = optim_implementation == "fused"
foreach = optim_implementation == "foreach"

if parallel_dims.ep_enabled:
# Because for Expert Parallel, we have two different device meshes.
fused, foreach = False, False

optimizer_kwargs = {
"lr": lr,
"betas": (beta1, beta2),
"eps": eps,
"weight_decay": weight_decay,
"fused": fused,
"foreach": foreach,
}
elif is_scion:
backend_steps = optimizer_config.backend_steps
zeropower_backend_algorithm = optimizer_config.zeropower_backend
momentum = optimizer_config.momentum
nesterov = optimizer_config.nesterov
is_light = optimizer_config.is_light
weight_decay = optimizer_config.weight_decay
if os.environ.get("SCION_DEBUG_GRAD") == "1":
# only if we want to debug the gradient, we dont run SVD
norm_factor = "none"
zeropower_backend_algorithm = "identity"
logger.warning(
'`SCION_DEBUG_GRAD` is set to 1, we will not run SVD and use the "identity" backend'
)
else:
norm_factor = "spectral"

optimizer_kwargs = {
"is_light": is_light,
"weight_decay": weight_decay,
"lr": lr,
"momentum": momentum,
"nesterov": nesterov,
"eps": eps,
"norm_factor": norm_factor,
"backend": zeropower_backend_algorithm,
"backend_steps": backend_steps,
}
else:
raise NotImplementedError(f"Optimizer {name} not added.")

# Configure parameter group settings
embed_lr = optimizer_config.embed_lr
embed_str_match = optimizer_config.embed_str_match
if embed_lr is not None and embed_str_match:
param_groups_config = optimizer_kwargs.setdefault("param_groups", [])
param_group_config = {
"param_str_match": embed_str_match,
"lr": embed_lr,
}
if is_scion:
param_group_config["norm_factor"] = "embed_sqrt"
param_group_config["backend"] = "identity"
param_groups_config.append(param_group_config)
unembed_lr = optimizer_config.unembed_lr
unembed_str_match = optimizer_config.unembed_str_match
if unembed_lr is not None and unembed_str_match:
param_groups_config = optimizer_kwargs.setdefault("param_groups", [])
param_group_config = {
"param_str_match": unembed_str_match,
"lr": unembed_lr,
}
if is_scion:
param_group_config["norm_factor"] = "unembed_sqrt"
param_group_config["backend"] = "identity"
param_groups_config.append(param_group_config)

router_str_match = optimizer_config.router_str_match
if router_str_match:
param_groups_config = optimizer_kwargs.setdefault("param_groups", [])
param_group_config = {
"param_str_match": router_str_match,
"lr": lr,
}
if is_scion:
param_group_config["norm_factor"] = "spectral"
param_group_config["backend"] = zeropower_backend_algorithm
param_groups_config.append(param_group_config)

optimizer_kwargs["extra_kwargs"] = {
"parallel_dims": parallel_dims,
**extra_kwargs,
}

optimizer_classes = {
"Adam": torch.optim.Adam,
"AdamW": torch.optim.AdamW,
"DistributedScion": DistributedScion,
}
if name not in optimizer_classes:
raise NotImplementedError(f"Optimizer {name} not added.")
Expand Down
34 changes: 34 additions & 0 deletions torchtitan/config/job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ class Metrics:
enable_wandb: bool = False
"""Whether to log metrics to Weights & Biases"""

log_norm_freq: int = -1
"""How often to log norms in iterations"""


@dataclass
class Model:
Expand Down Expand Up @@ -138,6 +141,37 @@ class Optimizer:
register_post_accumulate_grad_hook after the optimizer is built.
"""

# Below is Scion-specific configs
is_light: bool = False
"""Whether to use Scion's light (memory-saving) version"""

zeropower_backend: str = "newtonschulz5"
"""Which `zeropower_backend` to use."""

backend_steps: int = 5
"""Number of steps for the Scion backend"""

momentum: float = 0.95
"""Scion momentum to use"""

nesterov: bool = False
"""Whether to use Nesterov momentum in Scion"""

embed_lr: float | None = None
"""Embedding layer learning rate"""

unembed_lr: float | None = None
"""Unembedding layer learning rate"""

embed_str_match: str | None = None
"""String to match for embedding layer parameter group"""

unembed_str_match: str | None = None
"""String to match for unembedding layer parameter group"""

router_str_match: str | None = None
"""String to match for MoE router layer parameter group"""


@dataclass
class LRScheduler:
Expand Down
8 changes: 8 additions & 0 deletions torchtitan/experiments/distributed_scion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .distributed_scion import DistributedScion # noqa: F401
from .utils import remove_orig_mod_and_weight_for_p_name # noqa: F401
Loading