Skip to content

Commit 0b901a7

Browse files
author
Vincent Moens
committed
[BugFix] Consolidate lazy stacks of non-tensors
ghstack-source-id: afb1480 Pull Request resolved: #1222
1 parent 7b0fd93 commit 0b901a7

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

tensordict/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5043,7 +5043,8 @@ def assign(
50435043
cls = type(value)
50445044
if issubclass(cls, torch.Tensor):
50455045
pass
5046-
elif _is_non_tensor(cls):
5046+
# We want to skip NonTensorStacks
5047+
elif _is_non_tensor(cls) and not issubclass(cls, TensorDictBase):
50475048
if requires_metadata:
50485049
metadata_dict["non_tensors"][key] = (
50495050
value.data,
@@ -5411,7 +5412,8 @@ def _view_and_pad(tensor):
54115412
if non_blocking and device.type != "cuda":
54125413
# sync if needed
54135414
self._sync_all()
5414-
torch.cat(items, out=storage)
5415+
if items:
5416+
torch.cat(items, out=storage)
54155417
for v, (k, oldv) in _zip_strict(
54165418
storage.split(flat_size), list(flat_dict.items())
54175419
):

test/test_tensordict.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11404,6 +11404,17 @@ def test_stack(self, non_tensor_data):
1140411404
LazyStackedTensorDict,
1140511405
)
1140611406

11407+
def test_stack_consolidate(self):
11408+
td = torch.stack(
11409+
[
11410+
TensorDict(a="a string", b="b string"),
11411+
TensorDict(a="another string", b="bnother string"),
11412+
]
11413+
)
11414+
tdc = td.consolidate()
11415+
assert (tdc == td).all()
11416+
assert tdc["a"] == ["a string", "another string"]
11417+
1140711418
def test_assign_non_tensor(self):
1140811419
data = TensorDict({}, [1, 10])
1140911420

0 commit comments

Comments
 (0)