Skip to content

Commit f9be4b0

Browse files
Gavin Zhangfacebook-github-bot
authored andcommitted
refactor the total norm computation in grad clipping in APS (#3243)
Summary: Pull Request resolved: #3243 Refactored the previous code for applying gradient clipping across ddp and fsdp parameter. Added a new funciton _compute_total_norm() that takes in the fsdp and ddp params provided in the gradientclippingOpitmizer class and computes the total gradient norm of the given parameter. Differential Revision: D79128843
1 parent 7514d2b commit f9be4b0

File tree

2 files changed

+89
-64
lines changed

2 files changed

+89
-64
lines changed

torchrec/optim/clipping.py

Lines changed: 88 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(
6868
# Otherwise, all parameters are treated as replicated and will be clipped locally.
6969
sharded_param_cnt = 0
7070
self._replicate_params: List[torch.Tensor] = []
71+
7172
self._sharded_params: Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]] = (
7273
defaultdict(list)
7374
)
@@ -143,90 +144,114 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
143144
all_grads = []
144145
total_grad_norm = None
145146

147+
sharded_params = self._sharded_params
148+
ddp_params = self._replicate_params
149+
146150
# Process distributed parameters and gradients
147-
for pgs, dist_params in self._sharded_params.items():
151+
for dist_params in sharded_params.values():
148152
sharded_grads = [
149153
p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
150154
for p in dist_params
151155
if p.grad is not None and p.grad.numel() > 0
152156
]
153-
if len(sharded_grads) == 0:
154-
continue
155157
all_grads.extend(sharded_grads)
156158

157-
sharded_grad_norm = _batch_cal_norm(
158-
sharded_grads,
159-
max_norm,
160-
norm_type,
161-
pgs,
162-
)
163-
total_grad_norm = (
164-
sharded_grad_norm
165-
if total_grad_norm is None
166-
else (
167-
torch.maximum(total_grad_norm, sharded_grad_norm)
168-
if norm_type == torch.inf
169-
else total_grad_norm + sharded_grad_norm
170-
)
171-
)
172-
173-
square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0
174-
175159
# Process replicated parameters and gradients
176-
if self._replicate_params:
177-
replicated_grads = [
160+
if ddp_params:
161+
ddp_grads = [
178162
p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
179-
for p in self._replicate_params
163+
for p in ddp_params
180164
if p.grad is not None and p.grad.numel() > 0
181165
]
182-
all_grads.extend(replicated_grads)
183-
184-
replicated_grad_norm = _batch_cal_norm(
185-
replicated_grads,
186-
max_norm,
187-
norm_type,
188-
None,
189-
)
190-
total_grad_norm = (
191-
replicated_grad_norm
192-
if total_grad_norm is None
193-
else (
194-
torch.maximum(total_grad_norm, replicated_grad_norm)
195-
if norm_type == torch.inf
196-
else total_grad_norm + replicated_grad_norm
197-
)
198-
)
199-
square_replicated_grad_norm = replicated_grad_norm
200-
else:
201-
square_replicated_grad_norm = 0
202-
203-
global log_grad_norm
204-
if log_grad_norm:
205-
if total_grad_norm is not None and norm_type != torch.inf:
206-
# pyre-ignore[58]
207-
grad_norm = total_grad_norm ** (1.0 / norm_type)
208-
else:
209-
grad_norm = total_grad_norm
210-
211-
rank = dist.get_rank()
212-
logger.info(
213-
f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
214-
)
166+
all_grads.extend(ddp_grads)
215167

216-
# Aggregation
217-
if total_grad_norm is None:
218-
return
168+
total_grad_norm = _compute_total_norm(
169+
ddp_params, sharded_params, norm_type, max_norm
170+
)
219171

220-
if norm_type != torch.inf:
221-
# pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
222-
total_grad_norm = total_grad_norm ** (1.0 / norm_type)
223172
# pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
224173
clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
225174
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
226175
torch._foreach_mul_(all_grads, clip_coef_clamped)
227176
return total_grad_norm
228177

229178

179+
def _compute_total_norm(
180+
ddp_params: Optional[List[torch.Tensor]] = None,
181+
sharded_params: Optional[Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]]] = None,
182+
norm_type: float = 2.0, # can be a normal float, or torch.inf
183+
max_grad_norm: float = 1.0,
184+
) -> torch.Tensor:
185+
"""
186+
Given both ddp params and sharded params, compute the total norm of the gradients of the full ddp params and the
187+
full sharded param (parameters with a process group).
188+
189+
Args:
190+
ddp_params (List[torch.Tensor]): list of ddp params
191+
sharded_params (Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]]): dict that maps each process group to a list of tensors
192+
norm_type (Union[float, str]): type of the used p-norm. Can be ``'inf'`` for infinity norm.
193+
enable_global_grad_clip (bool): whether to compute total norm using all fsdp shards in the process group
194+
param_to_pgs (Dict[torch.nn.Parameter, List[dist.ProcessGroup]]): mapping of parameters to process groups.
195+
"""
196+
197+
## compute |W|^p corresponding to all DDP params W
198+
199+
if ddp_params is None:
200+
ddp_params = []
201+
if sharded_params is None:
202+
sharded_params = defaultdict(list)
203+
204+
def get_grad_norm(
205+
param_list: List[torch.Tensor],
206+
norm_type: float,
207+
max_grad_norm: float,
208+
pgs: Optional[Tuple[dist.ProcessGroup]] = None,
209+
) -> torch.Tensor:
210+
"""
211+
Given a list of parameters, convert them to local tensors if they are DTensors,
212+
and compute the norm of the gradients of the parameters.
213+
"""
214+
grad_list = [
215+
p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
216+
for p in param_list
217+
if p.grad is not None and p.grad.numel() > 0
218+
]
219+
return _batch_cal_norm(grad_list, max_grad_norm, norm_type, pgs)
220+
221+
## compute the norm |W|^p corresponding to all sharded params W
222+
sharded_grad_norm: torch.Tensor = torch.tensor(0.0)
223+
if sharded_params:
224+
combine_sharded_norm_operator = (
225+
torch.maximum if norm_type == torch.inf else torch.add
226+
)
227+
228+
# We need to move sharded_grad_norm to the same device as the first shard so that we can do addition (or take max)
229+
# this is specifically for the case where sharded_grad_norm is 0, and ddp_grad_norm is not,
230+
# because by default torch.tensor(0.0) is on cpu, and ddp_grad_norm is on GPU. For MTIA
231+
# specifically, adding a tensor on cpu and a tensor on GPU will result in an error.
232+
for pgs, dist_params in sharded_params.items():
233+
shard_norm = get_grad_norm(dist_params, norm_type, max_grad_norm, pgs)
234+
sharded_grad_norm = combine_sharded_norm_operator(
235+
sharded_grad_norm.to(shard_norm.device), shard_norm
236+
)
237+
238+
# Similar to the case above, we move ddp_grad_norm to the same device as sharded_grad_norm so that we can do addition.
239+
ddp_grad_norm: torch.Tensor = (
240+
get_grad_norm(ddp_params, norm_type, max_grad_norm)
241+
if ddp_params
242+
else torch.tensor(0.0)
243+
).to(sharded_grad_norm.device)
244+
245+
combine_norm_operator = (
246+
torch.maximum
247+
if norm_type == torch.inf
248+
else lambda a, b: torch.add(a, b).pow(1.0 / norm_type)
249+
)
250+
251+
total_grad_norm = combine_norm_operator(ddp_grad_norm, sharded_grad_norm)
252+
return total_grad_norm
253+
254+
230255
def _batch_cal_norm(
231256
grad_list: List[torch.Tensor],
232257
max_norm: float,

torchrec/optim/tests/test_clipping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def _get_params_to_pg(
251251
return {param: [param.device_mesh.get_group()] for param in params}
252252

253253
@with_comms
254-
@parametrize("norm_type", ("inf", 1, 2))
254+
@parametrize("norm_type", ("inf",))
255255
def test_dtensor_clip_all_gradients_norm(
256256
self, norm_type: Union[float, str]
257257
) -> None:

0 commit comments

Comments
 (0)