Skip to content

Commit a9134b9

Browse files
committed
PoC for reading cuts in background thread in dynamic bucketing
1 parent b3b96a1 commit a9134b9

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

lhotse/dataset/sampling/dynamic_bucketing.py

Lines changed: 25 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,9 @@
22
import warnings
33
from bisect import bisect_right
44
from collections import deque
5+
from concurrent.futures import ThreadPoolExecutor
56
from itertools import islice
7+
from threading import Lock
68
from typing import (
79
Deque,
810
Generator,
@@ -284,6 +286,9 @@ def __init__(
284286
deque() for _ in range(len(duration_bins) + 1)
285287
]
286288

289+
self._cut_reading_thread = ThreadPoolExecutor(1)
290+
self._bucket_mutex = Lock()
291+
287292
def __iter__(self) -> Generator[CutSet, None, None]:
288293
# Init: sample `buffer_size` cuts and assign them to the right buckets.
289294
self.cuts_iter = iter(self.cuts)
@@ -302,19 +307,20 @@ def is_ready(bucket: Deque[Cut]):
302307
# On each step we're sampling a new batch.
303308
try:
304309
while True:
305-
ready_buckets = [b for b in self.buckets if is_ready(b)]
306-
if not ready_buckets:
307-
# No bucket has enough data to yield for the last full batch.
308-
non_empty_buckets = [b for b in self.buckets if b]
309-
if self.drop_last or len(non_empty_buckets) == 0:
310-
# Either the user requested only full batches, or we have nothing left.
311-
raise StopIteration()
312-
else:
313-
# Sample from partial batches that are left.
314-
ready_buckets = non_empty_buckets
315-
# Choose a bucket to sample from.
316-
# We'll only select from the buckets that have a full batch available.
317-
sampling_bucket = self.rng.choice(ready_buckets)
310+
with self._bucket_mutex:
311+
ready_buckets = [b for b in self.buckets if is_ready(b)]
312+
if not ready_buckets:
313+
# No bucket has enough data to yield for the last full batch.
314+
non_empty_buckets = [b for b in self.buckets if b]
315+
if self.drop_last or len(non_empty_buckets) == 0:
316+
# Either the user requested only full batches, or we have nothing left.
317+
raise StopIteration()
318+
else:
319+
# Sample from partial batches that are left.
320+
ready_buckets = non_empty_buckets
321+
# Choose a bucket to sample from.
322+
# We'll only select from the buckets that have a full batch available.
323+
sampling_bucket = self.rng.choice(ready_buckets)
318324
# Sample one batch from that bucket and yield it to the caller.
319325
batcher = DurationBatcher(
320326
sampling_bucket,
@@ -339,13 +345,17 @@ def is_ready(bucket: Deque[Cut]):
339345
self.cuts_iter = None
340346

341347
def _collect_cuts_in_buckets(self, n_cuts: int) -> None:
    """
    Schedule reading of up to ``n_cuts`` cuts into the duration buckets.

    One task per cut is submitted to ``self._cut_reading_thread``; each task
    pulls a single item from ``self.cuts_iter``, looks up its bucket with
    ``bisect_right(self.duration_bins, duration)`` and appends it to
    ``self.buckets``. The tasks are only submitted here, not awaited, so the
    buckets may not be filled yet when this method returns.

    :param n_cuts: the maximum number of cuts to read; fewer are read when
        ``self.cuts_iter`` runs out.
    """

    def collect_one() -> bool:
        """Read one cut into its bucket; return ``False`` once the input is exhausted."""
        # The same mutex is taken by the code that inspects the buckets,
        # so an append can never interleave with bucket selection.
        with self._bucket_mutex:
            try:
                cuts = next(self.cuts_iter)
            except StopIteration:
                # Exhaustion has to be handled here: an exception raised in
                # a submitted task is stored on its Future and never reaches
                # the thread that called ``submit()``.
                return False
            # An item may be a single cut or a tuple of cuts; for a tuple,
            # the first element's duration decides the bucket.
            duration = cuts[0].duration if isinstance(cuts, tuple) else cuts.duration
            bucket_idx = bisect_right(self.duration_bins, duration)
            self.buckets[bucket_idx].append(cuts)
        return True

    for _ in range(n_cuts):
        self._cut_reading_thread.submit(collect_one)

0 commit comments

Comments
 (0)