diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index db6c43a3f..194b687fd 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1265,7 +1265,15 @@ def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): with self.timer('convert_timestep_indices_to_timesteps'): # convert the timestep_indices to a timestep timesteps = self.sd.noise_scheduler.timesteps[timestep_indices.long()] - + + # Clamp timesteps to multistage boundary range to prevent floating point edge cases + if self.sd.is_multistage: + boundaries = [1] + self.sd.multistage_boundaries + boundary_max, boundary_min = boundaries[self.current_boundary_index], boundaries[self.current_boundary_index + 1] + # Ensure timesteps stay strictly within (boundary_min * 1000, boundary_max * 1000) range + # Subtract small epsilon from max to avoid edge case where timestep equals boundary + timesteps = torch.clamp(timesteps, min=boundary_min * 1000, max=(boundary_max * 1000) - 0.01) + with self.timer('prepare_noise'): # get noise noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch, timestep=timesteps)