Skip to content

Commit 3085074

Browse files
btabacopybara-github
authored andcommitted
Update for new pmap.
PiperOrigin-RevId: 815939422 Change-Id: Ife3b7328864884512183301d9685ad6d9bc69dcb
1 parent ab34392 commit 3085074

File tree

1 file changed

+0
-6
lines changed

1 file changed

+0
-6
lines changed

brax/envs/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,6 @@
3232
from brax.envs import walker2d
3333
from brax.envs.base import Env, PipelineEnv, State, Wrapper
3434
from brax.envs.wrappers import training
35-
import jax
36-
37-
# NOTE(dsuo): Opt out of using the new `jax.pmap` implementation. The new
38-
# version is implemented using `jax.jit` and `jax.shard_map` and is causing
39-
# environment reset counts to be wrong.
40-
jax.config.update('jax_pmap_shmap_merge', False)
4135

4236
_envs = {
4337
'ant': ant.Ant,

0 commit comments

Comments
 (0)