
Commit dcab2b4

btaba authored and copybara-github committed
Add donate_argnums to brax PPO. Avoids graph recaptures with MJX-Warp.
PiperOrigin-RevId: 816256474 Change-Id: I2c5573ca8c3ee7a279fbe160bbf30a62cfe7ff88
1 parent 3085074 commit dcab2b4

File tree

6 files changed: +53 additions, -19 deletions
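
For readers unfamiliar with buffer donation: `donate_argnums` tells JAX that the listed positional arguments will not be reused after the call, so XLA may recycle their device buffers for the outputs. Reusing buffers in place keeps allocations (and, presumably, buffer addresses) stable across steps, which is how the commit message's note about avoiding graph recaptures with MJX-Warp should be read. A minimal sketch of the mechanism, using a toy function rather than any brax code:

```python
import jax
import jax.numpy as jnp

# Toy example of donate_argnums (not from this commit): the first positional
# argument is donated, so XLA may write the output into that same device
# buffer instead of allocating a fresh one on every call.
step = jax.jit(lambda state, delta: state + delta, donate_argnums=(0,))

state = jnp.zeros((1024,), dtype=jnp.float32)
for _ in range(3):
  # Rebind the result to `state`; the previous array must not be read again.
  state = step(state, 1.0)
```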

brax/training/acting.py

Lines changed: 14 additions & 5 deletions
@@ -61,7 +61,6 @@ def generate_unroll(
 ) -> Tuple[State, Transition]:
   """Collect trajectories of given unroll_length."""
 
-  @jax.jit
   def f(carry, unused_t):
     state, current_key = carry
     current_key, next_key = jax.random.split(current_key)
@@ -70,8 +69,9 @@ def f(carry, unused_t):
     )
     return (nstate, next_key), transition
 
+  f_jit = jax.jit(f, donate_argnums=(0,))
   (final_state, _), data = jax.lax.scan(
-      f, (env_state, key), (), length=unroll_length
+      f_jit, (env_state, key), (), length=unroll_length
   )
   return final_state, data
 
@@ -111,9 +111,12 @@ def __init__(
     self._eval_walltime = 0.0
 
     eval_env = envs.training.EvalWrapper(eval_env)
+    self._eval_state_to_donate = jax.jit(eval_env.reset)(
+        jax.random.split(key, num_eval_envs)
+    )
 
     def generate_eval_unroll(
-        policy_params: PolicyParams, key: PRNGKey
+        eval_env_state_donated: State, policy_params: PolicyParams, key: PRNGKey
     ) -> State:
       reset_keys = jax.random.split(key, num_eval_envs)
       eval_first_state = eval_env.reset(reset_keys)
@@ -125,7 +128,9 @@ def generate_eval_unroll(
           unroll_length=episode_length // action_repeat,
       )[0]
 
-    self._generate_eval_unroll = jax.jit(generate_eval_unroll)
+    self._generate_eval_unroll = jax.jit(
+        generate_eval_unroll, donate_argnums=(0,), keep_unused=True
+    )
     self._steps_per_unroll = episode_length * num_eval_envs
 
   def run_evaluation(
@@ -138,7 +143,11 @@ def run_evaluation(
     self._key, unroll_key = jax.random.split(self._key)
 
     t = time.time()
-    eval_state = self._generate_eval_unroll(policy_params, unroll_key)
+    eval_state = self._generate_eval_unroll(
+        self._eval_state_to_donate, policy_params, unroll_key
+    )
+    self._eval_state_to_donate = eval_state
+
     eval_metrics = eval_state.info['eval_metrics']
     eval_metrics.active_episodes.block_until_ready()
     epoch_eval_time = time.time() - t
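
The evaluator change follows one pattern throughout: keep the previous eval state around solely so its buffers can be donated to the next unroll, and pass `keep_unused=True` because the jitted function never actually reads the donated argument. A stripped-down sketch of that pattern, with a plain array standing in for brax's `State` pytree and illustrative sizes:

```python
import jax
import jax.numpy as jnp

NUM_ENVS, OBS_SIZE = 256, 8  # illustrative sizes, not brax defaults

def run_eval(state_to_donate, params, key):
  # The first argument exists only to be donated; it is never read, which is
  # why the jit below also needs keep_unused=True.
  del state_to_donate
  fresh_state = jnp.zeros((NUM_ENVS, OBS_SIZE))
  return fresh_state + params * jax.random.uniform(key, (NUM_ENVS, OBS_SIZE))

run_eval = jax.jit(run_eval, donate_argnums=(0,), keep_unused=True)

key = jax.random.PRNGKey(0)
state = jnp.zeros((NUM_ENVS, OBS_SIZE))  # initial buffer to donate
for _ in range(2):
  key, subkey = jax.random.split(key)
  # Feed the last result back in so its buffer is donated to this call.
  state = run_eval(state, 1.0, subkey)
```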

brax/training/agents/ppo/train.py

Lines changed: 27 additions & 7 deletions
@@ -394,16 +394,28 @@ def train(
       randomization_fn,
   )
 
-  if local_devices_to_use > 1 or use_pmap_on_reset:
-    reset_fn = jax.pmap(env.reset, axis_name=_PMAP_AXIS_NAME)
-  else:
-    reset_fn = jax.jit(jax.vmap(env.reset))
+  def reset_fn_donated_env_state(env_state_donated, key_envs):
+    return env.reset(key_envs)
 
   key_envs = jax.random.split(key_env, num_envs // process_count)
   key_envs = jnp.reshape(
       key_envs, (local_devices_to_use, -1) + key_envs.shape[1:]
   )
-  env_state = reset_fn(key_envs)
+  if local_devices_to_use > 1 or use_pmap_on_reset:
+    reset_fn_ = jax.pmap(env.reset, axis_name=_PMAP_AXIS_NAME)
+    env_state = reset_fn_(key_envs)
+    reset_fn = jax.pmap(
+        reset_fn_donated_env_state,
+        axis_name=_PMAP_AXIS_NAME,
+        donate_argnums=(0,),
+    )
+  else:
+    reset_fn_ = jax.jit(jax.vmap(env.reset))
+    env_state = reset_fn_(key_envs)
+    reset_fn = jax.jit(
+        reset_fn_donated_env_state, donate_argnums=(0,), keep_unused=True
+    )
+
 
   # Discard the batch axes over devices and envs.
   obs_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], env_state.obs)
 
@@ -611,7 +623,14 @@ def training_epoch(
     loss_metrics = jax.tree_util.tree_map(jnp.mean, loss_metrics)
     return training_state, state, loss_metrics
 
-  training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME)
+  training_epoch = jax.pmap(
+      training_epoch,
+      axis_name=_PMAP_AXIS_NAME,
+      donate_argnums=(
+          0,
+          1,
+      ),
+  )
 
   # Note that this is NOT a pure jittable method.
   def training_epoch_with_timing(
 
@@ -755,7 +774,8 @@ def training_epoch_with_timing(
         lambda x, s: jax.random.split(x[0], s), in_axes=(0, None)
     )(key_envs, key_envs.shape[1])
     # TODO(brax-team): move extra reset logic to the AutoResetWrapper.
-    env_state = reset_fn(key_envs) if num_resets_per_eval > 0 else env_state
+    if num_resets_per_eval > 0:
+      env_state = reset_fn((training_state, env_state), key_envs)
 
     if process_id != 0:
       continue
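
The `training_epoch` change donates the per-device carries so each epoch can reuse the previous epoch's buffers under `jax.pmap`. A toy version of that carry pattern, with illustrative names and shapes rather than brax's actual `TrainingState` and env `State`:

```python
import jax
import jax.numpy as jnp

def epoch(train_state, env_state, key):
  # Stand-in for one training epoch: both carries are returned with the same
  # shapes and dtypes, so their donated buffers can be reused in place.
  new_env_state = env_state + jax.random.normal(key, env_state.shape)
  new_train_state = train_state + new_env_state.mean()
  return new_train_state, new_env_state

epoch = jax.pmap(epoch, axis_name='i', donate_argnums=(0, 1))

n = jax.local_device_count()
train_state = jnp.zeros((n,))
env_state = jnp.zeros((n, 128))
keys = jax.random.split(jax.random.PRNGKey(0), n)
for _ in range(2):
  # The same keys are reused across iterations only to keep the sketch short.
  train_state, env_state = epoch(train_state, env_state, keys)
```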

brax/training/agents/ppo/train_test.py

Lines changed: 2 additions & 2 deletions
@@ -54,8 +54,8 @@ def testTrain(self, obs_mode):
         normalize_advantage=False,
     )
     self.assertGreater(metrics['eval/episode_reward'], 135)
-    self.assertEqual(fast.reset_count, 2)  # type: ignore
-    self.assertEqual(fast.step_count, 2)  # type: ignore
+    self.assertEqual(fast.reset_count, 4)  # type: ignore
+    self.assertEqual(fast.step_count, 3)  # type: ignore
 
   @parameterized.parameters(
       ('normal', 'scalar'),

brax/training/agents/sac/train.py

Lines changed: 5 additions & 2 deletions
@@ -416,7 +416,8 @@ def f(carry, unused):
   )[0]
 
   prefill_replay_buffer = jax.pmap(
-      prefill_replay_buffer, axis_name=_PMAP_AXIS_NAME
+      prefill_replay_buffer, axis_name=_PMAP_AXIS_NAME,
+      donate_argnums=(0, 1, 2)
   )
 
   def training_epoch(
 
@@ -441,7 +442,9 @@ def f(carry, unused_t):
     metrics = jax.tree_util.tree_map(jnp.mean, metrics)
     return training_state, env_state, buffer_state, metrics
 
-  training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME)
+  training_epoch = jax.pmap(
+      training_epoch, axis_name=_PMAP_AXIS_NAME, donate_argnums=(0, 1, 2)
+  )
 
   # Note that this is NOT a pure jittable method.
   def training_epoch_with_timing(

brax/training/agents/sac/train_test.py

Lines changed: 4 additions & 3 deletions
@@ -45,11 +45,12 @@ def testTrain(self):
         grad_updates_per_step=64,
         num_evals=3,
         seed=0,
+        eval_env=envs.get_environment('fast'),
     )
     self.assertGreater(metrics['eval/episode_reward'], 140 * 0.995)
-    self.assertEqual(fast.reset_count, 3)  # type: ignore
-    # once for prefill, once for train, once for eval
-    self.assertEqual(fast.step_count, 3)  # type: ignore
+    self.assertEqual(fast.reset_count, 2)  # type: ignore
+    # once for prefill, once for train
+    self.assertEqual(fast.step_count, 2)  # type: ignore
 
   @parameterized.parameters(True, False)
   def testNetworkEncoding(self, normalize_observations):

docs/release-notes/next-release.md

Lines changed: 1 addition & 0 deletions
@@ -6,3 +6,4 @@
 * Allow episode metrics during eval to be normalized by the episode length, as long as the metric name ends with "per_step".
 * Add adaptive learning rate to PPO. Desired KL is sensitive to network initialization weights and entropy cost and may require some tuning for your environment.
 * Add loss metrics to the PPO training logger.
+* Add `donate_argnums` to brax PPO to somewhat mitigate repeated graph captures when using MJX-Warp.
