From 75ccb931fc292f20e6e8980a25b3cd97caac2a69 Mon Sep 17 00:00:00 2001 From: tatp-yf Date: Sun, 21 Jun 2026 22:58:02 +0800 Subject: [PATCH] fix: support SAC multi-GPU replay IPC --- scripts/train_offpolicy.py | 93 ++- src/unilab/algos/torch/fast_sac/learner.py | 9 + .../algos/torch/offpolicy/multi_gpu_runner.py | 137 +++- src/unilab/algos/torch/offpolicy/worker.py | 734 +++++++++++------- .../replay_pipelines/multi_gpu_cpu_pinned.py | 360 +++++++++ .../algos/test_fast_sac_symmetry_contract.py | 15 + .../test_offpolicy_double_buffer_runner.py | 6 +- tests/algos/test_offpolicy_runner_unit.py | 136 ++-- tests/ipc/test_multi_gpu_replay_pack.py | 313 ++++++++ tests/scripts/test_train_scripts.py | 12 +- 10 files changed, 1391 insertions(+), 424 deletions(-) create mode 100644 src/unilab/ipc/replay_pipelines/multi_gpu_cpu_pinned.py create mode 100644 tests/ipc/test_multi_gpu_replay_pack.py diff --git a/scripts/train_offpolicy.py b/scripts/train_offpolicy.py index c71de8035..64b2cba4e 100644 --- a/scripts/train_offpolicy.py +++ b/scripts/train_offpolicy.py @@ -168,10 +168,9 @@ def build_runner(algo_name: str, cfg: DictConfig): "expected 'one_tick'" ) verbose_metrics = bool(getattr(cfg.training, "verbose_metrics", False)) - if cfg.training.num_gpus > 1: - if algo_name == "flashsac": - raise ValueError("FlashSAC does not support training.num_gpus > 1") - raise ValueError("cpu_pinned_double_buffer is currently single-GPU only") + num_gpus = int(getattr(cfg.training, "num_gpus", 1)) + if num_gpus > 1 and algo_name != "sac": + raise ValueError("Only SAC supports training.num_gpus > 1 in this validation round") if cfg.training.no_sync_collection: raise ValueError("cpu_pinned_double_buffer requires synchronized collection") @@ -239,30 +238,70 @@ def build_runner(algo_name: str, cfg: DictConfig): _algo_type = str(_custom_runtime.algo_type) _actor_kwargs = dict(_learner_extra_kwargs) - _learner = _learner_cls( - obs_dim=_obs_dim, - action_dim=_action_dim, - device=_device, - gamma=cfg.algo.gamma, - tau=cfg.algo.tau, - actor_lr=cfg.algo.actor_lr, - critic_lr=cfg.algo.critic_lr, - alpha_lr=cfg.algo.algo_params.alpha_lr, - alpha_init=cfg.algo.algo_params.alpha_init, - target_entropy_ratio=cfg.algo.algo_params.target_entropy_ratio, - actor_hidden_dim=cfg.algo.actor_hidden_dim, - critic_hidden_dim=cfg.algo.critic_hidden_dim, - num_atoms=cfg.algo.num_atoms, - use_layer_norm=cfg.algo.use_layer_norm, - max_grad_norm=cfg.algo.algo_params.max_grad_norm, - use_amp=cfg.training.use_amp, - amp_dtype=cfg.algo.algo_params.amp_dtype, - use_compile=cfg.algo.algo_params.use_compile, - use_symmetry=cfg.algo.use_symmetry, - symmetry_augmentation=_symmetry_aug, - critic_obs_dim=_critic_dim, + _learner_kwargs = { + "obs_dim": _obs_dim, + "action_dim": _action_dim, + "gamma": cfg.algo.gamma, + "tau": cfg.algo.tau, + "actor_lr": cfg.algo.actor_lr, + "critic_lr": cfg.algo.critic_lr, + "alpha_lr": cfg.algo.algo_params.alpha_lr, + "alpha_init": cfg.algo.algo_params.alpha_init, + "target_entropy_ratio": cfg.algo.algo_params.target_entropy_ratio, + "actor_hidden_dim": cfg.algo.actor_hidden_dim, + "critic_hidden_dim": cfg.algo.critic_hidden_dim, + "num_atoms": cfg.algo.num_atoms, + "use_layer_norm": cfg.algo.use_layer_norm, + "max_grad_norm": cfg.algo.algo_params.max_grad_norm, + "use_amp": cfg.training.use_amp, + "amp_dtype": cfg.algo.algo_params.amp_dtype, + "use_compile": cfg.algo.algo_params.use_compile, + "use_symmetry": cfg.algo.use_symmetry, + "symmetry_augmentation": _symmetry_aug, + "critic_obs_dim": _critic_dim, **_learner_extra_kwargs, - ) + } + _learner = _learner_cls(device=_device, **_learner_kwargs) + + if num_gpus > 1: + from unilab.algos.torch.offpolicy.multi_gpu_runner import MultiGPUOffPolicyRunner + + if cfg.algo.obs_normalization: + raise ValueError( + "SAC multi-GPU validation currently requires algo.obs_normalization=false" + ) + if not str(_device).startswith("cuda"): + raise ValueError("SAC multi-GPU training requires a CUDA device") + return MultiGPUOffPolicyRunner( + learner=_learner, + env_name=cfg.training.task_name, + algo_type=_algo_type, + learner_cls=_learner_cls, + learner_kwargs=_learner_kwargs, + num_gpus=num_gpus, + distributed_backend="nccl", + num_envs=cfg.algo.num_envs, + replay_buffer_n=cfg.algo.replay_buffer_n, + batch_size=_batch_size, + learning_starts=cfg.algo.learning_starts, + updates_per_step=cfg.algo.updates_per_step, + policy_frequency=cfg.algo.policy_frequency, + sync_collection=True, + env_steps_per_sync=cfg.training.env_steps_per_sync, + device=_device, + actor_hidden_dim=cfg.algo.actor_hidden_dim, + use_layer_norm=cfg.algo.use_layer_norm, + obs_normalization=False, + sim_backend=cfg.training.sim_backend, + env_cfg_override=env_cfg_override, + actor_kwargs=_actor_kwargs, + trace_enabled=cfg.training.trace_enabled, + trace_output_dir=cfg.training.trace_output_dir, + trace_thread_time=cfg.training.trace_thread_time, + trace_cuda_events=cfg.training.trace_cuda_events, + seed=cfg.algo.seed, + nan_guard_cfg=_nan_guard_cfg, + ) return DoubleBufferOffPolicyRunner( learner=_learner, diff --git a/src/unilab/algos/torch/fast_sac/learner.py b/src/unilab/algos/torch/fast_sac/learner.py index 7792bdb36..d4c743c5b 100644 --- a/src/unilab/algos/torch/fast_sac/learner.py +++ b/src/unilab/algos/torch/fast_sac/learner.py @@ -559,6 +559,15 @@ def _reduce_gradients(self, model: nn.Module) -> None: p.grad.copy_(flat[offset : offset + n].view_as(p.grad)) offset += n + def sync_initial_parameters(self, src: int = 0) -> None: + """Broadcast initial learner state for distributed off-policy training.""" + if self.world_size <= 1: + return + for module in (self.actor, self.qnet, self.qnet_target): + for parameter in module.parameters(): + dist.broadcast(parameter.data, src=src) + dist.broadcast(self.log_alpha.data, src=src) + def _get_actions_and_log_probs_for_critic( self, actor_obs: torch.Tensor, diff --git a/src/unilab/algos/torch/offpolicy/multi_gpu_runner.py b/src/unilab/algos/torch/offpolicy/multi_gpu_runner.py index a8d1ad486..9531b8fad 100644 --- a/src/unilab/algos/torch/offpolicy/multi_gpu_runner.py +++ b/src/unilab/algos/torch/offpolicy/multi_gpu_runner.py @@ -1,11 +1,12 @@ -"""Multi-GPU off-policy runner using NCCL all-reduce for FastSAC. +"""Multi-GPU off-policy runner using distributed gradient averaging. Architecture: Main process → creates ReplayBuffer (host-only), WeightSync, queues → spawns Collector subprocess (CPU, env simulation) → spawns N Learner workers via mp.spawn (one per GPU) - Learner rank i → samples packed CPU replay rows to its rank device, then - communicates via NCCL all_reduce + Learner rank i → samples packed CPU replay rows to its rank device through + a rank-local H2D pipeline, then communicates gradients + through the configured distributed backend. Collector → talks only to rank 0 via collection_ready_queue / trainer_done_queue """ @@ -18,13 +19,12 @@ import time from collections import defaultdict, deque from datetime import timedelta -from typing import Any, Dict, Optional, cast +from typing import Any, Dict, Optional import torch import torch.distributed as dist import torch.multiprocessing as tmp # torch.multiprocessing for spawn -from unilab.algos.torch.fast_sac.learner import FastSACLearner from unilab.algos.torch.offpolicy.runner import ( OffPolicyRunner, build_reward_comparison_metrics, @@ -35,6 +35,7 @@ from unilab.ipc import SharedWeightSync from unilab.ipc.async_runner import _SPAWN_CTX from unilab.ipc.replay_buffer import ReplayBuffer +from unilab.ipc.replay_pipelines.multi_gpu_cpu_pinned import MultiGPUCPUPinnedReplayPipeline from unilab.logging import OffPolicyLogger from unilab.training.seed import apply_training_seed, derive_worker_seed @@ -45,17 +46,6 @@ def _find_free_port() -> int: return int(s.getsockname()[1]) -def _broadcast_initial_params(learner: FastSACLearner, rank: int) -> None: - """Broadcast rank-0 initial parameters to all workers for consistent starting point.""" - for model in ( - cast(torch.nn.Module, learner.actor), - cast(torch.nn.Module, learner.qnet), - ): - for p in model.parameters(): - dist.broadcast(p.data, src=0) - dist.broadcast(learner.log_alpha.data, src=0) - - def _drain_metrics( metrics_queue: Any, reward_history: deque, @@ -98,6 +88,7 @@ def _drain_metrics( def _learner_worker( rank: int, world_size: int, + learner_cls: Any, learner_kwargs: Dict[str, Any], runner_kwargs: Dict[str, Any], replay_buffer: ReplayBuffer, @@ -108,6 +99,9 @@ def _learner_worker( collection_ready_queue: Any, trainer_done_queue: Any, metrics_queue: Any, + collector_pack_request_queue: Any, + collector_pack_ready_queue: Any, + collector_pack_shared_slots: Any, master_port: int, ) -> None: """Worker function executed on each GPU (called via torch.multiprocessing.spawn).""" @@ -115,12 +109,14 @@ def _learner_worker( os.environ["MASTER_PORT"] = str(master_port) device = f"cuda:{rank}" torch.cuda.set_device(rank) + backend = str(runner_kwargs.get("distributed_backend", "nccl")) dist.init_process_group( - "nccl", rank=rank, world_size=world_size, timeout=timedelta(seconds=120) + backend, rank=rank, world_size=world_size, timeout=timedelta(seconds=120) ) logger: Optional[OffPolicyLogger] = None weight_sync: SharedWeightSync | None = None + replay_pipeline: MultiGPUCPUPinnedReplayPipeline | None = None try: apply_training_seed( derive_worker_seed(runner_kwargs.get("seed"), worker_index=rank + 1000), @@ -131,10 +127,15 @@ def _learner_worker( replay_buffer.device = device # 2. Create learner on this device - learner = FastSACLearner(device=device, world_size=world_size, **learner_kwargs) + learner = learner_cls(device=device, world_size=world_size, **learner_kwargs) - # 3. Broadcast rank-0 params so all workers start identically - _broadcast_initial_params(learner, rank) + # 3. Broadcast rank-0 params so all workers start identically. + sync_initial_parameters = getattr(learner, "sync_initial_parameters", None) + if not callable(sync_initial_parameters): + raise ValueError( + "Multi-GPU off-policy learner must implement sync_initial_parameters(src=0)" + ) + sync_initial_parameters(src=0) # 4. Reconnect to the shared weight-sync buffer weight_sync = SharedWeightSync( @@ -157,12 +158,25 @@ def _learner_worker( logger_type: str = runner_kwargs.get("logger_type", "tensorboard") learning_starts = max(int(runner_kwargs.get("learning_starts", 0)), 0) train_start_threshold = compute_train_start_threshold(batch_size, learning_starts, num_envs) + sample_count = batch_size * updates_per_step + + replay_pipeline = MultiGPUCPUPinnedReplayPipeline( + replay_buffer, + rank=rank, + world_size=world_size, + device=device, + sample_count=sample_count, + base_seed=int(runner_kwargs.get("seed") or 0), + collector_pack_request_queue=collector_pack_request_queue[rank], + collector_pack_ready_queue=collector_pack_ready_queue[rank], + collector_pack_shared_slots=collector_pack_shared_slots[rank], + ) # 6. Logger (rank 0 only) if rank == 0: os.makedirs(log_dir, exist_ok=True) logger = OffPolicyLogger( - algo_name=f"FastSAC_x{world_size}GPU", + algo_name=f"Fast{str(runner_kwargs.get('algo_type', 'offpolicy')).upper()}_x{world_size}GPU", max_iterations=max_iterations, num_envs=num_envs, env_name=env_name, @@ -172,15 +186,19 @@ def _learner_worker( log_backend=logger_type, ) logger.set_collection_sync(sync_collection, env_steps_per_sync) + logger.log_status("Replay pipeline: multi_gpu_cpu_pinned") + logger.log_status(f"Replay batch semantics: per-rank batch_size={batch_size}") logger.start() reward_history: deque = deque(maxlen=100) latest_reward_components: dict = {} write_read_ema = 0.0 last_buf_log = 0 + prepared_tick: int | None = None # 7. Training loop for it in range(1, max_iterations + 1): + collector_released_for_next = False # --- Wait for data (rank 0 only, then barrier syncs everyone) --- wait_start = time.time() if rank == 0: @@ -228,12 +246,32 @@ def _learner_worker( iter_metrics: dict = defaultdict(list) ptr_before = int(replay_buffer.ptr[0]) if rank == 0 else 0 - large_batch = replay_buffer.sample(batch_size * updates_per_step) + if prepared_tick != it: + replay_pipeline.start_prepare(it, sample_count) + prepared_tick = it + while not replay_pipeline.batch_ready(it, sample_count): + if stop_event.is_set(): + return + time.sleep(0.001) + large_batch = replay_pipeline.sample_large_batch(it, sample_count) learner_incremental_h2d_time = ( - float(getattr(replay_buffer, "last_incremental_h2d_time_s", 0.0)) + float(getattr(replay_pipeline, "last_incremental_h2d_time_s", 0.0)) if rank == 0 else 0.0 ) + + if it < max_iterations: + min_snapshot_ptr = int(replay_buffer.ptr[0]) + (num_envs * env_steps_per_sync) + replay_pipeline.start_prepare( + it + 1, + sample_count, + min_snapshot_ptr=min_snapshot_ptr, + ) + prepared_tick = it + 1 + if rank == 0 and sync_collection and trainer_done_queue is not None: + trainer_done_queue.put(1) + collector_released_for_next = True + train_start = time.time() for update_idx in range(updates_per_step): @@ -245,13 +283,15 @@ def _learner_worker( for k, v in critic_metrics.items(): iter_metrics[k].append(v) - if update_idx % policy_frequency == 1: + if update_idx % policy_frequency == 0: actor_metrics = learner.update_actor(batch) for k, v in actor_metrics.items(): iter_metrics[k].append(v) learner.soft_update_target() + replay_pipeline.after_tick() + # Barrier: all ranks must finish this iteration before rank 0 proceeds dist.barrier() train_time = time.time() - train_start if rank == 0 else 0.0 @@ -263,7 +303,11 @@ def _learner_worker( weight_sync.write_weights(learner.actor.state_dict()) weight_sync_time = time.perf_counter() - weight_sync_start - if sync_collection and trainer_done_queue is not None: + if ( + sync_collection + and trainer_done_queue is not None + and not collector_released_for_next + ): trainer_done_queue.put(1) write_delta = int(replay_buffer.ptr[0]) - ptr_before @@ -289,6 +333,8 @@ def _learner_worker( weight_sync_time=weight_sync_time, extra_info={ "throughput_steps": num_envs * env_steps_per_sync, + "world_size": world_size, + "effective_batch_size": batch_size * world_size, }, ) @@ -306,12 +352,17 @@ def _learner_worker( logger.log_save(ckpt_path) logger.finish() + if replay_pipeline is not None: + replay_pipeline.close() + replay_pipeline = None weight_sync.close() weight_sync = None finally: if logger is not None: logger.close() + if replay_pipeline is not None: + replay_pipeline.close() if weight_sync is not None: weight_sync.close() dist.destroy_process_group() @@ -321,10 +372,9 @@ class MultiGPUOffPolicyRunner(OffPolicyRunner): """Multi-GPU off-policy runner. Keeps a single Collector on CPU and spawns *num_gpus* Learner workers via - ``torch.multiprocessing.spawn``. Each worker processes independent - mini-batches from the same shared ReplayBuffer; gradients are averaged - with NCCL all_reduce — equivalent to training on a *num_gpus× larger* - effective batch size per wall-clock second. + ``torch.multiprocessing.spawn``. Each worker processes an independent + mini-batch from the same shared ReplayBuffer through a rank-local H2D + pipeline; learner-owned distributed gradient reduction averages gradients. Falls back transparently to single-GPU when ``num_gpus <= 1``. """ @@ -349,8 +399,10 @@ def __init__( learner: Any, env_name: str, algo_type: str, + learner_cls: Any, learner_kwargs: Dict[str, Any], num_gpus: int = 1, + distributed_backend: str = "nccl", **kwargs: Any, ) -> None: self.validate_capabilities( @@ -361,7 +413,9 @@ def __init__( super().__init__(learner=learner, env_name=env_name, algo_type=algo_type, **kwargs) self.num_gpus = num_gpus self.world_size = num_gpus + self._learner_cls = learner_cls self._learner_kwargs = learner_kwargs + self.distributed_backend = distributed_backend def learn( self, @@ -378,6 +432,8 @@ def learn( logger_type=logger_type, ) return + if not self.sync_collection: + raise ValueError("Multi-GPU off-policy replay requires synchronized collection") self._learn_multi_gpu( max_iterations=max_iterations, save_interval=save_interval, @@ -422,6 +478,17 @@ def _learn_multi_gpu( ) metrics_queue = _SPAWN_CTX.Queue(maxsize=100) + collector_pack_request_queues = [_SPAWN_CTX.Queue(maxsize=2) for _ in range(self.num_gpus)] + collector_pack_ready_queues = [_SPAWN_CTX.Queue(maxsize=2) for _ in range(self.num_gpus)] + sample_count = self.batch_size * self.updates_per_step + packed_width = int(replay_buffer._storage.shape[1]) + collector_pack_shared_slots = [ + [ + torch.empty((sample_count, packed_width), dtype=torch.float32).share_memory_() + for _ in range(2) + ] + for _ in range(self.num_gpus) + ] # --- Start Collector (CPU, single process, unchanged) --- weight_param_shapes = {k: v.shape for k, v in self.learner.actor.state_dict().items()} @@ -445,7 +512,13 @@ def _learn_multi_gpu( "shared_obs_normalizer_stats": None, "sim_backend": self.sim_backend, "env_cfg_override": self.env_cfg_override, + "obs_dim": self.obs_dim, + "action_dim": self.action_dim, + "actor_kwargs": self.actor_kwargs, "seed": derive_worker_seed(self.seed, worker_index=0), + "collector_pack_request_queue": collector_pack_request_queues, + "collector_pack_ready_queue": collector_pack_ready_queues, + "collector_pack_shared_slots": collector_pack_shared_slots, } self._start_collector( target_fn=off_policy_collector_fn, @@ -476,6 +549,8 @@ def _learn_multi_gpu( "action_dim": self.action_dim, "logger_type": logger_type, "seed": self.seed, + "distributed_backend": self.distributed_backend, + "algo_type": self.algo_type, } try: @@ -483,6 +558,7 @@ def _learn_multi_gpu( _learner_worker, args=( self.num_gpus, + self._learner_cls, self._learner_kwargs, runner_kwargs, replay_buffer, @@ -493,6 +569,9 @@ def _learn_multi_gpu( collection_ready_queue, trainer_done_queue, metrics_queue, + collector_pack_request_queues, + collector_pack_ready_queues, + collector_pack_shared_slots, master_port, ), nprocs=self.num_gpus, diff --git a/src/unilab/algos/torch/offpolicy/worker.py b/src/unilab/algos/torch/offpolicy/worker.py index 0c9192c7d..cccc72626 100644 --- a/src/unilab/algos/torch/offpolicy/worker.py +++ b/src/unilab/algos/torch/offpolicy/worker.py @@ -6,6 +6,7 @@ import queue import sys +import threading import time from typing import Any, cast @@ -111,8 +112,18 @@ def _record_phase_ms(cycle_timing_ms: dict[str, float], key: str, start_ns: int) return end_ns +def _ranked_entry(collection, rank: int, world_size: int = 1): + if collection is None: + return None + if int(world_size) <= 1: + return collection + return collection[rank] + + def _collector_pack_shared_batch(replay_buffer, request: dict, shared_slots) -> dict: tick_id = int(request["tick_id"]) + rank = int(request.get("rank", 0)) + world_size = int(request.get("world_size", 1)) snapshot_ptr = int(replay_buffer.ptr[0]) snapshot_size = int(replay_buffer.size[0]) sample_seed = int(request["sample_seed"]) @@ -128,11 +139,16 @@ def _collector_pack_shared_batch(replay_buffer, request: dict, shared_slots) -> gen = torch.Generator(device="cpu") gen.manual_seed(sample_seed) indices = torch.randint(0, snapshot_size, (sample_count,), generator=gen) - dst = shared_slots[shared_slot] + rank_shared_slots = _ranked_entry(shared_slots, rank, world_size) + if rank_shared_slots is None: + raise RuntimeError("collector replay pack request is missing shared slots") + dst = rank_shared_slots[shared_slot] torch.index_select(replay_buffer._storage, 0, indices, out=dst) pack_end_ns = time.perf_counter_ns() return { "tick_id": tick_id, + "rank": rank, + "world_size": world_size, "snapshot_ptr": snapshot_ptr, "snapshot_size": snapshot_size, "sample_seed": sample_seed, @@ -193,10 +209,119 @@ def _service_collector_pack_requests( "pinned_memory": False, }, ) - ready_queue.put(ready) + target_ready_queue = _ranked_entry( + ready_queue, + int(ready.get("rank", 0)), + int(ready.get("world_size", 1)), + ) + if target_ready_queue is None: + raise RuntimeError("collector replay pack request is missing a ready queue") + target_ready_queue.put(ready) return True, None +def _drain_collector_pack_requests( + replay_buffer, + request_queue, + ready_queue, + shared_slots, + trace_recorder=None, + *, + pending_request: dict | None = None, + max_requests: int = 0, +) -> dict | None: + """Service currently available replay pack requests without blocking env progress.""" + serviced_count = 0 + pending = pending_request + while True: + if max_requests > 0 and serviced_count >= max_requests: + return pending + serviced, pending = _service_collector_pack_requests( + replay_buffer, + request_queue, + ready_queue, + shared_slots, + trace_recorder, + block_timeout=0.0, + pending_request=pending, + ) + if not serviced: + return pending + serviced_count += 1 + + +class _CollectorPackService: + """Background replay pack service for multi-rank off-policy learners.""" + + def __init__( + self, + replay_buffer, + request_queue, + ready_queue, + shared_slots, + trace_recorder=None, + *, + stop_event=None, + ) -> None: + self._replay_buffer = replay_buffer + self._request_queue = request_queue + self._ready_queue = ready_queue + self._shared_slots = shared_slots + self._trace_recorder = trace_recorder + self._stop_event = stop_event + self._threads: list[threading.Thread] = [] + self._started = False + + @staticmethod + def should_start(request_queue, ready_queue, shared_slots) -> bool: + return ( + isinstance(request_queue, list) + and isinstance(ready_queue, list) + and isinstance(shared_slots, list) + and len(request_queue) > 1 + ) + + def start(self) -> None: + if self._started: + return + self._started = True + world_size = len(self._request_queue) + for rank in range(world_size): + thread = threading.Thread( + target=self._rank_worker, + args=(rank, world_size), + name=f"collector_replay_pack_rank{rank}", + daemon=True, + ) + thread.start() + self._threads.append(thread) + + def _rank_worker(self, rank: int, world_size: int) -> None: + request_queue = self._request_queue[rank] + ready_queue = self._ready_queue + shared_slots = self._shared_slots + pending_request = None + while True: + if self._stop_event is not None and self._stop_event.is_set(): + return + serviced, pending_request = _service_collector_pack_requests( + self._replay_buffer, + request_queue, + ready_queue, + shared_slots, + self._trace_recorder, + block_timeout=0.001, + pending_request=pending_request, + ) + if not serviced and pending_request is not None: + time.sleep(0.0005) + + def close(self) -> None: + for thread in self._threads: + thread.join(timeout=1.0) + self._threads.clear() + + def off_policy_collector_fn( stop_event, env_name: str, @@ -394,315 +519,346 @@ def _run_collector( # Track env.step calls collected since the last learner phase. env_steps_since_sync = 0 pending_collector_pack_request = None - - # Collection loop - while not stop_event.is_set(): - cycle_timing_ms: dict[str, float] = dict.fromkeys(COLLECTOR_TIMING_KEYS, 0.0) - phase_start_ns = _time.perf_counter_ns() - - # Check for weight updates - if weight_sync.version > local_weight_version: - _wt_ns = _time.perf_counter_ns() - sd = dict(actor.state_dict()) - local_weight_version = weight_sync.read_weights_into(sd) - actor.load_state_dict(sd) - if trace_recorder: - trace_recorder.add_slice( - "collector/check_weight_update", - category="collector", - start_ns=_wt_ns, - end_ns=_time.perf_counter_ns(), - ) - - # Update normalizer stats - if obs_normalization and shared_obs_normalizer_stats is not None: - stats = shared_obs_normalizer_stats.get() - if stats is not None: - # Apply stats to a local normalizer if needed, or directly to actor - pass # Handled by EmpiricalNormalization in learner if actor possesses it. We need a local normalizer. - phase_start_ns = _record_phase_ms(cycle_timing_ms, "weight_sync_ms", phase_start_ns) - - # Normalize obs_np - obs_np_input = obs_np - if obs_normalization and shared_obs_normalizer_stats is not None: - stats = shared_obs_normalizer_stats.get() - if stats is not None: - mean, std = stats - obs_np_input = (obs_np - mean) / (std + 1e-8) - - # Select action - with torch.no_grad(): - _t_infer_ns = _time.perf_counter_ns() - obs_torch = torch.from_numpy(obs_np_input) - dones_torch = torch.from_numpy(prev_dones_np) - priv_info_np = resolve_offpolicy_actor_priv_info( - algo_type=algo_type, - obs_np=obs_np, - critic_np=critic_np, - info=info_dict, - ) - priv_info_torch = torch.from_numpy(priv_info_np) if priv_info_np is not None else None - actions_torch = sample_offpolicy_actions( - actor=actor, - algo_type=algo_type, - obs_torch=obs_torch, - prev_dones_torch=dones_torch, - priv_info_torch=priv_info_torch, - ) - actions_np = actions_torch.numpy() - if trace_recorder: - trace_recorder.add_slice( - "collector/actor_infer_cpu", - category="collector", - start_ns=_t_infer_ns, - end_ns=_time.perf_counter_ns(), - ) - phase_start_ns = _record_phase_ms(cycle_timing_ms, "action_select_ms", phase_start_ns) - - # Step environment - _env_ns = _time.perf_counter_ns() - state = env.step(actions_np) - if trace_recorder: - trace_recorder.add_slice( - "collector/env_step", - category="collector", - start_ns=_env_ns, - end_ns=_time.perf_counter_ns(), - args={"num_envs": num_envs}, - ) - phase_start_ns = _record_phase_ms(cycle_timing_ms, "env_step_ms", phase_start_ns) - - # Extract data as numpy - next_obs_np, next_critic_np = split_obs_dict(state.obs) - next_obs_np = np.asarray(next_obs_np, dtype=np.float32) - next_critic_np = np.asarray(next_critic_np, dtype=np.float32) - rewards_np = np.asarray(state.reward, dtype=np.float32).ravel() - - terminated_np = state.terminated.astype(np.float32, copy=False).ravel() - truncated_np = state.truncated.astype(np.float32, copy=False).ravel() - combined_dones = (state.terminated | state.truncated).astype(np.float32, copy=False).ravel() - prev_dones_np = combined_dones - done_mask_np = combined_dones > 0.5 - timeout_mask_np = truncated_np > 0.5 - terminated_mask_np = np.logical_and(terminated_np > 0.5, ~timeout_mask_np) - - done_count_window += int(np.count_nonzero(done_mask_np)) - timeout_count_window += int(np.count_nonzero(timeout_mask_np)) - terminated_count_window += int(np.count_nonzero(terminated_mask_np)) - - terminal_contract = resolve_terminal_observation_contract( - next_obs_batch_size=next_obs_np.shape[0], - final_observation=state.final_observation, - done=done_mask_np, - info=state.info, - truncated=truncated_np, - ) - phase_start_ns = _record_phase_ms(cycle_timing_ms, "replay_ms", phase_start_ns) - - # ReplayBuffer `dones` follows the UniLab env lifecycle contract: - # done = terminated | truncated. Learners use `truncated` to keep - # bootstrap enabled for timeout/truncation rows. - _rb_ns = _time.perf_counter_ns() - replay_buffer.add( - torch.from_numpy(obs_np), - torch.from_numpy(actions_np), - torch.from_numpy(rewards_np), - torch.from_numpy(next_obs_np), - torch.from_numpy(combined_dones), - torch.from_numpy(truncated_np), - terminal_mask=torch.from_numpy(terminal_contract.terminal_mask), - terminal_next_obs=( - torch.from_numpy(terminal_contract.terminal_obs) - if terminal_contract.terminal_obs is not None - else None - ), - critic=torch.from_numpy(critic_np), - next_critic=torch.from_numpy(next_critic_np), - terminal_next_critic=( - torch.from_numpy(terminal_contract.terminal_critic) - if terminal_contract.terminal_critic is not None - else None - ), - ) - if trace_recorder: - trace_recorder.add_slice( - "collector/replay_add", - category="collector", - start_ns=_rb_ns, - end_ns=_time.perf_counter_ns(), - ) - phase_start_ns = _record_phase_ms(cycle_timing_ms, "replay_ms", phase_start_ns) - _, pending_collector_pack_request = _service_collector_pack_requests( + collector_pack_service = None + if _CollectorPackService.should_start( + collector_pack_request_queue, + collector_pack_ready_queue, + collector_pack_shared_slots, + ): + collector_pack_service = _CollectorPackService( replay_buffer, collector_pack_request_queue, collector_pack_ready_queue, collector_pack_shared_slots, trace_recorder, - block_timeout=0.0, - pending_request=pending_collector_pack_request, + stop_event=stop_event, ) - phase_start_ns = _record_phase_ms(cycle_timing_ms, "replay_ms", phase_start_ns) - - # Track episode rewards - vectorized - current_ep_rewards += rewards_np - current_ep_lengths += 1 - reset_mask = combined_dones > 0.5 - reset_indices = np.where(reset_mask)[0] - if len(reset_indices) > 0: - ep_rewards.extend(current_ep_rewards[reset_indices].tolist()) - ep_lengths.extend(current_ep_lengths[reset_indices].tolist()) - current_ep_rewards[reset_indices] = 0.0 - current_ep_lengths[reset_indices] = 0 - - obs_np = next_obs_np - critic_np = next_critic_np - info_dict = state.info - total_steps += num_envs - env_steps_since_sync += 1 - phase_start_ns = _record_phase_ms(cycle_timing_ms, "sync_coordination_ms", phase_start_ns) - - # Signal the learner once this collection chunk is ready. - if ( - sync_collection - and collection_ready_queue is not None - and trainer_done_queue is not None - ): - if env_steps_since_sync >= env_steps_per_sync: - _sig_ns = _time.perf_counter_ns() - collection_ready_queue.put(1) + collector_pack_service.start() + + # Collection loop + try: + while not stop_event.is_set(): + cycle_timing_ms: dict[str, float] = dict.fromkeys(COLLECTOR_TIMING_KEYS, 0.0) + phase_start_ns = _time.perf_counter_ns() + + # Check for weight updates + if weight_sync.version > local_weight_version: + _wt_ns = _time.perf_counter_ns() + sd = dict(actor.state_dict()) + local_weight_version = weight_sync.read_weights_into(sd) + actor.load_state_dict(sd) if trace_recorder: trace_recorder.add_slice( - "collector/signal_ready", + "collector/check_weight_update", category="collector", - start_ns=_sig_ns, + start_ns=_wt_ns, end_ns=_time.perf_counter_ns(), ) - phase_start_ns = _record_phase_ms( - cycle_timing_ms, "sync_coordination_ms", phase_start_ns + + # Update normalizer stats + if obs_normalization and shared_obs_normalizer_stats is not None: + stats = shared_obs_normalizer_stats.get() + if stats is not None: + # Apply stats to a local normalizer if needed, or directly to actor + pass # Handled by EmpiricalNormalization in learner if actor possesses it. We need a local normalizer. + phase_start_ns = _record_phase_ms(cycle_timing_ms, "weight_sync_ms", phase_start_ns) + + # Normalize obs_np + obs_np_input = obs_np + if obs_normalization and shared_obs_normalizer_stats is not None: + stats = shared_obs_normalizer_stats.get() + if stats is not None: + mean, std = stats + obs_np_input = (obs_np - mean) / (std + 1e-8) + + # Select action + with torch.no_grad(): + _t_infer_ns = _time.perf_counter_ns() + obs_torch = torch.from_numpy(obs_np_input) + dones_torch = torch.from_numpy(prev_dones_np) + priv_info_np = resolve_offpolicy_actor_priv_info( + algo_type=algo_type, + obs_np=obs_np, + critic_np=critic_np, + info=info_dict, ) - _wait_ns = _time.perf_counter_ns() - while not stop_event.is_set(): - _, pending_collector_pack_request = _service_collector_pack_requests( - replay_buffer, - collector_pack_request_queue, - collector_pack_ready_queue, - collector_pack_shared_slots, - trace_recorder, - block_timeout=0.0, - pending_request=pending_collector_pack_request, - ) - phase_start_ns = _record_phase_ms(cycle_timing_ms, "replay_ms", phase_start_ns) - try: - trainer_done_queue.get(timeout=0.001) - phase_start_ns = _record_phase_ms( - cycle_timing_ms, "sync_coordination_ms", phase_start_ns - ) - _, pending_collector_pack_request = _service_collector_pack_requests( - replay_buffer, - collector_pack_request_queue, - collector_pack_ready_queue, - collector_pack_shared_slots, - trace_recorder, - block_timeout=0.0, - pending_request=pending_collector_pack_request, - ) - phase_start_ns = _record_phase_ms( - cycle_timing_ms, "replay_ms", phase_start_ns - ) - break - except queue.Empty: - phase_start_ns = _record_phase_ms( - cycle_timing_ms, "sync_coordination_ms", phase_start_ns - ) - continue + priv_info_torch = ( + torch.from_numpy(priv_info_np) if priv_info_np is not None else None + ) + actions_torch = sample_offpolicy_actions( + actor=actor, + algo_type=algo_type, + obs_torch=obs_torch, + prev_dones_torch=dones_torch, + priv_info_torch=priv_info_torch, + ) + actions_np = actions_torch.numpy() if trace_recorder: trace_recorder.add_slice( - "collector/wait_trainer_done", + "collector/actor_infer_cpu", category="collector", - start_ns=_wait_ns, + start_ns=_t_infer_ns, end_ns=_time.perf_counter_ns(), ) - if metrics_queue is not None: - try: - metrics_queue.put_nowait( - {"trace_events": trace_recorder.drain_events()} - ) - except Exception: - pass + phase_start_ns = _record_phase_ms(cycle_timing_ms, "action_select_ms", phase_start_ns) + + # Step environment + _env_ns = _time.perf_counter_ns() + state = env.step(actions_np) + if trace_recorder: + trace_recorder.add_slice( + "collector/env_step", + category="collector", + start_ns=_env_ns, + end_ns=_time.perf_counter_ns(), + args={"num_envs": num_envs}, + ) + phase_start_ns = _record_phase_ms(cycle_timing_ms, "env_step_ms", phase_start_ns) + + # Extract data as numpy + next_obs_np, next_critic_np = split_obs_dict(state.obs) + next_obs_np = np.asarray(next_obs_np, dtype=np.float32) + next_critic_np = np.asarray(next_critic_np, dtype=np.float32) + rewards_np = np.asarray(state.reward, dtype=np.float32).ravel() + + terminated_np = state.terminated.astype(np.float32, copy=False).ravel() + truncated_np = state.truncated.astype(np.float32, copy=False).ravel() + combined_dones = ( + (state.terminated | state.truncated).astype(np.float32, copy=False).ravel() + ) + prev_dones_np = combined_dones + done_mask_np = combined_dones > 0.5 + timeout_mask_np = truncated_np > 0.5 + terminated_mask_np = np.logical_and(terminated_np > 0.5, ~timeout_mask_np) + + done_count_window += int(np.count_nonzero(done_mask_np)) + timeout_count_window += int(np.count_nonzero(timeout_mask_np)) + terminated_count_window += int(np.count_nonzero(terminated_mask_np)) + + terminal_contract = resolve_terminal_observation_contract( + next_obs_batch_size=next_obs_np.shape[0], + final_observation=state.final_observation, + done=done_mask_np, + info=state.info, + truncated=truncated_np, + ) + phase_start_ns = _record_phase_ms(cycle_timing_ms, "replay_ms", phase_start_ns) + + # ReplayBuffer `dones` follows the UniLab env lifecycle contract: + # done = terminated | truncated. Learners use `truncated` to keep + # bootstrap enabled for timeout/truncation rows. + _rb_ns = _time.perf_counter_ns() + replay_buffer.add( + torch.from_numpy(obs_np), + torch.from_numpy(actions_np), + torch.from_numpy(rewards_np), + torch.from_numpy(next_obs_np), + torch.from_numpy(combined_dones), + torch.from_numpy(truncated_np), + terminal_mask=torch.from_numpy(terminal_contract.terminal_mask), + terminal_next_obs=( + torch.from_numpy(terminal_contract.terminal_obs) + if terminal_contract.terminal_obs is not None + else None + ), + critic=torch.from_numpy(critic_np), + next_critic=torch.from_numpy(next_critic_np), + terminal_next_critic=( + torch.from_numpy(terminal_contract.terminal_critic) + if terminal_contract.terminal_critic is not None + else None + ), + ) + if trace_recorder: + trace_recorder.add_slice( + "collector/replay_add", + category="collector", + start_ns=_rb_ns, + end_ns=_time.perf_counter_ns(), + ) + phase_start_ns = _record_phase_ms(cycle_timing_ms, "replay_ms", phase_start_ns) + if collector_pack_service is None: + pending_collector_pack_request = _drain_collector_pack_requests( + replay_buffer, + collector_pack_request_queue, + collector_pack_ready_queue, + collector_pack_shared_slots, + trace_recorder, + pending_request=pending_collector_pack_request, + ) + phase_start_ns = _record_phase_ms(cycle_timing_ms, "replay_ms", phase_start_ns) + + # Track episode rewards - vectorized + current_ep_rewards += rewards_np + current_ep_lengths += 1 + reset_mask = combined_dones > 0.5 + reset_indices = np.where(reset_mask)[0] + if len(reset_indices) > 0: + ep_rewards.extend(current_ep_rewards[reset_indices].tolist()) + ep_lengths.extend(current_ep_lengths[reset_indices].tolist()) + current_ep_rewards[reset_indices] = 0.0 + current_ep_lengths[reset_indices] = 0 + + obs_np = next_obs_np + critic_np = next_critic_np + info_dict = state.info + total_steps += num_envs + env_steps_since_sync += 1 + phase_start_ns = _record_phase_ms( + cycle_timing_ms, "sync_coordination_ms", phase_start_ns + ) + + # Signal the learner once this collection chunk is ready. + if ( + sync_collection + and collection_ready_queue is not None + and trainer_done_queue is not None + ): + if env_steps_since_sync >= env_steps_per_sync: + _sig_ns = _time.perf_counter_ns() + collection_ready_queue.put(1) + if trace_recorder: + trace_recorder.add_slice( + "collector/signal_ready", + category="collector", + start_ns=_sig_ns, + end_ns=_time.perf_counter_ns(), + ) phase_start_ns = _record_phase_ms( cycle_timing_ms, "sync_coordination_ms", phase_start_ns ) + _wait_ns = _time.perf_counter_ns() + while not stop_event.is_set(): + if collector_pack_service is None: + pending_collector_pack_request = _drain_collector_pack_requests( + replay_buffer, + collector_pack_request_queue, + collector_pack_ready_queue, + collector_pack_shared_slots, + trace_recorder, + pending_request=pending_collector_pack_request, + ) + phase_start_ns = _record_phase_ms( + cycle_timing_ms, "replay_ms", phase_start_ns + ) + try: + trainer_done_queue.get(timeout=0.001) + phase_start_ns = _record_phase_ms( + cycle_timing_ms, "sync_coordination_ms", phase_start_ns + ) + if collector_pack_service is None: + pending_collector_pack_request = _drain_collector_pack_requests( + replay_buffer, + collector_pack_request_queue, + collector_pack_ready_queue, + collector_pack_shared_slots, + trace_recorder, + pending_request=pending_collector_pack_request, + ) + phase_start_ns = _record_phase_ms( + cycle_timing_ms, "replay_ms", phase_start_ns + ) + break + except queue.Empty: + phase_start_ns = _record_phase_ms( + cycle_timing_ms, "sync_coordination_ms", phase_start_ns + ) + continue + if trace_recorder: + trace_recorder.add_slice( + "collector/wait_trainer_done", + category="collector", + start_ns=_wait_ns, + end_ns=_time.perf_counter_ns(), + ) + if metrics_queue is not None: + try: + metrics_queue.put_nowait( + {"trace_events": trace_recorder.drain_events()} + ) + except Exception: + pass + phase_start_ns = _record_phase_ms( + cycle_timing_ms, "sync_coordination_ms", phase_start_ns + ) + env_steps_since_sync = 0 + elif env_steps_since_sync >= env_steps_per_sync: env_steps_since_sync = 0 - elif env_steps_since_sync >= env_steps_per_sync: - env_steps_since_sync = 0 - phase_start_ns = _record_phase_ms(cycle_timing_ms, "sync_coordination_ms", phase_start_ns) - - # Progress log every 2 seconds - now = _time.time() - if now - _last_log_time > 2.0: - _last_log_time = now - - # Extract reward components from env info - log_info = state.info.get("log", {}) - if log_info: - for k, v in log_info.items(): - if k.startswith("reward/"): - ep_reward_components[k].append(v) - - # Send metrics periodically - if metrics_queue is not None and total_steps % (num_envs * 10) == 0: - import statistics + phase_start_ns = _record_phase_ms( + cycle_timing_ms, "sync_coordination_ms", phase_start_ns + ) - try: - msg = { - "total_steps": total_steps, - "buffer_size": int(replay_buffer.size[0]), - } - if ep_rewards: - msg["mean_ep_reward"] = statistics.mean(ep_rewards[-100:]) - msg["mean_ep_length"] = ( - statistics.mean(ep_lengths[-100:]) if ep_lengths else 0.0 - ) - # Add mean reward components - if ep_reward_components: - components_mean = {} - for k, vals in ep_reward_components.items(): - if vals: - components_mean[k] = statistics.mean(vals) - msg["reward_components"] = components_mean - ep_reward_components.clear() # reset after sending - - if timing_counts: - msg["collector_timing_ms"] = { - k: (v / timing_counts[k]) - for k, v in timing_accum_ms.items() - if timing_counts[k] > 0 + # Progress log every 2 seconds + now = _time.time() + if now - _last_log_time > 2.0: + _last_log_time = now + + # Extract reward components from env info + log_info = state.info.get("log", {}) + if log_info: + for k, v in log_info.items(): + if k.startswith("reward/"): + ep_reward_components[k].append(v) + + # Send metrics periodically + if metrics_queue is not None and total_steps % (num_envs * 10) == 0: + import statistics + + try: + msg = { + "total_steps": total_steps, + "buffer_size": int(replay_buffer.size[0]), } + if ep_rewards: + msg["mean_ep_reward"] = statistics.mean(ep_rewards[-100:]) + msg["mean_ep_length"] = ( + statistics.mean(ep_lengths[-100:]) if ep_lengths else 0.0 + ) + # Add mean reward components + if ep_reward_components: + components_mean = {} + for k, vals in ep_reward_components.items(): + if vals: + components_mean[k] = statistics.mean(vals) + msg["reward_components"] = components_mean + ep_reward_components.clear() # reset after sending + + if timing_counts: + msg["collector_timing_ms"] = { + k: (v / timing_counts[k]) + for k, v in timing_accum_ms.items() + if timing_counts[k] > 0 + } + + if done_count_window > 0: + msg["timeout_rate"] = timeout_count_window / done_count_window + msg["terminated_rate"] = terminated_count_window / done_count_window + done_count_window = 0 + timeout_count_window = 0 + terminated_count_window = 0 + + if trace_recorder: + msg["trace_events"] = trace_recorder.drain_events() + + metrics_queue.put_nowait(msg) + if "collector_timing_ms" in msg: + timing_accum_ms.clear() + timing_counts.clear() + except Exception as e: + print(f"[OffPolicyWorker] metrics enqueue error: {e}", file=sys.stderr) + phase_start_ns = _record_phase_ms( + cycle_timing_ms, "sync_coordination_ms", phase_start_ns + ) - if done_count_window > 0: - msg["timeout_rate"] = timeout_count_window / done_count_window - msg["terminated_rate"] = terminated_count_window / done_count_window - done_count_window = 0 - timeout_count_window = 0 - terminated_count_window = 0 - - if trace_recorder: - msg["trace_events"] = trace_recorder.drain_events() - - metrics_queue.put_nowait(msg) - if "collector_timing_ms" in msg: - timing_accum_ms.clear() - timing_counts.clear() - except Exception as e: - print(f"[OffPolicyWorker] metrics enqueue error: {e}", file=sys.stderr) - phase_start_ns = _record_phase_ms(cycle_timing_ms, "sync_coordination_ms", phase_start_ns) - - for key in COLLECTOR_TIMING_KEYS: - _record_timing_ms(timing_accum_ms, timing_counts, key, cycle_timing_ms[key]) + for key in COLLECTOR_TIMING_KEYS: + _record_timing_ms(timing_accum_ms, timing_counts, key, cycle_timing_ms[key]) - if metrics_queue is not None and trace_recorder: - try: - metrics_queue.put_nowait({"trace_events": trace_recorder.drain_events()}) - except Exception: - pass - weight_sync.close() + finally: + if collector_pack_service is not None: + collector_pack_service.close() + if metrics_queue is not None and trace_recorder: + try: + metrics_queue.put_nowait({"trace_events": trace_recorder.drain_events()}) + except Exception: + pass + weight_sync.close() diff --git a/src/unilab/ipc/replay_pipelines/multi_gpu_cpu_pinned.py b/src/unilab/ipc/replay_pipelines/multi_gpu_cpu_pinned.py new file mode 100644 index 000000000..ad85d229a --- /dev/null +++ b/src/unilab/ipc/replay_pipelines/multi_gpu_cpu_pinned.py @@ -0,0 +1,360 @@ +"""Rank-local multi-GPU replay pipeline for packed CPU replay samples.""" + +from __future__ import annotations + +import queue +import threading +import time +from typing import Dict + +import torch + +from unilab.ipc.replay_buffer import ReplayBuffer +from unilab.ipc.replay_pipelines.base import ReplayTickMetadata +from unilab.ipc.replay_pipelines.transfer import build_replay_transfer_backend + + +class MultiGPUCPUPinnedReplayPipeline: + """Per-rank replay pipeline with independent host slots and H2D stream. + + The replay buffer remains authoritative CPU shared storage. Each learner + rank owns its host staging slots and device slots, so H2D submission happens + concurrently in the rank-local worker process instead of funnelling through + rank 0. + """ + + def __init__( + self, + replay_buffer: ReplayBuffer, + *, + rank: int, + world_size: int, + device: str, + sample_count: int, + base_seed: int = 0, + trace_recorder=None, + trace_cuda_events: bool = True, + collector_pack_request_queue=None, + collector_pack_ready_queue=None, + collector_pack_shared_slots=None, + ) -> None: + if int(world_size) <= 1: + raise ValueError("MultiGPUCPUPinnedReplayPipeline requires world_size > 1") + if not getattr(replay_buffer, "_packed_cpu_storage", False): + raise ValueError("Multi-GPU replay pipeline requires packed ReplayBuffer storage") + if ( + collector_pack_request_queue is None + or collector_pack_ready_queue is None + or collector_pack_shared_slots is None + ): + raise ValueError("Multi-GPU replay pipeline requires collector pack IPC objects") + + self._replay_buffer = replay_buffer + self._rank = int(rank) + self._world_size = int(world_size) + self._device = torch.device(device) + self._sample_count = int(sample_count) + self._base_seed = int(base_seed) + self._trace_recorder = trace_recorder + self._pack_layout = "packed" + self._pack_executor = "collector_thread" + self._ring_depth = 2 + self._transfer_backend = build_replay_transfer_backend( + device=self._device, + ring_depth=self._ring_depth, + ) + self._trace_cuda_events = bool(trace_cuda_events) and ( + self._transfer_backend.supports_timing_events + ) + self._collector_pack_request_queue = collector_pack_request_queue + self._collector_pack_ready_queue = collector_pack_ready_queue + self._collector_pack_shared_slots = collector_pack_shared_slots + if len(self._collector_pack_shared_slots) != self._ring_depth or not all( + isinstance(slot, torch.Tensor) for slot in self._collector_pack_shared_slots + ): + raise ValueError( + "Multi-GPU replay pipeline expects rank-local shared slots with " + f"ring_depth={self._ring_depth}" + ) + self._transfer_backend.register_host_slots(self._collector_pack_shared_slots) + self._packed_width = int(replay_buffer._storage.shape[1]) + self._gpu_packed = self._transfer_backend.allocate_device_slots( + count=self._ring_depth, + shape=(self._sample_count, self._packed_width), + dtype=torch.float32, + ) + + self._hot = 0 + self._cold = 1 + self._has_hot_batch = False + self._hot_metadata: ReplayTickMetadata | None = None + self._prepared_metadata: ReplayTickMetadata | None = None + self._prepare_tick_id: int | None = None + self._prepare_state = "idle" + self._prepare_error: BaseException | None = None + self._prepare_condition = threading.Condition() + self._closed = False + self.last_incremental_h2d_time_s = 0.0 + self._collector_h2d_thread = threading.Thread( + target=self._collector_h2d_worker, + name=f"replay_rank{self._rank}_h2d", + daemon=True, + ) + self._collector_h2d_thread.start() + + @property + def h2d_submitter(self) -> str: + return self._transfer_backend.h2d_submitter + + @property + def transfer_manifest(self) -> dict[str, object]: + return { + "backend": type(self._transfer_backend).__name__, + "device": str(self._device), + "device_family": self._transfer_backend.device_family, + "host_memory_kind": self._transfer_backend.host_memory_kind, + "host_pinned": self._transfer_backend.host_pinned, + "direct_pinned_shared": self._transfer_backend.direct_pinned_shared, + "supports_async_submit": self._transfer_backend.supports_async_submit, + "supports_timing_events": self._transfer_backend.supports_timing_events, + "h2d_submitter": self._transfer_backend.h2d_submitter, + "rank": self._rank, + "world_size": self._world_size, + "ring_depth": self._ring_depth, + } + + def _snapshot(self) -> tuple[int, int]: + return int(self._replay_buffer.ptr[0]), int(self._replay_buffer.size[0]) + + def _packed_h2d_source(self, slot: int) -> torch.Tensor: + return self._collector_pack_shared_slots[slot] + + def _h2d_bytes(self) -> int: + source = self._packed_h2d_source(0) + return int(source.numel() * source.element_size()) + + def _packed_batch_view(self, packed: torch.Tensor) -> Dict[str, torch.Tensor]: + rb = self._replay_buffer + batch = { + "obs": packed[:, rb._obs_sl], + "next_obs": packed[:, rb._nobs_sl], + "actions": packed[:, rb._act_sl], + "rewards": packed[:, rb._rew_col], + "dones": packed[:, rb._done_col], + "truncated": packed[:, rb._trunc_col], + } + if rb._critic_dim > 0: + batch["critic"] = packed[:, rb._critic_sl] + batch["next_critic"] = packed[:, rb._ncritic_sl] + return batch + + def _submit_h2d(self, slot: int, metadata: ReplayTickMetadata) -> float: + self._transfer_backend.clear_ready(slot) + return self._transfer_backend.submit_h2d( + slot=slot, + dst=self._gpu_packed[slot], + src=self._packed_h2d_source(slot), + metadata=metadata, + trace_recorder=self._trace_recorder, + trace_cuda_events=self._trace_cuda_events, + h2d_bytes=self._h2d_bytes(), + pack_layout=self._pack_layout, + pack_executor=self._pack_executor, + ) + + def _collector_h2d_worker(self) -> None: + while True: + if self._closed: + return + try: + ready = self._collector_pack_ready_queue.get(timeout=0.1) + except queue.Empty: + continue + if ready is None: + return + try: + if int(ready.get("rank", self._rank)) != self._rank: + raise RuntimeError( + f"Rank {self._rank} received replay batch for rank {ready.get('rank')}" + ) + metadata = ReplayTickMetadata( + tick_id=int(ready["tick_id"]), + snapshot_ptr=int(ready["snapshot_ptr"]), + snapshot_size=int(ready["snapshot_size"]), + sample_seed=int(ready["sample_seed"]), + sample_count=int(ready["sample_count"]), + batch_host_slot=int(ready["shared_slot"]), + batch_gpu_slot=int(ready["target_gpu_slot"]), + ) + slot = metadata.batch_gpu_slot + assert slot is not None + self.last_incremental_h2d_time_s = self._submit_h2d(slot, metadata) + with self._prepare_condition: + if self._prepare_tick_id != metadata.tick_id: + raise RuntimeError( + f"Rank {self._rank} packed tick {metadata.tick_id} " + f"does not match pending tick {self._prepare_tick_id}" + ) + self._prepared_metadata = metadata + self._prepare_state = "h2d_submitted" + self._prepare_error = None + self._prepare_condition.notify_all() + except BaseException as exc: + with self._prepare_condition: + self._prepare_error = exc + self._prepare_condition.notify_all() + + def _refresh_prepare_state(self) -> None: + if self._prepare_error is not None: + raise self._prepare_error + if self._prepared_metadata is not None: + slot = self._prepared_metadata.batch_gpu_slot + if slot is not None and self._transfer_backend.ready_query(slot): + self._prepare_state = "ready" + + def start_prepare( + self, + tick_id: int, + sample_count: int, + min_snapshot_ptr: int | None = None, + ) -> bool: + if int(sample_count) != self._sample_count: + raise ValueError("sample_count must match the allocated multi-GPU replay slots") + if self._closed: + raise RuntimeError("Cannot prepare replay batch after pipeline.close()") + self._refresh_prepare_state() + active_tick = self._prepare_tick_id + if self._prepared_metadata is not None or self._prepare_state not in {"idle", "ready"}: + prepared_tick = ( + self._prepared_metadata.tick_id + if self._prepared_metadata is not None + else active_tick + ) + if prepared_tick == int(tick_id): + return False + raise RuntimeError( + "Cannot prepare a new replay batch before consuming the previous one" + ) + + slot = self._cold + self._transfer_backend.clear_ready(slot) + self._prepare_tick_id = int(tick_id) + self._prepare_error = None + snapshot_ptr, snapshot_size = self._snapshot() + sample_seed = self._base_seed + int(tick_id) * self._world_size + self._rank + min_snapshot_ptr = snapshot_ptr if min_snapshot_ptr is None else int(min_snapshot_ptr) + request = { + "tick_id": int(tick_id), + "rank": self._rank, + "world_size": self._world_size, + "snapshot_ptr": snapshot_ptr, + "snapshot_size": snapshot_size, + "min_snapshot_ptr": min_snapshot_ptr, + "sample_seed": sample_seed, + "sample_count": self._sample_count, + "shared_slot": slot, + "learner_hot_gpu_slot": self._hot, + "target_gpu_slot": slot, + "pack_layout": self._pack_layout, + "pack_executor": self._pack_executor, + } + self._prepare_state = "collector_pack_requested" + self._collector_pack_request_queue.put(request) + return True + + def batch_ready(self, tick_id: int, sample_count: int) -> bool: + if int(sample_count) != self._sample_count: + raise ValueError("sample_count must match the allocated multi-GPU replay slots") + if self._has_hot_batch: + if self._hot_metadata is not None and self._hot_metadata.tick_id != int(tick_id): + return False + return True + self._refresh_prepare_state() + if self._prepared_metadata is None: + return False + if self._prepared_metadata.tick_id != int(tick_id): + return False + return self._prepare_state == "ready" + + def wait_ready(self) -> None: + return None + + def wait_until_ready(self, tick_id: int, sample_count: int) -> bool: + if int(sample_count) != self._sample_count: + raise ValueError("sample_count must match the allocated multi-GPU replay slots") + self._refresh_prepare_state() + if self._prepared_metadata is None: + if self._prepare_tick_id is None: + self.start_prepare(tick_id, sample_count) + with self._prepare_condition: + while self._prepared_metadata is None and self._prepare_error is None: + self._prepare_condition.wait(timeout=0.1) + if self._prepare_error is not None: + raise self._prepare_error + assert self._prepared_metadata is not None + if self._prepared_metadata.tick_id != int(tick_id): + raise RuntimeError( + f"Rank {self._rank} prepared tick {self._prepared_metadata.tick_id} " + f"does not match requested tick {tick_id}" + ) + slot = self._prepared_metadata.batch_gpu_slot + assert slot is not None + self._transfer_backend.synchronize_ready(slot) + self._prepare_state = "ready" + return True + + def sample_large_batch(self, tick_id: int, sample_count: int) -> Dict[str, torch.Tensor]: + if int(sample_count) != self._sample_count: + raise ValueError("sample_count must match the allocated multi-GPU replay slots") + if self._has_hot_batch: + if self._hot_metadata is not None and self._hot_metadata.tick_id != int(tick_id): + raise RuntimeError( + f"Rank {self._rank} hot tick {self._hot_metadata.tick_id} " + f"does not match requested tick {tick_id}" + ) + return self._packed_batch_view(self._gpu_packed[self._hot]) + if not self.batch_ready(tick_id, sample_count): + self.wait_until_ready(tick_id, sample_count) + assert self._prepared_metadata is not None + slot = self._prepared_metadata.batch_gpu_slot + assert slot is not None + wait_begin_ns = time.perf_counter_ns() + self._transfer_backend.wait_current_stream_for_ready(slot) + wait_copy_time_s = float(getattr(self._transfer_backend, "last_wait_copy_time_s", 0.0)) + if wait_copy_time_s > 0.0: + self.last_incremental_h2d_time_s = wait_copy_time_s + if self._trace_recorder is not None: + self._trace_recorder.add_slice( + "replay_pipeline/rank_batch_h2d_wait", + category="replay_pipeline", + start_ns=wait_begin_ns, + end_ns=time.perf_counter_ns(), + args={"tick_id": tick_id, "rank": self._rank, "batch_gpu_slot": slot}, + ) + if slot != self._cold: + raise RuntimeError("Prepared multi-GPU replay batch is not in the current cold slot") + self._hot, self._cold = self._cold, self._hot + self._has_hot_batch = True + self._hot_metadata = self._prepared_metadata + self._prepared_metadata = None + self._prepare_tick_id = None + self._prepare_state = "idle" + return self._packed_batch_view(self._gpu_packed[self._hot]) + + def after_tick(self) -> None: + self._has_hot_batch = False + self._hot_metadata = None + + def close(self) -> None: + self._closed = True + try: + self._collector_pack_ready_queue.put_nowait(None) + except Exception: + pass + self._collector_h2d_thread.join(timeout=2.0) + if self._prepared_metadata is not None: + slot = self._prepared_metadata.batch_gpu_slot + if slot is not None: + self._transfer_backend.synchronize_ready(slot) + self._transfer_backend.close() + self._gpu_packed.clear() diff --git a/tests/algos/test_fast_sac_symmetry_contract.py b/tests/algos/test_fast_sac_symmetry_contract.py index b757a4661..65f7e77d5 100644 --- a/tests/algos/test_fast_sac_symmetry_contract.py +++ b/tests/algos/test_fast_sac_symmetry_contract.py @@ -119,6 +119,21 @@ def test_fast_sac_learner_rejects_symmetry_without_augmentation(): ) +def test_fast_sac_learner_exposes_multi_gpu_initial_sync_contract(): + from unilab.algos.torch.fast_sac.learner import FastSACLearner + + learner = FastSACLearner( + obs_dim=4, + action_dim=2, + critic_obs_dim=4, + device="cpu", + world_size=1, + ) + + assert callable(getattr(learner, "sync_initial_parameters", None)) + learner.sync_initial_parameters(src=0) + + def test_multi_gpu_offpolicy_runner_rejects_sac_symmetry_capability(): from unilab.algos.torch.offpolicy.multi_gpu_runner import MultiGPUOffPolicyRunner diff --git a/tests/algos/test_offpolicy_double_buffer_runner.py b/tests/algos/test_offpolicy_double_buffer_runner.py index 0aad883c3..af9af0127 100644 --- a/tests/algos/test_offpolicy_double_buffer_runner.py +++ b/tests/algos/test_offpolicy_double_buffer_runner.py @@ -113,14 +113,14 @@ def __init__(self, *args, **kwargs): assert runner.kwargs["learner"].kwargs["critic_obs_dim"] == 6 -def test_sac_multi_gpu_rejects_cpu_pinned_double_buffer(): +def test_sac_multi_gpu_rejects_obs_normalization_until_synchronized(): cfg = _offpolicy_cfg( [ "algo=sac", "training.num_gpus=2", ] ) - with pytest.raises(ValueError, match="currently single-GPU only"): + with pytest.raises(ValueError, match="requires algo.obs_normalization=false"): _offpolicy().build_runner("sac", cfg) @@ -300,7 +300,7 @@ def test_flashsac_double_buffer_multi_gpu_rejected(): "training.num_gpus=2", ] ) - with pytest.raises(ValueError, match="FlashSAC does not support training.num_gpus > 1"): + with pytest.raises(ValueError, match="Only SAC supports training.num_gpus > 1"): _offpolicy().build_runner("flashsac", cfg) diff --git a/tests/algos/test_offpolicy_runner_unit.py b/tests/algos/test_offpolicy_runner_unit.py index b5ce929b2..35832d685 100644 --- a/tests/algos/test_offpolicy_runner_unit.py +++ b/tests/algos/test_offpolicy_runner_unit.py @@ -54,6 +54,9 @@ def update_actor(self, batch: dict[str, torch.Tensor]) -> dict[str, float]: def soft_update_target(self) -> None: self.target_updates += 1 + def sync_initial_parameters(self, src: int = 0) -> None: + del src + def get_state_dict(self) -> dict[str, int]: return {"update_count": self.update_count} @@ -885,6 +888,7 @@ def test_multi_gpu_runner_passes_explicit_runtime_context_to_collector( learner=learner, env_name="DummyEnv", algo_type="sac", + learner_cls=_FakeLearner, learner_kwargs={}, num_gpus=2, num_envs=2, @@ -893,7 +897,7 @@ def test_multi_gpu_runner_passes_explicit_runtime_context_to_collector( learning_starts=6, updates_per_step=1, policy_frequency=1, - sync_collection=False, + sync_collection=True, env_steps_per_sync=1, device="cpu", sim_backend="motrix", @@ -929,6 +933,7 @@ def test_multi_gpu_runner_allocates_replay_critic_storage( learner=_FakeLearner(), env_name="DummyEnv", algo_type="sac", + learner_cls=_FakeLearner, learner_kwargs={}, num_gpus=2, num_envs=2, @@ -937,7 +942,7 @@ def test_multi_gpu_runner_allocates_replay_critic_storage( learning_starts=6, updates_per_step=1, policy_frequency=1, - sync_collection=False, + sync_collection=True, env_steps_per_sync=1, device="cpu", ) @@ -950,83 +955,72 @@ def test_multi_gpu_runner_allocates_replay_critic_storage( assert replay_buffer.critic_dim == 7 -def test_multi_gpu_worker_rank0_propagates_learner_timing_and_extra_info( +def test_multi_gpu_runner_spawn_receives_algorithm_agnostic_learner_and_rank_ipc( monkeypatch: pytest.MonkeyPatch, tmp_path ) -> None: monkeypatch.setattr(multi_gpu_runner_module, "ReplayBuffer", _FakeReplayBuffer) monkeypatch.setattr(multi_gpu_runner_module, "SharedWeightSync", _FakeWeightSync) - monkeypatch.setattr(multi_gpu_runner_module, "OffPolicyLogger", _FakeLogger) - monkeypatch.setattr(multi_gpu_runner_module, "FastSACLearner", _FakeLearner) - monkeypatch.setattr(multi_gpu_runner_module.torch.cuda, "set_device", lambda rank: None) - monkeypatch.setattr( - multi_gpu_runner_module.dist, "init_process_group", lambda *args, **kwargs: None - ) - monkeypatch.setattr(multi_gpu_runner_module.dist, "broadcast", lambda *args, **kwargs: None) - monkeypatch.setattr(multi_gpu_runner_module.dist, "barrier", lambda *args, **kwargs: None) - monkeypatch.setattr( - multi_gpu_runner_module.dist, "destroy_process_group", lambda *args, **kwargs: None - ) - monkeypatch.setattr(multi_gpu_runner_module.torch, "save", lambda *args, **kwargs: None) + monkeypatch.setattr(multi_gpu_runner_module.time, "sleep", lambda seconds: None) + monkeypatch.setattr(runner_module, "get_env_dims", lambda *args, **kwargs: (4, 2, 0)) + captured: dict[str, object] = {} - fake_clock = _FakeClock([300.1, 310.1, 310.2, 310.7]) - monkeypatch.setattr(multi_gpu_runner_module.time, "time", fake_clock.time) + def fake_spawn(fn, args, nprocs, join): + del fn, join + captured["args"] = args + captured["nprocs"] = nprocs - sleep_sizes = iter([4, 8, 12]) + monkeypatch.setattr(multi_gpu_runner_module.tmp, "spawn", fake_spawn) - def fake_sleep(seconds: float) -> None: - if seconds < 0.5: - next_size = next(sleep_sizes, 12) - replay_buffer = _FakeReplayBuffer.last_instance - assert replay_buffer is not None - replay_buffer.size[0] = next_size - replay_buffer.ptr[0] = next_size + runner = multi_gpu_runner_module.MultiGPUOffPolicyRunner( + learner=_FakeLearner(), + env_name="DummyEnv", + algo_type="sac", + learner_cls=_FakeLearner, + learner_kwargs={"obs_dim": 4, "action_dim": 2}, + num_gpus=2, + num_envs=2, + replay_buffer_n=8, + batch_size=8, + learning_starts=6, + updates_per_step=1, + policy_frequency=1, + sync_collection=True, + env_steps_per_sync=1, + device="cpu", + ) + monkeypatch.setattr(runner, "_start_collector", lambda *args, **kwargs: None) - monkeypatch.setattr(multi_gpu_runner_module.time, "sleep", fake_sleep) + runner.learn(max_iterations=0, save_interval=0, log_dir=str(tmp_path)) - replay_buffer = _FakeReplayBuffer(capacity=64, obs_dim=4, action_dim=2, device="cpu") - weight_sync = _FakeWeightSync() - metrics_queue = queue.Queue() - stop_event = type("_Stop", (), {"is_set": staticmethod(lambda: False)})() - - runner_kwargs = { - "max_iterations": 1, - "save_interval": 0, - "log_dir": str(tmp_path), - "batch_size": 8, - "learning_starts": 6, - "updates_per_step": 1, - "policy_frequency": 1, - "sync_collection": False, - "env_steps_per_sync": 1, - "env_name": "DummyEnv", - "num_envs": 2, - "obs_dim": 4, - "action_dim": 2, - "logger_type": "wandb", - } - - multi_gpu_runner_module._learner_worker( - rank=0, - world_size=2, - learner_kwargs={}, - runner_kwargs=runner_kwargs, - replay_buffer=cast(ReplayBuffer, replay_buffer), - weight_sync_name=weight_sync.name, - weight_sync_lock=weight_sync._lock, - weight_param_shapes={"weight": torch.Size([1])}, - stop_event=stop_event, - collection_ready_queue=None, - trainer_done_queue=None, - metrics_queue=metrics_queue, - master_port=12345, + args = cast(tuple, captured["args"]) + assert captured["nprocs"] == 2 + assert args[1] is _FakeLearner + assert args[2] == {"obs_dim": 4, "action_dim": 2} + assert len(cast(list, args[13])) == 2 + assert len(cast(list, args[14])) == 2 + + +def test_multi_gpu_runner_rejects_unsynchronized_collection_on_learn( + monkeypatch: pytest.MonkeyPatch, tmp_path +) -> None: + monkeypatch.setattr(runner_module, "get_env_dims", lambda *args, **kwargs: (4, 2, 0)) + runner = multi_gpu_runner_module.MultiGPUOffPolicyRunner( + learner=_FakeLearner(), + env_name="DummyEnv", + algo_type="sac", + learner_cls=_FakeLearner, + learner_kwargs={"obs_dim": 4, "action_dim": 2}, + num_gpus=2, + num_envs=2, + replay_buffer_n=8, + batch_size=8, + learning_starts=6, + updates_per_step=1, + policy_frequency=1, + sync_collection=False, + env_steps_per_sync=1, + device="cpu", ) - logger = _FakeLogger.last_instance - assert logger is not None - assert logger.step_calls - step = logger.step_calls[0] - assert "collect_time" not in step - assert step["learner_incremental_h2d_time"] == pytest.approx(0.004) - assert step["weight_sync_time"] >= 0.0 - assert step["extra_info"] == {"throughput_steps": 2} - assert step["extra_info"]["throughput_steps"] == 2 + with pytest.raises(ValueError, match="requires synchronized collection"): + runner.learn(max_iterations=0, save_interval=0, log_dir=str(tmp_path)) diff --git a/tests/ipc/test_multi_gpu_replay_pack.py b/tests/ipc/test_multi_gpu_replay_pack.py new file mode 100644 index 000000000..0ddd64e14 --- /dev/null +++ b/tests/ipc/test_multi_gpu_replay_pack.py @@ -0,0 +1,313 @@ +"""Tests for multi-rank replay pack IPC contract.""" + +from __future__ import annotations + +import queue + +import pytest +import torch + +from unilab.algos.torch.offpolicy.worker import ( + _drain_collector_pack_requests, + _service_collector_pack_requests, +) +from unilab.ipc.replay_buffer import ReplayBuffer +from unilab.ipc.replay_pipelines.multi_gpu_cpu_pinned import MultiGPUCPUPinnedReplayPipeline + + +def _make_replay_buffer() -> ReplayBuffer: + buf = ReplayBuffer( + capacity=16, + obs_dim=2, + action_dim=1, + device="cpu", + critic_dim=0, + packed_cpu_storage=True, + ) + obs = torch.arange(32, dtype=torch.float32).reshape(16, 2) + actions = torch.arange(16, dtype=torch.float32).reshape(16, 1) + rewards = torch.arange(16, dtype=torch.float32) + next_obs = obs + 100 + dones = torch.zeros(16) + truncated = torch.zeros(16) + buf.add(obs, actions, rewards, next_obs, dones, truncated) + return buf + + +def test_collector_pack_routes_ranked_requests_to_rank_slots_and_ready_queues() -> None: + replay_buffer = _make_replay_buffer() + packed_width = int(replay_buffer._storage.shape[1]) + sample_count = 4 + request_queue: queue.Queue = queue.Queue() + ready_queues = [queue.Queue(), queue.Queue()] + shared_slots = [ + [torch.empty((sample_count, packed_width), dtype=torch.float32) for _ in range(2)] + for _ in range(2) + ] + + request_queue.put( + { + "tick_id": 7, + "rank": 1, + "world_size": 2, + "sample_seed": 123, + "sample_count": sample_count, + "shared_slot": 0, + "target_gpu_slot": 0, + "learner_hot_gpu_slot": 1, + "min_snapshot_ptr": 0, + } + ) + + serviced, pending = _service_collector_pack_requests( + replay_buffer, + request_queue, + ready_queues, + shared_slots, + ) + + assert serviced is True + assert pending is None + assert ready_queues[0].empty() + ready = ready_queues[1].get_nowait() + assert ready["rank"] == 1 + assert ready["world_size"] == 2 + assert ready["sample_seed"] == 123 + assert ready["sample_count"] == sample_count + assert not torch.isnan(shared_slots[1][0]).any() + + +def test_collector_pack_defers_until_min_snapshot_ptr_for_ranked_request() -> None: + replay_buffer = _make_replay_buffer() + packed_width = int(replay_buffer._storage.shape[1]) + sample_count = 2 + request = { + "tick_id": 8, + "rank": 0, + "world_size": 2, + "sample_seed": 456, + "sample_count": sample_count, + "shared_slot": 1, + "target_gpu_slot": 1, + "learner_hot_gpu_slot": 0, + "min_snapshot_ptr": int(replay_buffer.ptr[0]) + 1, + } + request_queue: queue.Queue = queue.Queue() + request_queue.put(request) + ready_queues = [queue.Queue(), queue.Queue()] + shared_slots = [ + [torch.empty((sample_count, packed_width), dtype=torch.float32) for _ in range(2)] + for _ in range(2) + ] + + serviced, pending = _service_collector_pack_requests( + replay_buffer, + request_queue, + ready_queues, + shared_slots, + ) + + assert serviced is False + assert pending == request + assert ready_queues[0].empty() + assert ready_queues[1].empty() + + +def test_rank_local_pipeline_requests_rank_seed_and_consumes_cpu_batch() -> None: + replay_buffer = _make_replay_buffer() + packed_width = int(replay_buffer._storage.shape[1]) + sample_count = 3 + request_queue: queue.Queue = queue.Queue() + ready_queue: queue.Queue = queue.Queue() + shared_slots = [ + torch.empty((sample_count, packed_width), dtype=torch.float32) for _ in range(2) + ] + pipeline = MultiGPUCPUPinnedReplayPipeline( + replay_buffer, + rank=1, + world_size=2, + device="cpu", + sample_count=sample_count, + base_seed=100, + collector_pack_request_queue=request_queue, + collector_pack_ready_queue=ready_queue, + collector_pack_shared_slots=shared_slots, + ) + try: + assert pipeline.start_prepare(5, sample_count) + request = request_queue.get_nowait() + assert request["rank"] == 1 + assert request["world_size"] == 2 + assert request["sample_seed"] == 111 + + ready = { + "tick_id": request["tick_id"], + "rank": request["rank"], + "world_size": request["world_size"], + "snapshot_ptr": int(replay_buffer.ptr[0]), + "snapshot_size": int(replay_buffer.size[0]), + "sample_seed": request["sample_seed"], + "sample_count": sample_count, + "shared_slot": request["shared_slot"], + "target_gpu_slot": request["target_gpu_slot"], + } + torch.manual_seed(0) + shared_slots[request["shared_slot"]].copy_(replay_buffer._storage[:sample_count]) + ready_queue.put(ready) + + assert pipeline.wait_until_ready(5, sample_count) + batch = pipeline.sample_large_batch(5, sample_count) + assert batch["obs"].shape == (sample_count, 2) + assert torch.equal( + batch["obs"], replay_buffer._storage[:sample_count, replay_buffer._obs_sl] + ) + pipeline.after_tick() + finally: + pipeline.close() + + +def test_rank_local_pipeline_rejects_runner_rank_matrix_slots() -> None: + replay_buffer = _make_replay_buffer() + packed_width = int(replay_buffer._storage.shape[1]) + sample_count = 3 + rank_matrix_slots = [ + [torch.empty((sample_count, packed_width), dtype=torch.float32) for _ in range(2)] + for _ in range(2) + ] + + with pytest.raises(ValueError, match="rank-local shared slots"): + MultiGPUCPUPinnedReplayPipeline( + replay_buffer, + rank=1, + world_size=2, + device="cpu", + sample_count=sample_count, + collector_pack_request_queue=queue.Queue(), + collector_pack_ready_queue=queue.Queue(), + collector_pack_shared_slots=rank_matrix_slots, + ) + + +def test_multi_rank_pipeline_uses_runner_rank_local_ipc_shape() -> None: + replay_buffer = _make_replay_buffer() + packed_width = int(replay_buffer._storage.shape[1]) + sample_count = 3 + request_queues = [queue.Queue(), queue.Queue()] + ready_queues = [queue.Queue(), queue.Queue()] + shared_slots = [ + [torch.empty((sample_count, packed_width), dtype=torch.float32) for _ in range(2)] + for _ in range(2) + ] + pipeline = MultiGPUCPUPinnedReplayPipeline( + replay_buffer, + rank=1, + world_size=2, + device="cpu", + sample_count=sample_count, + base_seed=100, + collector_pack_request_queue=request_queues[1], + collector_pack_ready_queue=ready_queues[1], + collector_pack_shared_slots=shared_slots[1], + ) + try: + assert pipeline.start_prepare(5, sample_count) + request = request_queues[1].get_nowait() + pending = _drain_collector_pack_requests( + replay_buffer, + request_queues[1], + ready_queues, + shared_slots, + pending_request=request, + ) + assert pending is None + assert ready_queues[0].empty() + + assert pipeline.wait_until_ready(5, sample_count) + batch = pipeline.sample_large_batch(5, sample_count) + assert batch["obs"].shape == (sample_count, 2) + assert batch["obs"].device.type == "cpu" + finally: + pipeline.close() + + +def test_collector_pack_drain_services_available_multi_rank_requests() -> None: + replay_buffer = _make_replay_buffer() + packed_width = int(replay_buffer._storage.shape[1]) + sample_count = 2 + request_queue: queue.Queue = queue.Queue() + ready_queues = [queue.Queue(), queue.Queue()] + shared_slots = [ + [torch.empty((sample_count, packed_width), dtype=torch.float32) for _ in range(2)] + for _ in range(2) + ] + for rank in range(2): + request_queue.put( + { + "tick_id": 9, + "rank": rank, + "world_size": 2, + "sample_seed": 900 + rank, + "sample_count": sample_count, + "shared_slot": 0, + "target_gpu_slot": 0, + "learner_hot_gpu_slot": 1, + "min_snapshot_ptr": 0, + } + ) + + pending = _drain_collector_pack_requests( + replay_buffer, + request_queue, + ready_queues, + shared_slots, + ) + + assert pending is None + assert request_queue.empty() + assert ready_queues[0].get_nowait()["rank"] == 0 + assert ready_queues[1].get_nowait()["rank"] == 1 + + +def test_collector_pack_service_parallel_rank_queues() -> None: + from unilab.algos.torch.offpolicy.worker import _CollectorPackService + + replay_buffer = _make_replay_buffer() + packed_width = int(replay_buffer._storage.shape[1]) + sample_count = 2 + request_queues = [queue.Queue(), queue.Queue()] + ready_queues = [queue.Queue(), queue.Queue()] + shared_slots = [ + [torch.empty((sample_count, packed_width), dtype=torch.float32) for _ in range(2)] + for _ in range(2) + ] + stop_event = type("_Stop", (), {"_set": False})() + stop_event.is_set = lambda: bool(stop_event._set) + + service = _CollectorPackService( + replay_buffer, + request_queues, + ready_queues, + shared_slots, + stop_event=stop_event, + ) + try: + service.start() + for rank in range(2): + request_queues[rank].put( + { + "tick_id": 10, + "rank": rank, + "world_size": 2, + "sample_seed": 1000 + rank, + "sample_count": sample_count, + "shared_slot": 0, + "target_gpu_slot": 0, + "learner_hot_gpu_slot": 1, + "min_snapshot_ptr": 0, + } + ) + assert ready_queues[0].get(timeout=1.0)["sample_seed"] == 1000 + assert ready_queues[1].get(timeout=1.0)["sample_seed"] == 1001 + finally: + stop_event._set = True + service.close() diff --git a/tests/scripts/test_train_scripts.py b/tests/scripts/test_train_scripts.py index 1c787e42e..4e0d82843 100644 --- a/tests/scripts/test_train_scripts.py +++ b/tests/scripts/test_train_scripts.py @@ -2423,25 +2423,26 @@ def test_offpolicy_flashsac_rejects_multi_gpu(): ] ) - with pytest.raises(ValueError, match="FlashSAC does not support training.num_gpus > 1"): + with pytest.raises(ValueError, match="Only SAC supports training.num_gpus > 1"): _offpolicy().build_runner("flashsac", cfg) -def test_offpolicy_sac_multi_gpu_rejected_by_double_buffer(): +def test_offpolicy_sac_multi_gpu_requires_cuda_device(): cfg = _offpolicy_cfg( [ "algo=sac", "task=sac/g1_walk_flat/mujoco", "training.num_gpus=2", "training.device=cpu", + "algo.obs_normalization=false", ] ) - with pytest.raises(ValueError, match="currently single-GPU only"): + with pytest.raises(ValueError, match="requires a CUDA device"): _offpolicy().build_runner("sac", cfg) -def test_offpolicy_sac_multi_gpu_rejects_even_with_explicit_symmetry_disable(): +def test_offpolicy_sac_multi_gpu_requires_cuda_even_with_explicit_symmetry_disable(): cfg = _offpolicy_cfg( [ "algo=sac", @@ -2449,10 +2450,11 @@ def test_offpolicy_sac_multi_gpu_rejects_even_with_explicit_symmetry_disable(): "training.num_gpus=2", "training.device=cpu", "algo.use_symmetry=false", + "algo.obs_normalization=false", ] ) - with pytest.raises(ValueError, match="currently single-GPU only"): + with pytest.raises(ValueError, match="requires a CUDA device"): _offpolicy().build_runner("sac", cfg)