54 changes: 52 additions & 2 deletions brax/training/agents/sac/train.py
@@ -78,10 +78,11 @@ def _init_training_state(
alpha_optimizer: optax.GradientTransformation,
policy_optimizer: optax.GradientTransformation,
q_optimizer: optax.GradientTransformation,
initial_alpha: float = 1.0,
) -> TrainingState:
"""Inits the training state and replicates it over devices."""
key_policy, key_q = jax.random.split(key)
log_alpha = jnp.asarray(0.0, dtype=jnp.float32)
log_alpha = jnp.asarray(jnp.log(initial_alpha), dtype=jnp.float32)
alpha_optimizer_state = alpha_optimizer.init(log_alpha)

policy_params = sac_network.policy_network.init(key_policy)
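A quick sketch (not part of the diff) of what this initialization buys: the temperature is stored and optimized in log-space so it stays positive, and seeding `log_alpha` with `jnp.log(initial_alpha)` makes the first effective temperature exactly `initial_alpha`. The value 0.1 below is illustrative.

# Sketch only: relationship between log_alpha and the effective temperature.
import jax.numpy as jnp

initial_alpha = 0.1  # hypothetical starting temperature
log_alpha = jnp.asarray(jnp.log(initial_alpha), dtype=jnp.float32)
alpha = jnp.exp(log_alpha)  # == 0.1 before any gradient step; stays > 0 after updates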
@@ -128,6 +129,7 @@ def train(
max_devices_per_host: Optional[int] = None,
reward_scaling: float = 1.0,
tau: float = 0.005,
initial_alpha: float = 1.0,
min_replay_size: int = 0,
max_replay_size: Optional[int] = None,
grad_updates_per_step: int = 1,
@@ -143,7 +145,54 @@
checkpoint_logdir: Optional[str] = None,
restore_checkpoint_path: Optional[str] = None,
):
"""SAC training."""
"""SAC training.

Args:
environment: the environment to train on
num_timesteps: the total number of environment steps to use during training
episode_length: the length of an environment episode
wrap_env: If True, wrap the environment for training. Otherwise use the
environment as is.
wrap_env_fn: a custom function that wraps the environment for training. If
not specified, the environment is wrapped with the default training
wrapper.
action_repeat: the number of timesteps to repeat an action
num_envs: the number of parallel environments to use for rollouts
NOTE: `num_envs` must be divisible by the total number of chips since each
chip gets `num_envs // total_number_of_chips` environments to roll out
num_eval_envs: the number of envs to use for evaluation. Each env will run 1
episode, and all envs run in parallel during eval.
learning_rate: learning rate for SAC loss
discounting: discounting rate
seed: random seed
batch_size: the batch size for each minibatch SGD step
num_evals: the number of evals to run during the entire training run.
Increasing the number of evals increases total training time
normalize_observations: whether to normalize observations
max_devices_per_host: maximum number of chips to use per host process
reward_scaling: float scaling for reward
tau: interpolation factor in polyak averaging for target networks
initial_alpha: initial value of the entropy temperature α. The temperature is
still learned during training; this only sets its starting point
min_replay_size: the minimum number of samples in the replay buffer before
starting training. This is used to prefill the replay buffer with random
samples before training starts
max_replay_size: the maximum number of samples in the replay buffer. If None,
the maximum replay size defaults to `num_timesteps`
grad_updates_per_step: the number of gradient updates to run per actor step.
deterministic_eval: whether to run the eval with a deterministic policy
network_factory: function that generates networks for policy and value
functions
progress_fn: a user-defined callback function for reporting/plotting metrics
eval_env: an optional environment for eval only, defaults to `environment`
randomization_fn: a user-defined callback function that generates randomized
environments
checkpoint_logdir: the path used to save checkpoints. If None, no checkpoints
are saved. A checkpoint is saved at each evaluation
restore_checkpoint_path: the path used to restore previous model params

Returns:
Tuple of (make_policy function, network params, metrics)
"""
process_id = jax.process_index()
local_devices_to_use = jax.local_device_count()
if max_devices_per_host is not None:
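A hedged usage sketch of the new argument documented above. The environment choice and hyperparameter values are illustrative assumptions, not taken from the diff.

# Example call with a custom starting temperature; values are illustrative.
from brax import envs
from brax.training.agents.sac import train as sac

env = envs.get_environment('halfcheetah')
make_policy, params, metrics = sac.train(
    environment=env,
    num_timesteps=1_000_000,
    episode_length=1000,
    num_envs=64,
    batch_size=256,
    initial_alpha=0.5,  # start the entropy temperature below the default of 1.0
    seed=0,
)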
@@ -485,6 +534,7 @@ def training_epoch_with_timing(
alpha_optimizer=alpha_optimizer,
policy_optimizer=policy_optimizer,
q_optimizer=q_optimizer,
initial_alpha=initial_alpha,
)
del global_key

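For context, `initial_alpha` only sets where the temperature starts; α keeps being optimized toward a target entropy during training. Below is a minimal sketch of the standard SAC temperature update under that assumption; the target entropy, learning rate, and batch of log-probs are illustrative, and the names are not taken from brax internals.

# Minimal sketch of the standard SAC temperature update.
import jax
import jax.numpy as jnp
import optax

target_entropy = -6.0  # e.g. -action_dim for a 6-D action space (assumption)

def alpha_loss(log_alpha, log_prob):
  # Gradient only flows through log_alpha; log_prob comes from the policy.
  alpha = jnp.exp(log_alpha)
  return jnp.mean(alpha * jax.lax.stop_gradient(-log_prob - target_entropy))

optimizer = optax.adam(3e-4)
log_alpha = jnp.log(jnp.float32(0.5))     # corresponds to initial_alpha = 0.5
opt_state = optimizer.init(log_alpha)

log_prob = jnp.array([-5.0, -7.0, -6.5])  # fake batch of policy log-probs
grads = jax.grad(alpha_loss)(log_alpha, log_prob)
updates, opt_state = optimizer.update(grads, opt_state)
log_alpha = optax.apply_updates(log_alpha, updates)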