File tree Expand file tree Collapse file tree 2 files changed +15
-2
lines changed
Expand file tree Collapse file tree 2 files changed +15
-2
lines changed Original file line number Diff line number Diff line change @@ -5043,7 +5043,8 @@ def assign(
50435043 cls = type(value)
50445044 if issubclass(cls, torch.Tensor):
50455045 pass
5046- elif _is_non_tensor(cls):
5046+ # We want to skip NonTensorStacks
5047+ elif _is_non_tensor(cls) and not issubclass(cls, TensorDictBase):
50475048 if requires_metadata:
50485049 metadata_dict["non_tensors"][key] = (
50495050 value.data,
@@ -5411,7 +5412,8 @@ def _view_and_pad(tensor):
54115412 if non_blocking and device.type != "cuda":
54125413 # sync if needed
54135414 self._sync_all()
5414- torch.cat(items, out=storage)
5415+ if items:
5416+ torch.cat(items, out=storage)
54155417 for v, (k, oldv) in _zip_strict(
54165418 storage.split(flat_size), list(flat_dict.items())
54175419 ):
Original file line number Diff line number Diff line change @@ -11404,6 +11404,17 @@ def test_stack(self, non_tensor_data):
1140411404 LazyStackedTensorDict ,
1140511405 )
1140611406
11407+ def test_stack_consolidate (self ):
11408+ td = torch .stack (
11409+ [
11410+ TensorDict (a = "a string" , b = "b string" ),
11411+ TensorDict (a = "another string" , b = "bnother string" ),
11412+ ]
11413+ )
11414+ tdc = td .consolidate ()
11415+ assert (tdc == td ).all ()
11416+ assert tdc ["a" ] == ["a string" , "another string" ]
11417+
1140711418 def test_assign_non_tensor (self ):
1140811419 data = TensorDict ({}, [1 , 10 ])
1140911420
You can’t perform that action at this time.
0 commit comments