megatron/core/distributed/fsdp/src/megatron_fsdp
1 file changed, 8 insertions(+), 1 deletion(-)

@@ -2473,6 +2473,10 @@ def update_main_grads(self):
             optimizer_grad = group.main_grad_buffer.get_item(
                 item_id, only_shard=sharded_optimizer_state
             )
+            if group.main_weight_buffer is not None:
+                if getattr(self, "use_precision_aware_optimizer", False):
+                    # Convert the gradient to the main weight buffer dtype.
+                    optimizer_grad = optimizer_grad.to(param.dtype)

             if name not in self.dist_main_grad:
                 # Register the gradient as a distributed tensor.
@@ -2497,8 +2501,11 @@ def update_main_grads(self):

             # The presence of main_grad_buffer but no main_weight_buffer may imply
             # that a precision-aware optimizer is used.
-            if getattr(self, "use_precision_aware_optimizer", False):
+            if getattr(self, "use_precision_aware_optimizer", True):
                 setattr(param, "decoupled_grad", grad)
+            else:
+                # Attach the gradient to the optimizer parameter.
+                setattr(param, "grad", grad.to(param.dtype) if grad is not None else None)

     @property
     def num_buckets(self):
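
For context, below is a minimal standalone sketch (not the Megatron-FSDP source) of the gradient-attachment rule this diff adds to update_main_grads(): when a precision-aware optimizer is in use, the low-precision gradient shard is cast toward the main-weight dtype and handed over via the decoupled_grad attribute; otherwise it is cast to the parameter dtype and attached as the regular grad. The attribute names "decoupled_grad" and "grad" mirror the diff; the helper attach_grad() and the boolean flags has_main_weights / use_precision_aware_optimizer are hypothetical stand-ins for illustration only.

# Sketch of the gradient-attachment rule shown in the diff above; names other
# than the attached attributes are assumptions, not Megatron-FSDP API.
from typing import Optional

import torch


def attach_grad(
    param: torch.nn.Parameter,
    grad: Optional[torch.Tensor],
    has_main_weights: bool,
    use_precision_aware_optimizer: bool,
) -> None:
    """Hand a (possibly low-precision) gradient shard to the optimizer parameter."""
    if has_main_weights and use_precision_aware_optimizer and grad is not None:
        # Match the main-weight dtype so the optimizer step sees a consistent precision.
        grad = grad.to(param.dtype)

    if use_precision_aware_optimizer:
        # A precision-aware optimizer reads the gradient from a separate attribute
        # rather than param.grad, so its dtype may differ from the parameter's.
        param.decoupled_grad = grad
    else:
        # A conventional optimizer expects param.grad in the parameter's own dtype.
        param.grad = grad.to(param.dtype) if grad is not None else None


if __name__ == "__main__":
    p = torch.nn.Parameter(torch.zeros(4, dtype=torch.float32))
    g = torch.ones(4, dtype=torch.bfloat16)
    attach_grad(p, g, has_main_weights=True, use_precision_aware_optimizer=False)
    print(p.grad.dtype)  # torch.float32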