diff --git a/dreamerv3/agent.py b/dreamerv3/agent.py
index 4f2ad92df..f6c183ca3 100644
--- a/dreamerv3/agent.py
+++ b/dreamerv3/agent.py
@@ -2,6 +2,7 @@
 import jax
 import jax.numpy as jnp
 import ruamel.yaml as yaml
+from pathlib import Path

 tree_map = jax.tree_util.tree_map
 sg = lambda x: tree_map(jax.lax.stop_gradient, x)
@@ -23,12 +24,15 @@ def filter(self, record):
 class Agent(nj.Module):

   configs = yaml.YAML(typ='safe').load(
-      (embodied.Path(__file__).parent / 'configs.yaml').read())
+      (Path(__file__).parent / 'configs.yaml'))

   def __init__(self, obs_space, act_space, step, config):
     self.config = config
     self.obs_space = obs_space
-    self.act_space = act_space['action']
+    try:
+      self.act_space = act_space['action']
+    except KeyError:
+      # Hybrid action spaces keep the full dict of component spaces.
+      self.act_space = act_space
     self.step = step
     self.wm = WorldModel(obs_space, act_space, config, name='wm')
     self.task_behavior = getattr(behaviors, config.task_behavior)(
@@ -60,16 +64,28 @@ def policy(self, obs, state, mode='train'):
     expl_outs, expl_state = self.expl_behavior.policy(latent, expl_state)
     if mode == 'eval':
       outs = task_outs
-      outs['action'] = outs['action'].sample(seed=nj.rng())
-      outs['log_entropy'] = jnp.zeros(outs['action'].shape[:1])
+      if isinstance(self.act_space, dict):
+        # Compute entropies before replacing the distributions with modes.
+        outs['log_entropy'] = {
+            k: w.entropy() for k, w in outs['action'].items()}
+        outs['action'] = {k: w.mode() for k, w in outs['action'].items()}
+      else:
+        outs['action'] = outs['action'].sample(seed=nj.rng())
+        outs['log_entropy'] = jnp.zeros(outs['action'].shape[:1])
     elif mode == 'explore':
       outs = expl_outs
-      outs['log_entropy'] = outs['action'].entropy()
-      outs['action'] = outs['action'].sample(seed=nj.rng())
+      if isinstance(self.act_space, dict):
+        outs['log_entropy'] = {
+            k: w.entropy() for k, w in outs['action'].items()}
+        outs['action'] = {
+            k: w.sample(seed=nj.rng()) for k, w in outs['action'].items()}
+      else:
+        outs['log_entropy'] = outs['action'].entropy()
+        outs['action'] = outs['action'].sample(seed=nj.rng())
     elif mode == 'train':
       outs = task_outs
-      outs['log_entropy'] = outs['action'].entropy()
-      outs['action'] = outs['action'].sample(seed=nj.rng())
+      if isinstance(self.act_space, dict):
+        outs['log_entropy'] = {
+            k: w.entropy() for k, w in outs['action'].items()}
+        outs['action'] = {
+            k: w.sample(seed=nj.rng()) for k, w in outs['action'].items()}
+      else:
+        outs['log_entropy'] = outs['action'].entropy()
+        outs['action'] = outs['action'].sample(seed=nj.rng())
     state = ((latent, outs['action']), task_state, expl_state)
     return outs, state
@@ -119,7 +135,11 @@ class WorldModel(nj.Module):

   def __init__(self, obs_space, act_space, config):
     self.obs_space = obs_space
-    self.act_space = act_space['action']
+    try:
+      self.act_space_shape = act_space['action'].shape
+    except KeyError:
+      self.act_space_shape = {
+          k: v.shape for k, v in act_space.items() if k != 'reset'}
     self.config = config
     shapes = {k: tuple(v.shape) for k, v in obs_space.items()}
     shapes = {k: v for k, v in shapes.items() if not k.startswith('log_')}
@@ -138,11 +158,18 @@ def __init__(self, obs_space, act_space, config):

   def initial(self, batch_size):
     prev_latent = self.rssm.initial(batch_size)
-    prev_action = jnp.zeros((batch_size, *self.act_space.shape))
+    if isinstance(self.act_space_shape, dict):
+      prev_action = {
+          k: jnp.zeros((batch_size, *v))
+          for k, v in self.act_space_shape.items()}
+    else:
+      prev_action = jnp.zeros((batch_size, *self.act_space_shape))
     return prev_latent, prev_action

   def train(self, data, state):
     modules = [self.encoder,
         self.rssm, *self.heads.values()]
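+    # Hybrid replay batches carry the action components as separate
+    # top-level 'Continous' and 'Discrete' keys; fold them back into a
+    # single 'action' dict before optimizing.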
+    if 'action' not in data:
+      data['action'] = {
+          'Continous': data.pop('Continous'),
+          'Discrete': data.pop('Discrete')}
     mets, (state, outs, metrics) = self.opt(
         modules, self.loss, data, state, has_aux=True)
     metrics.update(mets)
@@ -151,7 +178,23 @@ def train(self, data, state):

   def loss(self, data, state):
     embed = self.encoder(data)
     prev_latent, prev_action = state
-    prev_actions = jnp.concatenate([
+    # State is (prev_latent, prev_action); the action may be a dict for
+    # hybrid (Continous + Discrete) spaces.
+    if 'action' not in data:
+      if len(data['Continous'].shape) == 2:
+        data['Continous'] = data['Continous'][..., None]
+      if len(data['Discrete'].shape) == 2:
+        data['Discrete'] = data['Discrete'][..., None]
+      data['action'] = {
+          'Continous': data.pop('Continous'),
+          'Discrete': data.pop('Discrete')}
+    if isinstance(data['action'], dict):
+      prev_actions = {
+          k: jnp.concatenate([
+              prev_action[k][:, None], data['action'][k][:, :-1]], 1)
+          for k in data['action']}
+    else:
+      prev_actions = jnp.concatenate([
         prev_action[:, None], data['action'][:, :-1]], 1)
     post, prior = self.rssm.observe(
         embed, prev_actions, data['is_first'], prev_latent)
@@ -173,7 +216,10 @@ def loss(self, data, state):
     out = {'embed': embed, 'post': post, 'prior': prior}
     out.update({f'{k}_loss': v for k, v in losses.items()})
     last_latent = {k: v[:, -1] for k, v in post.items()}
-    last_action = data['action'][:, -1]
+    if isinstance(data['action'], dict):
+      last_action = {k: v[:, -1] for k, v in data['action'].items()}
+    else:
+      last_action = data['action'][:, -1]
     state = last_latent, last_action
     metrics = self._metrics(data, dists, post, prior, losses, model_loss)
     return model_loss.mean(), (state, out, metrics)
@@ -189,8 +235,13 @@ def step(prev, _):
       return {**state, 'action': policy(state)}
     traj = jaxutils.scan(
         step, jnp.arange(horizon), start, self.config.imag_unroll)
-    traj = {
-        k: jnp.concatenate([start[k][None], v], 0) for k, v in traj.items()}
+    if isinstance(start['action'], dict):
+      # Prepend the starting action per component instead of per tensor.
+      traj_ = {
+          k: jnp.concatenate([start[k][None], v], 0)
+          for k, v in traj.items() if k != 'action'}
+      traj_['action'] = {
+          k: jnp.concatenate([start['action'][k][None], traj['action'][k]], 0)
+          for k in traj['action']}
+      traj = traj_
+    else:
+      traj = {
+          k: jnp.concatenate([start[k][None], v], 0) for k, v in traj.items()}
     cont = self.heads['cont'](traj).mode()
     traj['cont'] = jnp.concatenate([first_cont[None], cont[1:]], 0)
     discount = 1 - 1 / self.config.horizon
@@ -201,13 +252,21 @@ def report(self, data):
     state = self.initial(len(data['is_first']))
     report = {}
     report.update(self.loss(data, state)[-1][-1])
+    if isinstance(data['action'], dict):
+      action_context = {k: v[:6, :5] for k, v in data['action'].items()}
+      action_post = {k: v[:6, 5:] for k, v in data['action'].items()}
+    else:
+      action_context = data['action'][:6, :5]
+      action_post = data['action'][:6, 5:]
     context, _ = self.rssm.observe(
-        self.encoder(data)[:6, :5], data['action'][:6, :5],
+        self.encoder(data)[:6, :5], action_context,
         data['is_first'][:6, :5])
     start = {k: v[:, -1] for k, v in context.items()}
     recon = self.heads['decoder'](context)
     openl = self.heads['decoder'](
-        self.rssm.imagine(data['action'][:6, 5:], start))
+        self.rssm.imagine(action_post, start))
     for key in self.heads['decoder'].cnn_shapes.keys():
       truth = data[key][:6].astype(jnp.float32)
       model = jnp.concatenate([recon[key].mode()[:, :5], openl[key].mode()], 1)
@@ -246,11 +305,22 @@

   def __init__(self, critics, scales, act_space, config):
     self.scales = scales
     self.act_space = act_space
     self.config = config
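+    # A dict action space gets one distribution head per component; the
+    # gradient estimator is then resolved per component in loss() rather
+    # than once for the whole space.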
-    disc = act_space.discrete
-    self.grad = config.actor_grad_disc if disc else config.actor_grad_cont
+    if isinstance(self.act_space, dict):
+      shape = {k: v.shape for k, v in act_space.items() if k != 'reset'}
+      disc = False
+      # No single estimator fits a hybrid space; loss() combines both.
+      self.grad = None
+    else:
+      shape = act_space.shape
+      disc = act_space.discrete
+      self.grad = config.actor_grad_disc if disc else config.actor_grad_cont
     self.actor = nets.MLP(
-        name='actor', dims='deter', shape=act_space.shape, **config.actor,
-        dist=config.actor_dist_disc if disc else config.actor_dist_cont)
+        name='actor', dims='deter', shape=shape, **config.actor,
+        dist=config.actor_dist_disc if disc else config.actor_dist_cont,
+        dist_cont=config.actor_dist_cont if not disc else None,
+        dist_disc=config.actor_dist_disc if not disc else None)
     self.retnorms = {
         k: jaxutils.Moments(**config.retnorm, name=f'retnorm_{k}')
         for k in critics}
@@ -264,7 +334,10 @@ def policy(self, state, carry):

   def train(self, imagine, start, context):
     def loss(start):
-      policy = lambda s: self.actor(sg(s)).sample(seed=nj.rng())
+      if isinstance(self.act_space, dict):
+        policy = lambda s: {
+            k: w.sample(seed=nj.rng()) for k, w in self.actor(sg(s)).items()}
+      else:
+        policy = lambda s: self.actor(sg(s)).sample(seed=nj.rng())
       traj = imagine(policy, start, self.config.imag_horizon)
       loss, metrics = self.loss(traj)
       return loss, (traj, metrics)
@@ -291,9 +364,21 @@ def loss(self, traj):
       metrics[f'{key}_return_rate'] = (jnp.abs(ret) >= 0.5).mean()
     adv = jnp.stack(advs).sum(0)
     policy = self.actor(sg(traj))
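+    # Components are treated as independent, so the joint log-probability
+    # is the sum of the per-component log-probabilities.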
-    logpi = policy.log_prob(sg(traj['action']))[:-1]
-    loss = {'backprop': -adv, 'reinforce': -logpi * sg(adv)}[self.grad]
-    ent = policy.entropy()[:-1]
+    if isinstance(self.act_space, dict):
+      logpi = sum(
+          w.log_prob(sg(traj['action'][k]))[:-1] for k, w in policy.items())
+    else:
+      logpi = policy.log_prob(sg(traj['action']))[:-1]
+    loss = {'backprop': -adv, 'reinforce': -logpi * sg(adv)}
+    if self.grad:
+      loss = loss[self.grad]
+    else:
+      # Hybrid spaces have no single estimator; sum both terms.
+      loss = sum(loss.values())
+    if isinstance(self.act_space, dict):
+      ent = sum(w.entropy()[:-1] for w in policy.values())
+    else:
+      ent = policy.entropy()[:-1]
     loss -= self.config.actent * ent
     loss *= sg(traj['weight'])[:-1]
     loss *= self.config.loss_scales.actor
@@ -302,14 +387,29 @@ def loss(self, traj):

   def _metrics(self, traj, policy, logpi, ent, adv):
     metrics = {}
-    ent = policy.entropy()[:-1]
-    rand = (ent - policy.minent) / (policy.maxent - policy.minent)
-    rand = rand.mean(range(2, len(rand.shape)))
-    act = traj['action']
-    act = jnp.argmax(act, -1) if self.act_space.discrete else act
-    metrics.update(jaxutils.tensorstats(act, 'action'))
-    metrics.update(jaxutils.tensorstats(rand, 'policy_randomness'))
-    metrics.update(jaxutils.tensorstats(ent, 'policy_entropy'))
+    if isinstance(self.act_space, dict):
+      ent = {k: w.entropy()[:-1] for k, w in policy.items()}
+      rand = {
+          k: (ent[k] - w.minent) / (w.maxent - w.minent)
+          for k, w in policy.items()}
+      rand = {k: v.mean(range(2, len(v.shape))) for k, v in rand.items()}
+      # Only argmax the discrete components; continuous ones are logged raw.
+      act = {
+          k: (jnp.argmax(v, -1) if self.act_space[k].discrete else v)
+          for k, v in traj['action'].items()}
+      for k in traj['action']:
+        metrics.update(jaxutils.tensorstats(act[k], f'action_{k}'))
+        metrics.update(jaxutils.tensorstats(rand[k], f'policy_randomness_{k}'))
+        metrics.update(jaxutils.tensorstats(ent[k], f'policy_entropy_{k}'))
+    else:
+      ent = policy.entropy()[:-1]
+      rand = (ent - policy.minent) / (policy.maxent - policy.minent)
+      rand = rand.mean(range(2, len(rand.shape)))
+      act = traj['action']
+      act = jnp.argmax(act, -1) if self.act_space.discrete else act
+      metrics.update(jaxutils.tensorstats(act, 'action'))
+      metrics.update(jaxutils.tensorstats(rand, 'policy_randomness'))
+      metrics.update(jaxutils.tensorstats(ent, 'policy_entropy'))
     metrics.update(jaxutils.tensorstats(logpi, 'policy_logprob'))
     metrics.update(jaxutils.tensorstats(adv, 'adv'))
     metrics['imag_weight_dist'] = jaxutils.subsample(traj['weight'])
@@ -338,7 +438,13 @@ def train(self, traj, actor):

   def loss(self, traj, target):
     metrics = {}
-    traj = {k: v[:-1] for k, v in traj.items()}
+    if isinstance(traj['action'], dict):
+      # Truncate each action component along with the rest of the
+      # trajectory so all entries keep matching lengths.
+      action = {k: v[:-1] for k, v in traj['action'].items()}
+      traj = {k: v[:-1] for k, v in traj.items() if k != 'action'}
+      traj['action'] = action
+    else:
+      traj = {k: v[:-1] for k, v in traj.items()}
     dist = self.net(traj)
     loss = -dist.log_prob(sg(target))
     if self.config.critic_slowreg == 'logprob':
@@ -358,8 +464,12 @@ def loss(self, traj, target):

   def score(self, traj, actor=None):
     rew = self.rewfn(traj)
-    assert len(rew) == len(traj['action']) - 1, (
-        'should provide rewards for all but last action')
+    if isinstance(traj['action'], dict):
+      assert all(len(rew) == len(v) - 1 for v in traj['action'].values()), (
+          'should provide rewards for all but last action')
+    else:
+      assert len(rew) == len(traj['action']) - 1, (
+          'should provide rewards for all but last action')
     discount = 1 - 1 / self.config.horizon
     disc = traj['cont'][1:] * discount
     value = self.net(traj).mean()
diff --git a/dreamerv3/embodied/core/batch.py b/dreamerv3/embodied/core/batch.py
index 1c76e3092..974b99a70 100644
--- a/dreamerv3/embodied/core/batch.py
+++ b/dreamerv3/embodied/core/batch.py
@@ -24,11 +24,18 @@ def __len__(self):
     return len(self._envs)

   def step(self, action):
-    assert all(len(v) == len(self._envs) for v in action.values()), (
-        len(self._envs), {k: v.shape for k, v in action.items()})
+    assert all(
+        len(v) == len(self._envs) for v in action.values()
+        if not isinstance(v, dict)), (len(self._envs), action)
+    assert all(
+        len(v2) == len(self._envs)
+        for v in action.values() if isinstance(v, dict)
+        for v2 in v.values()), (len(self._envs), action)
     obs = []
     for i, env in enumerate(self._envs):
-      act = {k: v[i] for k, v in action.items()}
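+      # Slice out env i's actions; nested dicts hold the per-env slices
+      # of hybrid action components.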
+      act = {}
+      for k, v in action.items():
+        if isinstance(v, dict):
+          act[k] = {k2: v2[i] for k2, v2 in v.items()}
+        else:
+          act[k] = v[i]
       obs.append(env.step(act))
     if self._parallel:
       obs = [ob() for ob in obs]
diff --git a/dreamerv3/embodied/core/driver.py b/dreamerv3/embodied/core/driver.py
index aa4d66e63..6c7261063 100644
--- a/dreamerv3/embodied/core/driver.py
+++ b/dreamerv3/embodied/core/driver.py
@@ -42,12 +42,21 @@ def __call__(self, policy, steps=0, episodes=0):
       step, episode = self._step(policy, step, episode)

   def _step(self, policy, step, episode):
-    assert all(len(x) == len(self._env) for x in self._acts.values())
+    assert all(
+        len(x) == len(self._env) for x in self._acts.values()
+        if not isinstance(x, dict))
+    assert all(
+        len(x2) == len(self._env)
+        for x in self._acts.values() if isinstance(x, dict)
+        for x2 in x.values())
     acts = {k: v for k, v in self._acts.items() if not k.startswith('log_')}
     obs = self._env.step(acts)
     obs = {k: convert(v) for k, v in obs.items()}
     assert all(len(x) == len(self._env) for x in obs.values()), obs
     acts, self._state = policy(obs, self._state, **self._kwargs)
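+    # Environments and the replay pipeline expect flat keys, so hybrid
+    # policy outputs are unpacked: each action component and its entropy
+    # become top-level keys.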
+    if 'action' in acts and isinstance(acts['action'], dict):
+      for k, v in acts['action'].items():
+        acts[k] = v
+      del acts['action']
+      if 'log_entropy' in acts:
+        for k, v in acts['log_entropy'].items():
+          acts[f'log_entropy_{k}'] = v
+        del acts['log_entropy']
     acts = {k: convert(v) for k, v in acts.items()}
     if obs['is_last'].any():
       mask = 1 - obs['is_last']
diff --git a/dreamerv3/embodied/core/wrappers.py b/dreamerv3/embodied/core/wrappers.py
index 63ed52be7..aa53f4b84 100644
--- a/dreamerv3/embodied/core/wrappers.py
+++ b/dreamerv3/embodied/core/wrappers.py
@@ -62,10 +62,20 @@ def __init__(self, env, key='action', low=-1, high=1):
     self._key = key
     self._low = low
     self._high = high

   def step(self, action):
-    clipped = np.clip(action[self._key], self._low, self._high)
-    return self.env.step({**action, self._key: clipped})
+    if 'action' in action and isinstance(action['action'], dict):
+      # Hybrid actions arrive nested under 'action'; clip the wrapped
+      # component and pass the env a flat action dict.
+      clipped = np.clip(action['action'][self._key], self._low, self._high)
+      unpacked = {
+          **action['action'],
+          **{k: v for k, v in action.items() if k != 'action'},
+          self._key: clipped}
+      return self.env.step(unpacked)
+    else:
+      clipped = np.clip(action[self._key], self._low, self._high)
+      return self.env.step({**action, self._key: clipped})

 class NormalizeAction(base.Wrapper):
@@ -111,7 +121,7 @@ def step(self, action):
     assert action[self._key].min() == 0.0, action
     assert action[self._key].max() == 1.0, action
     assert action[self._key].sum() == 1.0, action
-    index = np.argmax(action[self._key])
+    index = np.argmax(action[self._key]).astype(np.int32)
     return self.env.step({**action, self._key: index})

   @staticmethod
diff --git a/dreamerv3/embodied/envs/dummy.py b/dreamerv3/embodied/envs/dummy.py
index 5165a386b..95936d47e 100644
--- a/dreamerv3/embodied/envs/dummy.py
+++ b/dreamerv3/embodied/envs/dummy.py
@@ -5,7 +5,7 @@ class Dummy(embodied.Env):

   def __init__(self, task, size=(64, 64), length=100):
-    assert task in ('cont', 'disc')
+    assert task in ('cont', 'Discrete')
     self._task = task
     self._size = size
     self._length = length
diff --git a/dreamerv3/embodied/replay/chunk.py b/dreamerv3/embodied/replay/chunk.py
index d9e14a8d0..7a86e3e86 100644
--- a/dreamerv3/embodied/replay/chunk.py
+++ b/dreamerv3/embodied/replay/chunk.py
@@ -32,12 +32,38 @@ def __bool__(self):

   def append(self, step):
     if not self.data:
-      example = {k: embodied.convert(v) for k, v in step.items()}
-      self.data = {
-          k: np.empty((self.size,) + v.shape, v.dtype)
-          for k, v in example.items()}
-    for key, value in step.items():
-      self.data[key][self.length] = value
+      # Nested action dicts get a nested buffer per component.
+      example = {
+          k: ({k2: embodied.convert(v2) for k2, v2 in v.items()}
+              if isinstance(v, dict) else embodied.convert(v))
+          for k, v in step.items()}
+      self.data = {
+          k: ({k2: np.empty((self.size,) + v2.shape, v2.dtype)
+               for k2, v2 in v.items()}
+              if isinstance(v, dict)
+              else np.empty((self.size,) + v.shape, v.dtype))
+          for k, v in example.items()}
+    for key, value in step.items():
+      if isinstance(value, dict):
+        for k2, v2 in value.items():
+          self.data[key][k2][self.length] = v2
+      else:
+        self.data[key][self.length] = value
     self.length += 1

   def save(self, directory):
diff --git a/dreamerv3/embodied/replay/generic.py b/dreamerv3/embodied/replay/generic.py
index 3ae64d53c..0c787b9d4 100644
--- a/dreamerv3/embodied/replay/generic.py
+++ b/dreamerv3/embodied/replay/generic.py
@@ -102,8 +102,28 @@ def _sample(self):
       seq = self.table[self.sampler()]
     else:
       seq = self.table[self.sampler()]
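+    # Steps may store hybrid actions as a nested dict under 'action';
+    # keys missing from a step are recovered from that dict so the
+    # sampled sequence stays flat and per-key.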
-    seq = {k: [step[k] for step in seq] for k in seq[0]}
-    seq = {k: embodied.convert(v) for k, v in seq.items()}
+    seq_ = {}
+    for k in seq[0]:
+      seq_[k] = []
+      for step in seq:
+        if k in step:
+          seq_[k].append(step[k])
+        elif 'action' in step:
+          seq_[k].append(step['action'][k])
+    seq = seq_
+    out = {}
+    for k, v in seq.items():
+      if isinstance(v[0], dict):
+        out[k] = [
+            {k2: embodied.convert(v2) for k2, v2 in d.items()} for d in v]
+      else:
+        out[k] = embodied.convert(v)
+    seq = out
     if 'is_first' in seq:
       seq['is_first'][0] = True
     return seq
diff --git a/dreamerv3/embodied/run/train.py b/dreamerv3/embodied/run/train.py
index 7b0cc5ab2..588d9188a 100644
--- a/dreamerv3/embodied/run/train.py
+++ b/dreamerv3/embodied/run/train.py
@@ -42,6 +42,21 @@ def per_episode(ep):
     for key in args.log_keys_video:
       if key in ep:
         stats[f'policy_{key}'] = ep[key]
+    # Flatten lists of hybrid action dicts into one array per component,
+    # logged under '<key>_<component>'.
+    ep_ = ep.copy()
+    for key, value in ep.items():
+      if isinstance(value[0], dict):
+        collect = {}
+        for action in value:
+          for k2, v2 in action.items():
+            collect.setdefault(k2, []).append(
+                list(v2) if v2.shape else np.float32(v2))
+        for k2, v2 in collect.items():
+          ep_[f'{key}_{k2}'] = np.array(v2)
+        del ep_[key]
+    ep = ep_
     for key, value in ep.items():
       if not args.log_zeros and key not in nonzeros and (value == 0).all():
         continue
diff --git a/dreamerv3/nets.py b/dreamerv3/nets.py
index e47031abb..c43d79172 100644
--- a/dreamerv3/nets.py
+++ b/dreamerv3/nets.py
@@ -53,9 +53,21 @@ def initial(self, bs):

   def observe(self, embed, action, is_first, state=None):
     swap = lambda x: x.transpose([1, 0] + list(range(2, len(x.shape))))
     if state is None:
-      state = self.initial(action.shape[0])
+      if isinstance(action, dict):
+        # Any component yields the batch size; avoid hardcoding keys.
+        state = self.initial(next(iter(action.values())).shape[0])
+      else:
+        state = self.initial(action.shape[0])
+    if isinstance(action, dict):
+      action = {k: swap(v) for k, v in action.items()}
+    else:
+      action = swap(action)
     step = lambda prev, inputs: self.obs_step(prev[0], *inputs)
-    inputs = swap(action), swap(embed), swap(is_first)
+    inputs = action, swap(embed), swap(is_first)
     start = state, state
     post, prior = jaxutils.scan(step, inputs, start, self._unroll)
     post = {k: swap(v) for k, v in post.items()}
@@ -64,9 +76,19 @@ def observe(self, embed, action, is_first, state=None):

   def imagine(self, action, state=None):
     swap = lambda x: x.transpose([1, 0] + list(range(2, len(x.shape))))
-    state = self.initial(action.shape[0]) if state is None else state
+    if state is None:
+      if isinstance(action, dict):
+        state = self.initial(next(iter(action.values())).shape[0])
+      else:
+        state = self.initial(action.shape[0])
     assert isinstance(state, dict), state
-    action = swap(action)
+    if isinstance(action, dict):
+      action = {k: swap(v) for k, v in action.items()}
+    else:
+      action = swap(action)
     prior = jaxutils.scan(self.img_step, action, state, self._unroll)
     prior = {k: swap(v) for k, v in prior.items()}
     return prior
@@ -82,6 +104,9 @@ def get_dist(self, state, argmax=False):

   def obs_step(self, prev_state, prev_action, embed, is_first):
     is_first = cast(is_first)
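+    # Hybrid actions are concatenated into one vector before entering the
+    # sequence model; sorted keys keep the component order identical
+    # between the replay and policy paths.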
+    if isinstance(prev_action, dict):
+      prev_action = jnp.concatenate(
+          [prev_action[k] for k in sorted(prev_action)], -1)
     prev_action = cast(prev_action)
     if self._action_clip > 0.0:
       prev_action *= sg(self._action_clip / jnp.maximum(
@@ -101,10 +126,17 @@ def obs_step(self, prev_state, prev_action, embed, is_first):
     return cast(post), cast(prior)

   def img_step(self, prev_state, prev_action):
+    if isinstance(prev_action, dict):
+      # Concatenate here as well, so the clip below sees a single array.
+      prev_action = jnp.concatenate(
+          [prev_action[k] for k in sorted(prev_action)], -1)
     prev_stoch = prev_state['stoch']
     prev_action = cast(prev_action)
     if self._action_clip > 0.0:
       prev_action *= sg(self._action_clip / jnp.maximum(
           self._action_clip, jnp.abs(prev_action)))
     if self._classes:
       shape = prev_stoch.shape[:-2] + (self._stoch * self._classes,)
@@ -216,9 +248,14 @@ def __init__(

   def __call__(self, data):
     some_key, some_shape = list(self.shapes.items())[0]
     batch_dims = data[some_key].shape[:-len(some_shape)]
-    data = {
-        k: v.reshape((-1,) + v.shape[len(batch_dims):])
-        for k, v in data.items()}
+    data = {
+        k: ({k2: v2.reshape((-1,) + v2.shape[len(batch_dims):])
+             for k2, v2 in v.items()}
+            if isinstance(v, dict)
+            else v.reshape((-1,) + v.shape[len(batch_dims):]))
+        for k, v in data.items()}
     outputs = []
     if self.cnn_shapes:
       inputs = jnp.concatenate([data[k] for k in self.cnn_shapes], -1)
@@ -415,9 +452,14 @@ def __init__(
     self._inputs = Input(inputs, dims=dims)
     self._symlog_inputs = symlog_inputs
     distkeys = (
-        'dist', 'outscale', 'minstd', 'maxstd', 'outnorm', 'unimix', 'bins')
+        'dist', 'outscale', 'minstd', 'maxstd', 'outnorm', 'unimix', 'bins',
+        'dist_cont', 'dist_disc')
     self._dense = {k: v for k, v in kw.items() if k not in distkeys}
     self._dist = {k: v for k, v in kw.items() if k in distkeys}
+    from collections import defaultdict
+    self._dist_dict = defaultdict(lambda: self._dist['dist'])
+    if self._dist.get('dist_cont') and self._dist.get('dist_disc'):
+      self._dist_dict.update({
+          'Continous': self._dist['dist_cont'],
+          'Discrete': self._dist['dist_disc']})
@@ -438,7 +480,7 @@ def __call__(self, inputs):
       raise ValueError(self._shape)

   def _out(self, name, shape, x):
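+    # Heads fall back to the default 'dist' unless the output name matches
+    # a hybrid component ('Continous' or 'Discrete'), which selects its
+    # own distribution type while keeping the shared dist options.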
-    return self.get(f'dist_{name}', Dist, shape, **self._dist)(x)
+    kw = {k: v for k, v in self._dist.items()
+          if k not in ('dist', 'dist_cont', 'dist_disc')}
+    return self.get(
+        f'dist_{name}', Dist, shape, dist=self._dist_dict[name], **kw)(x)

 class Dist(nj.Module):
diff --git a/example.py b/example.py
index e9009a477..b90aabd31 100644
--- a/example.py
+++ b/example.py
@@ -4,21 +4,23 @@ def main():
   import dreamerv3
   from dreamerv3 import embodied
   warnings.filterwarnings('ignore', '.*truncated to dtype int32.*')
+  import gym

   # See configs.yaml for all options.
   config = embodied.Config(dreamerv3.configs['defaults'])
-  config = config.update(dreamerv3.configs['medium'])
+  config = config.update(dreamerv3.configs['small'])
   config = config.update({
-      'logdir': '~/logdir/run1',
+      'logdir': '~/logdir/run_1',
       'run.train_ratio': 64,
       'run.log_every': 30,  # Seconds
       'batch_size': 16,
       'jax.prealloc': False,
-      'encoder.mlp_keys': '$^',
-      'decoder.mlp_keys': '$^',
-      'encoder.cnn_keys': 'image',
-      'decoder.cnn_keys': 'image',
-      # 'jax.platform': 'cpu',
+      'encoder.mlp_keys': '.*',
+      'decoder.mlp_keys': '.*',
+      'encoder.cnn_keys': '$^',
+      'decoder.cnn_keys': '$^',
+      'jax.platform': 'gpu',
   })
   config = embodied.Flags(config).parse()
@@ -32,10 +34,11 @@ def main():
       # embodied.logger.MLFlowOutput(logdir.name),
   ])

-  import crafter
-  from embodied.envs import from_gym
-  env = crafter.Env()  # Replace this with your Gym env.
-  env = from_gym.FromGym(env, obs_key='image')  # Or obs_key='vector'.
+  from dreamerv3.embodied.envs import from_gym
+  env = gym.make('CartPole-v1')
+  # env.render(mode='human')
+  env = from_gym.FromGym(env, obs_key='vector')  # Or obs_key='image'.
   env = dreamerv3.wrap_env(env, config)
   env = embodied.BatchEnv([env], parallel=False)
diff --git a/requirements.txt b/requirements.txt
index d5a83fe38..89170fac4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,7 @@
 cloudpickle
 crafter
+# wheel==0.38.4  # For gym
+# setuptools==65.5.0  # Needed for gym
 gym==0.19.0
 jax
 jaxlib