180 changes: 145 additions & 35 deletions dreamerv3/agent.py
@@ -2,6 +2,7 @@
import jax
import jax.numpy as jnp
import ruamel.yaml as yaml
+from pathlib import Path
tree_map = jax.tree_util.tree_map
sg = lambda x: tree_map(jax.lax.stop_gradient, x)

@@ -23,12 +24,15 @@ def filter(self, record):
class Agent(nj.Module):

configs = yaml.YAML(typ='safe').load(
-      (embodied.Path(__file__).parent / 'configs.yaml').read())
+      (Path(__file__).parent / 'configs.yaml'))

def __init__(self, obs_space, act_space, step, config):
self.config = config
self.obs_space = obs_space
-    self.act_space = act_space['action']
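+    # Hybrid control tasks pass a dict of named sub-spaces instead of a
+    # single 'action' entry, so fall back to keeping the whole dict.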
+    try:
+      self.act_space = act_space['action']
+    except KeyError:
+      self.act_space = act_space
self.step = step
self.wm = WorldModel(obs_space, act_space, config, name='wm')
self.task_behavior = getattr(behaviors, config.task_behavior)(
@@ -60,16 +64,28 @@ def policy(self, obs, state, mode='train'):
expl_outs, expl_state = self.expl_behavior.policy(latent, expl_state)
if mode == 'eval':
outs = task_outs
-      outs['action'] = outs['action'].sample(seed=nj.rng())
-      outs['log_entropy'] = jnp.zeros(outs['action'].shape[:1])
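+      # A dict action space yields one distribution per named sub-action.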
+      if isinstance(self.act_space, dict):
+        dists = outs['action']
+        # Compute entropies before replacing the distributions with modes.
+        outs['log_entropy'] = {k: w.entropy() for k, w in dists.items()}
+        outs['action'] = {k: w.mode() for k, w in dists.items()}
+      else:
+        outs['action'] = outs['action'].sample(seed=nj.rng())
+        outs['log_entropy'] = jnp.zeros(outs['action'].shape[:1])
elif mode == 'explore':
outs = expl_outs
-      outs['log_entropy'] = outs['action'].entropy()
-      outs['action'] = outs['action'].sample(seed=nj.rng())
+      if isinstance(self.act_space, dict):
+        outs['log_entropy'] = {k: w.entropy() for k, w in outs['action'].items()}
+        outs['action'] = {
+            k: w.sample(seed=nj.rng()) for k, w in outs['action'].items()}
+      else:
+        outs['log_entropy'] = outs['action'].entropy()
+        outs['action'] = outs['action'].sample(seed=nj.rng())
elif mode == 'train':
outs = task_outs
-      outs['log_entropy'] = outs['action'].entropy()
-      outs['action'] = outs['action'].sample(seed=nj.rng())
+      if isinstance(self.act_space, dict):
+        outs['log_entropy'] = {k: w.entropy() for k, w in outs['action'].items()}
+        outs['action'] = {
+            k: w.sample(seed=nj.rng()) for k, w in outs['action'].items()}
+      else:
+        outs['log_entropy'] = outs['action'].entropy()
+        outs['action'] = outs['action'].sample(seed=nj.rng())
state = ((latent, outs['action']), task_state, expl_state)
return outs, state

@@ -119,7 +135,11 @@ class WorldModel(nj.Module):

def __init__(self, obs_space, act_space, config):
self.obs_space = obs_space
-    self.act_space = act_space['action']
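+    # Only the shapes are needed below; keep one shape per sub-space for
+    # hybrid action dicts ('reset' is not part of the action).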
+    try:
+      self.act_space_shape = act_space['action'].shape
+    except KeyError:
+      self.act_space_shape = {
+          k: v.shape for k, v in act_space.items() if k != 'reset'}
self.config = config
shapes = {k: tuple(v.shape) for k, v in obs_space.items()}
shapes = {k: v for k, v in shapes.items() if not k.startswith('log_')}
@@ -138,11 +158,18 @@ def __init__(self, obs_space, act_space, config):

def initial(self, batch_size):
prev_latent = self.rssm.initial(batch_size)
-    prev_action = jnp.zeros((batch_size, *self.act_space.shape))
+    if isinstance(self.act_space_shape, dict):
+      prev_action = {
+          k: jnp.zeros((batch_size, *v))
+          for k, v in self.act_space_shape.items()}
+    else:
+      prev_action = jnp.zeros((batch_size, *self.act_space_shape))
return prev_latent, prev_action

def train(self, data, state):
modules = [self.encoder, self.rssm, *self.heads.values()]
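+    # Replay batches store hybrid actions as separate 'Continous' (sic) and
+    # 'Discrete' streams; repack them into a single nested action dict.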
+    if 'action' not in data:
+      data['action'] = {
+          'Continous': data.pop('Continous'),
+          'Discrete': data.pop('Discrete')}
mets, (state, outs, metrics) = self.opt(
modules, self.loss, data, state, has_aux=True)
metrics.update(mets)
@@ -151,7 +178,23 @@ def train(self, data, state):
def loss(self, data, state):
embed = self.encoder(data)
prev_latent, prev_action = state
-    prev_actions = jnp.concatenate([
-        prev_action[:, None], data['action'][:, :-1]], 1)
+    # The state is (prev_latent, prev_action); the action is a dict for
+    # hybrid (Discrete + Continous) control.
+    if 'action' not in data:
+      # Ensure both streams carry a trailing feature axis.
+      if len(data['Continous'].shape) == 2:
+        data['Continous'] = data['Continous'][..., None]
+      if len(data['Discrete'].shape) == 2:
+        data['Discrete'] = data['Discrete'][..., None]
+      data['action'] = {
+          'Continous': data.pop('Continous'),
+          'Discrete': data.pop('Discrete')}
+    if isinstance(data['action'], dict):
+      prev_actions = {
+          k: jnp.concatenate([
+              prev_action[k][:, None], data['action'][k][:, :-1]], 1)
+          for k in data['action']}
+    else:
+      prev_actions = jnp.concatenate([
+          prev_action[:, None], data['action'][:, :-1]], 1)
post, prior = self.rssm.observe(
embed, prev_actions, data['is_first'], prev_latent)
@@ -173,7 +216,10 @@ def loss(self, data, state):
out = {'embed': embed, 'post': post, 'prior': prior}
out.update({f'{k}_loss': v for k, v in losses.items()})
last_latent = {k: v[:, -1] for k, v in post.items()}
-    last_action = data['action'][:, -1]
+    if isinstance(data['action'], dict):
+      last_action = {k: v[:, -1] for k, v in data['action'].items()}
+    else:
+      last_action = data['action'][:, -1]
state = last_latent, last_action
metrics = self._metrics(data, dists, post, prior, losses, model_loss)
return model_loss.mean(), (state, out, metrics)
@@ -189,8 +235,13 @@ def step(prev, _):
return {**state, 'action': policy(state)}
traj = jaxutils.scan(
step, jnp.arange(horizon), start, self.config.imag_unroll)
-    traj = {
-        k: jnp.concatenate([start[k][None], v], 0) for k, v in traj.items()}
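+    # Prepend the start state; a dict-valued action needs per-key
+    # concatenation, all other entries are plain arrays.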
+    if isinstance(traj['action'], dict):
+      traj_ = {
+          k: jnp.concatenate([start[k][None], v], 0)
+          for k, v in traj.items() if k != 'action'}
+      traj_['action'] = {
+          k: jnp.concatenate([start['action'][k][None], traj['action'][k]], 0)
+          for k in traj['action']}
+      traj = traj_
+    else:
+      traj = {
+          k: jnp.concatenate([start[k][None], v], 0) for k, v in traj.items()}
cont = self.heads['cont'](traj).mode()
traj['cont'] = jnp.concatenate([first_cont[None], cont[1:]], 0)
discount = 1 - 1 / self.config.horizon
@@ -201,13 +252,21 @@ def report(self, data):
state = self.initial(len(data['is_first']))
report = {}
report.update(self.loss(data, state)[-1][-1])
+    if isinstance(data['action'], dict):
+      action_context = {k: v[:6, :5] for k, v in data['action'].items()}
+      action_post = {k: v[:6, 5:] for k, v in data['action'].items()}
+    else:
+      action_context = data['action'][:6, :5]
+      action_post = data['action'][:6, 5:]
context, _ = self.rssm.observe(
-        self.encoder(data)[:6, :5], data['action'][:6, :5],
+        self.encoder(data)[:6, :5], action_context,
data['is_first'][:6, :5])
start = {k: v[:, -1] for k, v in context.items()}
recon = self.heads['decoder'](context)
openl = self.heads['decoder'](
-        self.rssm.imagine(data['action'][:6, 5:], start))
+        self.rssm.imagine(action_post, start))
for key in self.heads['decoder'].cnn_shapes.keys():
truth = data[key][:6].astype(jnp.float32)
model = jnp.concatenate([recon[key].mode()[:, :5], openl[key].mode()], 1)
@@ -246,11 +305,22 @@ def __init__(self, critics, scales, act_space, config):
self.scales = scales
self.act_space = act_space
self.config = config
-    disc = act_space.discrete
-    self.grad = config.actor_grad_disc if disc else config.actor_grad_cont
+    if isinstance(self.act_space, dict):
+      shape = {k: v.shape for k, v in act_space.items() if k != 'reset'}
+      discrete = False
+      self.grad_disc = config.actor_grad_disc
+      self.grad_cont = config.actor_grad_cont
+      # Falsy sentinel: hybrid spaces sum both gradient styles in loss().
+      self.grad = False
+    else:
+      shape = act_space.shape
+      discrete = act_space.discrete
+      self.grad = (
+          config.actor_grad_disc if discrete else config.actor_grad_cont)

self.actor = nets.MLP(
-        name='actor', dims='deter', shape=act_space.shape, **config.actor,
-        dist=config.actor_dist_disc if disc else config.actor_dist_cont)
+        name='actor', dims='deter', shape=shape, **config.actor,
+        dist=config.actor_dist_disc if discrete else config.actor_dist_cont,
+        dist_cont=config.actor_dist_cont if not discrete else None,
+        dist_disc=config.actor_dist_disc if not discrete else None)
self.retnorms = {
k: jaxutils.Moments(**config.retnorm, name=f'retnorm_{k}')
for k in critics}
@@ -264,7 +334,10 @@ def policy(self, state, carry):

def train(self, imagine, start, context):
def loss(start):
-      policy = lambda s: self.actor(sg(s)).sample(seed=nj.rng())
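+      # Hybrid policies return a dict of distributions; sample each head.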
+      if isinstance(self.act_space, dict):
+        policy = lambda s: {
+            k: w.sample(seed=nj.rng()) for k, w in self.actor(sg(s)).items()}
+      else:
+        policy = lambda s: self.actor(sg(s)).sample(seed=nj.rng())
traj = imagine(policy, start, self.config.imag_horizon)
loss, metrics = self.loss(traj)
return loss, (traj, metrics)
@@ -291,9 +364,21 @@ def loss(self, traj):
metrics[f'{key}_return_rate'] = (jnp.abs(ret) >= 0.5).mean()
adv = jnp.stack(advs).sum(0)
policy = self.actor(sg(traj))
-    logpi = policy.log_prob(sg(traj['action']))[:-1]
-    loss = {'backprop': -adv, 'reinforce': -logpi * sg(adv)}[self.grad]
-    ent = policy.entropy()[:-1]
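+    # For hybrid actions the joint log-prob and entropy are sums over the
+    # independent sub-action heads.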
+    if isinstance(self.act_space, dict):
+      logpi = sum(
+          w.log_prob(sg(traj['action'][k]))[:-1] for k, w in policy.items())
+    else:
+      logpi = policy.log_prob(sg(traj['action']))[:-1]
+    loss = {'backprop': -adv, 'reinforce': -logpi * sg(adv)}
+    if self.grad:
+      loss = loss[self.grad]
+    else:
+      # Hybrid spaces have no single gradient mode; sum both loss styles.
+      loss = sum(loss.values())
+    if isinstance(self.act_space, dict):
+      ent = sum(w.entropy()[:-1] for w in policy.values())
+    else:
+      ent = policy.entropy()[:-1]
loss -= self.config.actent * ent
loss *= sg(traj['weight'])[:-1]
loss *= self.config.loss_scales.actor
@@ -302,14 +387,29 @@ def loss(self, traj):

def _metrics(self, traj, policy, logpi, ent, adv):
metrics = {}
-    ent = policy.entropy()[:-1]
-    rand = (ent - policy.minent) / (policy.maxent - policy.minent)
-    rand = rand.mean(range(2, len(rand.shape)))
-    act = traj['action']
-    act = jnp.argmax(act, -1) if self.act_space.discrete else act
-    metrics.update(jaxutils.tensorstats(act, 'action'))
-    metrics.update(jaxutils.tensorstats(rand, 'policy_randomness'))
-    metrics.update(jaxutils.tensorstats(ent, 'policy_entropy'))
+    if isinstance(self.act_space, dict):
+      ent = {k: w.entropy()[:-1] for k, w in policy.items()}
+      rand = {
+          k: (ent[k] - w.minent) / (w.maxent - w.minent)
+          for k, w in policy.items()}
+      # Argmax only the discrete heads; continuous actions are logged as-is.
+      act = {
+          k: jnp.argmax(v, -1) if self.act_space[k].discrete else v
+          for k, v in traj['action'].items()}
+      for k in traj['action']:
+        metrics.update(jaxutils.tensorstats(act[k], f'action_{k}'))
+        metrics.update(jaxutils.tensorstats(rand[k], f'policy_randomness_{k}'))
+        metrics.update(jaxutils.tensorstats(ent[k], f'policy_entropy_{k}'))
+    else:
+      ent = policy.entropy()[:-1]
+      rand = (ent - policy.minent) / (policy.maxent - policy.minent)
+      rand = rand.mean(range(2, len(rand.shape)))
+      act = traj['action']
+      act = jnp.argmax(act, -1) if self.act_space.discrete else act
+      metrics.update(jaxutils.tensorstats(act, 'action'))
+      metrics.update(jaxutils.tensorstats(rand, 'policy_randomness'))
+      metrics.update(jaxutils.tensorstats(ent, 'policy_entropy'))
metrics.update(jaxutils.tensorstats(logpi, 'policy_logprob'))
metrics.update(jaxutils.tensorstats(adv, 'adv'))
metrics['imag_weight_dist'] = jaxutils.subsample(traj['weight'])
@@ -338,7 +438,13 @@ def train(self, traj, actor):

def loss(self, traj, target):
metrics = {}
-    traj = {k: v[:-1] for k, v in traj.items()}
+    if isinstance(traj['action'], dict):
+      # Truncate the nested action entries the same way as the flat ones.
+      action = {k: v[:-1] for k, v in traj['action'].items()}
+      traj = {k: v[:-1] for k, v in traj.items() if k != 'action'}
+      traj['action'] = action
+    else:
+      traj = {k: v[:-1] for k, v in traj.items()}
dist = self.net(traj)
loss = -dist.log_prob(sg(target))
if self.config.critic_slowreg == 'logprob':
@@ -358,8 +464,12 @@ def loss(self, traj, target):

def score(self, traj, actor=None):
rew = self.rewfn(traj)
-    assert len(rew) == len(traj['action']) - 1, (
-        'should provide rewards for all but last action')
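+    # All sub-actions share the time dimension, so checking any one stream
+    # is sufficient for the dict case.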
+    if isinstance(traj['action'], dict):
+      length = len(next(iter(traj['action'].values())))
+    else:
+      length = len(traj['action'])
+    assert len(rew) == length - 1, (
+        'should provide rewards for all but last action')
discount = 1 - 1 / self.config.horizon
disc = traj['cont'][1:] * discount
value = self.net(traj).mean()
11 changes: 9 additions & 2 deletions dreamerv3/embodied/core/batch.py
@@ -24,11 +24,18 @@ def __len__(self):
return len(self._envs)

def step(self, action):
-    assert all(len(v) == len(self._envs) for v in action.values()), (
-        len(self._envs), {k: v.shape for k, v in action.items()})
+    # The error messages must tolerate nested dict actions.
+    assert all(
+        len(v) == len(self._envs)
+        for v in action.values() if not isinstance(v, dict)), (
+        len(self._envs), {k: getattr(v, 'shape', '(dict)')
+                          for k, v in action.items()})
+    assert all(
+        len(v2) == len(self._envs)
+        for v in action.values() if isinstance(v, dict)
+        for v2 in v.values()), (
+        len(self._envs), {k: getattr(v, 'shape', '(dict)')
+                          for k, v in action.items()})
obs = []
for i, env in enumerate(self._envs):
-      act = {k: v[i] for k, v in action.items()}
+      act = {}
+      for k, v in action.items():
+        if isinstance(v, dict):
+          # Index each sub-action of a hybrid action dict.
+          act[k] = {k2: v2[i] for k2, v2 in v.items()}
+        else:
+          act[k] = v[i]
obs.append(env.step(act))
if self._parallel:
obs = [ob() for ob in obs]
10 changes: 10 additions & 0 deletions dreamerv3/embodied/core/batcher.py
@@ -92,6 +92,16 @@ def _batcher(self, sources, output):
if isinstance(elem, Exception):
raise elem
batch = {k: np.stack([x[k] for x in elems], 0) for k in elems[0]}
+      # Nested action dicts should never reach this point: the driver
+      # flattens hybrid actions into top-level keys before episodes are
+      # stored, so a plain per-key np.stack is sufficient.

if self._postprocess:
batch = self._postprocess(batch)
output.put(batch) # Will wait here if the queue is full.
30 changes: 29 additions & 1 deletion dreamerv3/embodied/core/driver.py
@@ -42,12 +42,21 @@ def __call__(self, policy, steps=0, episodes=0):
step, episode = self._step(policy, step, episode)

def _step(self, policy, step, episode):
-    assert all(len(x) == len(self._env) for x in self._acts.values())
+    assert all(
+        len(x) == len(self._env)
+        for x in self._acts.values() if not isinstance(x, dict))
+    assert all(
+        len(x2) == len(self._env)
+        for x in self._acts.values() if isinstance(x, dict)
+        for x2 in x.values())
acts = {k: v for k, v in self._acts.items() if not k.startswith('log_')}
obs = self._env.step(acts)
obs = {k: convert(v) for k, v in obs.items()}
assert all(len(x) == len(self._env) for x in obs.values()), obs
acts, self._state = policy(obs, self._state, **self._kwargs)
if "action" in acts and isinstance(acts["action"], dict):
for k, v in acts["action"].items():
acts[k] = v
del acts["action"]
if "log_entropy" in acts:
for k, v in acts["log_entropy"].items():
acts[f"log_entropy_{k}"] = v
del acts["log_entropy"]
acts = {k: convert(v) for k, v in acts.items()}
if obs['is_last'].any():
mask = 1 - obs['is_last']
@@ -60,6 +69,15 @@ def _step(self, policy, step, episode):
if first:
self._eps[i].clear()
for i in range(len(self._env)):
# if "action" in trns and isinstance(trns["action"], dict):
# out = {}
# for k, v in trns.items():
# if type(v) is not dict:
# out[k] = v[i]
# else:
# out[k] = {k2: v2[i] for k2, v2 in v.items()}
# trn = out
# else:
trn = {k: v[i] for k, v in trns.items()}
[self._eps[i][k].append(v) for k, v in trn.items()]
[fn(trn, i, **self._kwargs) for fn in self._on_steps]
@@ -68,6 +86,16 @@ def _step(self, policy, step, episode):
for i, done in enumerate(obs['is_last']):
if done:
ep = {k: convert(v) for k, v in self._eps[i].items()}
+        # Episode values are lists of flat arrays (actions were flattened
+        # in _step), so a single convert per key suffices.
[fn(ep.copy(), i, **self._kwargs) for fn in self._on_episodes]
episode += 1
return step, episode
Expand Down
Loading