Skip to content

Commit eaa0248

Browse files
committed
[BugFix] Fix chunk following split fix (#1377)
1 parent 3cfe770 commit eaa0248

File tree

4 files changed

+58
-7
lines changed

4 files changed

+58
-7
lines changed

tensordict/_lazy.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3740,6 +3740,10 @@ def iter_across_tds():
37403740
for tds in _zip_strict(*tds)
37413741
)
37423742

3743+
def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
3744+
splits = -(self.batch_size[dim] // -chunks)
3745+
return self.split(splits, dim)
3746+
37433747
lock_ = TensorDictBase.lock_
37443748
lock = _renamed_inplace_method(lock_)
37453749

@@ -4410,6 +4414,10 @@ def _cast_reduction(
44104414
**kwargs,
44114415
)
44124416

4417+
def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
4418+
splits = -(self.batch_size[dim] // -chunks)
4419+
return self.split(splits, dim)
4420+
44134421
__xor__ = TensorDict.__xor__
44144422
__or__ = TensorDict.__or__
44154423
__eq__ = TensorDict.__eq__

tensordict/_td.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1790,6 +1790,46 @@ def split(
17901790
)
17911791
return result
17921792

1793+
def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
1794+
if chunks < 1:
1795+
raise ValueError(
1796+
f"chunks must be a strictly positive integer, got {chunks}."
1797+
)
1798+
# fall back on split, using upper rounding
1799+
batch_size = self.batch_size
1800+
dim = _maybe_correct_neg_dim(dim, batch_size)
1801+
max_size = batch_size[dim]
1802+
split_size = -(max_size // -chunks)
1803+
segments = _create_segments_from_int(split_size, max_size)
1804+
splits = {k: v.chunk(chunks, dim) for k, v in self.items()}
1805+
names = self._maybe_names()
1806+
batch_sizes = [
1807+
torch.Size(
1808+
tuple(d if i != dim else end - start for i, d in enumerate(batch_size))
1809+
)
1810+
for start, end in segments
1811+
]
1812+
splits = [
1813+
{k: v[ss] for k, v in splits.items()} for ss in range(len(batch_sizes))
1814+
]
1815+
device = self.device
1816+
is_shared = self._is_shared
1817+
is_memmap = self._is_memmap
1818+
is_locked = self.is_locked
1819+
result = tuple(
1820+
self._new_unsafe(
1821+
source=split,
1822+
batch_size=bsz,
1823+
names=names,
1824+
device=device,
1825+
lock=is_locked,
1826+
is_shared=is_shared,
1827+
is_memmap=is_memmap,
1828+
)
1829+
for split, bsz in _zip_strict(splits, batch_sizes)
1830+
)
1831+
return result
1832+
17931833
def masked_select(self, mask: Tensor) -> T:
17941834
d = {}
17951835
mask_expand = mask
@@ -4350,6 +4390,10 @@ def _cast_reduction(
43504390
reshape = TensorDict.reshape
43514391
split = TensorDict.split
43524392

4393+
def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
4394+
splits = -(self.batch_size[dim] // -chunks)
4395+
return self.split(splits, dim)
4396+
43534397
def _view(self, *args, **kwargs):
43544398
raise RuntimeError(
43554399
"Cannot call `view` on a sub-tensordict. Call `reshape` instead."

tensordict/base.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3419,6 +3419,7 @@ def unbind(self, dim: int) -> tuple[T, ...]:
34193419
def _unbind(self, dim: int) -> tuple[T, ...]:
34203420
raise NotImplementedError
34213421

3422+
@abc.abstractmethod
34223423
def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
34233424
"""Splits a tensordict into the specified number of chunks, if possible.
34243425

@@ -3443,13 +3444,7 @@ def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
34433444
[18, 19]]])
34443445

34453446
"""
3446-
if chunks < 1:
3447-
raise ValueError(
3448-
f"chunks must be a strictly positive integer, got {chunks}."
3449-
)
3450-
# fall back on split, using upper rounding
3451-
split_size = -(self.batch_size[dim] // -chunks)
3452-
return self.split(split_size, dim=dim)
3447+
raise NotImplementedError
34533448

34543449
@overload
34553450
def unsqueeze(self, dim: int) -> T: ...

tensordict/persistent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,6 +1429,10 @@ def _unsqueeze(self, dim: int):
14291429
"Cannot call `unsqueeze` on a persistent tensordict. Make it dense before calling this method by calling `to_tensordict`."
14301430
)
14311431

1432+
def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
1433+
splits = -(self.batch_size[dim] // -chunks)
1434+
return self.split(splits, dim)
1435+
14321436
__eq__ = TensorDict.__eq__
14331437
__ne__ = TensorDict.__ne__
14341438
__xor__ = TensorDict.__xor__

0 commit comments

Comments
 (0)