Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 66 additions & 27 deletions scripts/train_offpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,9 @@ def build_runner(algo_name: str, cfg: DictConfig):
"expected 'one_tick'"
)
verbose_metrics = bool(getattr(cfg.training, "verbose_metrics", False))
if cfg.training.num_gpus > 1:
if algo_name == "flashsac":
raise ValueError("FlashSAC does not support training.num_gpus > 1")
raise ValueError("cpu_pinned_double_buffer is currently single-GPU only")
num_gpus = int(getattr(cfg.training, "num_gpus", 1))
if num_gpus > 1 and algo_name != "sac":
raise ValueError("Only SAC supports training.num_gpus > 1 in this validation round")

if cfg.training.no_sync_collection:
raise ValueError("cpu_pinned_double_buffer requires synchronized collection")
Expand Down Expand Up @@ -239,30 +238,70 @@ def build_runner(algo_name: str, cfg: DictConfig):
_algo_type = str(_custom_runtime.algo_type)
_actor_kwargs = dict(_learner_extra_kwargs)

_learner = _learner_cls(
obs_dim=_obs_dim,
action_dim=_action_dim,
device=_device,
gamma=cfg.algo.gamma,
tau=cfg.algo.tau,
actor_lr=cfg.algo.actor_lr,
critic_lr=cfg.algo.critic_lr,
alpha_lr=cfg.algo.algo_params.alpha_lr,
alpha_init=cfg.algo.algo_params.alpha_init,
target_entropy_ratio=cfg.algo.algo_params.target_entropy_ratio,
actor_hidden_dim=cfg.algo.actor_hidden_dim,
critic_hidden_dim=cfg.algo.critic_hidden_dim,
num_atoms=cfg.algo.num_atoms,
use_layer_norm=cfg.algo.use_layer_norm,
max_grad_norm=cfg.algo.algo_params.max_grad_norm,
use_amp=cfg.training.use_amp,
amp_dtype=cfg.algo.algo_params.amp_dtype,
use_compile=cfg.algo.algo_params.use_compile,
use_symmetry=cfg.algo.use_symmetry,
symmetry_augmentation=_symmetry_aug,
critic_obs_dim=_critic_dim,
_learner_kwargs = {
"obs_dim": _obs_dim,
"action_dim": _action_dim,
"gamma": cfg.algo.gamma,
"tau": cfg.algo.tau,
"actor_lr": cfg.algo.actor_lr,
"critic_lr": cfg.algo.critic_lr,
"alpha_lr": cfg.algo.algo_params.alpha_lr,
"alpha_init": cfg.algo.algo_params.alpha_init,
"target_entropy_ratio": cfg.algo.algo_params.target_entropy_ratio,
"actor_hidden_dim": cfg.algo.actor_hidden_dim,
"critic_hidden_dim": cfg.algo.critic_hidden_dim,
"num_atoms": cfg.algo.num_atoms,
"use_layer_norm": cfg.algo.use_layer_norm,
"max_grad_norm": cfg.algo.algo_params.max_grad_norm,
"use_amp": cfg.training.use_amp,
"amp_dtype": cfg.algo.algo_params.amp_dtype,
"use_compile": cfg.algo.algo_params.use_compile,
"use_symmetry": cfg.algo.use_symmetry,
"symmetry_augmentation": _symmetry_aug,
"critic_obs_dim": _critic_dim,
**_learner_extra_kwargs,
)
}
_learner = _learner_cls(device=_device, **_learner_kwargs)

if num_gpus > 1:
from unilab.algos.torch.offpolicy.multi_gpu_runner import MultiGPUOffPolicyRunner

if cfg.algo.obs_normalization:
raise ValueError(
"SAC multi-GPU validation currently requires algo.obs_normalization=false"
)
if not str(_device).startswith("cuda"):
raise ValueError("SAC multi-GPU training requires a CUDA device")
return MultiGPUOffPolicyRunner(
learner=_learner,
env_name=cfg.training.task_name,
algo_type=_algo_type,
learner_cls=_learner_cls,
learner_kwargs=_learner_kwargs,
num_gpus=num_gpus,
distributed_backend="nccl",
num_envs=cfg.algo.num_envs,
replay_buffer_n=cfg.algo.replay_buffer_n,
batch_size=_batch_size,
learning_starts=cfg.algo.learning_starts,
updates_per_step=cfg.algo.updates_per_step,
policy_frequency=cfg.algo.policy_frequency,
sync_collection=True,
env_steps_per_sync=cfg.training.env_steps_per_sync,
device=_device,
actor_hidden_dim=cfg.algo.actor_hidden_dim,
use_layer_norm=cfg.algo.use_layer_norm,
obs_normalization=False,
sim_backend=cfg.training.sim_backend,
env_cfg_override=env_cfg_override,
actor_kwargs=_actor_kwargs,
trace_enabled=cfg.training.trace_enabled,
trace_output_dir=cfg.training.trace_output_dir,
trace_thread_time=cfg.training.trace_thread_time,
trace_cuda_events=cfg.training.trace_cuda_events,
seed=cfg.algo.seed,
nan_guard_cfg=_nan_guard_cfg,
)

return DoubleBufferOffPolicyRunner(
learner=_learner,
Expand Down
9 changes: 9 additions & 0 deletions src/unilab/algos/torch/fast_sac/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,15 @@ def _reduce_gradients(self, model: nn.Module) -> None:
p.grad.copy_(flat[offset : offset + n].view_as(p.grad))
offset += n

def sync_initial_parameters(self, src: int = 0) -> None:
    """Make every rank start from the learner state held by rank ``src``.

    Each parameter tensor of the actor, the Q-network and the target
    Q-network is passed to ``dist.broadcast`` in that order, followed by
    ``log_alpha``. Nothing happens when ``world_size`` is 1 or less.

    Args:
        src: Rank handed to ``dist.broadcast`` as the source of the values.
    """
    if self.world_size > 1:
        networks = (self.actor, self.qnet, self.qnet_target)
        # Lazy generator: each tensor is fetched right before it is broadcast.
        weights = (p.data for net in networks for p in net.parameters())
        for tensor in weights:
            dist.broadcast(tensor, src=src)
        # The entropy temperature is handled last, after all network weights.
        dist.broadcast(self.log_alpha.data, src=src)

def _get_actions_and_log_probs_for_critic(
self,
actor_obs: torch.Tensor,
Expand Down
Loading
Loading