
Commit 22f86e0

add docstring to SAC train function
1 parent c340aa2 commit 22f86e0

1 file changed (+52, -1 lines changed)

brax/training/agents/sac/train.py

Lines changed: 52 additions & 1 deletion
@@ -145,7 +145,58 @@ def train(
     checkpoint_logdir: Optional[str] = None,
     restore_checkpoint_path: Optional[str] = None,
 ):
-  """SAC training."""
+  """SAC training.
+
+  Args:
+    environment: the environment to train
+    num_timesteps: the total number of environment steps to use during training
+    episode_length: the length of an environment episode
+    wrap_env: If True, wrap the environment for training. Otherwise use the
+      environment as is.
+    wrap_env_fn: a custom function that wraps the environment for training. If
+      not specified, the environment is wrapped with the default training
+      wrapper.
+    action_repeat: the number of timesteps to repeat an action
+    num_envs: the number of parallel environments to use for rollouts
+      NOTE: `num_envs` must be divisible by the total number of chips since each
+        chip gets `num_envs // total_number_of_chips` environments to roll out
+      NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since
+        data generated by `num_envs` parallel envs gets used for gradient
+        updates over `num_minibatches` of data, where each minibatch has a
+        leading dimension of `batch_size`
+    num_eval_envs: the number of envs to use for evaluation. Each env will run
+      1 episode, and all envs run in parallel during eval.
+    learning_rate: learning rate for the SAC losses
+    discounting: discounting rate
+    seed: random seed
+    batch_size: the batch size for each minibatch SGD step
+    num_evals: the number of evals to run during the entire training run.
+      Increasing the number of evals increases total training time
+    normalize_observations: whether to normalize observations
+    max_devices_per_host: maximum number of chips to use per host process
+    reward_scaling: float scaling for reward
+    tau: interpolation factor in polyak averaging for target networks
+    initial_alpha: initial value for the temperature parameter alpha
+    min_replay_size: the minimum number of samples in the replay buffer before
+      training starts; used to prefill the replay buffer with random samples
+    max_replay_size: the maximum number of samples in the replay buffer. If
+      None, the buffer is sized to hold `num_timesteps` samples
+    grad_updates_per_step: the number of gradient updates to run per actor step
+    deterministic_eval: whether to run the eval with a deterministic policy
+    network_factory: function that generates networks for policy and value
+      functions
+    progress_fn: a user-defined callback function for reporting/plotting metrics
+    eval_env: an optional environment for eval only, defaults to `environment`
+    randomization_fn: a user-defined callback function that generates randomized
+      environments
+    checkpoint_logdir: the path used to save checkpoints. If None, no
+      checkpoints are saved; otherwise a checkpoint is saved at each eval
+    restore_checkpoint_path: the path used to restore previous model params
+
+  Returns:
+    Tuple of (make_policy function, network params, metrics)
+  """
   process_id = jax.process_index()
   local_devices_to_use = jax.local_device_count()
   if max_devices_per_host is not None:
