From 41a0c085fe2339cdc1e7bc8a5278e037de3e4ed7 Mon Sep 17 00:00:00 2001
From: Gavin Zhang
Date: Wed, 6 Aug 2025 00:25:53 -0700
Subject: [PATCH] refactor the total norm computation in grad clipping in APS
 (#3243)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/3243

Refactored the previous code for applying gradient clipping across DDP and
FSDP parameters. Added a new function _compute_total_norm() that takes in the
replicated and sharded params provided in the GradientClippingOptimizer class
and computes the total gradient norm of the given parameters.

Differential Revision: D79128843
---
 torchrec/optim/clipping.py | 181 +++++++++++++++++++------------------
 1 file changed, 95 insertions(+), 86 deletions(-)

diff --git a/torchrec/optim/clipping.py b/torchrec/optim/clipping.py
index 2ba9a6290..366a51c8e 100644
--- a/torchrec/optim/clipping.py
+++ b/torchrec/optim/clipping.py
@@ -59,7 +59,7 @@ def __init__(
         super().__init__(optimizer)
         self._clipping = clipping
         self._max_gradient = max_gradient
-        self._norm_type = norm_type
+        self._norm_type = float(norm_type)
         self._check_meta: bool = True
         self._enable_global_grad_clip = enable_global_grad_clip
         self._step_num = 0
@@ -122,121 +122,130 @@ def step(self, closure: Any = None) -> None:
                     for p in self._replicate_params
                 ]
                 torch.nn.utils.clip_grad_norm_(
-                    replicate_params,
-                    self._max_gradient,
-                    norm_type=float(self._norm_type),
+                    parameters=replicate_params,
+                    max_norm=self._max_gradient,
+                    norm_type=self._norm_type,
                 )
             else:
                 self.clip_grad_norm_()
         elif self._clipping == GradientClipping.VALUE:
-            torch.nn.utils.clip_grad_value_(self._replicate_params, self._max_gradient)
+            torch.nn.utils.clip_grad_value_(
+                parameters=self._replicate_params, clip_value=self._max_gradient
+            )
 
         super().step(closure)
         self._step_num += 1
 
-    @torch.no_grad()
     def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
-        max_norm = self._max_gradient
-        norm_type = float(self._norm_type)
+
+        # self._norm_type is converted to a float in __init__ if it's a string. Used in the case where self._norm_type is 'inf'.
         all_grads = []
-        total_grad_norm = None
+        sharded_params = self._sharded_params
+        replicate_params = self._replicate_params
 
         # Process distributed parameters and gradients
-        for pgs, dist_params in self._sharded_params.items():
-            sharded_grads = [
-                p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
-                for p in dist_params
-                if p.grad is not None and p.grad.numel() > 0
-            ]
-            if len(sharded_grads) == 0:
-                continue
-            all_grads.extend(sharded_grads)
-
-            sharded_grad_norm = _batch_cal_norm(
-                sharded_grads,
-                max_norm,
-                norm_type,
-                pgs,
-            )
-            total_grad_norm = (
-                sharded_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + sharded_grad_norm
-                )
-            )
-
-        square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0
+        sharded_grads = {
+            pgs: _get_grads(dist_params) for pgs, dist_params in sharded_params.items()
+        }
+        all_grads.extend(*sharded_grads.values())
 
         # Process replicated parameters and gradients
-        if self._replicate_params:
-            replicated_grads = [
-                p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
-                for p in self._replicate_params
-                if p.grad is not None and p.grad.numel() > 0
-            ]
-            all_grads.extend(replicated_grads)
-
-            replicated_grad_norm = _batch_cal_norm(
-                replicated_grads,
-                max_norm,
-                norm_type,
-                None,
-            )
-            total_grad_norm = (
-                replicated_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + replicated_grad_norm
-                )
-            )
-            square_replicated_grad_norm = replicated_grad_norm
-        else:
-            square_replicated_grad_norm = 0
-
-        global log_grad_norm
-        if log_grad_norm:
-            if total_grad_norm is not None and norm_type != torch.inf:
-                # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
-            else:
-                grad_norm = total_grad_norm
-
-            rank = dist.get_rank()
-            logger.info(
-                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
-            )
-
-        # Aggregation
-        if total_grad_norm is None:
-            return
+        replicate_grads = _get_grads(replicate_params)
+        all_grads.extend(replicate_grads)
+
+        total_grad_norm = _compute_total_norm(
+            replicate_grads=replicate_grads,
+            sharded_grads=sharded_grads,
+            norm_type=self._norm_type,
+            max_grad_norm=self._max_gradient,
+        )
 
-        if norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
         # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
-        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
+        clip_coef = cast(torch.Tensor, self._max_gradient / (total_grad_norm + 1e-6))
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm
 
 
+def _get_grads(
+    param_list: List[torch.Tensor],
+) -> List[torch.Tensor]:
+    """Get the gradients of a list of parameters.
+    Converts DTensors to local tensors if needed."""
+    grads = [
+        p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
+        for p in param_list
+        if p.grad is not None and p.grad.numel() > 0
+    ]
+    return grads
+
+
+def _compute_total_norm(
+    replicate_grads: List[torch.Tensor],
+    sharded_grads: Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]],
+    norm_type: float = 2.0,  # can be a normal float or torch.inf
+    max_grad_norm: float = 1.0,
+) -> torch.Tensor:
+    """
+    Given both replicate grads and sharded grads, compute the total norm of the gradients of the full replicate params and the
+    full sharded params (parameters with a process group).
+
+    Args:
+        replicate_grads (List[torch.Tensor]): list of gradients for replicate params
+        sharded_grads (Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]]): dict that maps each process group to a list of gradients for sharded params
+        norm_type (float): type of the used p-norm. Can be torch.inf for infinity norm.
+        max_grad_norm (float): max gradient norm.
+    """
+
+    ## compute the norm |W|^p corresponding to all sharded params W
+    sharded_grad_norm: torch.Tensor = torch.tensor(0.0)
+    combine_norm_operator = torch.maximum if norm_type == torch.inf else torch.add
+
+    # We need to move sharded_grad_norm to the same device as the current shard so that we can do addition (or take the max).
+    # This is specifically for the case where sharded_grad_norm is 0 and replicate_grad_norm is not,
+    # because by default torch.tensor(0.0) is on CPU while replicate_grad_norm is on GPU. For MTIA
+    # specifically, adding a tensor on CPU and a tensor on GPU will result in an error.
+    for pgs, dist_params in sharded_grads.items():
+        current_shard_norm = _batch_cal_norm(
+            grad_list=dist_params,
+            max_norm=max_grad_norm,
+            norm_type=norm_type,
+            process_groups=pgs,
+        )
+        sharded_grad_norm = combine_norm_operator(
+            sharded_grad_norm.to(current_shard_norm.device), current_shard_norm
+        )
+
+    # compute |W|^p corresponding to all replicate params W
+    # Similar to the case above, we move replicate_grad_norm to the same device as sharded_grad_norm so that we can do addition.
+    replicate_grad_norm: torch.Tensor = (
+        _batch_cal_norm(
+            grad_list=replicate_grads, max_norm=max_grad_norm, norm_type=norm_type
+        )
+        if replicate_grads
+        else torch.tensor(0.0)
+    ).to(sharded_grad_norm.device)
+
+    # In the p-norm case, we are given norms |W_sharded|^p and |W_replicate|^p. To compute the total norm, we need to
+    # sum them and take the p-th root. In the inf-norm case, we are given max(|W_sharded|) and max(|W_replicate|).
+    # To compute the total norm, we need to take max(max(|W_sharded|), max(|W_replicate|)).
+    combined_norm = combine_norm_operator(replicate_grad_norm, sharded_grad_norm)
+    total_grad_norm = (
+        combined_norm.pow(1.0 / norm_type) if norm_type != torch.inf else combined_norm
+    )
+
+    return total_grad_norm
+
+
 def _batch_cal_norm(
     grad_list: List[torch.Tensor],
     max_norm: float,
     norm_type: float = 2.0,
     process_groups: Optional[Tuple[dist.ProcessGroup]] = None,
 ) -> torch.Tensor:
-    """Helper function that calculates the norm of a list of gradients in batches. If process_groups
-    are passed in, the norm will be aggregated across all ranks in the process group.
+    """Helper function that calculates the p-th power of the norm of a list of gradients in batches.
+    If process_groups are passed in, the norm will be aggregated across all ranks in the process group.
     """
-
     global use_64bit_grad_norm
    if use_64bit_grad_norm:
        grad_norms = torch.linalg.vector_norm(
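
For illustration only (not part of this patch): a minimal single-process sketch of the norm aggregation that _compute_total_norm performs. The names toy_partial_norm and toy_total_norm are hypothetical stand-ins for _batch_cal_norm and _compute_total_norm, and the sketch omits the per-process-group all-reduce, so it only demonstrates the combine-then-root arithmetic described in the docstrings above.

import torch


def toy_partial_norm(grads, norm_type):
    # Mirrors the helper's contract: return |W|^p for a finite p-norm, or max(|w|) for the inf norm.
    per_tensor = torch.stack([torch.linalg.vector_norm(g, norm_type) for g in grads])
    return per_tensor.max() if norm_type == torch.inf else per_tensor.pow(norm_type).sum()


def toy_total_norm(replicate_grads, sharded_grads, norm_type=2.0):
    # Combine per-group partials with add (finite p-norm) or max (inf norm), then take the p-th root.
    combine = torch.maximum if norm_type == torch.inf else torch.add
    total = torch.tensor(0.0)
    for grads in sharded_grads.values():
        partial = toy_partial_norm(grads, norm_type)
        total = combine(total.to(partial.device), partial)
    if replicate_grads:
        total = combine(total, toy_partial_norm(replicate_grads, norm_type).to(total.device))
    return total if norm_type == torch.inf else total.pow(1.0 / norm_type)


# Two "sharded" groups plus one replicated gradient (toy values).
sharded = {("pg0",): [torch.ones(4)], ("pg1",): [torch.full((2,), 2.0)]}
replicated = [torch.full((3,), 3.0)]
print(toy_total_norm(replicated, sharded, norm_type=2.0))        # sqrt(4 + 8 + 27) = sqrt(39)
print(toy_total_norm(replicated, sharded, norm_type=torch.inf))  # 3.0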