LoRA training QoL improvements: UI progress bar, deterministic seeding, make gradient checkpointing optional #8668
Adding a UI progress bar lets users see training progress in the UI (obviously), but it also makes it possible to cancel training mid-run.
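A minimal sketch of what the reporting hook looks like, assuming ComfyUI's `ProgressBar` helper and interruption check; the loop structure and `train_step` name here are illustrative, not the exact PR code:

```python
import comfy.utils
import comfy.model_management


def run_training(total_steps, train_step):
    # Progress bar shows up in the UI for the node that is currently running.
    pbar = comfy.utils.ProgressBar(total_steps)
    for step in range(total_steps):
        # Raises an exception if the user hit "Cancel", which unwinds the
        # training loop and stops the job cleanly.
        comfy.model_management.throw_exception_if_processing_interrupted()
        train_step(step)
        pbar.update(1)
```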
Gradient checkpointing, especially with this many checkpoints, is computationally expensive and not necessary if memory isn't a constraint. I left it enabled by default, but disabling it is a free speed boost when you have the VRAM to spare.
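A sketch of how such a toggle can work, using `torch.utils.checkpoint` directly; the `BlockWrapper` class and `gradient_checkpointing` flag are illustrative rather than the PR's actual implementation:

```python
import torch
import torch.utils.checkpoint


class BlockWrapper(torch.nn.Module):
    """Wraps a transformer block and optionally checkpoints its forward."""

    def __init__(self, block, gradient_checkpointing=True):
        super().__init__()
        self.block = block
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, *args, **kwargs):
        if self.gradient_checkpointing and torch.is_grad_enabled():
            # Recomputes the block's activations during the backward pass,
            # trading extra compute for lower peak memory.
            return torch.utils.checkpoint.checkpoint(
                self.block, *args, use_reentrant=False, **kwargs
            )
        # Plain forward: faster, but keeps all activations in memory.
        return self.block(*args, **kwargs)
```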
As for seeding, I replaced the unused generator: the global RNG states are temporarily stored, everything is seeded, and the saved states are restored once training finishes. This seeds the weight initialization without having to pass a generator through every call site. The RNG of weight initialization matters quite a bit: if it's allowed to be random, workflows that incorporate LoRA training directly (instead of loading a trained file) are impossible to reproduce. It also seeds timestep sampling, which is the main factor driving training loss at small batch sizes.
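A rough sketch of the save/seed/restore pattern described above; the `seeded_training` helper and its arguments are hypothetical names for illustration:

```python
import random

import numpy as np
import torch


def seeded_training(seed, run_training):
    # Snapshot the global RNG states so unrelated nodes in the workflow
    # are not affected by the training seed.
    py_state = random.getstate()
    np_state = np.random.get_state()
    torch_state = torch.get_rng_state()
    cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    try:
        # Seed everything: covers LoRA weight init and timestep sampling
        # without threading a generator through every call site.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)  # also seeds CUDA RNGs
        return run_training()
    finally:
        # Restore the global states so the rest of the workflow behaves
        # as if training never touched the RNGs.
        random.setstate(py_state)
        np.random.set_state(np_state)
        torch.set_rng_state(torch_state)
        if cuda_states is not None:
            torch.cuda.set_rng_state_all(cuda_states)
```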
With this change, fp32 training is now fully deterministic. bf16 training is still partially nondeterministic, and I wasn't able to track down the cause; my guess is it could be related to stochastic rounding?