@@ -59,7 +59,7 @@ def __init__(
        super().__init__(optimizer)
        self._clipping = clipping
        self._max_gradient = max_gradient
-        self._norm_type = norm_type
+        self._norm_type = float(norm_type)
        self._check_meta: bool = True
        self._enable_global_grad_clip = enable_global_grad_clip
        self._step_num = 0
@@ -122,121 +122,130 @@ def step(self, closure: Any = None) -> None:
                    for p in self._replicate_params
                ]
                torch.nn.utils.clip_grad_norm_(
-                    replicate_params,
-                    self._max_gradient,
-                    norm_type=float(self._norm_type),
+                    parameters=replicate_params,
+                    max_norm=self._max_gradient,
+                    norm_type=self._norm_type,
                )
            else:
                self.clip_grad_norm_()

        elif self._clipping == GradientClipping.VALUE:
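+            # clip_grad_value_ clamps each gradient element in-place to the range
+            # [-clip_value, clip_value]; no gradient norm is computed in this branch.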
-            torch.nn.utils.clip_grad_value_(self._replicate_params, self._max_gradient)
+            torch.nn.utils.clip_grad_value_(
+                parameters=self._replicate_params, clip_value=self._max_gradient
+            )

        super().step(closure)
        self._step_num += 1

-    @torch.no_grad()
    def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
        """Clip the gradient norm of all parameters."""
-        max_norm = self._max_gradient
-        norm_type = float(self._norm_type)
+
+        # converts self._norm_type to a float if it's a string. Used in the case where self._norm_type is 'inf'.
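+        # (The conversion itself happens in __init__ via float(norm_type); float("inf") compares
+        # equal to torch.inf, which is why the string "inf" works as a norm type.)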
        all_grads = []
-        total_grad_norm = None
+        sharded_params = self._sharded_params
+        replicate_params = self._replicate_params

        # Process distributed parameters and gradients
-        for pgs, dist_params in self._sharded_params.items():
-            sharded_grads = [
-                p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
-                for p in dist_params
-                if p.grad is not None and p.grad.numel() > 0
-            ]
-            if len(sharded_grads) == 0:
-                continue
-            all_grads.extend(sharded_grads)
-
-            sharded_grad_norm = _batch_cal_norm(
-                sharded_grads,
-                max_norm,
-                norm_type,
-                pgs,
-            )
-            total_grad_norm = (
-                sharded_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + sharded_grad_norm
-                )
-            )
-
-        square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0
+        sharded_grads = {
+            pgs: _get_grads(dist_params) for pgs, dist_params in sharded_params.items()
+        }
+        all_grads.extend(*sharded_grads.values())

        # Process replicated parameters and gradients
-        if self._replicate_params:
-            replicated_grads = [
-                p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
-                for p in self._replicate_params
-                if p.grad is not None and p.grad.numel() > 0
-            ]
-            all_grads.extend(replicated_grads)
-
-            replicated_grad_norm = _batch_cal_norm(
-                replicated_grads,
-                max_norm,
-                norm_type,
-                None,
-            )
-            total_grad_norm = (
-                replicated_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + replicated_grad_norm
-                )
-            )
-            square_replicated_grad_norm = replicated_grad_norm
-        else:
-            square_replicated_grad_norm = 0
-
-        global log_grad_norm
-        if log_grad_norm:
-            if total_grad_norm is not None and norm_type != torch.inf:
-                # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
-            else:
-                grad_norm = total_grad_norm
-
-            rank = dist.get_rank()
-            logger.info(
-                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
-            )
-
-        # Aggregation
-        if total_grad_norm is None:
-            return
+        replicate_grads = _get_grads(replicate_params)
+        all_grads.extend(replicate_grads)
+
+        total_grad_norm = _compute_total_norm(
+            replicate_grads=replicate_grads,
+            sharded_grads=sharded_grads,
+            norm_type=self._norm_type,
+            max_grad_norm=self._max_gradient,
+        )

-        if norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
        # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
-        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
+        clip_coef = cast(torch.Tensor, self._max_gradient / (total_grad_norm + 1e-6))
        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
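+        # For example (toy values): with self._max_gradient = 1.0 and total_grad_norm = 4.0,
+        # clip_coef is roughly 0.25, so every gradient below is scaled by ~0.25; when
+        # total_grad_norm <= self._max_gradient, the clamp keeps the scale at 1.0 and the
+        # gradients are left unchanged.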
        torch._foreach_mul_(all_grads, clip_coef_clamped)
        return total_grad_norm


+def _get_grads(
+    param_list: List[torch.Tensor],
+) -> List[torch.Tensor]:
+    """Get the gradients of a list of parameters. Converts DTensors to local tensors if needed."""
+    grads = [
+        p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
+        for p in param_list
+        if p.grad is not None and p.grad.numel() > 0
+    ]
+    return grads
+
+
+def _compute_total_norm(
+    replicate_grads: List[torch.Tensor],
+    sharded_grads: Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]],
+    norm_type: float = 2.0,  # can be a normal float or torch.inf
+    max_grad_norm: float = 1.0,
+) -> torch.Tensor:
+    """
+    Given both replicate grads and sharded grads, compute the total norm of the gradients of the
+    full replicate params and the full sharded params (parameters with a process group).
+
+    Args:
+        replicate_grads (List[torch.Tensor]): list of gradients for replicate params
+        sharded_grads (Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]]): dict that maps each process group to a list of gradients for sharded params
+        norm_type (float): type of the used p-norm. Can be torch.inf for infinity norm.
+        max_grad_norm (float): max gradient norm.
+    """
+
+    ## compute the norm |W|^p corresponding to all sharded params W
+    sharded_grad_norm: torch.Tensor = torch.tensor(0.0)
+    combine_norm_operator = torch.maximum if norm_type == torch.inf else torch.add
+
+    # We need to move sharded_grad_norm to the same device as the first shard so that we can
+    # do the addition (or take the max). This matters specifically when sharded_grad_norm is 0
+    # and replicate_grad_norm is not, because torch.tensor(0.0) is on CPU by default while
+    # replicate_grad_norm is on GPU; on MTIA in particular, adding a CPU tensor and a GPU
+    # tensor results in an error.
+    for pgs, dist_params in sharded_grads.items():
+        current_shard_norm = _batch_cal_norm(
+            grad_list=dist_params,
+            max_norm=max_grad_norm,
+            norm_type=norm_type,
+            process_groups=pgs,
+        )
+        sharded_grad_norm = combine_norm_operator(
+            sharded_grad_norm.to(current_shard_norm.device), current_shard_norm
+        )
+    # compute |W|^p corresponding to all replicate params W
+    # Similar to the case above, we move replicate_grad_norm to the same device as
+    # sharded_grad_norm so that we can do the addition.
+    replicate_grad_norm: torch.Tensor = (
+        _batch_cal_norm(
+            grad_list=replicate_grads, max_norm=max_grad_norm, norm_type=norm_type
+        )
+        if replicate_grads
+        else torch.tensor(0.0)
+    ).to(sharded_grad_norm.device)
+
+    # In the p-norm case, we are given the norms |W_sharded|^p and |W_replicate|^p; to get the
+    # total norm, we sum them and take the p-th root. In the inf-norm case, we are given
+    # max(|W_sharded|) and max(|W_replicate|), and the total norm is
+    # max(max(|W_sharded|), max(|W_replicate|)).
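+    # Illustrative numbers (assumed for this comment, not taken from the code): with
+    # norm_type=2.0, sharded_grad_norm = 9.0 (i.e. 3 ** 2) and replicate_grad_norm = 16.0
+    # (i.e. 4 ** 2) combine to (9.0 + 16.0) ** (1.0 / 2.0) = 5.0; with norm_type=torch.inf
+    # the two inputs are already maxima, so the total is max(9.0, 16.0) = 16.0.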
+    combined_norm = combine_norm_operator(replicate_grad_norm, sharded_grad_norm)
+    total_grad_norm = (
+        combined_norm.pow(1.0 / norm_type) if norm_type != torch.inf else combined_norm
+    )
+
+    return total_grad_norm
+
+
def _batch_cal_norm(
    grad_list: List[torch.Tensor],
    max_norm: float,
    norm_type: float = 2.0,
    process_groups: Optional[Tuple[dist.ProcessGroup]] = None,
) -> torch.Tensor:
-    """Helper function that calculates the norm of a list of gradients in batches. If process_groups
-    are passed in, the norm will be aggregated across all ranks in the process group.
+    """Helper function that calculates the p-th power of the norm of a list of gradients in batches.
+    If process_groups are passed in, the norm will be aggregated across all ranks in the process group.
    """
-
    global use_64bit_grad_norm
    if use_64bit_grad_norm:
        grad_norms = torch.linalg.vector_norm(