Skip to content

Commit 124bf39

Browse files
btabacopybara-github
authored andcommitted
Fix enum member py version issue.
PiperOrigin-RevId: 829322405 Change-Id: I471abb204907e0337a53cd74877dd9e7d1ccc446
1 parent 84c7037 commit 124bf39

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

brax/training/agents/es/train.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import enum
2121
import functools
22+
import sys
2223
import time
2324
from typing import Any, Callable, Dict, Optional, Tuple
2425

@@ -40,6 +41,11 @@
4041
Metrics = types.Metrics
4142
InferenceParams = Tuple[running_statistics.NestedMeanStd, Params]
4243

44+
if sys.version_info >= (3, 11):
45+
from enum import member as enum_member
46+
else:
47+
enum_member = lambda x: x
48+
4349

4450
@flax.struct.dataclass
4551
class TrainingState:
@@ -67,9 +73,9 @@ def wierstra(x: jnp.ndarray) -> jnp.ndarray:
6773

6874

6975
class FitnessShaping(enum.Enum):
70-
ORIGINAL = enum.member(functools.partial(lambda x: x))
71-
CENTERED_RANK = enum.member(functools.partial(centered_rank))
72-
WIERSTRA = enum.member(functools.partial(wierstra))
76+
ORIGINAL = enum_member(functools.partial(lambda x: x))
77+
CENTERED_RANK = enum_member(functools.partial(centered_rank))
78+
WIERSTRA = enum_member(functools.partial(wierstra))
7379

7480

7581
# TODO(eorsini): Pass the network as argument.

0 commit comments

Comments
 (0)