@@ -143,6 +143,49 @@ def test_dynamic_bucketing_sampler():
143143 assert sum (c .duration for c in batches [3 ]) == 2
144144
145145
146+ def test_dynamic_bucketing_sampler_too_small_data_can_be_sampled ():
147+ cuts = DummyManifest (CutSet , begin_id = 0 , end_id = 10 )
148+ for i , c in enumerate (cuts ):
149+ if i < 5 :
150+ c .duration = 1
151+ else :
152+ c .duration = 2
153+
154+ # 10 cuts with 30s total are not enough to satisfy max_duration of 100 with 2 buckets
155+ sampler = DynamicBucketingSampler (cuts , max_duration = 100 , num_buckets = 2 , seed = 0 )
156+ batches = [b for b in sampler ]
157+ sampled_cuts = [c for b in batches for c in b ]
158+
159+ # Invariant: no duplicated cut IDs
160+ assert len (set (c .id for b in batches for c in b )) == len (sampled_cuts )
161+
162+ # Same number of sampled and source cuts.
163+ assert len (sampled_cuts ) == len (cuts )
164+
165+ # We sampled 10 batches
166+ assert len (batches ) == 2
167+
168+ # Each batch has five cuts
169+ for b in batches :
170+ assert len (b ) == 5
171+
172+
173+ def test_dynamic_bucketing_sampler_too_small_data_drop_last_true_results_in_no_batches ():
174+ cuts = DummyManifest (CutSet , begin_id = 0 , end_id = 10 )
175+ for i , c in enumerate (cuts ):
176+ if i < 5 :
177+ c .duration = 1
178+ else :
179+ c .duration = 2
180+
181+ # 10 cuts with 30s total are not enough to satisfy max_duration of 100 with 2 buckets
182+ sampler = DynamicBucketingSampler (
183+ cuts , max_duration = 100 , num_buckets = 2 , seed = 0 , drop_last = True
184+ )
185+ batches = [b for b in sampler ]
186+ assert len (batches ) == 0
187+
188+
146189def test_dynamic_bucketing_sampler_filter ():
147190 cuts = DummyManifest (CutSet , begin_id = 0 , end_id = 10 )
148191 for i , c in enumerate (cuts ):
0 commit comments