22import warnings
33from bisect import bisect_right
44from collections import deque
5+ from concurrent .futures import ThreadPoolExecutor
56from itertools import islice
7+ from threading import Lock
68from typing import (
79 Deque ,
810 Generator ,
@@ -284,6 +286,9 @@ def __init__(
284286 deque () for _ in range (len (duration_bins ) + 1 )
285287 ]
286288
289+ self ._cut_reading_thread = ThreadPoolExecutor (1 )
290+ self ._bucket_mutex = Lock ()
291+
287292 def __iter__ (self ) -> Generator [CutSet , None , None ]:
288293 # Init: sample `buffer_size` cuts and assign them to the right buckets.
289294 self .cuts_iter = iter (self .cuts )
@@ -302,19 +307,20 @@ def is_ready(bucket: Deque[Cut]):
302307 # On each step we're sampling a new batch.
303308 try :
304309 while True :
305- ready_buckets = [b for b in self .buckets if is_ready (b )]
306- if not ready_buckets :
307- # No bucket has enough data to yield for the last full batch.
308- non_empty_buckets = [b for b in self .buckets if b ]
309- if self .drop_last or len (non_empty_buckets ) == 0 :
310- # Either the user requested only full batches, or we have nothing left.
311- raise StopIteration ()
312- else :
313- # Sample from partial batches that are left.
314- ready_buckets = non_empty_buckets
315- # Choose a bucket to sample from.
316- # We'll only select from the buckets that have a full batch available.
317- sampling_bucket = self .rng .choice (ready_buckets )
310+ with self ._bucket_mutex :
311+ ready_buckets = [b for b in self .buckets if is_ready (b )]
312+ if not ready_buckets :
313+ # No bucket has enough data to yield for the last full batch.
314+ non_empty_buckets = [b for b in self .buckets if b ]
315+ if self .drop_last or len (non_empty_buckets ) == 0 :
316+ # Either the user requested only full batches, or we have nothing left.
317+ raise StopIteration ()
318+ else :
319+ # Sample from partial batches that are left.
320+ ready_buckets = non_empty_buckets
321+ # Choose a bucket to sample from.
322+ # We'll only select from the buckets that have a full batch available.
323+ sampling_bucket = self .rng .choice (ready_buckets )
318324 # Sample one batch from that bucket and yield it to the caller.
319325 batcher = DurationBatcher (
320326 sampling_bucket ,
@@ -339,13 +345,17 @@ def is_ready(bucket: Deque[Cut]):
339345 self .cuts_iter = None
340346
def _collect_cuts_in_buckets(self, n_cuts: int):
    """
    Schedule reading of up to ``n_cuts`` items from ``self.cuts_iter`` into buckets.

    The work is submitted to ``self._cut_reading_thread`` and this method
    returns immediately without waiting for it; callers must not assume the
    buckets have been populated by the time the call returns.

    Each item is placed in ``self.buckets[i]`` where ``i`` is the position of
    its duration in ``self.duration_bins`` (found with ``bisect_right``).
    When an item is a tuple, the duration of its first element is used.

    :param n_cuts: the maximum number of items to pull from ``self.cuts_iter``.
        Fewer are read if the iterator is exhausted first.
    """

    def collect_many(n_items: int) -> None:
        for _ in range(n_items):
            # The mutex is taken per item (not around the whole loop) so that
            # the thread consuming the buckets can interleave with the reads.
            with self._bucket_mutex:
                try:
                    cuts = next(self.cuts_iter)
                except StopIteration:
                    # The exhaustion has to be handled here, in the worker:
                    # an exception raised inside a submitted task is stored
                    # in its Future and never reaches the submitting thread.
                    return
                duration = (
                    cuts[0].duration if isinstance(cuts, tuple) else cuts.duration
                )
                bucket_idx = bisect_right(self.duration_bins, duration)
                self.buckets[bucket_idx].append(cuts)

    # A single task preserves the FIFO read order of the one-worker executor
    # and lets the worker stop at the first StopIteration instead of calling
    # next() on an exhausted iterator for every remaining item.
    self._cut_reading_thread.submit(collect_many, n_cuts)
0 commit comments