Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions src/pyannote/audio/tasks/segmentation/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ def train__iter__helper(self, rng: random.Random, **filters):

num_chunks_per_file = getattr(self, "num_chunks_per_file", 1)

# DDP support: Get rank info for deterministic sample assignment
trainer = getattr(self.model, 'trainer', None)
if trainer is not None and hasattr(trainer, 'world_size'):
world_size = trainer.world_size
global_rank = trainer.global_rank
else:
world_size = 1
global_rank = 0

# Counter to track generated samples for DDP distribution
sample_counter = 0

while True:
# select one file at random (with probability proportional to its annotated duration)
file_id = file_ids[cum_prob_annotated_duration.searchsorted(rng.random())]
Expand Down Expand Up @@ -135,7 +147,11 @@ def train__iter__helper(self, rng: random.Random, **filters):
]
start_time = rng.uniform(start, start + region_duration - duration)

yield self.prepare_chunk(file_id, start_time, duration)
# DDP: Only yield samples assigned to this rank (deterministic round-robin)
if sample_counter % world_size == global_rank:
yield self.prepare_chunk(file_id, start_time, duration)

sample_counter += 1

def train__iter__(self):
"""Iterate over training samples
Expand Down Expand Up @@ -171,7 +187,11 @@ def train__iter__(self):
filters = {key: value for key, value in zip(balance, product)}
subchunks[product] = self.train__iter__helper(rng, **filters)

while True:
# Calculate how many batches to generate (accounting for DDP)
num_batches = self.train__len__()

# Generate exactly num_batches samples
for _ in range(num_batches):
# select one subchunk generator at random (with uniform probability)
# so that it is balanced on average
if balance is not None:
Expand Down Expand Up @@ -249,7 +269,25 @@ def train__len__(self):
)[0]

duration = np.sum(self.prepared_data["audio-annotated"][train_file_ids])
return max(self.batch_size, math.ceil(duration / self.duration))

total_batches = max(self.batch_size, math.ceil(duration / self.duration))

# Adjust for DDP: return per-rank batch count
trainer = getattr(self.model, 'trainer', None)
if trainer is not None and hasattr(trainer, 'world_size'):
world_size = trainer.world_size
global_rank = trainer.global_rank

# Calculate batches per rank using modulo distribution
# This matches the assignment logic in train__iter__helper()
# Ranks 0 to (total_batches % world_size - 1) get one extra batch
batches_per_rank = total_batches // world_size
if global_rank < total_batches % world_size:
batches_per_rank += 1

return batches_per_rank

return total_batches

def prepare_validation(self, prepared_data: Dict):
validation_chunks = list()
Expand Down