From a83fd39a0655e2cc324ec20a2ee2649d33504388 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 18 Jun 2026 16:37:46 +0800 Subject: [PATCH 1/2] fix --- src/twinkle/dataloader/device_mesh_fetcher.py | 7 +++---- src/twinkle/dataloader/device_mesh_sampler.py | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/twinkle/dataloader/device_mesh_fetcher.py b/src/twinkle/dataloader/device_mesh_fetcher.py index f545f7018..65c8e8311 100644 --- a/src/twinkle/dataloader/device_mesh_fetcher.py +++ b/src/twinkle/dataloader/device_mesh_fetcher.py @@ -77,9 +77,8 @@ def fetch(self, _): else: data = next(self.dataset_iter) + if self.min_batch_size is not None and len(data) < self.min_batch_size: + raise StopIteration if self.device_mesh: - if len(data) < self.min_batch_size: - raise StopIteration - else: - data = data[self.device_mesh.get_slice(len(data))] + data = data[self.device_mesh.get_slice(len(data))] return self.collate_fn(data) diff --git a/src/twinkle/dataloader/device_mesh_sampler.py b/src/twinkle/dataloader/device_mesh_sampler.py index 1f649de37..305ebf9c5 100644 --- a/src/twinkle/dataloader/device_mesh_sampler.py +++ b/src/twinkle/dataloader/device_mesh_sampler.py @@ -34,13 +34,12 @@ def __iter__(self): batch = batch[self.skip_samples - skipped:] skipped = self.skip_samples + if self.min_batch_size is not None and len(batch) < self.min_batch_size: + return if not self.device_mesh: yield batch else: - if len(batch) < self.min_batch_size: - return - else: - yield batch[self.device_mesh.get_slice(len(batch))] + yield batch[self.device_mesh.get_slice(len(batch))] def __len__(self): return len(self.original_sampler) From 097a716f4b1f5248d6ec5e1287fed17ac020bfe9 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 18 Jun 2026 16:48:45 +0800 Subject: [PATCH 2/2] fix --- src/twinkle/dataloader/dataloader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py index c392d56cf..975a2b924 100644 --- a/src/twinkle/dataloader/dataloader.py +++ b/src/twinkle/dataloader/dataloader.py @@ -134,6 +134,7 @@ def __iter__(self): _iter._dataset_fetcher.drop_last, self.batch_size, self.device_mesh, + min_batch_size=self.min_batch_size, max_retries=self.max_retries) return self._tracking_iter(_iter)