Skip to content

Commit 0ca6e92

Browse files
committed
cherrypick a fix from my develop branch
(not tested in this branch)
1 parent d552c92 commit 0ca6e92

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

pyannote/audio/tasks/segmentation/mixins.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,10 +425,16 @@ def train__iter__helper(self, rng: random.Random, **filters):
425425
value
426426
)
427427
file_ids = np.where(training)[0]
428+
if len(file_ids) == 0:
429+
yield None
428430

429431
# turn annotated duration into a probability distribution
430432
annotated_duration = self.annotated_duration[file_ids]
431-
cum_prob_annotated_duration = np.cumsum(
433+
# in case there is a 0 seconds annotated region somewhere
434+
annotated_duration = np.nan_to_num(annotated_duration, nan=0.0, copy=True)
435+
if np.sum(annotated_duration) == 0:
436+
yield None
437+
cum_prob_annotated_duration: np.ndarray = np.cumsum(
432438
annotated_duration / np.sum(annotated_duration)
433439
)
434440

@@ -498,7 +504,13 @@ def train__iter__(self):
498504
# eg: for balance=["database", "split"], with 2 databases and 2 splits:
499505
# ("DIHARD", "A"), ("DIHARD", "B"), ("REPERE", "A"), ("REPERE", "B")
500506
filters = {key: value for key, value in zip(balance, product)}
501-
subchunks[product] = self.train__iter__helper(rng, **filters)
507+
product_iterator = self.train__iter__helper(rng, **filters)
508+
509+
# This specific product may not exist. For example, if balance=['database']
510+
# and there is a database that's not present in the training set (only in the val).
511+
# Or if a certain filter combination does not have any matching file.
512+
if next(product_iterator) is not None:
513+
subchunks[product] = product_iterator
502514

503515
# Compute the balance weights.
504516
# To get the weights of each subchunk generator,

0 commit comments

Comments
 (0)