@@ -1,5 +1,4 @@
 import json
-import random
 from importlib.resources import files
 
 import torch
@@ -170,6 +169,7 @@ class DynamicBatchSampler(Sampler[list[int]]):
         in a batch to ensure that the total number of frames are less
         than a certain threshold.
     2. Make sure the padding efficiency in the batch is high.
+    3. Shuffle batches each epoch while maintaining reproducibility.
     """
 
     def __init__(
@@ -178,6 +178,8 @@ def __init__(
         self.sampler = sampler
         self.frames_threshold = frames_threshold
         self.max_samples = max_samples
+        self.random_seed = random_seed
+        self.epoch = 0
 
         indices, batches = [], []
         data_source = self.sampler.data_source
@@ -210,17 +212,23 @@ def __init__(
             batches.append(batch)
 
         del indices
-
-        # if want to have different batches between epochs, may just set a seed and log it in ckpt
-        # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different
-        # e.g. for epoch n, use (random_seed + n)
-        random.seed(random_seed)
-        random.shuffle(batches)
-
         self.batches = batches
 
+    def set_epoch(self, epoch: int) -> None:
+        """Sets the epoch for this sampler."""
+        self.epoch = epoch
+
     def __iter__(self):
-        return iter(self.batches)
+        # Use both random_seed and epoch for deterministic but different shuffling per epoch
+        if self.random_seed is not None:
+            g = torch.Generator()
+            g.manual_seed(self.random_seed + self.epoch)
+            # Use PyTorch's random permutation for better reproducibility across PyTorch versions
+            indices = torch.randperm(len(self.batches), generator=g).tolist()
+            batches = [self.batches[i] for i in indices]
+        else:
+            batches = self.batches
+        return iter(batches)
 
     def __len__(self):
         return len(self.batches)
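
A hedged usage sketch of the new API (not part of the diff): `set_epoch` follows the same convention as `torch.utils.data.distributed.DistributedSampler.set_epoch`, so the training loop calls it once at the start of each epoch. The dataset wiring and all argument values below are illustrative assumptions, not taken from this repository.

```python
# Usage sketch only: shows how set_epoch() drives the per-epoch reshuffle.
# train_dataset, the epoch count, and the constructor values are placeholders.
from torch.utils.data import DataLoader, SequentialSampler

train_dataset = ...  # placeholder for the repo's actual dataset
base_sampler = SequentialSampler(train_dataset)
batch_sampler = DynamicBatchSampler(
    base_sampler,
    frames_threshold=38400,  # example value, not from the diff
    max_samples=64,          # example value, not from the diff
    random_seed=666,         # example value, not from the diff
)
loader = DataLoader(train_dataset, batch_sampler=batch_sampler)

for epoch in range(num_epochs):  # num_epochs: placeholder
    # Seeding the shuffle with (random_seed + epoch) gives each epoch a
    # different but reproducible batch order; a run resumed at epoch n
    # replays exactly the order it would have seen.
    batch_sampler.set_epoch(epoch)
    for batch in loader:
        ...
```

If `set_epoch` is never called, `self.epoch` stays at 0 and every epoch yields the same order for a fixed `random_seed`; passing `random_seed=None` skips shuffling entirely, per the new `__iter__`.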