Skip to content

Commit ab34392

Browse files
danielsuocopybara-github
authored andcommitted
[brax] Explicitly set jax_pmap_shmap_merge=False.
Part of a larger clean-up to prepare for migrating `jax.pmap` users to `jax.shard_map` or the new `jax.pmap` that's implemented using `jax.shard_map`. PiperOrigin-RevId: 811785897 Change-Id: I32b92af43d1bcf2605e15582c80c8e4f8df2d75d
1 parent 9f7c582 commit ab34392

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

brax/envs/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@
3232
from brax.envs import walker2d
3333
from brax.envs.base import Env, PipelineEnv, State, Wrapper
3434
from brax.envs.wrappers import training
35+
import jax
36+
37+
# NOTE(dsuo): Opt out of using the new `jax.pmap` implementation. The new
38+
# version is implemented using `jax.jit` and `jax.shard_map` and is causing
39+
# environment reset counts to be wrong.
40+
jax.config.update('jax_pmap_shmap_merge', False)
3541

3642
_envs = {
3743
'ant': ant.Ant,

0 commit comments

Comments
 (0)