Skip to content

Commit ea9014e

Browse files
authored
Merge pull request #670 from lhotse-speech/feature/dynamic-bucketing-supports-small-data
DynamicBucketingSampler supports very small data
2 parents d32d3f0 + ba01e09 commit ea9014e

File tree

2 files changed

+47
-6
lines changed

2 files changed

+47
-6
lines changed

lhotse/dataset/sampling/dynamic_bucketing.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ def __iter__(self) -> "DynamicBucketingSampler":
176176
buffer_size=self.buffer_size,
177177
strict=self.strict,
178178
rng=self.rng,
179+
diagnostics=self.diagnostics,
179180
)
180-
self.cuts_iter.diagnostics = self.diagnostics
181181
self.cuts_iter = iter(self.cuts_iter)
182182
return self
183183

@@ -224,9 +224,9 @@ def estimate_duration_buckets(cuts: Iterable[Cut], num_buckets: int) -> List[Sec
224224

225225
durs = np.array([c.duration for c in cuts])
226226
durs.sort()
227-
assert num_buckets < durs.shape[0], (
228-
f"The number of buckets ({num_buckets}) must be smaller "
229-
f"than the number of cuts ({durs.shape[0]})."
227+
assert num_buckets <= durs.shape[0], (
228+
f"The number of buckets ({num_buckets}) must be smaller than "
229+
f"or equal to the number of cuts ({durs.shape[0]})."
230230
)
231231
bucket_duration = durs.sum() / num_buckets
232232

@@ -298,8 +298,6 @@ def is_ready(bucket: Deque[Cut]):
298298
return True
299299
return False
300300

301-
assert any(is_ready(bucket) for bucket in self.buckets)
302-
303301
# The iteration code starts here.
304302
# On each step we're sampling a new batch.
305303
try:

test/dataset/sampling/test_dynamic_bucketing.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,49 @@ def test_dynamic_bucketing_sampler():
143143
assert sum(c.duration for c in batches[3]) == 2
144144

145145

146+
def test_dynamic_bucketing_sampler_too_small_data_can_be_sampled():
147+
cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
148+
for i, c in enumerate(cuts):
149+
if i < 5:
150+
c.duration = 1
151+
else:
152+
c.duration = 2
153+
154+
# 10 cuts with 15s total (5 x 1s + 5 x 2s) are not enough to satisfy max_duration of 100 with 2 buckets
155+
sampler = DynamicBucketingSampler(cuts, max_duration=100, num_buckets=2, seed=0)
156+
batches = [b for b in sampler]
157+
sampled_cuts = [c for b in batches for c in b]
158+
159+
# Invariant: no duplicated cut IDs
160+
assert len(set(c.id for b in batches for c in b)) == len(sampled_cuts)
161+
162+
# Same number of sampled and source cuts.
163+
assert len(sampled_cuts) == len(cuts)
164+
165+
# We sampled 2 batches (one per bucket)
166+
assert len(batches) == 2
167+
168+
# Each batch has five cuts
169+
for b in batches:
170+
assert len(b) == 5
171+
172+
173+
def test_dynamic_bucketing_sampler_too_small_data_drop_last_true_results_in_no_batches():
174+
cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
175+
for i, c in enumerate(cuts):
176+
if i < 5:
177+
c.duration = 1
178+
else:
179+
c.duration = 2
180+
181+
# 10 cuts with 15s total (5 x 1s + 5 x 2s) are not enough to satisfy max_duration of 100 with 2 buckets
182+
sampler = DynamicBucketingSampler(
183+
cuts, max_duration=100, num_buckets=2, seed=0, drop_last=True
184+
)
185+
batches = [b for b in sampler]
186+
assert len(batches) == 0
187+
188+
146189
def test_dynamic_bucketing_sampler_filter():
147190
cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
148191
for i, c in enumerate(cuts):

0 commit comments

Comments
 (0)