@@ -2098,7 +2098,7 @@ def set(self) -> None:
20982098
20992099 def __exit__ (self , exc_type : Any , exc_value : Any , traceback : Any ) -> None :
21002100 global _LAZY_OP
2101- _LAZY_OP = bool ( self ._old_mode )
2101+ _LAZY_OP = self ._old_mode
21022102 os .environ ["LAZY_LEGACY_OP" ] = str (_LAZY_OP )
21032103
21042104
@@ -2121,6 +2121,92 @@ def _legacy_lazy(func):
21212121 return func
21222122
21232123
# non tensor stack control
# Default used when neither the environment variable nor a
# set_capture_non_tensor_stack(...) call has specified a value.
_DEFAULT_CAPTURE_NONTENSOR_STACK = True
# Raw value from the environment: a string when set, ``None`` when unset.
# It is normalized to bool lazily by capture_non_tensor_stack().
_CAPTURE_NONTENSOR_STACK = os.environ.get("CAPTURE_NONTENSOR_STACK")
2128+
class set_capture_non_tensor_stack(_DecoratorContextManager):
    """A context manager or decorator to control whether identical non-tensor data should be stacked into a single NonTensorData object or a NonTensorStack.

    Args:
        mode (bool): Whether to capture non-tensor stacks. If ``False``, identical
            non-tensor data will be stacked into a :class:`~tensordict.NonTensorStack`. If ``True``,
            a single :class:`~tensordict.NonTensorData` object will contain the unique value, but with the desired batch-size.
            Defaults to ``True``.

    .. note:: Until v0.9, this will raise a warning if the same value is encountered and the value is not set
        explicitly (`capture_non_tensor_stack() = True` default behavior).
        You can set the value of :func:`~tensordict.capture_non_tensor_stack` through:

        - The ``CAPTURE_NONTENSOR_STACK`` environment variable;
        - By setting ``set_capture_non_tensor_stack(val: bool).set()`` at the beginning of your script;
        - By using ``set_capture_non_tensor_stack(val: bool)`` as a context manager or a decorator.

        It is recommended to use the `set_capture_non_tensor_stack(False)` behavior.

    .. seealso:: :class:`~tensordict.capture_non_tensor_stack`

    Examples:
        >>> with set_capture_non_tensor_stack(False):
        ...     torch.stack([NonTensorData("a"), NonTensorData("a")])
        NonTensorStack(["a", "a"], stack_dim=0)
        >>> @set_capture_non_tensor_stack(False)
        ... def my_function():
        ...     return torch.stack([NonTensorData("a"), NonTensorData("a")])
        >>> my_function()
        NonTensorStack(["a", "a"], stack_dim=0)
    """

    def __init__(self, mode: bool) -> None:
        super().__init__()
        self.mode = mode

    def clone(self) -> "set_capture_non_tensor_stack":
        # Override this method if your children class takes __init__ parameters.
        # The annotation is a string so the class name need not exist yet when
        # the method object is created (safe with or without PEP 563).
        return type(self)(self.mode)

    def __enter__(self) -> None:
        self.set()

    def set(self) -> None:
        """Apply ``mode`` globally, remembering the previous value for restore."""
        global _CAPTURE_NONTENSOR_STACK
        self._old_mode = _CAPTURE_NONTENSOR_STACK
        _CAPTURE_NONTENSOR_STACK = bool(self.mode)
        # we do this such that sub-processes see the same setting than the main one
        os.environ["CAPTURE_NONTENSOR_STACK"] = str(_CAPTURE_NONTENSOR_STACK)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global _CAPTURE_NONTENSOR_STACK
        _CAPTURE_NONTENSOR_STACK = self._old_mode
        if self._old_mode is None:
            # The setting was unset before entering: unset it again instead of
            # writing the string "None", which sub-processes could not parse
            # back into a bool.
            os.environ.pop("CAPTURE_NONTENSOR_STACK", None)
        else:
            os.environ["CAPTURE_NONTENSOR_STACK"] = str(_CAPTURE_NONTENSOR_STACK)
2183+
2184+
def capture_non_tensor_stack(allow_none: bool = False) -> "bool | None":
    """Get the current setting for capturing non-tensor stacks.

    Args:
        allow_none (bool, optional): If ``True``, returns ``None`` if no setting has been
            specified. Otherwise, returns the default setting. Defaults to ``False``.

    .. seealso:: :func:`~tensordict.set_capture_non_tensor_stack`

    Returns:
        bool or None: The current setting for capturing non-tensor stacks.

    """
    # No ``global`` statement needed: the module-level value is only read here.
    if _CAPTURE_NONTENSOR_STACK is None:
        return None if allow_none else _DEFAULT_CAPTURE_NONTENSOR_STACK
    # The value is a raw string when it originated from the environment.
    if isinstance(_CAPTURE_NONTENSOR_STACK, str):
        return strtobool(_CAPTURE_NONTENSOR_STACK)
    return _CAPTURE_NONTENSOR_STACK
2208+
2209+
21242210# Process initializer for map
21252211def _proc_init (base_seed , queue , num_threads ):
21262212 worker_id = queue .get (timeout = 120 )
0 commit comments