@@ -64,9 +64,6 @@ def __init__(
6464 discrete_actions = True
6565 self .action_distribution_type = "categorical"
6666
67- if fuse_keys is None :
68- fuse_keys = []
69-
7067 super ().__init__ (
7168 PointNavResNetNet (
7269 observation_space = observation_space ,
@@ -237,7 +234,7 @@ def __init__(
237234 backbone ,
238235 resnet_baseplanes ,
239236 normalize_visual_inputs : bool ,
240- fuse_keys : List [str ],
237+ fuse_keys : Optional [ List [str ] ],
241238 force_blind_policy : bool = False ,
242239 discrete_actions : bool = True ,
243240 ):
@@ -255,9 +252,15 @@ def __init__(
255252
256253 # Only fuse the 1D state inputs. Other inputs are processed by the
257254 # visual encoder
258- self ._fuse_keys : List [str ] = [
259- k for k in fuse_keys if len (observation_space .spaces [k ].shape ) == 1
260- ]
255+ self ._fuse_keys : List [str ] = (
256+ [
257+ k
258+ for k in fuse_keys
259+ if len (observation_space .spaces [k ].shape ) == 1
260+ ]
261+ if fuse_keys is not None
262+ else []
263+ )
261264 if len (self ._fuse_keys ) != 0 :
262265 rnn_input_size += sum (
263266 [observation_space .spaces [k ].shape [0 ] for k in self ._fuse_keys ]
@@ -355,12 +358,16 @@ def __init__(
355358 if force_blind_policy :
356359 use_obs_space = spaces .Dict ({})
357360 else :
358- use_obs_space = spaces .Dict (
359- {
360- k : observation_space .spaces [k ]
361- for k in fuse_keys
362- if len (observation_space .spaces [k ].shape ) == 3
363- }
361+ use_obs_space = (
362+ spaces .Dict (
363+ {
364+ k : observation_space .spaces [k ]
365+ for k in fuse_keys
366+ if len (observation_space .spaces [k ].shape ) == 3
367+ }
368+ )
369+ if fuse_keys is not None
370+ else observation_space
364371 )
365372
366373 self .visual_encoder = ResNetEncoder (
0 commit comments