Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3,359 changes: 3,359 additions & 0 deletions DreamerExperiment.ipynb

Large diffs are not rendered by default.

5,598 changes: 5,598 additions & 0 deletions ablation_results/metrics (1).jsonl

Large diffs are not rendered by default.

5,583 changes: 5,583 additions & 0 deletions ablation_results/metrics (2).jsonl

Large diffs are not rendered by default.

5,462 changes: 5,462 additions & 0 deletions ablation_results/metrics (3).jsonl

Large diffs are not rendered by default.

5,667 changes: 5,667 additions & 0 deletions ablation_results/metrics.jsonl

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions dreamerv3/configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ defaults:
minecraft: {size: [64, 64], break_speed: 100.0, logs: False, length: 36000}
dmc: {size: [64, 64], repeat: 1, proprio: True, image: True, camera: -1}
loconav: {size: [64, 64], repeat: 1, camera: -1}
nethack: {size: [64, 64], max_episode_steps: 5000, use_seed: True}

replay:
size: 5e6
Expand Down Expand Up @@ -195,6 +196,10 @@ loconav:
env.loconav.repeat: 1
run.train_ratio: 256

nethack:
task: nethack_Challenge
run: {steps: 1e8, train_ratio: 64, envs: 8}

multicpu:
batch_size: 12
jax.mock_devices: 8
Expand Down
1 change: 1 addition & 0 deletions dreamerv3/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def make_env(config, index, **overrides):
'langroom': 'embodied.envs.langroom:LangRoom',
'procgen': 'embodied.envs.procgen:ProcGen',
'bsuite': 'embodied.envs.bsuite:BSuite',
'nethack': 'embodied.envs.nethack:NetHack',
'memmaze': lambda task, **kw: from_gym.FromGym(
f'MemoryMaze-{task}-v0', **kw),
}[suite]
Expand Down
85 changes: 85 additions & 0 deletions embodied/envs/nethack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import elements
import embodied
import numpy as np


class NetHack(embodied.Env):
  """Adapter exposing a Gymnasium NetHack environment through `embodied.Env`.

  Each observation contains an RGB `image` derived from the `glyphs` array,
  the `blstats` vector cast to float32, the scalar `reward`, and the
  `is_first` / `is_last` / `is_terminal` episode flags.
  """

  def __init__(self, task, size=(64, 64), max_episode_steps=5000, seed=None):
    """Create the wrapped environment.

    Args:
      task: Task suffix; underscores are replaced by dashes and the result
        is looked up as the Gymnasium id `NetHack<task>-v0`.
      size: Image size as `(height, width)`, matching the `(*size, 3)` shape
        declared in `obs_space`.
      max_episode_steps: Step limit handed to Gymnasium's `TimeLimit`.
      seed: Optional seed forwarded to the first `reset()` call only.
    """
    import warnings

    import gymnasium as gym
    # Not referenced directly; presumably imported for the side effect of
    # registering the NetHack* environment ids -- TODO confirm.
    import nle  # noqa: F401
    from gymnasium.wrappers import TimeLimit

    env_name = f'NetHack{task.replace("_", "-")}-v0'
    try:
      base_env = gym.make(env_name)
    except Exception as exc:
      # Keep the best-effort fallback, but make it visible: silently training
      # on a different task than the one requested is hard to diagnose.
      warnings.warn(
          f'gym.make({env_name!r}) failed with {exc!r}; '
          "falling back to 'NetHackChallenge-v0'.",
          RuntimeWarning,
          stacklevel=2,
      )
      base_env = gym.make('NetHackChallenge-v0')

    self._env = TimeLimit(base_env, max_episode_steps=max_episode_steps)
    self._seed = seed
    # Tracks whether `seed` has already been handed to `reset()`.
    self._seeded = False
    self._size = size
    # Starting in the "done" state forces a reset on the first `step()` call.
    self._done = True
    obs_space = self._env.observation_space

    self._blstats_shape = tuple(obs_space["blstats"].shape)

  @property
  def obs_space(self):
    """Return the observation spaces keyed by observation name."""
    return {
        'image': elements.Space(np.uint8, (*self._size, 3)),
        'blstats': elements.Space(np.float32, self._blstats_shape),
        'reward': elements.Space(np.float32),
        'is_first': elements.Space(bool),
        'is_last': elements.Space(bool),
        'is_terminal': elements.Space(bool),
    }

  @property
  def act_space(self):
    """Return the action spaces: a discrete `action` and a `reset` flag."""
    return {
        'action': elements.Space(np.int32, (), 0, self._env.action_space.n),
        'reset': elements.Space(bool),
    }

  def step(self, action):
    """Advance the environment by one step, or reset it.

    A reset happens when `action['reset']` is truthy or the previous step
    ended the episode; in that case `action['action']` is ignored and the
    returned observation has `is_first=True` and zero reward.

    Args:
      action: Mapping with the keys `action` and `reset`.

    Returns:
      The observation dict built by `_obs`.
    """
    if action['reset'] or self._done:
      self._done = False
      # Seed only the first reset. Passing the same integer on every reset
      # re-initialises the PRNG each time, which would replay the same
      # episode in any environment that honours the seed.
      seed = None if self._seeded else self._seed
      self._seeded = True
      obs, _ = self._env.reset(seed=seed)
      return self._obs(obs, 0.0, is_first=True)

    obs, reward, terminated, truncated, _ = self._env.step(action['action'])
    self._done = bool(terminated or truncated)
    return self._obs(
        obs, reward,
        is_last=self._done,
        is_terminal=bool(terminated),
    )

  def _obs(self, obs, reward, is_first=False, is_last=False, is_terminal=False):
    """Convert a raw observation dict into the format declared by `obs_space`."""
    image = self._render_glyphs(obs['glyphs'])
    return dict(
        image=image,
        blstats=obs['blstats'].astype(np.float32),
        reward=np.float32(reward),
        is_first=is_first,
        is_last=is_last,
        is_terminal=is_terminal,
    )

  def _render_glyphs(self, glyphs):
    """Pack the 2-D integer `glyphs` array into a resized uint8 RGB image.

    Returns:
      Array of shape `(*self._size, 3)` and dtype uint8.
    """
    from PIL import Image
    h, w = glyphs.shape
    rgb = np.zeros((h, w, 3), dtype=np.uint8)
    rgb[:, :, 0] = ((glyphs >> 0) & 0xFF).astype(np.uint8)
    rgb[:, :, 1] = ((glyphs >> 8) & 0xFF).astype(np.uint8)
    # NOTE(review): this channel overlaps bits already present in the other
    # two (bits 4-11); kept as-is -- confirm the shift of 4 is intentional.
    rgb[:, :, 2] = ((glyphs >> 4) & 0xFF).astype(np.uint8)
    image = Image.fromarray(rgb)
    # PIL expects (width, height) whereas `self._size` is (height, width), as
    # implied by the `(*self._size, 3)` image shape in `obs_space`.
    height, width = self._size
    image = image.resize((width, height), Image.BILINEAR)
    return np.array(image)

  def close(self):
    """Close the wrapped environment, ignoring any error raised while doing so."""
    try:
      self._env.close()
    except Exception:
      # Deliberate best-effort: a failing shutdown should not mask the
      # caller's own error handling.
      pass
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ optax
portal>=3.5.0
scope>=0.4.4
tqdm
crafter
gymnasium[atari]