Skip to content

Commit b966e84

Browse files
author
Vincent Moens
committed
[Feature] OrderedDict for TensorDictSequential
ghstack-source-id: b2a0a12 Pull Request resolved: #1142
1 parent 2aea3dd commit b966e84

File tree

2 files changed

+116
-14
lines changed

2 files changed

+116
-14
lines changed

tensordict/nn/sequence.py

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55

66
from __future__ import annotations
77

8+
import collections
89
import logging
910
from copy import deepcopy
10-
from typing import Any, Iterable, List
11+
from typing import Any, Callable, Iterable, List, OrderedDict, overload
1112

1213
from tensordict._nestedkey import NestedKey
1314

@@ -170,19 +171,57 @@ class TensorDictSequential(TensorDictModule):
170171
module: nn.ModuleList
171172
_select_before_return = False
172173

174+
@overload
173175
def __init__(
174176
self,
175-
*modules: TensorDictModuleBase,
177+
modules: OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]],
178+
*,
179+
partial_tolerant: bool = False,
180+
selected_out_keys: List[NestedKey] | None = None,
181+
) -> None: ...
182+
183+
@overload
184+
def __init__(
185+
self,
186+
modules: List[Callable[[TensorDictBase], TensorDictBase]],
187+
*,
188+
partial_tolerant: bool = False,
189+
selected_out_keys: List[NestedKey] | None = None,
190+
) -> None: ...
191+
192+
def __init__(
193+
self,
194+
*modules: Callable[[TensorDictBase], TensorDictBase],
176195
partial_tolerant: bool = False,
177196
selected_out_keys: List[NestedKey] | None = None,
178197
) -> None:
179-
modules = self._convert_modules(modules)
180-
in_keys, out_keys = self._compute_in_and_out_keys(modules)
181-
self._complete_out_keys = list(out_keys)
182198

183-
super().__init__(
184-
module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys
185-
)
199+
if len(modules) == 1 and isinstance(modules[0], collections.OrderedDict):
200+
modules_vals = self._convert_modules(modules[0].values())
201+
in_keys, out_keys = self._compute_in_and_out_keys(modules_vals)
202+
self._complete_out_keys = list(out_keys)
203+
modules = collections.OrderedDict(
204+
**{key: val for key, val in zip(modules[0], modules_vals)}
205+
)
206+
super().__init__(
207+
module=nn.ModuleDict(modules), in_keys=in_keys, out_keys=out_keys
208+
)
209+
elif len(modules) == 1 and isinstance(
210+
modules[0], collections.abc.MutableSequence
211+
):
212+
modules = self._convert_modules(modules[0])
213+
in_keys, out_keys = self._compute_in_and_out_keys(modules)
214+
self._complete_out_keys = list(out_keys)
215+
super().__init__(
216+
module=nn.ModuleList(modules), in_keys=in_keys, out_keys=out_keys
217+
)
218+
else:
219+
modules = self._convert_modules(modules)
220+
in_keys, out_keys = self._compute_in_and_out_keys(modules)
221+
self._complete_out_keys = list(out_keys)
222+
super().__init__(
223+
module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys
224+
)
186225

187226
self.partial_tolerant = partial_tolerant
188227
if selected_out_keys:
@@ -408,7 +447,7 @@ def select_subsequence(
408447
out_keys = deepcopy(self.out_keys)
409448
out_keys = unravel_key_list(out_keys)
410449

411-
module_list = list(self.module)
450+
module_list = list(self._module_iter())
412451
id_to_keep = set(range(len(module_list)))
413452
for i, module in enumerate(module_list):
414453
if (
@@ -445,8 +484,12 @@ def select_subsequence(
445484
raise ValueError(
446485
"No modules left after selection. Make sure that in_keys and out_keys are coherent."
447486
)
448-
449-
return type(self)(*modules)
487+
if isinstance(self.module, nn.ModuleList):
488+
return type(self)(*modules)
489+
else:
490+
keys = [key for key in self.module if self.module[key] in modules]
491+
modules_dict = OrderedDict(**{key: val for key, val in zip(keys, modules)})
492+
return type(self)(modules_dict)
450493

451494
def _run_module(
452495
self,
@@ -466,6 +509,12 @@ def _run_module(
466509
module(sub_td, **kwargs)
467510
return tensordict
468511

512+
def _module_iter(self):
513+
if isinstance(self.module, nn.ModuleDict):
514+
yield from self.module.children()
515+
else:
516+
yield from self.module
517+
469518
@dispatch(auto_batch_size=False)
470519
@_set_skip_existing_None()
471520
def forward(
@@ -481,7 +530,7 @@ def forward(
481530
else:
482531
tensordict_exec = tensordict
483532
if not len(kwargs):
484-
for module in self.module:
533+
for module in self._module_iter():
485534
tensordict_exec = self._run_module(module, tensordict_exec, **kwargs)
486535
else:
487536
raise RuntimeError(
@@ -510,8 +559,8 @@ def forward(
510559
def __len__(self) -> int:
511560
return len(self.module)
512561

513-
def __getitem__(self, index: int | slice) -> TensorDictModuleBase:
514-
if isinstance(index, int):
562+
def __getitem__(self, index: int | slice | str) -> TensorDictModuleBase:
563+
if isinstance(index, (int, str)):
515564
return self.module.__getitem__(index)
516565
else:
517566
return type(self)(*self.module.__getitem__(index))

test/test_nn.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pickle
1111
import unittest
1212
import weakref
13+
from collections import OrderedDict
1314

1415
import pytest
1516
import torch
@@ -797,6 +798,58 @@ def test_tdmodule_inplace(self):
797798

798799

799800
class TestTDSequence:
801+
def test_ordered_dict(self):
802+
linear = nn.Linear(3, 4)
803+
linear.weight.data.fill_(0)
804+
linear.bias.data.fill_(1)
805+
layer0 = TensorDictModule(linear, in_keys=["x"], out_keys=["y"])
806+
ordered_dict = OrderedDict(
807+
layer0=layer0,
808+
layer1=lambda x: x + 1,
809+
)
810+
seq = TensorDictSequential(ordered_dict)
811+
td = seq(TensorDict(x=torch.ones(3)))
812+
assert (td["x"] == 2).all()
813+
assert (td["y"] == 2).all()
814+
assert seq["layer0"] is layer0
815+
816+
def test_ordered_dict_select_subsequence(self):
817+
ordered_dict = OrderedDict(
818+
layer0=TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"]),
819+
layer1=TensorDictModule(lambda x: x - 1, in_keys=["y"], out_keys=["z"]),
820+
layer2=TensorDictModule(
821+
lambda x, y: x + y, in_keys=["x", "y"], out_keys=["a"]
822+
),
823+
)
824+
seq = TensorDictSequential(ordered_dict)
825+
assert len(seq) == 3
826+
assert isinstance(seq.module, nn.ModuleDict)
827+
seq_select = seq.select_subsequence(out_keys=["a"])
828+
assert len(seq_select) == 2
829+
assert isinstance(seq_select.module, nn.ModuleDict)
830+
assert list(seq_select.module) == ["layer0", "layer2"]
831+
832+
def test_ordered_dict_select_outkeys(self):
833+
ordered_dict = OrderedDict(
834+
layer0=TensorDictModule(
835+
lambda x: x + 1, in_keys=["x"], out_keys=["intermediate"]
836+
),
837+
layer1=TensorDictModule(
838+
lambda x: x - 1, in_keys=["intermediate"], out_keys=["z"]
839+
),
840+
layer2=TensorDictModule(
841+
lambda x, y: x + y, in_keys=["x", "z"], out_keys=["a"]
842+
),
843+
)
844+
seq = TensorDictSequential(ordered_dict)
845+
assert len(seq) == 3
846+
assert isinstance(seq.module, nn.ModuleDict)
847+
seq.select_out_keys("z", "a")
848+
td = seq(TensorDict(x=0))
849+
assert "intermediate" not in td
850+
assert "z" in td
851+
assert "a" in td
852+
800853
@pytest.mark.parametrize("args", [True, False])
801854
def test_input_keys(self, args):
802855
module0 = TensorDictModule(lambda x: x + 0, in_keys=["input"], out_keys=["1"])

0 commit comments

Comments
 (0)