Skip to content

Commit 7b0fd93

Browse files
author
Vincent Moens
committed
[Feature] capture_non_tensor_stack
ghstack-source-id: 8667892 Pull Request resolved: #1221
1 parent 3ec3be7 commit 7b0fd93

File tree

7 files changed

+233
-42
lines changed

7 files changed

+233
-42
lines changed

docs/source/reference/tensordict.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,13 +231,15 @@ Utils
231231
utils.expand_right
232232
utils.isin
233233
utils.remove_duplicates
234+
capture_non_tensor_stack
235+
dense_stack_tds
234236
is_batchedtensor
235237
is_tensor_collection
238+
lazy_legacy
236239
make_tensordict
237240
merge_tensordicts
238241
pad
239242
pad_sequence
240-
dense_stack_tds
241-
set_lazy_legacy
242-
lazy_legacy
243243
parse_tensor_dict_string
244+
set_capture_non_tensor_stack
245+
set_lazy_legacy

tensordict/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,13 @@
5757
from tensordict.utils import (
5858
assert_allclose_td,
5959
assert_close,
60+
capture_non_tensor_stack,
6061
is_batchedtensor,
6162
is_non_tensor,
6263
is_tensorclass,
6364
lazy_legacy,
6465
parse_tensor_dict_string,
66+
set_capture_non_tensor_stack,
6567
set_lazy_legacy,
6668
unravel_key,
6769
unravel_key_list,

tensordict/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
_unravel_key_to_tuple,
8383
_zip_strict,
8484
cache,
85+
capture_non_tensor_stack,
8586
convert_ellipsis_to_idx,
8687
DeviceType,
8788
erase_cache,
@@ -6229,6 +6230,10 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):
62296230
value = self._get_str(key, default=default)
62306231

62316232
if is_non_tensor(value):
6233+
from tensordict import NonTensorStack
6234+
6235+
if isinstance(value, NonTensorStack) and not capture_non_tensor_stack():
6236+
return value.tolist()
62326237
data = getattr(value, "data", None)
62336238
if data is None:
62346239
return value.tolist()

tensordict/tensorclass.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,12 @@
5656
_TENSORCLASS_MEMO,
5757
_unravel_key_to_tuple,
5858
_zip_strict,
59+
capture_non_tensor_stack,
5960
DeviceType,
6061
IndexType,
6162
is_tensorclass,
6263
KeyDependentDefaultDict,
64+
set_capture_non_tensor_stack,
6365
)
6466
from torch import multiprocessing as mp, Tensor
6567
from torch.multiprocessing import Manager
@@ -3219,7 +3221,7 @@ def _stack_non_tensor(cls, list_of_non_tensor, dim=0, raise_if_non_unique=False)
32193221

32203222
ids = set()
32213223
firstdata = NO_DEFAULT
3222-
return_stack = False
3224+
return_stack = not capture_non_tensor_stack()
32233225
for data in list_of_non_tensor:
32243226
if not isinstance(data, NonTensorData):
32253227
if raise_if_non_unique:
@@ -3242,8 +3244,18 @@ def _stack_non_tensor(cls, list_of_non_tensor, dim=0, raise_if_non_unique=False)
32423244
return_stack = True
32433245
break
32443246
else:
3245-
return_stack = False
3247+
return_stack = not capture_non_tensor_stack()
32463248
if not return_stack:
3249+
if capture_non_tensor_stack(allow_none=True) is None:
3250+
warnings.warn(
3251+
"The default behavior of stacking non-tensor data will change in "
3252+
"version v0.9 and switch from True to False (current default). "
3253+
"To prepare for this change, use set_capture_non_tensor_stack(val: bool) as a decorator or context "
3254+
"manager, or set the environment variable CAPTURE_NONTENSOR_STACK "
3255+
"to 'False'.",
3256+
FutureWarning,
3257+
stacklevel=2,
3258+
)
32473259
batch_size = list(first.batch_size)
32483260
batch_size.insert(dim, len(list_of_non_tensor))
32493261
return NonTensorData(
@@ -3772,9 +3784,13 @@ def data(self):
37723784
Raises a ValueError if there is more than one unique value.
37733785
"""
37743786
try:
3775-
return NonTensorData._stack_non_tensor(
3776-
self.tensordicts, raise_if_non_unique=True
3777-
).data
3787+
with set_capture_non_tensor_stack(True):
3788+
nt = NonTensorData._stack_non_tensor(
3789+
self.tensordicts, raise_if_non_unique=True
3790+
)
3791+
if not isinstance(nt, NonTensorData):
3792+
raise ValueError
3793+
return nt.data
37783794
except ValueError:
37793795
raise AttributeError(
37803796
"Cannot get the non-unique data of a NonTensorStack. Use .tolist() instead."

tensordict/utils.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2098,7 +2098,7 @@ def set(self) -> None:
20982098

20992099
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
21002100
global _LAZY_OP
2101-
_LAZY_OP = bool(self._old_mode)
2101+
_LAZY_OP = self._old_mode
21022102
os.environ["LAZY_LEGACY_OP"] = str(_LAZY_OP)
21032103

21042104

@@ -2121,6 +2121,92 @@ def _legacy_lazy(func):
21212121
return func
21222122

21232123

2124+
# non tensor stack control
2125+
_DEFAULT_CAPTURE_NONTENSOR_STACK = True
2126+
_CAPTURE_NONTENSOR_STACK = os.environ.get("CAPTURE_NONTENSOR_STACK")
2127+
2128+
2129+
class set_capture_non_tensor_stack(_DecoratorContextManager):
2130+
"""A context manager or decorator to control whether identical non-tensor data should be stacked into a single NonTensorData object or a NonTensorStack.
2131+
2132+
Args:
2133+
mode (bool): Whether to capture non-tensor stacks. If ``False``, identical
2134+
non-tensor data will be stacked into a :class:`~tensordict.NonTensorStack`. If ``True``,
2135+
a single :class:`~tensordict.NonTensorData` object will contain the unique value, but with the desired batch-size.
2136+
Defaults to ``True``.
2137+
2138+
.. note:: Until v0.9, this will raise a warning if the same value is encountered and the value is not set
2139+
explicitly (the default being ``capture_non_tensor_stack() == True``).
2140+
You can set the value of :func:`~tensordict.capture_non_tensor_stack` through:
2141+
2142+
- The ``CAPTURE_NONTENSOR_STACK`` environment variable;
2143+
- By setting ``set_capture_non_tensor_stack(val: bool).set()`` at the beginning of your script;
2144+
- By using ``set_capture_non_tensor_stack(val: bool)`` as a context manager or a decorator.
2145+
2146+
It is recommended to use the `set_capture_non_tensor_stack(False)` behavior.
2147+
2148+
.. seealso:: :func:`~tensordict.capture_non_tensor_stack`
2149+
2150+
Examples:
2151+
>>> with set_capture_non_tensor_stack(True):
2152+
... torch.stack([NonTensorData("a"), NonTensorData("a")])
2153+
NonTensorData("a", batch_size=[2])
2154+
>>> @set_capture_non_tensor_stack(False)
2155+
... def my_function():
2156+
... return torch.stack([NonTensorData("a"), NonTensorData("a")])
2157+
>>> my_function()
2158+
NonTensorStack(["a", "a"], stack_dim=0)
2159+
"""
2160+
2161+
def __init__(self, mode: bool) -> None:
2162+
super().__init__()
2163+
self.mode = mode
2164+
2165+
def clone(self) -> set_capture_non_tensor_stack:
2166+
# override this method if your child class takes __init__ parameters
2167+
return type(self)(self.mode)
2168+
2169+
def __enter__(self) -> None:
2170+
self.set()
2171+
2172+
def set(self) -> None:
2173+
global _CAPTURE_NONTENSOR_STACK
2174+
self._old_mode = _CAPTURE_NONTENSOR_STACK
2175+
_CAPTURE_NONTENSOR_STACK = bool(self.mode)
2176+
# we do this so that sub-processes see the same capture setting as the main one
2177+
os.environ["CAPTURE_NONTENSOR_STACK"] = str(_CAPTURE_NONTENSOR_STACK)
2178+
2179+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
2180+
global _CAPTURE_NONTENSOR_STACK
2181+
_CAPTURE_NONTENSOR_STACK = self._old_mode
2182+
os.environ["CAPTURE_NONTENSOR_STACK"] = str(_CAPTURE_NONTENSOR_STACK)
2183+
2184+
2185+
def capture_non_tensor_stack(allow_none=False):
2186+
"""Get the current setting for capturing non-tensor stacks.
2187+
2188+
Args:
2189+
allow_none (bool, optional): If ``True``, returns ``None`` if no setting has been
2190+
specified. Otherwise, returns the default setting. Defaults to ``False``.
2191+
2192+
.. seealso:: :func:`~tensordict.set_capture_non_tensor_stack`
2193+
2194+
Returns:
2195+
bool or None: The current setting for capturing non-tensor stacks.
2196+
2197+
"""
2198+
global _CAPTURE_NONTENSOR_STACK
2199+
if _CAPTURE_NONTENSOR_STACK is None and allow_none:
2200+
return None
2201+
elif _CAPTURE_NONTENSOR_STACK is None:
2202+
return _DEFAULT_CAPTURE_NONTENSOR_STACK
2203+
return (
2204+
strtobool(_CAPTURE_NONTENSOR_STACK)
2205+
if isinstance(_CAPTURE_NONTENSOR_STACK, str)
2206+
else _CAPTURE_NONTENSOR_STACK
2207+
)
2208+
2209+
21242210
# Process initializer for map
21252211
def _proc_init(base_seed, queue, num_threads):
21262212
worker_id = queue.get(timeout=120)

test/test_tensorclass.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,6 @@
2424
import pytest
2525
import tensordict.utils
2626
import torch
27-
from tensordict import TensorClass
28-
from tensordict.tensorclass import from_dataclass
29-
30-
try:
31-
import torchsnapshot
32-
33-
_has_torchsnapshot = True
34-
TORCHSNAPSHOT_ERR = ""
35-
except ImportError as err:
36-
_has_torchsnapshot = False
37-
TORCHSNAPSHOT_ERR = str(err)
3827

3928
from _utils_internal import get_available_devices
4029

@@ -44,14 +33,26 @@
4433
lazy_legacy,
4534
LazyStackedTensorDict,
4635
MemoryMappedTensor,
36+
set_capture_non_tensor_stack,
4737
tensorclass,
38+
TensorClass,
4839
TensorDict,
4940
TensorDictBase,
5041
)
5142
from tensordict._lazy import _PermutedTensorDict, _ViewedTensorDict
5243
from tensordict.base import _GENERIC_NESTED_ERR
44+
from tensordict.tensorclass import from_dataclass
5345
from torch import Tensor
5446

47+
try:
48+
import torchsnapshot
49+
50+
_has_torchsnapshot = True
51+
TORCHSNAPSHOT_ERR = ""
52+
except ImportError as err:
53+
_has_torchsnapshot = False
54+
TORCHSNAPSHOT_ERR = str(err)
55+
5556
# Capture all warnings
5657
pytestmark = [
5758
pytest.mark.filterwarnings("error"),
@@ -381,7 +382,8 @@ class MyData:
381382
data3 = MyData(D, B, A, C=C, E=E, batch_size=[3, 4])
382383
data4 = MyData(D, B, A, C, E=E, batch_size=[3, 4])
383384
data5 = MyData(D, B, A, C, E, batch_size=[3, 4])
384-
data = torch.stack([data1, data2, data3, data4, data5], 0)
385+
with set_capture_non_tensor_stack(True):
386+
data = torch.stack([data1, data2, data3, data4, data5], 0)
385387
assert (data.A == A).all()
386388
assert (data.B == B).all()
387389
assert (data.C == C).all()
@@ -1857,7 +1859,8 @@ class MyDataNested:
18571859
if lazy:
18581860
stacked_tc = LazyStackedTensorDict.lazy_stack([data1, data2], 0)
18591861
else:
1860-
stacked_tc = torch.stack([data1, data2], 0)
1862+
with set_capture_non_tensor_stack(True):
1863+
stacked_tc = torch.stack([data1, data2], 0)
18611864
assert type(stacked_tc) is type(data1)
18621865
assert isinstance(stacked_tc.y, type(data1.y))
18631866
assert stacked_tc.X.shape == torch.Size([2, 3, 4, 5])
@@ -2145,7 +2148,8 @@ def z(self) -> torch.Tensor:
21452148
y1 = Y(weakref.ref(obj), batch_size=[1])
21462149
y = torch.cat([y0, y1])
21472150
assert y.z.shape == torch.Size(())
2148-
y = torch.stack([y0, y1])
2151+
with set_capture_non_tensor_stack(True):
2152+
y = torch.stack([y0, y1])
21492153
assert y.z.shape == torch.Size(())
21502154

21512155

@@ -2253,9 +2257,13 @@ class TensorClass:
22532257
def get_nested(self):
22542258
c = self.TensorClass(torch.ones(1), ("a", "b", "c"), "Hello", batch_size=[])
22552259

2256-
td = torch.stack(
2257-
[TensorDict({"t": torch.ones(1), "c": c}, batch_size=[]) for _ in range(3)]
2258-
)
2260+
with set_capture_non_tensor_stack(True):
2261+
td = torch.stack(
2262+
[
2263+
TensorDict({"t": torch.ones(1), "c": c}, batch_size=[])
2264+
for _ in range(3)
2265+
]
2266+
)
22592267
return td
22602268

22612269
def test_apply(self):

0 commit comments

Comments
 (0)