Skip to content

Commit f350acd

Browse files
[Trainer/bug] Ensure model is not in inference mode (CORE-72) (Comfy-Org#13400)
* Ensure model is not in inference mode
* Force clone inside training mode to avoid inference tensors
* Allow forced deepcopy for model patcher
1 parent 46d45aa commit f350acd

2 files changed

Lines changed: 38 additions & 37 deletions

File tree

comfy/model_patcher.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,10 +379,11 @@ def get_free_memory(self, device):
379379
def get_clone_model_override(self):
380380
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
381381

382-
def clone(self, disable_dynamic=False, model_override=None):
382+
def clone(self, disable_dynamic=False, model_override=None, force_deepcopy=False):
383383
class_ = self.__class__
384-
if self.is_dynamic() and disable_dynamic:
385-
class_ = ModelPatcher
384+
if self.is_dynamic() and disable_dynamic or force_deepcopy:
385+
if self.is_dynamic() and disable_dynamic:
386+
class_ = ModelPatcher
386387
if model_override is None:
387388
if self.cached_patcher_init is None:
388389
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")

comfy_extras/nodes_train.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,45 +1149,45 @@ def execute(
11491149
# Process conditioning
11501150
positive = _process_conditioning(positive)
11511151

1152-
# Setup model and dtype
1153-
mp = model.clone()
1154-
use_grad_scaler = False
1155-
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
1156-
if training_dtype != "none":
1157-
dtype = node_helpers.string_to_torch_dtype(training_dtype)
1158-
mp.set_model_compute_dtype(dtype)
1159-
else:
1160-
# Detect model's native dtype for autocast
1161-
model_dtype = mp.model.get_dtype()
1162-
if model_dtype == torch.float16:
1163-
dtype = torch.float16
1164-
# GradScaler only supports float16 gradients, not bfloat16.
1165-
# Only enable it when lora params will also be in float16.
1166-
if lora_dtype != torch.bfloat16:
1167-
use_grad_scaler = True
1168-
# Warn about fp16 accumulation instability during training
1169-
if PerformanceFeature.Fp16Accumulation in args.fast:
1170-
logging.warning(
1171-
"WARNING: FP16 model detected with fp16_accumulation enabled. "
1172-
"This combination can be numerically unstable during training and may cause NaN values. "
1173-
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
1174-
)
1152+
with torch.inference_mode(False):
1153+
# Setup model and dtype
1154+
mp = model.clone(force_deepcopy=True)
1155+
use_grad_scaler = False
1156+
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
1157+
if training_dtype != "none":
1158+
dtype = node_helpers.string_to_torch_dtype(training_dtype)
1159+
mp.set_model_compute_dtype(dtype)
11751160
else:
1176-
# For fp8, bf16, or other dtypes, use bf16 autocast
1177-
dtype = torch.bfloat16
1161+
# Detect model's native dtype for autocast
1162+
model_dtype = mp.model.get_dtype()
1163+
if model_dtype == torch.float16:
1164+
dtype = torch.float16
1165+
# GradScaler only supports float16 gradients, not bfloat16.
1166+
# Only enable it when lora params will also be in float16.
1167+
if lora_dtype != torch.bfloat16:
1168+
use_grad_scaler = True
1169+
# Warn about fp16 accumulation instability during training
1170+
if PerformanceFeature.Fp16Accumulation in args.fast:
1171+
logging.warning(
1172+
"WARNING: FP16 model detected with fp16_accumulation enabled. "
1173+
"This combination can be numerically unstable during training and may cause NaN values. "
1174+
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
1175+
)
1176+
else:
1177+
# For fp8, bf16, or other dtypes, use bf16 autocast
1178+
dtype = torch.bfloat16
11781179

1179-
# Prepare latents and compute counts
1180-
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
1181-
latents, num_images, multi_res = _prepare_latents_and_count(
1182-
latents, latents_dtype, bucket_mode
1183-
)
1180+
# Prepare latents and compute counts
1181+
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
1182+
latents, num_images, multi_res = _prepare_latents_and_count(
1183+
latents, latents_dtype, bucket_mode
1184+
)
11841185

1185-
# Validate and expand conditioning
1186-
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
1186+
# Validate and expand conditioning
1187+
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
11871188

1188-
with torch.inference_mode(False):
11891189
# Setup models for training
1190-
mp.model.requires_grad_(False)
1190+
mp.model.requires_grad_(False).train()
11911191

11921192
# Load existing LoRA weights if provided
11931193
existing_weights, existing_steps = _load_existing_lora(existing_lora)

0 commit comments

Comments
 (0)