Skip to content

LoRA training QoL improvements: UI progress bar, deterministic seeding, make gradient checkpointing optional #8668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions comfy_extras/nodes_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,13 @@ def INPUT_TYPES(s):
["bf16", "fp32"],
{"default": "bf16", "tooltip": "The dtype to use for lora."},
),
"gradient_checkpointing": (
IO.BOOLEAN,
{
"default": True,
"tooltip": "Use gradient checkpointing to reduce memory usage at the cost of speed)",
},
),
"existing_lora": (
folder_paths.get_filename_list("loras") + ["[None]"],
{
Expand Down Expand Up @@ -372,9 +379,11 @@ def train(
seed,
training_dtype,
lora_dtype,
gradient_checkpointing,
existing_lora,
):
mp = model.clone()
device = comfy.model_management.get_torch_device()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
Expand All @@ -384,8 +393,9 @@ def train(

with torch.inference_mode(False):
lora_sd = {}
generator = torch.Generator()
generator.manual_seed(seed)
old_cpu_rng_state = torch.get_rng_state()
old_device_rng_state = torch.cuda.get_rng_state(device)
torch.manual_seed(seed)

# Load existing LoRA weights if provided
existing_weights = {}
Expand Down Expand Up @@ -472,8 +482,12 @@ def train(
criterion = torch.nn.SmoothL1Loss()

# setup models
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
patch(m)
if gradient_checkpointing:
modules_to_patch = find_all_highest_child_module_with_forward(mp.model.diffusion_model)
for m in modules_to_patch:
patch(m)
logging.info(f"Added gradient checkpoints to {len(modules_to_patch)} modules")

comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)

# Setup sampler and guider like in test script
Expand All @@ -493,6 +507,9 @@ def loss_callback(loss):
# Training loop
torch.cuda.empty_cache()
try:
if comfy.utils.PROGRESS_BAR_ENABLED:
ui_pbar = comfy.utils.ProgressBar(steps)

for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
# Generate random sigma
sigma = mp.model.model_sampling.percent_to_sigma(
Expand All @@ -506,6 +523,7 @@ def loss_callback(loss):
ss.sample(
noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()}
)
ui_pbar.update(1)
finally:
for m in mp.model.modules():
unpatch(m)
Expand All @@ -518,6 +536,9 @@ def loss_callback(loss):
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)

torch.set_rng_state(old_cpu_rng_state)
torch.cuda.set_rng_state(old_device_rng_state, device)

return (mp, lora_sd, loss_map, steps + existing_steps)


Expand Down