Commit 906b1af

Merge pull request #765 from hcsolakoglu/dynbatchsampler-epoch-shuffle
Add Per-Epoch Batch Shuffling to DynamicBatchSampler
2 parents bebbfbb + 33e8651 commit 906b1af

2 files changed: 26 additions & 10 deletions

src/f5_tts/model/dataset.py

Lines changed: 17 additions & 9 deletions
@@ -1,5 +1,4 @@
 import json
-import random
 from importlib.resources import files
 
 import torch
@@ -170,6 +169,7 @@ class DynamicBatchSampler(Sampler[list[int]]):
     in a batch to ensure that the total number of frames are less
     than a certain threshold.
     2. Make sure the padding efficiency in the batch is high.
+    3. Shuffle batches each epoch while maintaining reproducibility.
     """
 
     def __init__(
@@ -178,6 +178,8 @@ def __init__(
         self.sampler = sampler
         self.frames_threshold = frames_threshold
         self.max_samples = max_samples
+        self.random_seed = random_seed
+        self.epoch = 0
 
         indices, batches = [], []
         data_source = self.sampler.data_source
@@ -210,17 +212,23 @@ def __init__(
             batches.append(batch)
 
         del indices
-
-        # if want to have different batches between epochs, may just set a seed and log it in ckpt
-        # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different
-        # e.g. for epoch n, use (random_seed + n)
-        random.seed(random_seed)
-        random.shuffle(batches)
-
         self.batches = batches
 
+    def set_epoch(self, epoch: int) -> None:
+        """Sets the epoch for this sampler."""
+        self.epoch = epoch
+
     def __iter__(self):
-        return iter(self.batches)
+        # Use both random_seed and epoch for deterministic but different shuffling per epoch
+        if self.random_seed is not None:
+            g = torch.Generator()
+            g.manual_seed(self.random_seed + self.epoch)
+            # Use PyTorch's random permutation for better reproducibility across PyTorch versions
+            indices = torch.randperm(len(self.batches), generator=g).tolist()
+            batches = [self.batches[i] for i in indices]
+        else:
+            batches = self.batches
+        return iter(batches)
 
     def __len__(self):
         return len(self.batches)
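
For reference, the following is a minimal standalone sketch of the shuffling logic that the new __iter__ implements. The toy batches list and the seed value are illustrative placeholders rather than repository values; only the torch.Generator / torch.randperm pattern mirrors the diff above.

import torch

# Toy stand-ins for self.batches and self.random_seed (illustrative values only).
batches = [[0, 1], [2, 3], [4], [5, 6, 7]]
random_seed = 42


def shuffled_batches(epoch: int) -> list[list[int]]:
    # Seed a dedicated generator with (random_seed + epoch): the order is
    # reproducible for a given epoch but differs from one epoch to the next.
    g = torch.Generator()
    g.manual_seed(random_seed + epoch)
    order = torch.randperm(len(batches), generator=g).tolist()
    return [batches[i] for i in order]


assert shuffled_batches(3) == shuffled_batches(3)  # same seed + epoch -> same order
print(shuffled_batches(0))  # epoch 0 order, identical on every run
print(shuffled_batches(1))  # typically a different order for epoch 1

Seeding a local torch.Generator instead of the global random module also keeps the shuffle independent of any other code that touches Python's global RNG state.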

src/f5_tts/model/trainer.py

Lines changed: 9 additions & 1 deletion
@@ -279,7 +279,11 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
             self.accelerator.even_batches = False
             sampler = SequentialSampler(train_dataset)
             batch_sampler = DynamicBatchSampler(
-                sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
+                sampler,
+                self.batch_size,
+                max_samples=self.max_samples,
+                random_seed=resumable_with_seed,  # This enables reproducible shuffling
+                drop_last=False,
             )
             train_dataloader = DataLoader(
                 train_dataset,
@@ -329,6 +333,10 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
                 progress_bar_initial = 0
                 current_dataloader = train_dataloader
 
+            # Set epoch for the batch sampler if it exists
+            if hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "set_epoch"):
+                train_dataloader.batch_sampler.set_epoch(epoch)
+
             progress_bar = tqdm(
                 range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
                 desc=f"Epoch {epoch+1}/{self.epochs}",
