@@ -425,10 +425,16 @@ def train__iter__helper(self, rng: random.Random, **filters):
425425 value
426426 )
427427 file_ids = np .where (training )[0 ]
428+ if len (file_ids ) == 0 :
429+ yield None
428430
429431 # turn annotated duration into a probability distribution
430432 annotated_duration = self .annotated_duration [file_ids ]
431- cum_prob_annotated_duration = np .cumsum (
433+ # in case there is a 0 seconds annotated region somewhere
434+ annotated_duration = np .nan_to_num (annotated_duration , nan = 0.0 , copy = True )
435+ if np .sum (annotated_duration ) == 0 :
436+ yield None
437+ cum_prob_annotated_duration : np .ndarray = np .cumsum (
432438 annotated_duration / np .sum (annotated_duration )
433439 )
434440
@@ -498,7 +504,13 @@ def train__iter__(self):
498504 # eg: for balance=["database", "split"], with 2 databases and 2 splits:
499505 # ("DIHARD", "A"), ("DIHARD", "B"), ("REPERE", "A"), ("REPERE", "B")
500506 filters = {key : value for key , value in zip (balance , product )}
501- subchunks [product ] = self .train__iter__helper (rng , ** filters )
507+ product_iterator = self .train__iter__helper (rng , ** filters )
508+
509+ # This specific product may not exist. For example, if balance=['database']
510+ # and there is a database that's not present in the training set (only in the val).
511+ # Or if a certain filter combination does not have any matching file.
512+ if next (product_iterator ) is not None :
513+ subchunks [product ] = product_iterator
502514
503515 # Compute the balance weights.
504516 # To get the weights of each subchunk generator,
0 commit comments