@@ -218,7 +218,9 @@ def __init__(
218218                ** {key : val  for  key , val  in  _zip_strict (modules [0 ], modules_vals )}
219219            )
220220            super ().__init__ (
221-                 module = nn .ModuleDict (modules ), in_keys = in_keys , out_keys = out_keys 
221+                 module = nn .ModuleDict (modules ),
222+                 in_keys = in_keys ,
223+                 out_keys = out_keys ,
222224            )
223225        elif  len (modules ) ==  1  and  isinstance (
224226            modules [0 ], collections .abc .MutableSequence 
@@ -227,20 +229,25 @@ def __init__(
227229            in_keys , out_keys  =  self ._compute_in_and_out_keys (modules )
228230            self ._complete_out_keys  =  list (out_keys )
229231            super ().__init__ (
230-                 module = nn .ModuleList (modules ), in_keys = in_keys , out_keys = out_keys 
232+                 module = nn .ModuleList (modules ),
233+                 in_keys = in_keys ,
234+                 out_keys = out_keys ,
231235            )
232236        elif  len (modules ) ==  1  and  isinstance (modules [0 ], dict ):
233237            return  self .__init__ (
234238                collections .OrderedDict (modules [0 ]),
235239                partial_tolerant = partial_tolerant ,
236240                selected_out_keys = selected_out_keys ,
241+                 inplace = inplace ,
237242            )
238243        else :
239244            modules  =  self ._convert_modules (modules )
240245            in_keys , out_keys  =  self ._compute_in_and_out_keys (modules )
241246            self ._complete_out_keys  =  list (out_keys )
242247            super ().__init__ (
243-                 module = nn .ModuleList (list (modules )), in_keys = in_keys , out_keys = out_keys 
248+                 module = nn .ModuleList (list (modules )),
249+                 in_keys = in_keys ,
250+                 out_keys = out_keys ,
244251            )
245252
246253        self .inplace  =  inplace 
@@ -628,6 +635,7 @@ def forward(
628635            )
629636        if  tensordict_out  is  not None :
630637            result  =  tensordict_out 
638+             print ('here! update' )
631639            result .update (tensordict_exec , keys_to_update = self .out_keys )
632640        else :
633641            result  =  tensordict_exec 
0 commit comments