Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions brax/envs/wrappers/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def reset(self, rng: jax.Array) -> State:
state = self.env.reset(rng)
state.info['first_pipeline_state'] = state.pipeline_state
state.info['first_obs'] = state.obs
state.info["obs_st"] = state.obs
return state

def step(self, state: State, action: jax.Array) -> State:
Expand All @@ -143,6 +144,9 @@ def step(self, state: State, action: jax.Array) -> State:
state = state.replace(done=jp.zeros_like(state.done))
state = self.env.step(state, action)

# Store next_obs before reset
obs_st = state.obs

def where_done(x, y):
done = state.done
if done.shape and done.shape[0] != x.shape[0]:
Expand All @@ -155,6 +159,7 @@ def where_done(x, y):
where_done, state.info['first_pipeline_state'], state.pipeline_state
)
obs = jax.tree.map(where_done, state.info['first_obs'], state.obs)
state.info["obs_st"] = obs_st
return state.replace(pipeline_state=pipeline_state, obs=obs)


Expand Down
34 changes: 34 additions & 0 deletions brax/envs/wrappers/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,40 @@

class TrainingTest(absltest.TestCase):

def test_autoreset_termination(self):
  """Runs `_run_termination` for 'ant' and 'halfcheetah', one subTest each."""
  env_ids = ("ant", "halfcheetah")
  for name in env_ids:
    # A separate subTest per environment so a failure reports its env_id.
    with self.subTest(env_id=name):
      self._run_termination(name)

def _run_termination(self, env_id):
  """Checks `info['obs_st']` against `obs` along one rollout of `env_id`.

  The environment is created with `envs.create`, reset with a fixed PRNG key
  and stepped with an all-zero action for `env.episode_length + 1` steps.
  On every step flagged `done` the stored `obs_st` must differ from the
  returned `obs`; on every other step the two must be equal.

  Args:
    env_id: environment name passed to `envs.create`.
  """
  env = envs.create(env_id)
  key = jax.random.PRNGKey(42)
  max_steps_in_episode = env.episode_length

  state = jax.jit(env.reset)(key)
  action = jp.zeros(env.sys.act_size())

  env_step_fn = jax.jit(env.step)

  def step_fn(state, _):
    # Carry the new state forward and record what the checks below need.
    next_state = env_step_fn(state, action)
    return next_state, (next_state.obs, next_state.done, next_state.info)

  # One step more than `episode_length` -- presumably so the rollout is
  # guaranteed to cross at least one episode boundary (TODO confirm).
  _, (observations, dones, infos) = jax.lax.scan(
      f=step_fn, init=state, xs=None, length=max_steps_in_episode + 1
  )

  observations_step = infos["obs_st"]
  # Should have at least finished once.
  self.assertGreaterEqual(
      float(jp.sum(dones)), 1.0, msg=f"{env_id}: no step reported done"
  )
  for step, (obs, done, obs_st) in enumerate(
      zip(observations, dones, observations_step)
  ):
    if done:
      # Ensure we stored the last obs from the finished episode, which
      # differs from the first obs of the new episode.
      self.assertFalse(
          jp.array_equal(obs_st, obs),
          msg=f"{env_id} step {step}: obs_st equals obs on a done step",
      )
    else:
      self.assertTrue(
          jp.array_equal(obs_st, obs),
          msg=f"{env_id} step {step}: obs_st differs from obs",
      )

def test_domain_randomization_wrapper(self):
def rand(sys, rng):
@jax.vmap
Expand Down