
@relaxis relaxis commented Nov 23, 2025

Summary

Fixes a critical bug in multistage training (e.g., Wan 2.2 LoRAs) that causes training to crash with "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn" when a sampled timestep lands at or slightly beyond the stage boundary value due to floating point precision.

The Problem

When training multistage models like Wan 2.2 with separate high/low noise experts:

  • Low noise expert trains on timesteps 0-900 (boundary 0.0-0.9)
  • Timestep sampling can occasionally produce values like 900.223 due to floating point precision
  • This timestep falls outside the valid range for the low noise expert
  • The wrong expert is selected based on the timestep, so the forward pass produces a loss tensor with no grad_fn and the backward pass crashes with the error above
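
For illustration, a minimal sketch of the failure mode (hypothetical names, not the repository's actual routing code): a timestep just past the boundary routes to the frozen expert, and the resulting loss has no grad_fn.

import torch

def select_expert(timestep_value, low_noise_expert, high_noise_expert, boundary=900.0):
    # Low noise expert handles t in [0, 900); anything at or above the boundary
    # routes to the high noise expert, which is frozen during low-noise LoRA training.
    if timestep_value < boundary:
        return low_noise_expert
    return high_noise_expert

low = torch.nn.Linear(4, 4)                         # stands in for the trainable (LoRA) expert
high = torch.nn.Linear(4, 4).requires_grad_(False)  # stands in for the frozen expert

x = torch.randn(1, 4)
loss = select_expert(900.223, low, high)(x).sum()   # 900.223 routes to the frozen expert
loss.backward()  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn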

The Fix

After converting timestep indices to actual timestep values (line 1267), the fix adds clamping to ensure timesteps stay strictly within the multistage boundary range:

if self.sd.is_multistage:
    boundaries = [1] + self.sd.multistage_boundaries
    boundary_max, boundary_min = boundaries[self.current_boundary_index], boundaries[self.current_boundary_index + 1]
    timesteps = torch.clamp(timesteps, min=boundary_min * 1000, max=(boundary_max * 1000) - 0.01)

For low noise training (boundary 0.0-0.9), timesteps are now clamped to the range [0, 899.99].
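
As a quick standalone sanity check of the clamping behavior (values chosen for illustration, using the low noise boundaries 0.0 and 0.9):

import torch

timesteps = torch.tensor([900.223, 450.0, 0.0])
clamped = torch.clamp(timesteps, min=0.0 * 1000, max=(0.9 * 1000) - 0.01)
print(clamped)  # 900.223 is pulled back inside the low-noise range: approximately [899.99, 450.0, 0.0]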

Testing

  • Tested on Wan 2.2 14B I2V LoRA training (low noise expert, boundary 0-0.9)
  • Training previously crashed at step 9602 with timestep t=900.223
  • With the fix, training continues past step 9602 without issues
  • No performance impact observed

Related Issues

This may be related to other reports of training crashes during multistage LoRA training with similar error messages.

🤖 Generated with Claude Code

Co-Authored-By: Claude [email protected]
