v0.5.0: `consolidate`, compile compatibility and better non-tensor support
This release is packed with new features and performance improvements.
What's new
TensorDict.consolidate
There is now a TensorDict.consolidate method that puts all the tensors in a single storage. This greatly speeds up serialization in multiprocessing and distributed settings.
PT2 support
Common TensorDict ops (get, set, indexing, arithmetic ops, etc.) now work within torch.compile.
The list of supported operations can be found in test/test_compile.py. We encourage users to report any graph break caused by tensordict, as we intend to improve coverage as much as possible.
Python 3.12 support
#807 enables Python 3.12 support, a long-awaited feature!
Global reduction for mean, std and other reduction methods
It is now possible to get the grand average of a tensordict's content using tensordict.mean(reduce=True).
This applies to mean, nanmean, prod, std, sum, nansum and var.
from_pytree and to_pytree
We made it easy to convert a tensordict to a given pytree structure, and to build a tensordict from any pytree, using to_pytree and from_pytree. #832
Similarly, conversion to and from namedtuples is now easy thanks to #788.
map_iter
One can now iterate over a TensorDict's batch dimension and apply a function in separate processes thanks to map_iter.
This should enable the construction of datasets using TensorDict, where the preprocessing step is executed in a separate process. #847
Using flatten and unflatten, flatten_keys and unflatten_keys as context managers
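It is now possible to use flatten_keys and flatten as context managers (#908, #779). Changes made to the flattened view are written back to the original tensordict on exit:
```python
from tensordict import TensorDict

td = TensorDict({"flat": {"key": 1}}, batch_size=[])
with td.flatten_keys() as flat_td:  # nested keys are joined with "."
    flat_td["flat.key"] = 0
assert td["flat", "key"] == 0  # written back on exit
```
Building a tensordict using keyword arguments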
We made it easy to build tensordicts with simple keyword arguments, just as a dict is built in Python:
```python
import torch
from tensordict import TensorDict

td = TensorDict(a=0, b=1)
assert td["a"] == torch.tensor(0)
assert td["b"] == torch.tensor(1)
```
The batch_size is now optional for both tensordict and tensorclasses. #905
Load tensordicts directly on device
Thanks to #769, it is now possible to load a tensordict directly on a destination device (including "meta" device):
```python
from tensordict import TensorDict
td = TensorDict.load(path, device=device)  # device can also be "meta"
```
New features
- [Feature,Performance] `to(device, pin_memory, num_threads)` by @vmoens in #846
- [Feature] Allow calls to get_mode, get_mean and get_median in case mode, mean or median is not present by @vmoens in #804
- [Feature] Arithmetic ops for tensorclass by @vmoens in #786
- [Feature] Best attempt to densely stack sub-tds when LazyStacked TDs are passed to maybe_dense_stack by @vmoens in #799
- [Feature] Better dtype coverage by @vmoens in #834
- [Feature] Change default interaction types to DETERMINISTIC by @vmoens in #825
- [Feature] DETERMINISTIC interaction mode by @vmoens in #824
- [Feature] Expose call_on_nested to apply and named_apply by @vmoens in #768
- [Feature] Expose stack / cat as class methods by @vmoens in #793
- [Feature] Load tensordicts on device, incl. meta by @vmoens in #769
- [Feature] Make Probabilistic modules aware of CompositeDistributions out_keys by @vmoens in #810
- [Feature] Memory-mapped nested tensors by @vmoens in #618
- [Feature] Multithreaded apply by @vmoens in #844
- [Feature] Multithreaded pin_memory by @vmoens in #845
- [Feature] Support for non tensor data in h5 by @vmoens in #772
- [Feature] TensorDict.consolidate by @vmoens in #814
- [Feature] TensorDict.numpy() by @vmoens in #787
- [Feature] TensorDict.replace by @vmoens in #774
- [Feature] `out` argument in `apply` by @vmoens in #794
- [Feature] `to` for consolidated TDs by @vmoens in #851
- [Feature] `zero_grad` and `requires_grad_` by @vmoens in #901
- [Feature] add_custom_mapping and NPE refactors by @vmoens in #910
- [Feature] construct tds with kwargs by @vmoens in #905
- [Feature] deterministic_sample for composite dist by @vmoens in #827
- [Feature] expand_as by @vmoens in #792
- [Feature] flatten and unflatten as decorators by @vmoens in #779
- [Feature] from and to_pytree by @vmoens in #832
- [Feature] from_modules expand_identical kwarg by @vmoens in #911
- [Feature] grad and data for tensorclasses by @vmoens in #904
- [Feature] isfinite, isnan, isreal by @vmoens in #829
- [Feature] map_iter by @vmoens in #847
- [Feature] map_names for composite dists by @vmoens in #809
- [Feature] online editing of memory-mapped tensordicts by @vmoens in #775
- [Feature] remove distutils dependency and enable 3.12 support by @GaetanLepage in #807
- [Feature] to_namedtuple and from_namedtuple by @vmoens in #788
- [Feature] view(dtype) by @vmoens in #835
Performance
- [Performance] Faster getattr in TC by @vmoens in #912
- [Performance] Faster lock_/unlock_ when sub-tds are already locked by @vmoens in #816
- [Performance] Faster multithreaded pin_memory by @vmoens in #919
- [Performance] Faster tensorclass by @vmoens in #791
- [Performance] Faster tensorclass set by @vmoens in #880
- [Performance] Faster to-module by @vmoens in #914
Bug Fixes
- [BugFix,CI] Fix storage filename tests by @vmoens in #850
- [BugFix] `@property` setter in tensorclass by @vmoens in #813
- [BugFix] Allow any tensorclass to have a data field by @vmoens in #906
- [BugFix] Allow fake-tensor detection pass through in torch 2.0 by @vmoens in #802
- [BugFix] Avoid collapsing NonTensorStack when calling where by @vmoens in #837
- [BugFix] Check if the current user has write access by @MateuszGuzek in #781
- [BugFix] Ensure dtype is preserved with autocast by @vmoens in #773
- [BugFix] Fix non-tensor writing in modules by @vmoens in #822
- [BugFix] Fix (keys, values) in sub by @vmoens in #907
- [BugFix] Fix `_make_dtype_promotion` backward compat by @vmoens in #842
- [BugFix] Fix `pad_sequence` behavior for non-tensor attributes of tensorclass by @kurtamohler in #884
- [BugFix] Fix builds by @vmoens in #849
- [BugFix] Fix compile + vmap by @vmoens in #924
- [BugFix] Fix deterministic fallback when the dist has no support by @vmoens in #830
- [BugFix] Fix device parsing in augmented funcs by @vmoens in #770
- [BugFix] Fix empty tuple index by @vmoens in #811
- [BugFix] Fix fallback of deterministic samples when mean is not available by @vmoens in #828
- [BugFix] Fix functorch dim mock by @vmoens in #777
- [BugFix] Fix gather device by @vmoens in #815
- [BugFix] Fix h5 auto batch size by @vmoens in #798
- [BugFix] Fix key ordering in pointwise ops by @vmoens in #855
- [BugFix] Fix lazy stack features (where and norm) by @vmoens in #795
- [BugFix] Fix map by @vmoens in #862
- [BugFix] Fix map test with fork on cuda by @vmoens in #765
- [BugFix] Fix pad_sequence for non tensors by @vmoens in #784
- [BugFix] Fix setting non-tensors as data in NonTensorData by @vmoens in #864
- [BugFix] Fix stack of tensorclasses (and nontensors) by @vmoens in #820
- [BugFix] Fix storage.filename compat with torch 2.0 by @vmoens in #803
- [BugFix] Fix tensorclass register by @vmoens in #817
- [BugFix] Fix torch version assertion by @vmoens in #917
- [BugFix] Fix vmap compatibility with torch<2.2 by @vmoens in #925
- [BugFix] Fix vmap for tensorclass by @vmoens in #778
- [BugFix] Fix wheels by @vmoens in #856
- [BugFix] Keep stack dim name in LazyStackedTensorDict copy ops by @vmoens in #801
- [BugFix] Read-only compatibility in MemoryMappedTensor by @vmoens in #780
- [BugFix] Refactor map and map_iter by @vmoens in #869
- [BugFix] Sync cuda only if initialized by @vmoens in #767
- [BugFix] fix _expand_to_match_shape for single bool tensor by @vmoens in #902
- [BugFix] fix construction of lazy stacks from tds by @vmoens in #903
- [BugFix] fix tensorclass set by @vmoens in #854
- [BugFix] remove inplace updates when using td as a decorator by @vmoens in #796
- [BugFix] use as_subclass in Buffer by @vmoens in #913
Refactoring and code quality
- [Quality] Better nested detection in numpy() by @vmoens in #800
- [Quality] Better repr of keys by @vmoens in #897
- [Quality] fix c++ binaries formatting by @vmoens in #859
- [Quality] non_blocking_pin instead of pin_memory by @vmoens in #915
- [Quality] zip-strict when possible by @vmoens in #886
- [Refactor] Better tensorclass method registration by @vmoens in #797
- [Refactor] Make all leaves in tensorclass part of `_tensordict`, except for NonTensorData by @vmoens in #841
- [Refactor] Refactor c++ binaries location by @vmoens in #860
- [Refactor] Refactor is_dynamo_compile imports by @vmoens in #916
- [Refactor] Remove `_run_checks` from `__init__` by @vmoens in #843
- [Refactor] use from_file instead of mmap+from_buffer for readonly files by @vmoens in #808
Others
- Bump jinja2 from 3.1.3 to 3.1.4 in /docs by @dependabot in #840
- [Benchmark] Benchmark tensorclass ops by @vmoens in #790
- [Benchmark] Fix recursion and cache errors in benchmarks by @vmoens in #900
- [CI] Fix nightly build by @vmoens in #861
- [CI] Python 3.12 compatibility by @kurtamohler in #818
- [Doc] Fix symbolic trace reference in doc by @vmoens in #918
- [Formatting] Lint revamp by @vmoens in #890
- [Test] Test FC of memmap save and load by @vmoens in #838
- [Versioning] Allow any torch version for local builds by @vmoens in #764
- [Versioning] Make dependence on uint16 optional for older PT versions by @vmoens in #839
- [Versioning] tree_leaves for pytorch < 2.3 by @vmoens in #806
- [Versioning] v0.5 bump by @vmoens in #848
New Contributors
- @MateuszGuzek made their first contribution in #781
- @GaetanLepage made their first contribution in #807
- @kurtamohler made their first contribution in #818
Full Changelog: v0.4.0...v0.5.0