lhotse/dataset/sampling/dynamic_bucketing.py (17 additions, 3 deletions)
@@ -1,7 +1,9 @@
+import concurrent.futures
 import random
 import warnings
 from bisect import bisect_right
 from collections import deque
+from concurrent.futures import ThreadPoolExecutor
 from itertools import islice
 from typing import Any, Deque, Dict, Generator, Iterable, List, Optional, Tuple, Union

@@ -334,6 +336,9 @@ def __init__(
             deque() for _ in range(len(duration_bins) + 1)
         ]

+        self._cut_reading_thread = ThreadPoolExecutor(1)
Contributor:
Any reason not to use a process pool? Due to the global interpreter lock, only one thread can be executing Python code at any given time, I think.
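A standalone sketch of the trade-off in question (illustrative, not part of this PR): pure-Python work does serialize across threads because of the GIL, but blocking calls such as time.sleep, and likewise file or network I/O, release it, so an I/O-bound reader thread can genuinely overlap with other work.

```python
import time
from concurrent.futures import ThreadPoolExecutor

def cpu_bound():
    # Pure-Python loop: holds the GIL, so threads cannot run it in parallel.
    total = 0
    for i in range(10_000_000):
        total += i
    return total

def io_bound():
    # time.sleep releases the GIL, so these calls overlap across threads.
    time.sleep(1.0)

for fn in (cpu_bound, io_bound):
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=4) as ex:
        list(ex.map(lambda _: fn(), range(4)))
    print(f"{fn.__name__}: {time.perf_counter() - start:.1f}s with 4 threads")
# cpu_bound takes roughly 4x a single run; io_bound finishes in ~1s, not ~4s.
```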

Collaborator (author):
Yes. In some setups that use IterableDatasetWrapper, the sampler is placed inside a dataloader worker process, and AFAIK you can't spawn a nested process pool there because that worker process is daemonic.
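For reference, a minimal sketch of that restriction (the names are illustrative, not the actual dataloader setup): multiprocessing forbids daemonic processes from spawning children, so creating a process pool inside one fails.

```python
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

def worker():
    # This runs inside a daemonic process, analogous to a dataloader worker.
    try:
        with ProcessPoolExecutor(max_workers=1) as pool:
            pool.submit(print, "never reached").result()
    except Exception as e:
        print(f"{type(e).__name__}: {e}")

if __name__ == "__main__":
    p = mp.Process(target=worker, daemon=True)
    p.start()
    p.join()
    # Prints an error like: "daemonic processes are not allowed to have children"
```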

Anyway, a thread should be sufficient here, as I expect the CPU to be mostly idle while the forward and backward passes run on the GPU... The reason it didn't work for you is likely that the thread could not populate the buckets fast enough, so the sampler thought they were depleted (a race condition). This can be solved with a proper synchronization mechanism, but unfortunately I don't have the time to add it right now. I'll return to it sometime.
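One possible synchronization mechanism, as a hypothetical sketch (BucketBuffer and its methods are invented for illustration; this is not what the PR implements): the reader thread signals a Condition each time it appends, and the consumer waits on that signal instead of treating a momentarily empty bucket as depleted.

```python
import threading
from collections import deque

class BucketBuffer:
    """Sketch: a producer/consumer handshake so the consumer blocks
    until data arrives or the producer declares the stream finished."""

    def __init__(self) -> None:
        self.bucket: deque = deque()
        self.finished = False
        self.cond = threading.Condition()

    def put(self, item) -> None:
        with self.cond:
            self.bucket.append(item)
            self.cond.notify()

    def close(self) -> None:
        with self.cond:
            self.finished = True
            self.cond.notify_all()

    def get(self):
        with self.cond:
            # Wait instead of assuming an empty bucket means depletion.
            self.cond.wait_for(lambda: self.bucket or self.finished)
            if self.bucket:
                return self.bucket.popleft()
            raise StopIteration
```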

+        self._cut_reading_future: Optional[concurrent.futures.Future] = None
+
     def __iter__(self) -> Generator[CutSet, None, None]:
         # Init: sample `buffer_size` cuts and assign them to the right buckets.
         self.cuts_iter = iter(self.cuts)
@@ -356,6 +361,7 @@ def is_ready(bucket: Deque[Cut]):
         # On each step we're sampling a new batch.
         try:
             while True:
+                self._wait_for_cut_collection()
                 ready_buckets = [b for b in self.buckets if is_ready(b)]
                 if not ready_buckets:
                     # No bucket has enough data to yield for the last full batch.
@@ -394,13 +400,21 @@ def is_ready(bucket: Deque[Cut]):
         self.cuts_iter = None

     def _collect_cuts_in_buckets(self, n_cuts: int):
-        try:
+        def collect():
             for _ in range(n_cuts):
                 cuts = next(self.cuts_iter)
                 duration = (
                     cuts[0].duration if isinstance(cuts, tuple) else cuts.duration
                 )
                 bucket_idx = bisect_right(self.duration_bins, duration)
                 self.buckets[bucket_idx].append(cuts)
-        except StopIteration:
-            pass
+
+        assert self._cut_reading_future is None
+        self._cut_reading_future = self._cut_reading_thread.submit(collect)
+
+    def _wait_for_cut_collection(self):
+        assert self._cut_reading_future is not None
+        err = self._cut_reading_future.exception()
+        if err is not None and not isinstance(err, StopIteration):
+            raise err
+        self._cut_reading_future = None
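Distilled into a standalone, hedged sketch (PrefetchingReader and its method names are invented; the real sampler is more involved): submit the collection work to a single-thread executor, and before sampling, block on the future and re-raise any exception except StopIteration, which simply signals that the source iterator is exhausted.

```python
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

class PrefetchingReader:
    # Hypothetical, simplified version of the pattern in this diff.
    def __init__(self, source):
        self.source = iter(source)
        self.buffer: list = []
        self._thread = ThreadPoolExecutor(max_workers=1)
        self._future: Optional[concurrent.futures.Future] = None

    def start_collecting(self, n: int) -> None:
        def collect():
            for _ in range(n):
                self.buffer.append(next(self.source))  # may raise StopIteration

        assert self._future is None
        self._future = self._thread.submit(collect)

    def wait(self) -> None:
        assert self._future is not None
        err = self._future.exception()  # blocks until collect() finishes
        if err is not None and not isinstance(err, StopIteration):
            raise err  # real errors propagate; exhaustion is expected
        self._future = None

reader = PrefetchingReader(range(10))
reader.start_collecting(20)  # asks for more than the source holds
reader.wait()                # swallows the resulting StopIteration
print(reader.buffer)         # [0, 1, ..., 9]
```

Calling future.exception() rather than result() both joins the background work and surfaces errors in a single step.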