diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py
index 74b2727f..b213c3af 100644
--- a/distributed_shampoo/distributed_shampoo.py
+++ b/distributed_shampoo/distributed_shampoo.py
@@ -296,6 +296,10 @@ class DistributedShampoo(torch.optim.Optimizer):
             3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail.
         track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes.
             (Default: False)
+        experimental_param_to_lr (Optional[Callable[[Tensor], float]]): Optional mapping from parameter to learning rate.
+            If set, this mapping must cover all parameters in param_groups.
+            This setting supersedes the learning rate of each parameter group.
+            (Default: None)

     """

@@ -326,6 +330,7 @@ def __init__(
         precision_config: Optional[PrecisionConfig] = None,
         use_protected_eigh: bool = True,
         track_root_inv_residuals: bool = False,
+        experimental_param_to_lr: Optional[Callable[[torch.Tensor], float]] = None,
     ) -> None:
         # Hyperparameter checks.
         if not lr >= 0.0:
@@ -474,6 +479,7 @@ def __init__(
         self._shampoo_pt2_compile_config: Optional[ShampooPT2CompileConfig] = (
             shampoo_pt2_compile_config
         )
+        self._experimental_param_to_lr = experimental_param_to_lr

         # Initialize dictionary containing lists of .
         self._per_group_state_lists: List[Dict[str, Any]] = [
@@ -1142,6 +1148,114 @@ def _per_group_step_impl(
             masked_blocked_search_directions=masked_blocked_search_directions
         )

+    @torch.no_grad()
+    def _per_group_step_experimental_lrs(
+        self,
+        state_lists: Dict[str, Any],
+        step: torch.Tensor,
+        neg_lrs: List[torch.Tensor],
+        beta1: float,
+        beta3: float,
+        weight_decay: float,
+        momentum_param: float,
+        dampening: float,
+        grafting_config_not_none: bool,
+        compute_root_inverse: bool,
+        use_decoupled_weight_decay: bool,
+        use_bias_correction: bool,
+        use_grafting_method: bool,
+        use_nesterov: bool,
+    ) -> None:
+        # Incorporate L2-regularization or (coupled) weight decay if enabled.
+        #   G <- G + weight_decay * W
+        self._add_l2_regularization(
+            state_lists,
+            weight_decay,
+            use_decoupled_weight_decay,
+        )
+
+        with DequantizePreconditionersContext(
+            preconditioner_list=state_lists[SHAMPOO_PRECONDITIONER_LIST]
+        ), (
+            DequantizePreconditionersContext(
+                preconditioner_list=state_lists[GRAFTING_PRECONDITIONER_LIST]
+            )
+            if grafting_config_not_none
+            else contextlib.nullcontext()
+        ):
+            # Update Shampoo and grafting preconditioners / factor matrices.
+            # Example for AdaGrad accumulation:
+            #   L <- L + G * G^T
+            #   R <- R + G^T * G
+            #   V <- V + G^2 (element-wise)
+            #   (and similar)
+            self._update_preconditioners(
+                state_lists,
+                step,
+                grafting_config_not_none,
+            )
+
+            # Compute matrix root inverse.
+            #   L_inv <- L ** (-1/4)
+            #   R_inv <- R ** (-1/4)
+            #   (and similar)
+            self._compute_root_inverse(state_lists, compute_root_inverse)
+
+            # Compute filtered gradient or EMA of the gradients if beta1 > 0 and beta3 > 0.
+            # Note that we use two beta factors here akin to Lion.
+            #   G_bar <- beta3 * G_tilde + (1 - beta3) * G
+            #   G_tilde <- beta1 * G_tilde + (1 - beta1) * G
+            masked_filtered_grad_list = self._compute_filtered_grad_list(
+                state_lists,
+                step,
+                beta1,
+                beta3,
+                use_bias_correction,
+            )
+
+            # Precondition and graft filtered gradients.
+            # PT2 compile is currently disabled for preconditioning and grafting.
+            # TODO: Resolve preconditioning and grafting PT2 NEX issue and enable them.
+            #
+            #   P_shampoo <- L_inv * G_bar * R_inv (and similar)
+            #   P_grafting <- G_bar / (sqrt(V) + epsilon)
+            #   P <- P_grafting if step < start_preconditioning_step
+            #   P <- ||P_grafting|| / ||P_shampoo|| * P_shampoo otherwise
+            masked_blocked_search_directions = self._precondition_and_grafting(
+                state_lists,
+                masked_filtered_grad_list,
+                use_grafting_method,
+                grafting_config_not_none,
+            )
+
+        # Incorporate decoupled weight decay into search direction if enabled.
+        #   P <- P + weight_decay * W
+        self._apply_decoupled_weight_decay(
+            state_lists,
+            masked_blocked_search_directions,
+            weight_decay,
+            use_decoupled_weight_decay,
+        )
+
+        # Update momentum optimizer state and use momentum / Nesterov if enabled.
+        #   M <- momentum_param * M + (1 - dampening) * P
+        #   P <- (1 - dampening) * P + momentum_param * M if use_nesterov
+        #   P <- M otherwise.
+        self._update_momentum(
+            state_lists,
+            masked_blocked_search_directions,
+            momentum_param,
+            dampening,
+            use_nesterov,
+        )
+
+        # Update parameters in a distributed fashion, scaling each block by its own negative learning rate.
+        # If DDP, executes AllGather communication to ensure all parameters are updated after local updates.
+        torch._foreach_mul_(masked_blocked_search_directions, neg_lrs)
+        state_lists[DISTRIBUTOR].update_params(
+            masked_blocked_search_directions=masked_blocked_search_directions
+        )
+
     @torch.no_grad()
     def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
         """Performs a single optimization step.

@@ -1173,12 +1287,6 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
             # Iterate group step counter and define Python scalar step.
             step = state_lists[STEP].add_(1)

-            # NOTE: Wrap scalar of group[LR] into a 0D tensor to avoid PT2 recompilation;
-            # Send 0D tensor to GPU in `non_blocking` to avoid QPS regression. Remove the gpu
-            # tensor impl once PT2 supports cpu 0D tensor properly.
-            lr = torch.tensor(group[LR], dtype=torch.float).to(
-                self._device, non_blocking=True
-            )
             beta1 = group[BETAS][0]
             beta3 = group[BETA3]
             weight_decay = group[WEIGHT_DECAY]
@@ -1200,22 +1308,58 @@
             )
             use_nesterov = group[USE_NESTEROV]

-            self._per_group_step(
-                state_lists,
-                step,
-                lr,
-                beta1,
-                beta3,
-                weight_decay,
-                momentum_param,
-                dampening,
-                grafting_config_not_none,
-                compute_root_inverse,
-                use_decoupled_weight_decay,
-                use_bias_correction,
-                use_grafting_method,
-                use_nesterov,
-            )
+            if self._experimental_param_to_lr is None:
+                # NOTE: Wrap scalar of group[LR] into a 0D tensor to avoid PT2 recompilation;
+                # send the 0D tensor to GPU with `non_blocking` to avoid a QPS regression. Remove the GPU
+                # tensor impl once PT2 supports CPU 0D tensors properly.
+                lr = torch.tensor(group[LR], dtype=torch.float).to(
+                    self._device, non_blocking=True
+                )
+                self._per_group_step(
+                    state_lists,
+                    step,
+                    lr,
+                    beta1,
+                    beta3,
+                    weight_decay,
+                    momentum_param,
+                    dampening,
+                    grafting_config_not_none,
+                    compute_root_inverse,
+                    use_decoupled_weight_decay,
+                    use_bias_correction,
+                    use_grafting_method,
+                    use_nesterov,
+                )
+            else:
+                local_block_info_list = compress_list(
+                    state_lists[DISTRIBUTOR].global_block_info_list,
+                    state_lists[DISTRIBUTOR].distributor_selector,
+                )
+                neg_lr_tensors = []
+                for local_block_info in local_block_info_list:
+                    lr_scalar = self._experimental_param_to_lr(local_block_info.param)
+                    lr = torch.tensor(-lr_scalar, dtype=torch.float).to(
+                        self._device, non_blocking=True
+                    )
+                    neg_lr_tensors.append(lr)
+
+                self._per_group_step_experimental_lrs(
+                    state_lists,
+                    step,
+                    neg_lr_tensors,
+                    beta1,
+                    beta3,
+                    weight_decay,
+                    momentum_param,
+                    dampening,
+                    grafting_config_not_none,
+                    compute_root_inverse,
+                    use_decoupled_weight_decay,
+                    use_bias_correction,
+                    use_grafting_method,
+                    use_nesterov,
+                )

         return loss
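For orientation, here is a minimal usage sketch of the hook added above. Only the `experimental_param_to_lr` keyword argument comes from this diff; the model, learning-rate values, and import path are illustrative assumptions:

```python
import torch

from distributed_shampoo.distributed_shampoo import DistributedShampoo

# Toy model; any nn.Module works.
model = torch.nn.Linear(8, 2)

# torch.Tensor hashes by identity, so parameters can key a dict directly.
per_param_lr = {p: 1e-3 * (i + 1) for i, p in enumerate(model.parameters())}

optimizer = DistributedShampoo(
    model.parameters(),
    lr=1e-3,  # superseded per-parameter by the mapping below
    experimental_param_to_lr=lambda param: per_param_lr[param],
)
```

Per the docstring, the mapping must cover every parameter in param_groups; a parameter missing from `per_param_lr` would raise a `KeyError` inside `step()`.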
diff --git a/distributed_shampoo/examples/ddp_cifar10_example.py b/distributed_shampoo/examples/ddp_cifar10_example.py
index 55ce7260..e81a7a73 100644
--- a/distributed_shampoo/examples/ddp_cifar10_example.py
+++ b/distributed_shampoo/examples/ddp_cifar10_example.py
@@ -131,6 +131,12 @@
         ),
         use_protected_eigh=args.use_protected_eigh,
         track_root_inv_residuals=args.track_root_inv_residuals,
+        experimental_lrs=(
+            [float(f) for f in args.experimental_lrs.split(",")]
+            if args.experimental_lrs
+            else []
+        ),
+        experimental_param_to_lr_mapping=args.experimental_param_to_lr_mapping,
     )

     # checks for checkpointing
@@ -140,7 +146,7 @@
         raise ValueError(
             "Distributed checkpointing is only supported with DistributedShampoo!"
         )
-    if args.se_distributed_checkpoint and args.checkpoint_dir is None:
+    if args.use_distributed_checkpoint and args.checkpoint_dir is None:
         raise ValueError(
             "Trying to use distributed checkpointing but checkpoint directory is not provided!"
         )
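The flag plumbing above is plain string splitting. A quick self-contained check with a hypothetical flag value (nothing here beyond what the hunk already does):

```python
# Mirrors the --experimental-lrs parsing in the example above.
raw = "1e-3,5e-4"  # as if passed via --experimental-lrs
experimental_lrs = [float(f) for f in raw.split(",")] if raw else []
assert experimental_lrs == [1e-3, 5e-4]
```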
diff --git a/distributed_shampoo/examples/trainer_utils.py b/distributed_shampoo/examples/trainer_utils.py
index 9efe9fb6..964803a8 100644
--- a/distributed_shampoo/examples/trainer_utils.py
+++ b/distributed_shampoo/examples/trainer_utils.py
@@ -12,7 +12,7 @@
 import logging
 import random
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np

@@ -92,6 +92,12 @@ def get_args():

     # Arguments for optimizer.
     parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate.")
+    parser.add_argument(
+        "--experimental-lrs",
+        type=str,
+        default="",
+        help="Comma-separated list of learning rates. Overrides --lr when set.",
+    )
     parser.add_argument(
         "--beta1", type=float, default=0.9, help="Beta1 for gradient filtering."
     )
@@ -195,6 +201,11 @@ def get_args():
         action="store_true",
         help="Use debug mode for examining root inverse residuals.",
     )
+    parser.add_argument(
+        "--experimental-param-to-lr-mapping",
+        action="store_true",
+        help="Use an experimental feature that maps parameters to learning rates.",
+    )

     # Arguments for grafting.
     parser.add_argument(
@@ -410,6 +421,8 @@ def instantiate_optimizer(
     precision_config: Optional[PrecisionConfig],
     use_protected_eigh: bool,
     track_root_inv_residuals: bool,
+    experimental_lrs: List[float] = [],
+    experimental_param_to_lr_mapping: bool = False,
 ) -> torch.optim.Optimizer:
     if optimizer_type == OptimizerType.SGD:
         optimizer = torch.optim.SGD(
@@ -438,33 +451,108 @@
             weight_decay=weight_decay,
         )
     elif optimizer_type == OptimizerType.DISTRIBUTED_SHAMPOO:
-        optimizer = DistributedShampoo(
-            model.parameters(),
-            lr=lr,
-            betas=betas,
-            beta3=beta3,
-            epsilon=epsilon,
-            momentum=momentum,
-            dampening=dampening,
-            weight_decay=weight_decay,
-            max_preconditioner_dim=max_preconditioner_dim,
-            precondition_frequency=precondition_frequency,
-            start_preconditioning_step=start_preconditioning_step,
-            inv_root_override=inv_root_override,
-            exponent_multiplier=exponent_multiplier,
-            use_nesterov=use_nesterov,
-            use_bias_correction=use_bias_correction,
-            use_decoupled_weight_decay=use_decoupled_weight_decay,
-            grafting_config=instantiate_grafting_config(
-                grafting_type, grafting_beta2, grafting_epsilon
-            ),
-            use_merge_dims=use_merge_dims,
-            use_pytorch_compile=use_pytorch_compile,
-            distributed_config=distributed_config,
-            precision_config=precision_config,
-            use_protected_eigh=use_protected_eigh,
-            track_root_inv_residuals=track_root_inv_residuals,
-        )
+        if len(experimental_lrs) == 0:
+            # The default, standard Shampoo behavior.
+            optimizer = DistributedShampoo(
+                model.parameters(),
+                lr=lr,
+                betas=betas,
+                beta3=beta3,
+                epsilon=epsilon,
+                momentum=momentum,
+                dampening=dampening,
+                weight_decay=weight_decay,
+                max_preconditioner_dim=max_preconditioner_dim,
+                precondition_frequency=precondition_frequency,
+                start_preconditioning_step=start_preconditioning_step,
+                inv_root_override=inv_root_override,
+                exponent_multiplier=exponent_multiplier,
+                use_nesterov=use_nesterov,
+                use_bias_correction=use_bias_correction,
+                use_decoupled_weight_decay=use_decoupled_weight_decay,
+                grafting_config=instantiate_grafting_config(
+                    grafting_type, grafting_beta2, grafting_epsilon
+                ),
+                use_merge_dims=use_merge_dims,
+                use_pytorch_compile=use_pytorch_compile,
+                distributed_config=distributed_config,
+                precision_config=precision_config,
+                use_protected_eigh=use_protected_eigh,
+                track_root_inv_residuals=track_root_inv_residuals,
+            )
+        else:
+            param_to_lr_idx = {
+                param: param_idx % len(experimental_lrs)
+                for param_idx, param in enumerate(model.parameters())
+            }
+            if not experimental_param_to_lr_mapping:
+                # Assign one of the given learning rates to each parameter,
+                # round-robin for simplicity.
+                param_groups = [{"params": [], "lr": group_lr} for group_lr in experimental_lrs]
+                for param in model.parameters():
+                    param_groups[param_to_lr_idx[param]]["params"].append(param)
+
+                optimizer = DistributedShampoo(
+                    param_groups,
+                    lr=lr,
+                    betas=betas,
+                    beta3=beta3,
+                    epsilon=epsilon,
+                    momentum=momentum,
+                    dampening=dampening,
+                    weight_decay=weight_decay,
+                    max_preconditioner_dim=max_preconditioner_dim,
+                    precondition_frequency=precondition_frequency,
+                    start_preconditioning_step=start_preconditioning_step,
+                    inv_root_override=inv_root_override,
+                    exponent_multiplier=exponent_multiplier,
+                    use_nesterov=use_nesterov,
+                    use_bias_correction=use_bias_correction,
+                    use_decoupled_weight_decay=use_decoupled_weight_decay,
+                    grafting_config=instantiate_grafting_config(
+                        grafting_type, grafting_beta2, grafting_epsilon
+                    ),
+                    use_merge_dims=use_merge_dims,
+                    use_pytorch_compile=use_pytorch_compile,
+                    distributed_config=distributed_config,
+                    precision_config=precision_config,
+                    use_protected_eigh=use_protected_eigh,
+                    track_root_inv_residuals=track_root_inv_residuals,
+                )
+            else:
+                # Pass learning rates via the experimental interface.
+                # Note that we pass only a single param_group.
+                def param_to_lr(param: torch.Tensor) -> float:
+                    return experimental_lrs[param_to_lr_idx[param]]
+
+                optimizer = DistributedShampoo(
+                    model.parameters(),
+                    lr=lr,
+                    betas=betas,
+                    beta3=beta3,
+                    epsilon=epsilon,
+                    momentum=momentum,
+                    dampening=dampening,
+                    weight_decay=weight_decay,
+                    max_preconditioner_dim=max_preconditioner_dim,
+                    precondition_frequency=precondition_frequency,
+                    start_preconditioning_step=start_preconditioning_step,
+                    inv_root_override=inv_root_override,
+                    exponent_multiplier=exponent_multiplier,
+                    use_nesterov=use_nesterov,
+                    use_bias_correction=use_bias_correction,
+                    use_decoupled_weight_decay=use_decoupled_weight_decay,
+                    grafting_config=instantiate_grafting_config(
+                        grafting_type, grafting_beta2, grafting_epsilon
+                    ),
+                    use_merge_dims=use_merge_dims,
+                    use_pytorch_compile=use_pytorch_compile,
+                    distributed_config=distributed_config,
+                    precision_config=precision_config,
+                    use_protected_eigh=use_protected_eigh,
+                    track_root_inv_residuals=track_root_inv_residuals,
+                    experimental_param_to_lr=param_to_lr,
+                )
+
     else:
         raise ValueError(f"Invalid OptimizerType {optimizer_type}!")
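To make the two experimental branches above concrete: both assign the given learning rates to parameters round-robin; the first materializes the assignment as explicit param groups, while the second hands the optimizer a single callable. A self-contained sketch of that equivalence (toy model and values; the assertion is illustrative, not a test from this PR):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
experimental_lrs = [1e-3, 5e-4, 1e-4]

# Round-robin assignment shared by both branches.
param_to_lr_idx = {
    param: param_idx % len(experimental_lrs)
    for param_idx, param in enumerate(model.parameters())
}

# Branch 1: one param group per learning rate.
param_groups = [{"params": [], "lr": group_lr} for group_lr in experimental_lrs]
for param in model.parameters():
    param_groups[param_to_lr_idx[param]]["params"].append(param)

# Branch 2: a single callable mapping each parameter to its learning rate.
def param_to_lr(param: torch.Tensor) -> float:
    return experimental_lrs[param_to_lr_idx[param]]

# Both branches agree on every parameter's learning rate.
for group in param_groups:
    for param in group["params"]:
        assert param_to_lr(param) == group["lr"]
```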