Skip to content

Commit 5ce36a6

Browse files
authored
Fix backwards compatability for Hab1 tasks and Hab2 gym-style tasks (#878)
1 parent 69dd8b8 commit 5ce36a6

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

habitat_baselines/rl/ddppo/policy/resnet_policy.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,6 @@ def __init__(
6464
discrete_actions = True
6565
self.action_distribution_type = "categorical"
6666

67-
if fuse_keys is None:
68-
fuse_keys = []
69-
7067
super().__init__(
7168
PointNavResNetNet(
7269
observation_space=observation_space,
@@ -237,7 +234,7 @@ def __init__(
237234
backbone,
238235
resnet_baseplanes,
239236
normalize_visual_inputs: bool,
240-
fuse_keys: List[str],
237+
fuse_keys: Optional[List[str]],
241238
force_blind_policy: bool = False,
242239
discrete_actions: bool = True,
243240
):
@@ -255,9 +252,15 @@ def __init__(
255252

256253
# Only fuse the 1D state inputs. Other inputs are processed by the
257254
# visual encoder
258-
self._fuse_keys: List[str] = [
259-
k for k in fuse_keys if len(observation_space.spaces[k].shape) == 1
260-
]
255+
self._fuse_keys: List[str] = (
256+
[
257+
k
258+
for k in fuse_keys
259+
if len(observation_space.spaces[k].shape) == 1
260+
]
261+
if fuse_keys is not None
262+
else []
263+
)
261264
if len(self._fuse_keys) != 0:
262265
rnn_input_size += sum(
263266
[observation_space.spaces[k].shape[0] for k in self._fuse_keys]
@@ -355,12 +358,16 @@ def __init__(
355358
if force_blind_policy:
356359
use_obs_space = spaces.Dict({})
357360
else:
358-
use_obs_space = spaces.Dict(
359-
{
360-
k: observation_space.spaces[k]
361-
for k in fuse_keys
362-
if len(observation_space.spaces[k].shape) == 3
363-
}
361+
use_obs_space = (
362+
spaces.Dict(
363+
{
364+
k: observation_space.spaces[k]
365+
for k in fuse_keys
366+
if len(observation_space.spaces[k].shape) == 3
367+
}
368+
)
369+
if fuse_keys is not None
370+
else observation_space
364371
)
365372

366373
self.visual_encoder = ResNetEncoder(

habitat_baselines/rl/ddppo/policy/running_mean_and_var.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
class RunningMeanAndVar(nn.Module):
1414
def __init__(self, n_channels: int) -> None:
1515
super().__init__()
16+
assert n_channels > 0
1617
self.register_buffer("_mean", torch.zeros(1, n_channels, 1, 1))
1718
self.register_buffer("_var", torch.zeros(1, n_channels, 1, 1))
1819
self.register_buffer("_count", torch.zeros(()))

0 commit comments

Comments
 (0)