55
66from __future__ import annotations
77
8+ import collections
89import logging
910from copy import deepcopy
10- from typing import Any , Iterable , List
11+ from typing import Any , Callable , Iterable , List , OrderedDict , overload
1112
1213from tensordict ._nestedkey import NestedKey
1314
@@ -170,19 +171,57 @@ class TensorDictSequential(TensorDictModule):
170171 module : nn .ModuleList
171172 _select_before_return = False
172173
174+ @overload
173175 def __init__ (
174176 self ,
175- * modules : TensorDictModuleBase ,
177+ modules : OrderedDict [str , Callable [[TensorDictBase ], TensorDictBase ]],
178+ * ,
179+ partial_tolerant : bool = False ,
180+ selected_out_keys : List [NestedKey ] | None = None ,
181+ ) -> None : ...
182+
183+ @overload
184+ def __init__ (
185+ self ,
186+ modules : List [Callable [[TensorDictBase ], TensorDictBase ]],
187+ * ,
188+ partial_tolerant : bool = False ,
189+ selected_out_keys : List [NestedKey ] | None = None ,
190+ ) -> None : ...
191+
192+ def __init__ (
193+ self ,
194+ * modules : Callable [[TensorDictBase ], TensorDictBase ],
176195 partial_tolerant : bool = False ,
177196 selected_out_keys : List [NestedKey ] | None = None ,
178197 ) -> None :
179- modules = self ._convert_modules (modules )
180- in_keys , out_keys = self ._compute_in_and_out_keys (modules )
181- self ._complete_out_keys = list (out_keys )
182198
183- super ().__init__ (
184- module = nn .ModuleList (list (modules )), in_keys = in_keys , out_keys = out_keys
185- )
199+ if len (modules ) == 1 and isinstance (modules [0 ], collections .OrderedDict ):
200+ modules_vals = self ._convert_modules (modules [0 ].values ())
201+ in_keys , out_keys = self ._compute_in_and_out_keys (modules_vals )
202+ self ._complete_out_keys = list (out_keys )
203+ modules = collections .OrderedDict (
204+ ** {key : val for key , val in zip (modules [0 ], modules_vals )}
205+ )
206+ super ().__init__ (
207+ module = nn .ModuleDict (modules ), in_keys = in_keys , out_keys = out_keys
208+ )
209+ elif len (modules ) == 1 and isinstance (
210+ modules [0 ], collections .abc .MutableSequence
211+ ):
212+ modules = self ._convert_modules (modules [0 ])
213+ in_keys , out_keys = self ._compute_in_and_out_keys (modules )
214+ self ._complete_out_keys = list (out_keys )
215+ super ().__init__ (
216+ module = nn .ModuleList (modules ), in_keys = in_keys , out_keys = out_keys
217+ )
218+ else :
219+ modules = self ._convert_modules (modules )
220+ in_keys , out_keys = self ._compute_in_and_out_keys (modules )
221+ self ._complete_out_keys = list (out_keys )
222+ super ().__init__ (
223+ module = nn .ModuleList (list (modules )), in_keys = in_keys , out_keys = out_keys
224+ )
186225
187226 self .partial_tolerant = partial_tolerant
188227 if selected_out_keys :
@@ -408,7 +447,7 @@ def select_subsequence(
408447 out_keys = deepcopy (self .out_keys )
409448 out_keys = unravel_key_list (out_keys )
410449
411- module_list = list (self .module )
450+ module_list = list (self ._module_iter () )
412451 id_to_keep = set (range (len (module_list )))
413452 for i , module in enumerate (module_list ):
414453 if (
@@ -445,8 +484,12 @@ def select_subsequence(
445484 raise ValueError (
446485 "No modules left after selection. Make sure that in_keys and out_keys are coherent."
447486 )
448-
449- return type (self )(* modules )
487+ if isinstance (self .module , nn .ModuleList ):
488+ return type (self )(* modules )
489+ else :
490+ keys = [key for key in self .module if self .module [key ] in modules ]
491+ modules_dict = OrderedDict (** {key : val for key , val in zip (keys , modules )})
492+ return type (self )(modules_dict )
450493
451494 def _run_module (
452495 self ,
@@ -466,6 +509,12 @@ def _run_module(
466509 module (sub_td , ** kwargs )
467510 return tensordict
468511
512+ def _module_iter (self ):
513+ if isinstance (self .module , nn .ModuleDict ):
514+ yield from self .module .children ()
515+ else :
516+ yield from self .module
517+
469518 @dispatch (auto_batch_size = False )
470519 @_set_skip_existing_None ()
471520 def forward (
@@ -481,7 +530,7 @@ def forward(
481530 else :
482531 tensordict_exec = tensordict
483532 if not len (kwargs ):
484- for module in self .module :
533+ for module in self ._module_iter () :
485534 tensordict_exec = self ._run_module (module , tensordict_exec , ** kwargs )
486535 else :
487536 raise RuntimeError (
@@ -510,8 +559,8 @@ def forward(
510559 def __len__ (self ) -> int :
511560 return len (self .module )
512561
513- def __getitem__ (self , index : int | slice ) -> TensorDictModuleBase :
514- if isinstance (index , int ):
562+ def __getitem__ (self , index : int | slice | str ) -> TensorDictModuleBase :
563+ if isinstance (index , ( int , str ) ):
515564 return self .module .__getitem__ (index )
516565 else :
517566 return type (self )(* self .module .__getitem__ (index ))
0 commit comments