Labels: Bug, Contributions Welcome, Medium Priority, Optimizers
Description
System Info
Running a standard training loop, I save the optimizer state_dict using opt.state_dict().
When I load it with opt.load_state_dict() to resume, the model immediately produces NaNs after the first backprop step.
This only occurs with the AdEMAMix optimizer:
bnb.optim.AdEMAMix8bit(model.parameters(), lr=lr, t_alpha=T, t_beta3=T)
AdamW and the other optimizers load their state dicts perfectly fine. Any ideas?
Reproduction
```python
import torch
import bitsandbytes as bnb

opt = bnb.optim.AdEMAMix8bit(model.parameters())
# run training loop
torch.save(opt.state_dict(), "dt.pt")

# try resuming opt from the saved state_dict later
opt.load_state_dict(torch.load("dt.pt"))
# run training loop again -> NaNs after the first step
```
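For context, here is a self-contained sketch of the same save/resume pattern. It is not taken from the report: the toy `nn.Linear` model, the synthetic data, and the `t_alpha`/`t_beta3` values are assumptions, and a CUDA device is assumed since the bitsandbytes 8-bit optimizers run on the GPU.

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb

device = "cuda"
# Toy model (assumption); 128x128 weight is large enough to use the 8-bit state path.
model = nn.Linear(128, 128).to(device)
# t_alpha / t_beta3 values here are placeholders, not from the report.
opt = bnb.optim.AdEMAMix8bit(model.parameters(), lr=1e-3, t_alpha=10_000, t_beta3=10_000)

def train_step():
    # Synthetic batch and a simple scalar loss, just to drive optimizer updates.
    x = torch.randn(32, 128, device=device)
    loss = model(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

# Run a few steps, then checkpoint the optimizer state.
for _ in range(10):
    train_step()
torch.save(opt.state_dict(), "dt.pt")

# Recreate the optimizer and resume from the checkpoint.
opt = bnb.optim.AdEMAMix8bit(model.parameters(), lr=1e-3, t_alpha=10_000, t_beta3=10_000)
opt.load_state_dict(torch.load("dt.pt"))

# First resumed step: the reported behavior is that the loss/weights go NaN here.
loss = train_step()
print("loss after resume:", loss,
      "NaN in weights:", torch.isnan(model.weight).any().item())
```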
Expected behavior
The optimizer should resume training from the loaded state_dict without producing NaNs.