@@ -1149,45 +1149,45 @@ def execute(
11491149 # Process conditioning
11501150 positive = _process_conditioning (positive )
11511151
1152- # Setup model and dtype
1153- mp = model .clone ()
1154- use_grad_scaler = False
1155- lora_dtype = node_helpers .string_to_torch_dtype (lora_dtype )
1156- if training_dtype != "none" :
1157- dtype = node_helpers .string_to_torch_dtype (training_dtype )
1158- mp .set_model_compute_dtype (dtype )
1159- else :
1160- # Detect model's native dtype for autocast
1161- model_dtype = mp .model .get_dtype ()
1162- if model_dtype == torch .float16 :
1163- dtype = torch .float16
1164- # GradScaler only supports float16 gradients, not bfloat16.
1165- # Only enable it when lora params will also be in float16.
1166- if lora_dtype != torch .bfloat16 :
1167- use_grad_scaler = True
1168- # Warn about fp16 accumulation instability during training
1169- if PerformanceFeature .Fp16Accumulation in args .fast :
1170- logging .warning (
1171- "WARNING: FP16 model detected with fp16_accumulation enabled. "
1172- "This combination can be numerically unstable during training and may cause NaN values. "
1173- "Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
1174- )
1152+ with torch .inference_mode (False ):
1153+ # Setup model and dtype
1154+ mp = model .clone (force_deepcopy = True )
1155+ use_grad_scaler = False
1156+ lora_dtype = node_helpers .string_to_torch_dtype (lora_dtype )
1157+ if training_dtype != "none" :
1158+ dtype = node_helpers .string_to_torch_dtype (training_dtype )
1159+ mp .set_model_compute_dtype (dtype )
11751160 else :
1176- # For fp8, bf16, or other dtypes, use bf16 autocast
1177- dtype = torch .bfloat16
1161+ # Detect model's native dtype for autocast
1162+ model_dtype = mp .model .get_dtype ()
1163+ if model_dtype == torch .float16 :
1164+ dtype = torch .float16
1165+ # GradScaler only supports float16 gradients, not bfloat16.
1166+ # Only enable it when lora params will also be in float16.
1167+ if lora_dtype != torch .bfloat16 :
1168+ use_grad_scaler = True
1169+ # Warn about fp16 accumulation instability during training
1170+ if PerformanceFeature .Fp16Accumulation in args .fast :
1171+ logging .warning (
1172+ "WARNING: FP16 model detected with fp16_accumulation enabled. "
1173+ "This combination can be numerically unstable during training and may cause NaN values. "
1174+ "Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
1175+ )
1176+ else :
1177+ # For fp8, bf16, or other dtypes, use bf16 autocast
1178+ dtype = torch .bfloat16
11781179
1179- # Prepare latents and compute counts
1180- latents_dtype = dtype if dtype not in (None ,) else torch .bfloat16
1181- latents , num_images , multi_res = _prepare_latents_and_count (
1182- latents , latents_dtype , bucket_mode
1183- )
1180+ # Prepare latents and compute counts
1181+ latents_dtype = dtype if dtype not in (None ,) else torch .bfloat16
1182+ latents , num_images , multi_res = _prepare_latents_and_count (
1183+ latents , latents_dtype , bucket_mode
1184+ )
11841185
1185- # Validate and expand conditioning
1186- positive = _validate_and_expand_conditioning (positive , num_images , bucket_mode )
1186+ # Validate and expand conditioning
1187+ positive = _validate_and_expand_conditioning (positive , num_images , bucket_mode )
11871188
1188- with torch .inference_mode (False ):
11891189 # Setup models for training
1190- mp .model .requires_grad_ (False )
1190+ mp .model .requires_grad_ (False ). train ()
11911191
11921192 # Load existing LoRA weights if provided
11931193 existing_weights , existing_steps = _load_existing_lora (existing_lora )
0 commit comments