@@ -145,7 +145,58 @@ def train(
     checkpoint_logdir: Optional[str] = None,
     restore_checkpoint_path: Optional[str] = None,
 ):
-  """SAC training."""
148+ """SAC training.
149+
150+ Args:
151+ environment: the environment to train
152+ num_timesteps: the total number of environment steps to use during training
153+ episode_length: the length of an environment episode
154+ wrap_env: If True, wrap the environment for training. Otherwise use the
155+ environment as is.
156+ wrap_env_fn: a custom function that wraps the environment for training. If
157+ not specified, the environment is wrapped with the default training
158+ wrapper.
159+ action_repeat: the number of timesteps to repeat an action
160+ num_envs: the number of parallel environments to use for rollouts
161+ NOTE: `num_envs` must be divisible by the total number of chips since each
162+ chip gets `num_envs // total_number_of_chips` environments to roll out
163+ NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since
164+ data generated by `num_envs` parallel envs gets used for gradient
165+ updates over `num_minibatches` of data, where each minibatch has a
166+ leading dimension of `batch_size`
167+ num_eval_envs: the number of envs to use for evluation. Each env will run 1
168+ episode, and all envs run in parallel during eval.
169+ learning_rate: learning rate for ppo loss
170+ discounting: discounting rate
171+ seed: random seed
172+ batch_size: the batch size for each minibatch SGD step
173+ num_evals: the number of evals to run during the entire training run.
174+ Increasing the number of evals increases total training time
175+ normalize_observations: whether to normalize observations
176+ max_devices_per_host: maximum number of chips to use per host process
177+ reward_scaling: float scaling for reward
178+ tau: interpolation factor in polyak averaging for target networks
179+ intial_alpha: initial value for the temperature parameter alpha
180+ min_replay_size: the minimum number of samples in the replay buffer before
181+ starting training. This is used to prefill the replay buffer with random
182+ samples before training starts
183+ max_replay_size: the maximum number of samples in the replay buffer. If None,
184+ the replay buffer will be filled with `num_timesteps` samples
185+ grad_updates_per_step: the number of gradient updates to run per actor step.
186+ deterministic_eval: whether to run the eval with a deterministic policy
187+ network_factory: function that generates networks for policy and value
188+ functions
189+ progress_fn: a user-defined callback function for reporting/plotting metrics
190+ eval_env: an optional environment for eval only, defaults to `environment`
191+ randomization_fn: a user-defined callback function that generates randomized
192+ environments
193+ checkpoint_logdir: the path used to save checkpoints. If None, no checkpoints
194+ are saved. The checkpoint will be saved every `num_evals` steps
195+ restore_checkpoint_path: the path used to restore previous model params
196+
197+ Returns:
198+ Tuple of (make_policy function, network params, metrics)
199+ """
   process_id = jax.process_index()
   local_devices_to_use = jax.local_device_count()
   if max_devices_per_host is not None:
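For reference, here is a minimal usage sketch of the documented `train()` entry point (assuming the Brax SAC trainer at `brax.training.agents.sac` and `brax.envs`; the environment name and hyperparameters are illustrative, not tuned):

```python
# Minimal usage sketch; assumes the Brax SAC trainer and `brax.envs`.
# Environment name and hyperparameters are illustrative only.
from brax import envs
from brax.training.agents.sac import train as sac

env = envs.get_environment('ant')

def progress(num_steps, metrics):
  # Called at each of the `num_evals` evaluations during training.
  print(num_steps, metrics.get('eval/episode_reward'))

make_policy, params, _ = sac.train(
    environment=env,
    num_timesteps=1_000_000,
    episode_length=1000,
    num_envs=64,
    batch_size=256,
    num_evals=10,
    normalize_observations=True,
    progress_fn=progress,
)

# `make_policy(params)` returns an inference function for rollouts/eval,
# matching the (make_policy, params, metrics) return tuple described above.
policy = make_policy(params, deterministic=True)
```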