Commit a9b43eb

Author: Lifu Zhang (committed)
fix for when use_precision_aware_optimizer=false
Signed-off-by: Lifu Zhang <[email protected]>
1 parent e7184a2 commit a9b43eb

File tree: 1 file changed, +8 -1 lines changed

megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py

Lines changed: 8 additions & 1 deletion
@@ -2473,6 +2473,10 @@ def update_main_grads(self):
             optimizer_grad = group.main_grad_buffer.get_item(
                 item_id, only_shard=sharded_optimizer_state
             )
+            if group.main_weight_buffer is not None:
+                if getattr(self, "use_precision_aware_optimizer", False):
+                    # Convert the gradient to the main weight buffer dtype.
+                    optimizer_grad = optimizer_grad.to(param.dtype)
 
             if name not in self.dist_main_grad:
                 # Register the gradient as a distributed tensor.
@@ -2497,8 +2501,11 @@ def update_main_grads(self):
 
             # The presence of main_grad_buffer but no main_weight_buffer may imply
             # that a precision-aware optimizer is used.
-            if getattr(self, "use_precision_aware_optimizer", False):
+            if getattr(self, "use_precision_aware_optimizer", True):
                 setattr(param, "decoupled_grad", grad)
+            else:
+                # Attach the gradient to the optimizer parameter.
+                setattr(param, "grad", grad.to(param.dtype) if grad is not None else None)
 
     @property
     def num_buckets(self):
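
For context, the two paths above differ in where the gradient ends up: the precision-aware path stores it on a separate decoupled_grad attribute in its original dtype, while the standard path casts it to the parameter dtype and assigns it to the regular grad slot. Below is a minimal, standalone sketch of that pattern; the attach_grad helper is hypothetical and not part of the Megatron-FSDP API.

import torch
from typing import Optional


def attach_grad(
    param: torch.nn.Parameter,
    grad: Optional[torch.Tensor],
    use_precision_aware_optimizer: bool,
) -> None:
    # Hypothetical stand-in for the attachment logic in update_main_grads().
    if use_precision_aware_optimizer:
        # Precision-aware path: keep the gradient in its own dtype on a
        # separate attribute, so it is not subject to the dtype matching
        # PyTorch enforces on .grad assignments.
        param.decoupled_grad = grad
    else:
        # Standard path: cast to the parameter dtype and use the regular slot.
        param.grad = grad.to(param.dtype) if grad is not None else None


# Example: a bf16 parameter with an fp32 main gradient.
p = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
g = torch.ones(4, dtype=torch.float32)

attach_grad(p, g, use_precision_aware_optimizer=False)
assert p.grad.dtype == torch.bfloat16

attach_grad(p, g, use_precision_aware_optimizer=True)
assert p.decoupled_grad.dtype == torch.float32

As the second hunk shows, the previous code attached nothing when the flag was false; the new else branch covers that case by writing the cast gradient to param.grad.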
