@@ -1790,6 +1790,46 @@ def split(
17901790 )
17911791 return result
17921792
1793+ def chunk (self , chunks : int , dim : int = 0 ) -> tuple [TensorDictBase , ...]:
1794+ if chunks < 1 :
1795+ raise ValueError (
1796+ f"chunks must be a strictly positive integer, got { chunks } ."
1797+ )
1798+ # fall back on split, using upper rounding
1799+ batch_size = self .batch_size
1800+ dim = _maybe_correct_neg_dim (dim , batch_size )
1801+ max_size = batch_size [dim ]
1802+ split_size = - (max_size // - chunks )
1803+ segments = _create_segments_from_int (split_size , max_size )
1804+ splits = {k : v .chunk (chunks , dim ) for k , v in self .items ()}
1805+ names = self ._maybe_names ()
1806+ batch_sizes = [
1807+ torch .Size (
1808+ tuple (d if i != dim else end - start for i , d in enumerate (batch_size ))
1809+ )
1810+ for start , end in segments
1811+ ]
1812+ splits = [
1813+ {k : v [ss ] for k , v in splits .items ()} for ss in range (len (batch_sizes ))
1814+ ]
1815+ device = self .device
1816+ is_shared = self ._is_shared
1817+ is_memmap = self ._is_memmap
1818+ is_locked = self .is_locked
1819+ result = tuple (
1820+ self ._new_unsafe (
1821+ source = split ,
1822+ batch_size = bsz ,
1823+ names = names ,
1824+ device = device ,
1825+ lock = is_locked ,
1826+ is_shared = is_shared ,
1827+ is_memmap = is_memmap ,
1828+ )
1829+ for split , bsz in _zip_strict (splits , batch_sizes )
1830+ )
1831+ return result
1832+
17931833 def masked_select (self , mask : Tensor ) -> T :
17941834 d = {}
17951835 mask_expand = mask
@@ -4350,6 +4390,10 @@ def _cast_reduction(
43504390 reshape = TensorDict .reshape
43514391 split = TensorDict .split
43524392
4393+ def chunk (self , chunks : int , dim : int = 0 ) -> tuple [TensorDictBase , ...]:
4394+ splits = - (self .batch_size [dim ] // - chunks )
4395+ return self .split (splits , dim )
4396+
43534397 def _view (self , * args , ** kwargs ):
43544398 raise RuntimeError (
43554399 "Cannot call `view` on a sub-tensordict. Call `reshape` instead."
0 commit comments