
Commit 9453c7c

convert alpha to log value
1 parent 74aefc9 commit 9453c7c

File tree

1 file changed, +3 -7 lines changed


brax/training/agents/sac/train.py

Lines changed: 3 additions & 7 deletions
@@ -78,11 +78,11 @@ def _init_training_state(
     alpha_optimizer: optax.GradientTransformation,
     policy_optimizer: optax.GradientTransformation,
     q_optimizer: optax.GradientTransformation,
-    initial_alpha: float = 0.0,
+    initial_alpha: float = 1.0,
 ) -> TrainingState:
   """Inits the training state and replicates it over devices."""
   key_policy, key_q = jax.random.split(key)
-  log_alpha = jnp.asarray(initial_alpha, dtype=jnp.float32)
+  log_alpha = jnp.asarray(jnp.log(initial_alpha), dtype=jnp.float32)
   alpha_optimizer_state = alpha_optimizer.init(log_alpha)
 
   policy_params = sac_network.policy_network.init(key_policy)
@@ -160,10 +160,6 @@ def train(
     num_envs: the number of parallel environments to use for rollouts
       NOTE: `num_envs` must be divisible by the total number of chips since each
         chip gets `num_envs // total_number_of_chips` environments to roll out
-      NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since
-        data generated by `num_envs` parallel envs gets used for gradient
-        updates over `num_minibatches` of data, where each minibatch has a
-        leading dimension of `batch_size`
     num_eval_envs: the number of envs to use for evaluation. Each env will run 1
       episode, and all envs run in parallel during eval.
     learning_rate: learning rate for SAC loss
@@ -176,7 +172,7 @@ def train(
     max_devices_per_host: maximum number of chips to use per host process
     reward_scaling: float scaling for reward
     tau: interpolation factor in polyak averaging for target networks
-    intial_alpha: initial value for the temperature parameter alpha
+    initial_alpha: initial value for the temperature parameter α
     min_replay_size: the minimum number of samples in the replay buffer before
       starting training. This is used to prefill the replay buffer with random
       samples before training starts
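
The change stores the SAC entropy temperature as log α rather than α itself: gradient updates on an unconstrained `log_alpha` keep α = exp(log_alpha) strictly positive, which a direct parameterization does not guarantee. The new default `initial_alpha = 1.0` also preserves the previous starting point, since `jnp.log(1.0) == 0.0` means the stored `log_alpha` still begins at the 0.0 the old code assigned directly. Below is a minimal sketch of a temperature update in this parameterization; the `alpha_loss` form and `target_entropy` value are standard SAC conventions assumed here, not code from this file:

import jax
import jax.numpy as jnp

target_entropy = -1.0  # assumed; commonly set to -action_dim

def alpha_loss(log_alpha, log_prob):
  # Standard SAC temperature objective, written in terms of the
  # unconstrained parameter log_alpha: alpha * (-log_prob - target_entropy).
  alpha = jnp.exp(log_alpha)
  return jnp.mean(alpha * jax.lax.stop_gradient(-log_prob - target_entropy))

# Mirrors the diff: jnp.log(1.0) == 0.0, so log_alpha starts at the same
# value the old code assigned directly.
initial_alpha = 1.0
log_alpha = jnp.asarray(jnp.log(initial_alpha), dtype=jnp.float32)

# One plain gradient step on log_alpha; alpha = exp(log_alpha) stays > 0.
log_prob = jnp.asarray([-0.5, -1.2])  # example policy log-probabilities
log_alpha = log_alpha - 1e-3 * jax.grad(alpha_loss)(log_alpha, log_prob)
print(jnp.exp(log_alpha))  # the recovered temperature alpha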
