@@ -78,11 +78,11 @@ def _init_training_state(
     alpha_optimizer: optax.GradientTransformation,
     policy_optimizer: optax.GradientTransformation,
     q_optimizer: optax.GradientTransformation,
-    initial_alpha: float = 0.0,
+    initial_alpha: float = 1.0,
 ) -> TrainingState:
   """Inits the training state and replicates it over devices."""
   key_policy, key_q = jax.random.split(key)
-  log_alpha = jnp.asarray(initial_alpha, dtype=jnp.float32)
+  log_alpha = jnp.asarray(jnp.log(initial_alpha), dtype=jnp.float32)
   alpha_optimizer_state = alpha_optimizer.init(log_alpha)

   policy_params = sac_network.policy_network.init(key_policy)
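For context, SAC stores the temperature in log-space, so the effective temperature is `exp(log_alpha)`. With the change above, a user-supplied `initial_alpha` is interpreted as the temperature itself rather than its log, and the new default (`initial_alpha = 1.0`, i.e. `log_alpha = 0.0`) reproduces the previous default behavior. A minimal sketch of how a log-parameterized temperature typically enters the alpha loss; the exact loss in this file may differ, and `target_entropy` and the loss form below are assumptions:

```python
import jax
import jax.numpy as jnp


def alpha_loss(log_alpha: jnp.ndarray,
               log_prob: jnp.ndarray,
               target_entropy: float) -> jnp.ndarray:
  # The effective temperature is exp(log_alpha); optimizing in log-space
  # keeps the temperature positive without explicit constraints.
  alpha = jnp.exp(log_alpha)
  # Standard SAC temperature objective: push alpha up when the policy's
  # entropy (-log_prob) is below target_entropy, down otherwise.
  return jnp.mean(alpha * jax.lax.stop_gradient(-log_prob - target_entropy))
```

Because the gradient flows through `exp(log_alpha)`, initializing `log_alpha` with the raw temperature (as before this change) would only be correct for `initial_alpha` values whose log happens to equal the value itself.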
@@ -160,10 +160,6 @@ def train(
     num_envs: the number of parallel environments to use for rollouts
       NOTE: `num_envs` must be divisible by the total number of chips since each
         chip gets `num_envs // total_number_of_chips` environments to roll out
-      NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since
-        data generated by `num_envs` parallel envs gets used for gradient
-        updates over `num_minibatches` of data, where each minibatch has a
-        leading dimension of `batch_size`
     num_eval_envs: the number of envs to use for evluation. Each env will run 1
       episode, and all envs run in parallel during eval.
     learning_rate: learning rate for SAC loss
@@ -176,7 +172,7 @@ def train(
     max_devices_per_host: maximum number of chips to use per host process
     reward_scaling: float scaling for reward
     tau: interpolation factor in polyak averaging for target networks
-    intial_alpha: initial value for the temperature parameter alpha
+    initial_alpha: initial value for the temperature parameter α
     min_replay_size: the minimum number of samples in the replay buffer before
       starting training. This is used to prefill the replay buffer with random
       samples before training starts
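As a side note on the `tau` argument documented above: target Q-network parameters in SAC are typically updated by Polyak averaging after each gradient step. A minimal sketch, where the helper name `soft_update` is illustrative and not necessarily the function used in this file:

```python
import jax


def soft_update(target_params, online_params, tau: float):
  # Polyak averaging: move each target parameter a small step (tau) toward
  # the corresponding online parameter; small tau means slow-moving targets.
  return jax.tree_util.tree_map(
      lambda t, p: (1.0 - tau) * t + tau * p, target_params, online_params)
```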